├── README.md
├── ResNet
│   ├── README.md
│   ├── ResNet_ImageNet.py
│   ├── main.py
│   └── test_bottleneck.py
├── Swin-Transformer
│   ├── README.md
│   ├── build.py
│   ├── config.py
│   ├── configs
│   │   ├── acmix_swin_small_patch4_window7_224.yaml
│   │   ├── acmix_swin_tiny_patch4_window7_224.yaml
│   │   ├── swin_base_patch4_window12_384.yaml
│   │   ├── swin_base_patch4_window7_224.yaml
│   │   ├── swin_large_patch4_window12_384.yaml
│   │   ├── swin_large_patch4_window7_224.yaml
│   │   ├── swin_small_patch4_window7_224.yaml
│   │   └── swin_tiny_patch4_window7_224.yaml
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── cached_image_folder.py
│   │   ├── samplers.py
│   │   └── zipreader.py
│   ├── logger.py
│   ├── lr_scheduler.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── swin_transformer.py
│   │   └── swin_transformer_acmix.py
│   ├── optimizer.py
│   └── utils.py
└── figure
    ├── main.png
    ├── result.png
    └── shift.png

/README.md:
--------------------------------------------------------------------------------
1 | # ACmix
2 | This repo contains the official **PyTorch** code and pre-trained models for ACmix.
3 | 
4 | + [On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556v1.pdf)
5 | 
6 | ## Update
7 | 
8 | + **2022.4.13: ResNet training code released.**
9 | 
10 | **Notice:** Self-attention in ResNet is adopted following [Stand-Alone Self-Attention in Vision Models, NeurIPS 2019](https://proceedings.neurips.cc/paper/2019/file/3416a75f4cea9109507cacd8e2f2aefc-Paper.pdf). This sliding-window attention pattern is extremely inefficient unless it is backed by a carefully designed CUDA implementation. We therefore **highly recommend** using ACmix with SAN (which has a more efficient self-attention pattern) or with Transformer-based models rather than with vanilla ResNet.
11 | 
12 | ## Introduction
13 | 
14 | ![main](figure/main.png)
15 | 
16 | We reveal a close relationship between convolution and self-attention: both can be built on the same heavy computation
17 | (1×1 convolutions) and differ only in the remaining lightweight aggregation operations, which ACmix combines into a single module.
18 | 
19 | ## Results
20 | 
21 | + Top-1 accuracy on ImageNet vs. Multiply-Adds
22 | 
23 | ![image-20211208195403247](figure/result.png)
24 | 
25 | ## Pretrained Models
26 | 
27 | | Backbone Models | Params | FLOPs | Top-1 Acc (%) | Links |
28 | | --------------- | ------ | ----- | ------------- | ----- |
29 | | ResNet-26 | 10.6M | 2.3G | 76.1 (+2.5) | In process |
30 | | ResNet-38 | 14.6M | 2.9G | 77.4 (+1.4) | In process |
31 | | ResNet-50 | 18.6M | 3.6G | 77.8 (+0.9) | In process |
32 | | SAN-10 | 12.1M | 1.9G | 77.6 (+0.5) | In process |
33 | | SAN-15 | 16.6M | 2.7G | 78.4 (+0.4) | In process |
34 | | SAN-19 | 21.2M | 3.4G | 78.7 (+0.5) | In process |
35 | | PVT-T | 13M | 2.0G | 78.0 (+2.9) | In process |
36 | | PVT-S | 25M | 3.9G | 81.7 (+1.9) | In process |
37 | | Swin-T | 30M | 4.6G | 81.9 (+0.6) | [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/210c89a6f9eb4bd0beb6/) / [Google Drive](https://drive.google.com/file/d/1qJnYhtQ65rWd0zUuV9eZQW5Wxj-TiiSg/view?usp=sharing) |
38 | | Swin-S | 51M | 9.0G | 83.5 (+0.5) | [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/8bed555e75c840f8a00d/) / [Google Drive](https://drive.google.com/file/d/12PhN5YOEtWAgO8eSLhopfCVZ7JozMf9T/view?usp=sharing) |
39 | 
40 | ## Get Started
41 | 
42 | Please go to the [ResNet](https://github.com/LeapLabTHU/ACmix/tree/main/ResNet) and [Swin-Transformer](https://github.com/LeapLabTHU/ACmix/tree/main/Swin-Transformer) folders for model-specific docs.
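As a quick sanity check (a minimal sketch, not part of the official training pipeline), you can build the ACmix ResNet-50 variant defined in `ResNet/ResNet_ImageNet.py` and run a dummy forward pass. The snippet assumes the pinned requirements from `ResNet/README.md` and that it is executed from inside the `ResNet` folder so the module is importable:

```python
# Minimal sketch: build the ACmix-augmented ResNet-50 layout and check the output shape.
# Run from inside ResNet/ with the dependencies pinned in ResNet/README.md; no GPU is
# needed for this shape check (training itself should use main.py as shown in the docs).
import torch
from ResNet_ImageNet import ACmix_ResNet

model = ACmix_ResNet(layers=[3, 4, 6, 3], k_att=7, head=4, k_conv=3)  # ResNet-50 layout
model.eval()
x = torch.randn(1, 3, 224, 224)  # dummy ImageNet-sized input
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])
```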
43 | 
44 | ## Contact
45 | 
46 | If you have any questions, please feel free to contact the authors. Xuran Pan: [pxr18@mails.tsinghua.edu.cn](mailto:pxr18@mails.tsinghua.edu.cn).
47 | 
48 | ## Acknowledgment
49 | 
50 | Our code is based on [SAN](https://github.com/hszhao/SAN), [PVT](https://github.com/whai362/PVT), and [Swin Transformer](https://github.com/microsoft/Swin-Transformer).
51 | 
52 | ## Citation
53 | 
54 | If you find our work useful in your research, please consider citing:
55 | 
56 | ```bibtex
57 | @misc{pan2021integration,
58 |       title={On the Integration of Self-Attention and Convolution}, 
59 |       author={Xuran Pan and Chunjiang Ge and Rui Lu and Shiji Song and Guanfu Chen and Zeyi Huang and Gao Huang},
60 |       year={2021},
61 |       eprint={2111.14556},
62 |       archivePrefix={arXiv},
63 |       primaryClass={cs.CV}
64 | }
65 | ```
66 | 
67 | 
--------------------------------------------------------------------------------
/ResNet/README.md:
--------------------------------------------------------------------------------
1 | # ResNet
2 | 
3 | This folder contains the implementation of ACmix based on ResNet models for image classification.
4 | 
5 | ### Requirements
6 | 
7 | + Python 3.7
8 | + PyTorch==1.8.0
9 | + torchvision==0.9.0
10 | 
11 | We use the standard ImageNet dataset; you can download it from http://image-net.org/. The file structure should look like:
12 | 
13 | ```
14 | $ tree data
15 | imagenet
16 | ├── train
17 | │   ├── class1
18 | │   │   ├── img1.jpeg
19 | │   │   ├── img2.jpeg
20 | │   │   └── ...
21 | │   ├── class2
22 | │   │   ├── img3.jpeg
23 | │   │   └── ...
24 | │   └── ...
25 | └── val
26 |     ├── class1
27 |     │   ├── img4.jpeg
28 |     │   ├── img5.jpeg
29 |     │   └── ...
30 |     ├── class2
31 |     │   ├── img6.jpeg
32 |     │   └── ...
33 |     └── ...
34 | ```
35 | 
36 | ### Run
37 | 
38 | Train ResNet + ACmix on ImageNet:
39 | 
40 | ```bash
41 | python main.py --dist-url 'tcp://127.0.0.1:12345' --dist-backend 'nccl' --world-size 1 --rank 0 --data_url <path-to-imagenet> --batch-size 128
42 | ```
43 | 
--------------------------------------------------------------------------------
/ResNet/ResNet_ImageNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.utils import load_state_dict_from_url
4 | from test_bottleneck import ACmix
5 | 
6 | 
7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8 |     """3x3 convolution with padding"""
9 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
11 | 
12 | 
13 | def conv1x1(in_planes, out_planes, stride=1):
14 |     """1x1 convolution"""
15 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 | 
17 | 
18 | class Bottleneck(nn.Module):
19 |     # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
20 |     # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
21 |     # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
22 |     # This variant is also known as ResNet V1.5 and improves accuracy according to
23 |     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
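    # In this repo, the 3x3 spatial convolution of the standard bottleneck (self.conv2) is
    # replaced by an ACmix module (defined in test_bottleneck.py): shared 1x1 projections
    # produce q/k/v, which feed a local self-attention branch and a convolution branch
    # (a grouped convolution over shifted features), and the two outputs are summed with
    # learned mixing rates (rate1, rate2).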
24 | 25 | expansion = 4 26 | 27 | def __init__(self, inplanes, planes, k_att, head, k_conv, stride=1, downsample=None, groups=1, 28 | base_width=64, dilation=1, norm_layer=None): 29 | super(Bottleneck, self).__init__() 30 | if norm_layer is None: 31 | norm_layer = nn.BatchNorm2d 32 | width = int(planes * (base_width / 64.)) * groups 33 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 34 | self.conv1 = conv1x1(inplanes, width) 35 | self.bn1 = norm_layer(width) 36 | self.conv2 = ACmix(width, width, k_att, head, k_conv, stride=stride, dilation=dilation) 37 | self.bn2 = norm_layer(width) 38 | self.conv3 = conv1x1(width, planes * self.expansion) 39 | self.bn3 = norm_layer(planes * self.expansion) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | 69 | def __init__(self, block, layers, k_att=7, head=4, k_conv=3, num_classes=1000, zero_init_residual=False, 70 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 71 | norm_layer=None): 72 | super(ResNet, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | self._norm_layer = norm_layer 76 | 77 | self.inplanes = 64 78 | self.dilation = 1 79 | if replace_stride_with_dilation is None: 80 | # each element in the tuple indicates if we should replace 81 | # the 2x2 stride with a dilated convolution instead 82 | replace_stride_with_dilation = [False, False, False] 83 | if len(replace_stride_with_dilation) != 3: 84 | raise ValueError("replace_stride_with_dilation should be None " 85 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 86 | self.groups = groups 87 | self.base_width = width_per_group 88 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = norm_layer(self.inplanes) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 93 | self.layer1 = self._make_layer(block, 64, layers[0], k_att, head, k_conv) 94 | self.layer2 = self._make_layer(block, 128, layers[1], k_att, head, k_conv, stride=2, 95 | dilate=replace_stride_with_dilation[0]) 96 | self.layer3 = self._make_layer(block, 256, layers[2], k_att, head, k_conv, stride=2, 97 | dilate=replace_stride_with_dilation[1]) 98 | self.layer4 = self._make_layer(block, 512, layers[3], k_att, head, k_conv, stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 106 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | # Zero-initialize the last BN in each residual branch, 111 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
112 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 113 | if zero_init_residual: 114 | for m in self.modules(): 115 | if isinstance(m, Bottleneck): 116 | nn.init.constant_(m.bn3.weight, 0) 117 | 118 | def _make_layer(self, block, planes, blocks, rate, k, head, stride=1, dilate=False): 119 | norm_layer = self._norm_layer 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | norm_layer(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, rate, k, head, stride, downsample, self.groups, 133 | self.base_width, previous_dilation, norm_layer)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append(block(self.inplanes, planes, rate, k, head, groups=self.groups, 137 | base_width=self.base_width, dilation=self.dilation, 138 | norm_layer=norm_layer)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def _forward_impl(self, x): 143 | # See note [TorchScript super()] 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = torch.flatten(x, 1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | def forward(self, x): 161 | return self._forward_impl(x) 162 | 163 | 164 | def _resnet(block, layers, **kwargs): 165 | model = ResNet(block, layers, **kwargs) 166 | return model 167 | 168 | 169 | def ACmix_ResNet(layers=[3,4,6,3], **kwargs): 170 | return _resnet(Bottleneck, layers, **kwargs) 171 | 172 | 173 | if __name__ == '__main__': 174 | model = ACmix_ResNet().cuda() 175 | input = torch.randn([2,3,224,224]).cuda() 176 | total_params = sum(p.numel() for p in model.parameters()) 177 | print(f'{total_params:,} total parameters.') 178 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 179 | print(f'{total_trainable_params:,} training parameters.') 180 | print(model(input).shape) 181 | # print(summary(model, torch.zeros((1, 3, 224, 224)).cuda())) -------------------------------------------------------------------------------- /ResNet/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | from enum import Enum 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | from torch.optim.lr_scheduler import StepLR 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | from ResNet_ImageNet import ACmix_ResNet 23 | import torch.nn.functional as F 24 | 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('--data_url', default = '/home/data/ImageNet/', type=str) 28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 29 | help='number of data loading workers (default: 4)') 30 | 
parser.add_argument('--epochs', default=90, type=int, metavar='N', 31 | help='number of total epochs to run') 32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 33 | help='manual epoch number (useful on restarts)') 34 | parser.add_argument('-b', '--batch-size', default=32, type=int, 35 | metavar='N', 36 | help='mini-batch size (default: 256), this is the total ' 37 | 'batch size of all GPUs on the current node when ' 38 | 'using Data Parallel or Distributed Data Parallel') 39 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 40 | metavar='LR', help='initial learning rate', dest='lr') 41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 42 | help='momentum') 43 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 44 | metavar='W', help='weight decay (default: 1e-4)', 45 | dest='weight_decay') 46 | parser.add_argument('-p', '--print-freq', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--world-size', default=-1, type=int, 55 | help='number of nodes for distributed training') 56 | parser.add_argument('--rank', default=-1, type=int, 57 | help='node rank for distributed training') 58 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 59 | help='url used to set up distributed training') 60 | parser.add_argument('--dist-backend', default='nccl', type=str, 61 | help='distributed backend') 62 | parser.add_argument('--seed', default=None, type=int, 63 | help='seed for initializing training. ') 64 | parser.add_argument('--gpu', default=None, type=int, 65 | help='GPU id to use.') 66 | parser.add_argument('--multiprocessing-distributed', action='store_true', 67 | help='Use multi-processing distributed training to launch ' 68 | 'N processes per node, which has N GPUs. This is the ' 69 | 'fastest way to use PyTorch for either single node or ' 70 | 'multi node data parallel training') 71 | 72 | parser.add_argument('--lr_scheduler', default = 'cosine', type = str) 73 | parser.add_argument('--train_url', default='./output', type=str, metavar='Logdir', help='log dir') 74 | parser.add_argument('--layers', default='3,4,6,3', type=str, help='Res50: 3,4,6,3, Res38: 2,3,5,2, Res26: 1,2,4,1') 75 | parser.add_argument('--k_att', default=7, type=int, help='kernel size for attention') 76 | parser.add_argument('--k_conv', default=3, type=int, help='kernel size for convolution') 77 | parser.add_argument('--head', default=4, type=int, help='number of heads of attention') 78 | parser.add_argument('--ls', default=0.1, type=float, help='url used to set up distributed training') 79 | best_acc1 = 0 80 | 81 | 82 | def main(): 83 | args = parser.parse_args() 84 | 85 | save_path = f'{args.train_url}/log/' 86 | args.train_url = save_path 87 | 88 | if args.seed is not None: 89 | random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | cudnn.deterministic = True 92 | warnings.warn('You have chosen to seed training. ' 93 | 'This will turn on the CUDNN deterministic setting, ' 94 | 'which can slow down your training considerably! 
' 95 | 'You may see unexpected behavior when restarting ' 96 | 'from checkpoints.') 97 | 98 | if args.gpu is not None: 99 | warnings.warn('You have chosen a specific GPU. This will completely ' 100 | 'disable data parallelism.') 101 | 102 | if args.dist_url == "env://" and args.world_size == -1: 103 | args.world_size = int(os.environ["WORLD_SIZE"]) 104 | 105 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 106 | 107 | ngpus_per_node = torch.cuda.device_count() 108 | if args.multiprocessing_distributed: 109 | # Since we have ngpus_per_node processes per node, the total world_size 110 | # needs to be adjusted accordingly 111 | args.world_size = ngpus_per_node * args.world_size 112 | # Use torch.multiprocessing.spawn to launch distributed processes: the 113 | # main_worker process function 114 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 115 | else: 116 | # Simply call main_worker function 117 | main_worker(args.gpu, ngpus_per_node, args) 118 | 119 | 120 | def main_worker(gpu, ngpus_per_node, args): 121 | global best_acc1 122 | args.gpu = gpu 123 | 124 | os.makedirs(args.train_url, exist_ok=True) 125 | with open(args.train_url+'train_configs.txt', "w") as f: 126 | f.write(str(args)) 127 | 128 | if args.gpu is not None: 129 | print("Use GPU: {} for training".format(args.gpu)) 130 | 131 | if args.distributed: 132 | if args.dist_url == "env://" and args.rank == -1: 133 | args.rank = int(os.environ["RANK"]) 134 | if args.multiprocessing_distributed: 135 | # For multiprocessing distributed training, rank needs to be the 136 | # global rank among all the processes 137 | args.rank = args.rank * ngpus_per_node + gpu 138 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 139 | world_size=args.world_size, rank=args.rank) 140 | # create model 141 | model = ACmix_ResNet(layers=[int(i) for i in args.layers.split(',')], k_att=args.k_att, head=args.head, k_conv=args.k_conv) 142 | 143 | with open(args.train_url+'model_arch.txt', "w") as f: 144 | f.write(str(model)) 145 | 146 | if not torch.cuda.is_available(): 147 | print('using CPU, this will be slow') 148 | elif args.distributed: 149 | # For multiprocessing distributed, DistributedDataParallel constructor 150 | # should always set the single device scope, otherwise, 151 | # DistributedDataParallel will use all available devices. 152 | if args.gpu is not None: 153 | torch.cuda.set_device(args.gpu) 154 | model.cuda(args.gpu) 155 | # When using a single GPU per process and per 156 | # DistributedDataParallel, we need to divide the batch size 157 | # ourselves based on the total number of GPUs of the current node. 
158 | args.batch_size = int(args.batch_size / ngpus_per_node) 159 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 160 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 161 | else: 162 | model.cuda() 163 | # DistributedDataParallel will divide and allocate batch_size to all 164 | # available GPUs if device_ids are not set 165 | model = torch.nn.parallel.DistributedDataParallel(model) 166 | elif args.gpu is not None: 167 | torch.cuda.set_device(args.gpu) 168 | model = model.cuda(args.gpu) 169 | else: 170 | model = torch.nn.DataParallel(model).cuda() 171 | 172 | # define loss function (criterion), optimizer, and learning rate scheduler 173 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 174 | 175 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 176 | momentum=args.momentum, 177 | weight_decay=args.weight_decay) 178 | 179 | 180 | # optionally resume from a checkpoint 181 | if args.resume: 182 | if os.path.isfile(args.resume): 183 | print("=> loading checkpoint '{}'".format(args.resume)) 184 | if args.gpu is None: 185 | checkpoint = torch.load(args.resume) 186 | else: 187 | # Map model to be loaded to specified single gpu. 188 | loc = 'cuda:{}'.format(args.gpu) 189 | checkpoint = torch.load(args.resume, map_location=loc) 190 | args.start_epoch = checkpoint['epoch'] 191 | best_acc1 = checkpoint['best_acc1'] 192 | if args.gpu is not None: 193 | # best_acc1 may be from a checkpoint from a different GPU 194 | best_acc1 = best_acc1.to(args.gpu) 195 | model.load_state_dict(checkpoint['state_dict']) 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | scheduler.load_state_dict(checkpoint['scheduler']) 198 | print("=> loaded checkpoint '{}' (epoch {})" 199 | .format(args.resume, checkpoint['epoch'])) 200 | else: 201 | print("=> no checkpoint found at '{}'".format(args.resume)) 202 | 203 | cudnn.benchmark = True 204 | 205 | # Data loading code 206 | traindir = os.path.join(args.data_url, 'train') 207 | valdir = os.path.join(args.data_url, 'val') 208 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 209 | std=[0.229, 0.224, 0.225]) 210 | 211 | train_dataset = datasets.ImageFolder( 212 | traindir, 213 | transforms.Compose([ 214 | transforms.RandomResizedCrop(224), 215 | transforms.RandomHorizontalFlip(), 216 | transforms.ToTensor(), 217 | normalize, 218 | ])) 219 | 220 | if args.distributed: 221 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 222 | else: 223 | train_sampler = None 224 | 225 | train_loader = torch.utils.data.DataLoader( 226 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 227 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 228 | 229 | val_loader = torch.utils.data.DataLoader( 230 | datasets.ImageFolder(valdir, transforms.Compose([ 231 | transforms.Resize(256), 232 | transforms.CenterCrop(224), 233 | transforms.ToTensor(), 234 | normalize, 235 | ])), 236 | batch_size=args.batch_size, shuffle=False, 237 | num_workers=args.workers, pin_memory=True) 238 | 239 | if args.evaluate: 240 | validate(val_loader, model, criterion, args) 241 | return 242 | 243 | for epoch in range(args.start_epoch, args.epochs): 244 | if args.distributed: 245 | train_sampler.set_epoch(epoch) 246 | 247 | # train for one epoch 248 | train(train_loader, model, criterion, optimizer, epoch, args) 249 | 250 | # evaluate on validation set 251 | acc1 = validate(val_loader, model, criterion, args) 252 | 253 | scheduler.step() 254 | 255 | 256 | # remember 
best acc@1 and save checkpoint 257 | is_best = acc1 > best_acc1 258 | best_acc1 = max(acc1, best_acc1) 259 | 260 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 261 | and args.rank % ngpus_per_node == 0): 262 | save_checkpoint({ 263 | 'epoch': epoch + 1, 264 | 'arch': args.arch, 265 | 'state_dict': model.state_dict(), 266 | 'best_acc1': best_acc1, 267 | 'optimizer' : optimizer.state_dict(), 268 | 'scheduler' : scheduler.state_dict() 269 | }, is_best) 270 | 271 | 272 | def train(train_loader, model, criterion, optimizer, epoch, args): 273 | batch_time = AverageMeter('Time', ':6.3f') 274 | data_time = AverageMeter('Data', ':6.3f') 275 | losses = AverageMeter('Loss', ':.4e') 276 | top1 = AverageMeter('Acc@1', ':6.2f') 277 | top5 = AverageMeter('Acc@5', ':6.2f') 278 | train_batches_num = len(train_loader) 279 | progress = ProgressMeter( 280 | len(train_loader), 281 | [batch_time, data_time, losses, top1, top5], 282 | prefix="Epoch: [{}]".format(epoch)) 283 | 284 | # switch to train mode 285 | model.train() 286 | 287 | end = time.time() 288 | for i, (images, target) in enumerate(train_loader): 289 | ### Adjust learning rate 290 | lr = adjust_learning_rate_iter_warmup(optimizer, epoch, i, train_batches_num, args) 291 | 292 | # measure data loading time 293 | data_time.update(time.time() - end) 294 | 295 | if args.gpu is not None: 296 | images = images.cuda(args.gpu, non_blocking=True) 297 | if torch.cuda.is_available(): 298 | target = target.cuda(args.gpu, non_blocking=True) 299 | 300 | # compute output 301 | output = model(images) 302 | if args.ls > 0: 303 | # enable label smoothing 304 | loss = smooth_loss(output, target, args.ls) 305 | else: 306 | loss = criterion(output, target) 307 | 308 | # measure accuracy and record loss 309 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 310 | losses.update(loss.item(), images.size(0)) 311 | top1.update(acc1[0], images.size(0)) 312 | top5.update(acc5[0], images.size(0)) 313 | 314 | # compute gradient and do SGD step 315 | optimizer.zero_grad() 316 | loss.backward() 317 | optimizer.step() 318 | 319 | # measure elapsed time 320 | batch_time.update(time.time() - end) 321 | end = time.time() 322 | 323 | if i % args.print_freq == 0: 324 | progress.display(i) 325 | 326 | 327 | def validate(val_loader, model, criterion, args): 328 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 329 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 330 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 331 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 332 | progress = ProgressMeter( 333 | len(val_loader), 334 | [batch_time, losses, top1, top5], 335 | prefix='Test: ') 336 | 337 | # switch to evaluate mode 338 | model.eval() 339 | 340 | with torch.no_grad(): 341 | end = time.time() 342 | for i, (images, target) in enumerate(val_loader): 343 | if args.gpu is not None: 344 | images = images.cuda(args.gpu, non_blocking=True) 345 | if torch.cuda.is_available(): 346 | target = target.cuda(args.gpu, non_blocking=True) 347 | 348 | # compute output 349 | output = model(images) 350 | loss = criterion(output, target) 351 | 352 | # measure accuracy and record loss 353 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 354 | losses.update(loss.item(), images.size(0)) 355 | top1.update(acc1[0], images.size(0)) 356 | top5.update(acc5[0], images.size(0)) 357 | 358 | # measure elapsed time 359 | batch_time.update(time.time() - end) 360 | end = time.time() 361 | 362 | if i % args.print_freq == 0: 363 | progress.display(i) 364 | 
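    # After the last batch, print the epoch-level summary (average Acc@1 / Acc@5).
    # The Acc@1 average returned below is what main_worker() uses to decide whether
    # the current checkpoint is the best one so far.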
365 | progress.display_summary() 366 | 367 | return top1.avg 368 | 369 | 370 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 371 | torch.save(state, filename) 372 | if is_best: 373 | shutil.copyfile(filename, 'model_best.pth.tar') 374 | 375 | class Summary(Enum): 376 | NONE = 0 377 | AVERAGE = 1 378 | SUM = 2 379 | COUNT = 3 380 | 381 | class AverageMeter(object): 382 | """Computes and stores the average and current value""" 383 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 384 | self.name = name 385 | self.fmt = fmt 386 | self.summary_type = summary_type 387 | self.reset() 388 | 389 | def reset(self): 390 | self.val = 0 391 | self.avg = 0 392 | self.sum = 0 393 | self.count = 0 394 | 395 | def update(self, val, n=1): 396 | self.val = val 397 | self.sum += val * n 398 | self.count += n 399 | self.avg = self.sum / self.count 400 | 401 | def __str__(self): 402 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 403 | return fmtstr.format(**self.__dict__) 404 | 405 | def summary(self): 406 | fmtstr = '' 407 | if self.summary_type is Summary.NONE: 408 | fmtstr = '' 409 | elif self.summary_type is Summary.AVERAGE: 410 | fmtstr = '{name} {avg:.3f}' 411 | elif self.summary_type is Summary.SUM: 412 | fmtstr = '{name} {sum:.3f}' 413 | elif self.summary_type is Summary.COUNT: 414 | fmtstr = '{name} {count:.3f}' 415 | else: 416 | raise ValueError('invalid summary type %r' % self.summary_type) 417 | 418 | return fmtstr.format(**self.__dict__) 419 | 420 | 421 | class ProgressMeter(object): 422 | def __init__(self, num_batches, meters, prefix=""): 423 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 424 | self.meters = meters 425 | self.prefix = prefix 426 | 427 | def display(self, batch): 428 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 429 | entries += [str(meter) for meter in self.meters] 430 | print('\t'.join(entries)) 431 | 432 | def display_summary(self): 433 | entries = [" *"] 434 | entries += [meter.summary() for meter in self.meters] 435 | print(' '.join(entries)) 436 | 437 | def _get_batch_fmtstr(self, num_batches): 438 | num_digits = len(str(num_batches // 1)) 439 | fmt = '{:' + str(num_digits) + 'd}' 440 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 441 | 442 | 443 | def adjust_learning_rate_iter_warmup(optimizer, epoch, step, len_epoch, args): 444 | if args.lr_scheduler == 'multiStep': 445 | factor = epoch // 30 446 | if epoch >= 80: 447 | factor = factor + 1 448 | lr = args.lr*(0.1**factor) 449 | """Warmup""" 450 | if epoch < 5: 451 | lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch) 452 | 453 | elif args.lr_scheduler == 'cosine': 454 | if epoch < 5: 455 | lr = args.lr/4 + 3*args.lr/4 * float(1 + step + epoch*len_epoch)/(5.*len_epoch) 456 | else: 457 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * float((epoch-5)*len_epoch+step) / float((args.epochs-5)*len_epoch))) 458 | 459 | for param_group in optimizer.param_groups: 460 | param_group['lr'] = lr 461 | 462 | return lr 463 | 464 | 465 | def adjust_learning_rate(optimizer, epoch, args): 466 | if args.lr_scheduler == 'multiStep': 467 | lr = args.lr * (0.1 ** (epoch // 30)) 468 | for param_group in optimizer.param_groups: 469 | param_group['lr'] = lr 470 | print('lr:',param_group['lr']) 471 | elif args.lr_scheduler == 'cosine': 472 | for param_group in optimizer.param_groups: 473 | param_group['lr'] = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) 474 | print('lr:',param_group['lr']) 475 | else: 476 | assert('scheduler not 
defined') 477 | 478 | 479 | def accuracy(output, target, topk=(1,)): 480 | """Computes the accuracy over the k top predictions for the specified values of k""" 481 | with torch.no_grad(): 482 | maxk = max(topk) 483 | batch_size = target.size(0) 484 | 485 | _, pred = output.topk(maxk, 1, True, True) 486 | pred = pred.t() 487 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 488 | 489 | res = [] 490 | for k in topk: 491 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 492 | res.append(correct_k.mul_(100.0 / batch_size)) 493 | return res 494 | 495 | 496 | def smooth_loss(output, target, eps=0.1): 497 | w = torch.zeros_like(output).scatter(1, target.unsqueeze(1), 1) 498 | w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) 499 | log_prob = F.log_softmax(output, dim=1) 500 | loss = (-w * log_prob).sum(dim=1).mean() 501 | return loss 502 | 503 | 504 | if __name__ == '__main__': 505 | main() -------------------------------------------------------------------------------- /ResNet/test_bottleneck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | 6 | def position(H, W, is_cuda=True): 7 | if is_cuda: 8 | loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1) 9 | loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W) 10 | else: 11 | loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1) 12 | loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W) 13 | loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0) 14 | return loc 15 | 16 | 17 | def stride(x, stride): 18 | b, c, h, w = x.shape 19 | return x[:, :, ::stride, ::stride] 20 | 21 | def init_rate_half(tensor): 22 | if tensor is not None: 23 | tensor.data.fill_(0.5) 24 | 25 | def init_rate_0(tensor): 26 | if tensor is not None: 27 | tensor.data.fill_(0.) 
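# Summary of the helpers above, used by the ACmix module below:
# - position(H, W): a 1 x 2 x H x W grid of (x, y) coordinates normalized to [-1, 1],
#   projected by conv_p and used as the positional encoding of the attention branch.
# - stride(x, stride): spatial subsampling of a feature map when the block uses stride > 1.
# - init_rate_half / init_rate_0: initializers for the learned mixing rates (rate1, rate2)
#   and for the bias of the dep_conv layer.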
28 | 29 | 30 | class ACmix(nn.Module): 31 | def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1): 32 | super(ACmix, self).__init__() 33 | self.in_planes = in_planes 34 | self.out_planes = out_planes 35 | self.head = head 36 | self.kernel_att = kernel_att 37 | self.kernel_conv = kernel_conv 38 | self.stride = stride 39 | self.dilation = dilation 40 | self.rate1 = torch.nn.Parameter(torch.Tensor(1)) 41 | self.rate2 = torch.nn.Parameter(torch.Tensor(1)) 42 | self.head_dim = self.out_planes // self.head 43 | 44 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 45 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 46 | self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 47 | self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1) 48 | 49 | self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2 50 | self.pad_att = torch.nn.ReflectionPad2d(self.padding_att) 51 | self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride) 52 | self.softmax = torch.nn.Softmax(dim=1) 53 | 54 | self.fc = nn.Conv2d(3*self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False) 55 | self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1, stride=stride) 56 | 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | init_rate_half(self.rate1) 61 | init_rate_half(self.rate2) 62 | kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv) 63 | for i in range(self.kernel_conv * self.kernel_conv): 64 | kernel[i, i//self.kernel_conv, i%self.kernel_conv] = 1. 65 | kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1) 66 | self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True) 67 | self.dep_conv.bias = init_rate_0(self.dep_conv.bias) 68 | 69 | def forward(self, x): 70 | q, k, v = self.conv1(x), self.conv2(x), self.conv3(x) 71 | scaling = float(self.head_dim) ** -0.5 72 | b, c, h, w = q.shape 73 | h_out, w_out = h//self.stride, w//self.stride 74 | 75 | 76 | # ### att 77 | # ## positional encoding 78 | pe = self.conv_p(position(h, w, x.is_cuda)) 79 | 80 | q_att = q.view(b*self.head, self.head_dim, h, w) * scaling 81 | k_att = k.view(b*self.head, self.head_dim, h, w) 82 | v_att = v.view(b*self.head, self.head_dim, h, w) 83 | 84 | if self.stride > 1: 85 | q_att = stride(q_att, self.stride) 86 | q_pe = stride(pe, self.stride) 87 | else: 88 | q_pe = pe 89 | 90 | unfold_k = self.unfold(self.pad_att(k_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) # b*head, head_dim, k_att^2, h_out, w_out 91 | unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) # 1, head_dim, k_att^2, h_out, w_out 92 | 93 | att = (q_att.unsqueeze(2)*(unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out) 94 | att = self.softmax(att) 95 | 96 | out_att = self.unfold(self.pad_att(v_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) 97 | out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out) 98 | 99 | ## conv 100 | f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h*w), k.view(b, self.head, self.head_dim, h*w), v.view(b, self.head, self.head_dim, h*w)], 1)) 101 | f_conv = 
f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1]) 102 | 103 | out_conv = self.dep_conv(f_conv) 104 | 105 | return self.rate1 * out_att + self.rate2 * out_conv -------------------------------------------------------------------------------- /Swin-Transformer/README.md: -------------------------------------------------------------------------------- 1 | # Swin Transformer 2 | 3 | This folder contains the implementation of the ACmix based on Swin Transformer models for image classification. 4 | 5 | ### Requirements 6 | 7 | + Python 3.7 8 | 9 | + PyTorch==1.8.0 10 | 11 | + torchvision==0.9.0 12 | 13 | + timm==0.3.2 14 | 15 | + opencv-python==4.4.0.46 16 | 17 | + termcolor==1.1.0 18 | 19 | + yacs==0.1.8 20 | 21 | + Install Apex: 22 | 23 | ```python 24 | git clone https://github.com/NVIDIA/apex 25 | cd apex 26 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 27 | ``` 28 | 29 | ### Data preparation 30 | 31 | We use standard ImageNet dataset, you can download it from http://image-net.org/. The file structure should look like: 32 | 33 | ``` 34 | $ tree data 35 | imagenet 36 | ├── train 37 | │ ├── class1 38 | │ │ ├── img1.jpeg 39 | │ │ ├── img2.jpeg 40 | │ │ └── ... 41 | │ ├── class2 42 | │ │ ├── img3.jpeg 43 | │ │ └── ... 44 | │ └── ... 45 | └── val 46 | ├── class1 47 | │ ├── img4.jpeg 48 | │ ├── img5.jpeg 49 | │ └── ... 50 | ├── class2 51 | │ ├── img6.jpeg 52 | │ └── ... 53 | └── ... 54 | ``` 55 | 56 | ### Run 57 | 58 | Train Swin-T + ACmix on ImageNet 59 | 60 | ```python 61 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py --cfg configs/acmix_swin_tiny_patch4_window7_224.yaml --data-path --batch-size 128 62 | ``` 63 | 64 | Train Swin-S + ACmix on ImageNet 65 | 66 | ```python 67 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py --cfg configs/acmix_swin_small_patch4_window7_224.yaml --data-path --batch-size 128 68 | ``` 69 | 70 | -------------------------------------------------------------------------------- /Swin-Transformer/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | from .swin_transformer_hac import SwinTransformer_hac 10 | from .swin_transformer_hac_v3 import SwinTransformer_hac_v3 11 | from .swin_transformer_hac_v4 import SwinTransformer_hac_v4 12 | from .swin_transformer_hac_v5 import SwinTransformer_hac_v5 13 | from .swin_transformer_hac_v6 import SwinTransformer_hac_v6 14 | from .swin_transformer_hac_v7 import SwinTransformer_hac_v7 15 | from .swin_transformer_hac_v8 import SwinTransformer_hac_v8 16 | from .swin_transformer_hac_v9 import SwinTransformer_hac_v9 17 | from .swin_transformer_hac_v10 import SwinTransformer_hac_v10 18 | from .swin_transformer_hac_v11 import SwinTransformer_hac_v11 19 | from .swin_transformer_hac_v12 import SwinTransformer_hac_v12 20 | from .swin_transformer_hac_v13 import SwinTransformer_hac_v13 21 | from .swin_transformer_hac_v14 import SwinTransformer_hac_v14 22 | from .swin_transformer_hac_v15 import SwinTransformer_hac_v15 23 | from .swin_transformer_hac_v16 import SwinTransformer_hac_v16 24 | 25 | 26 | def build_model(config): 
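    """Instantiate the backbone selected by config.MODEL.TYPE.

    Only the plain 'swin' type and the 'swin_hac*' development variants are registered in
    this builder; the ACmix configs under configs/ set MODEL.TYPE to 'swin_acmix', whose
    model definition lives in models/swin_transformer_acmix.py.
    """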
27 | model_type = config.MODEL.TYPE 28 | if model_type == 'swin': 29 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 30 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 31 | in_chans=config.MODEL.SWIN.IN_CHANS, 32 | num_classes=config.MODEL.NUM_CLASSES, 33 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 34 | depths=config.MODEL.SWIN.DEPTHS, 35 | num_heads=config.MODEL.SWIN.NUM_HEADS, 36 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 37 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 38 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 39 | qk_scale=config.MODEL.SWIN.QK_SCALE, 40 | drop_rate=config.MODEL.DROP_RATE, 41 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 42 | ape=config.MODEL.SWIN.APE, 43 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 44 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 45 | elif model_type == 'swin_hac': 46 | model = SwinTransformer_hac(img_size=config.DATA.IMG_SIZE, 47 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 48 | in_chans=config.MODEL.SWIN.IN_CHANS, 49 | num_classes=config.MODEL.NUM_CLASSES, 50 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 51 | depths=config.MODEL.SWIN.DEPTHS, 52 | num_heads=config.MODEL.SWIN.NUM_HEADS, 53 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 54 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 55 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 56 | qk_scale=config.MODEL.SWIN.QK_SCALE, 57 | drop_rate=config.MODEL.DROP_RATE, 58 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 59 | ape=config.MODEL.SWIN.APE, 60 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 61 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 62 | elif model_type == 'swin_hac_v3': 63 | model = SwinTransformer_hac_v3(img_size=config.DATA.IMG_SIZE, 64 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 65 | in_chans=config.MODEL.SWIN.IN_CHANS, 66 | num_classes=config.MODEL.NUM_CLASSES, 67 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 68 | depths=config.MODEL.SWIN.DEPTHS, 69 | num_heads=config.MODEL.SWIN.NUM_HEADS, 70 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 71 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 72 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 73 | qk_scale=config.MODEL.SWIN.QK_SCALE, 74 | drop_rate=config.MODEL.DROP_RATE, 75 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 76 | ape=config.MODEL.SWIN.APE, 77 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 78 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 79 | elif model_type == 'swin_hac_v4': 80 | model = SwinTransformer_hac_v4(img_size=config.DATA.IMG_SIZE, 81 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 82 | in_chans=config.MODEL.SWIN.IN_CHANS, 83 | num_classes=config.MODEL.NUM_CLASSES, 84 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 85 | depths=config.MODEL.SWIN.DEPTHS, 86 | num_heads=config.MODEL.SWIN.NUM_HEADS, 87 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 88 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 89 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 90 | qk_scale=config.MODEL.SWIN.QK_SCALE, 91 | drop_rate=config.MODEL.DROP_RATE, 92 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 93 | ape=config.MODEL.SWIN.APE, 94 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 95 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 96 | elif model_type == 'swin_hac_v5': 97 | model = SwinTransformer_hac_v5(img_size=config.DATA.IMG_SIZE, 98 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 99 | in_chans=config.MODEL.SWIN.IN_CHANS, 100 | num_classes=config.MODEL.NUM_CLASSES, 101 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 102 | depths=config.MODEL.SWIN.DEPTHS, 103 | num_heads=config.MODEL.SWIN.NUM_HEADS, 104 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 105 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 106 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 107 | 
qk_scale=config.MODEL.SWIN.QK_SCALE, 108 | drop_rate=config.MODEL.DROP_RATE, 109 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 110 | ape=config.MODEL.SWIN.APE, 111 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 112 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 113 | elif model_type == 'swin_hac_v6': 114 | model = SwinTransformer_hac_v6(img_size=config.DATA.IMG_SIZE, 115 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 116 | in_chans=config.MODEL.SWIN.IN_CHANS, 117 | num_classes=config.MODEL.NUM_CLASSES, 118 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 119 | depths=config.MODEL.SWIN.DEPTHS, 120 | num_heads=config.MODEL.SWIN.NUM_HEADS, 121 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 122 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 123 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 124 | qk_scale=config.MODEL.SWIN.QK_SCALE, 125 | drop_rate=config.MODEL.DROP_RATE, 126 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 127 | ape=config.MODEL.SWIN.APE, 128 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 129 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 130 | elif model_type == 'swin_hac_v7': 131 | model = SwinTransformer_hac_v7(img_size=config.DATA.IMG_SIZE, 132 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 133 | in_chans=config.MODEL.SWIN.IN_CHANS, 134 | num_classes=config.MODEL.NUM_CLASSES, 135 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 136 | depths=config.MODEL.SWIN.DEPTHS, 137 | num_heads=config.MODEL.SWIN.NUM_HEADS, 138 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 139 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 140 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 141 | qk_scale=config.MODEL.SWIN.QK_SCALE, 142 | drop_rate=config.MODEL.DROP_RATE, 143 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 144 | ape=config.MODEL.SWIN.APE, 145 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 146 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 147 | elif model_type == 'swin_hac_v8': 148 | model = SwinTransformer_hac_v8(img_size=config.DATA.IMG_SIZE, 149 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 150 | in_chans=config.MODEL.SWIN.IN_CHANS, 151 | num_classes=config.MODEL.NUM_CLASSES, 152 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 153 | depths=config.MODEL.SWIN.DEPTHS, 154 | num_heads=config.MODEL.SWIN.NUM_HEADS, 155 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 156 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 157 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 158 | qk_scale=config.MODEL.SWIN.QK_SCALE, 159 | drop_rate=config.MODEL.DROP_RATE, 160 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 161 | ape=config.MODEL.SWIN.APE, 162 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 163 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 164 | elif model_type == 'swin_hac_v9': 165 | model = SwinTransformer_hac_v9(img_size=config.DATA.IMG_SIZE, 166 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 167 | in_chans=config.MODEL.SWIN.IN_CHANS, 168 | num_classes=config.MODEL.NUM_CLASSES, 169 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 170 | depths=config.MODEL.SWIN.DEPTHS, 171 | num_heads=config.MODEL.SWIN.NUM_HEADS, 172 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 173 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 174 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 175 | qk_scale=config.MODEL.SWIN.QK_SCALE, 176 | drop_rate=config.MODEL.DROP_RATE, 177 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 178 | ape=config.MODEL.SWIN.APE, 179 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 180 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 181 | elif model_type == 'swin_hac_v10': 182 | model = SwinTransformer_hac_v10(img_size=config.DATA.IMG_SIZE, 183 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 184 | in_chans=config.MODEL.SWIN.IN_CHANS, 185 | 
num_classes=config.MODEL.NUM_CLASSES, 186 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 187 | depths=config.MODEL.SWIN.DEPTHS, 188 | num_heads=config.MODEL.SWIN.NUM_HEADS, 189 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 190 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 191 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 192 | qk_scale=config.MODEL.SWIN.QK_SCALE, 193 | drop_rate=config.MODEL.DROP_RATE, 194 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 195 | ape=config.MODEL.SWIN.APE, 196 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 197 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 198 | elif model_type == 'swin_hac_v11': 199 | model = SwinTransformer_hac_v11(img_size=config.DATA.IMG_SIZE, 200 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 201 | in_chans=config.MODEL.SWIN.IN_CHANS, 202 | num_classes=config.MODEL.NUM_CLASSES, 203 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 204 | depths=config.MODEL.SWIN.DEPTHS, 205 | num_heads=config.MODEL.SWIN.NUM_HEADS, 206 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 207 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 208 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 209 | qk_scale=config.MODEL.SWIN.QK_SCALE, 210 | drop_rate=config.MODEL.DROP_RATE, 211 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 212 | ape=config.MODEL.SWIN.APE, 213 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 214 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 215 | elif model_type == 'swin_hac_v12': 216 | model = SwinTransformer_hac_v12(img_size=config.DATA.IMG_SIZE, 217 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 218 | in_chans=config.MODEL.SWIN.IN_CHANS, 219 | num_classes=config.MODEL.NUM_CLASSES, 220 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 221 | depths=config.MODEL.SWIN.DEPTHS, 222 | num_heads=config.MODEL.SWIN.NUM_HEADS, 223 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 224 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 225 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 226 | qk_scale=config.MODEL.SWIN.QK_SCALE, 227 | drop_rate=config.MODEL.DROP_RATE, 228 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 229 | ape=config.MODEL.SWIN.APE, 230 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 231 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 232 | elif model_type == 'swin_hac_v13': 233 | model = SwinTransformer_hac_v13(img_size=config.DATA.IMG_SIZE, 234 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 235 | in_chans=config.MODEL.SWIN.IN_CHANS, 236 | num_classes=config.MODEL.NUM_CLASSES, 237 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 238 | depths=config.MODEL.SWIN.DEPTHS, 239 | num_heads=config.MODEL.SWIN.NUM_HEADS, 240 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 241 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 242 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 243 | qk_scale=config.MODEL.SWIN.QK_SCALE, 244 | drop_rate=config.MODEL.DROP_RATE, 245 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 246 | ape=config.MODEL.SWIN.APE, 247 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 248 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 249 | elif model_type == 'swin_hac_v14': 250 | model = SwinTransformer_hac_v14(img_size=config.DATA.IMG_SIZE, 251 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 252 | in_chans=config.MODEL.SWIN.IN_CHANS, 253 | num_classes=config.MODEL.NUM_CLASSES, 254 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 255 | depths=config.MODEL.SWIN.DEPTHS, 256 | num_heads=config.MODEL.SWIN.NUM_HEADS, 257 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 258 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 259 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 260 | qk_scale=config.MODEL.SWIN.QK_SCALE, 261 | drop_rate=config.MODEL.DROP_RATE, 262 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 263 | 
ape=config.MODEL.SWIN.APE, 264 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 265 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 266 | elif model_type == 'swin_hac_v15': 267 | model = SwinTransformer_hac_v15(img_size=config.DATA.IMG_SIZE, 268 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 269 | in_chans=config.MODEL.SWIN.IN_CHANS, 270 | num_classes=config.MODEL.NUM_CLASSES, 271 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 272 | depths=config.MODEL.SWIN.DEPTHS, 273 | num_heads=config.MODEL.SWIN.NUM_HEADS, 274 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 275 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 276 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 277 | qk_scale=config.MODEL.SWIN.QK_SCALE, 278 | drop_rate=config.MODEL.DROP_RATE, 279 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 280 | ape=config.MODEL.SWIN.APE, 281 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 282 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 283 | elif model_type == 'swin_hac_v16': 284 | model = SwinTransformer_hac_v16(img_size=config.DATA.IMG_SIZE, 285 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 286 | in_chans=config.MODEL.SWIN.IN_CHANS, 287 | num_classes=config.MODEL.NUM_CLASSES, 288 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 289 | depths=config.MODEL.SWIN.DEPTHS, 290 | num_heads=config.MODEL.SWIN.NUM_HEADS, 291 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 292 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 293 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 294 | qk_scale=config.MODEL.SWIN.QK_SCALE, 295 | drop_rate=config.MODEL.DROP_RATE, 296 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 297 | ape=config.MODEL.SWIN.APE, 298 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 299 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 300 | else: 301 | raise NotImplementedError(f"Unkown model: {model_type}") 302 | 303 | return model 304 | -------------------------------------------------------------------------------- /Swin-Transformer/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.RESUME = '' 51 | # Number of classes, overwritten in data preparation 52 | _C.MODEL.NUM_CLASSES = 1000 53 | # Dropout rate 54 | _C.MODEL.DROP_RATE = 0.0 55 | # Drop path rate 56 | _C.MODEL.DROP_PATH_RATE = 0.1 57 | # Label Smoothing 58 | _C.MODEL.LABEL_SMOOTHING = 0.1 59 | 60 | # Swin Transformer parameters 61 | _C.MODEL.SWIN = CN() 62 | _C.MODEL.SWIN.PATCH_SIZE = 4 63 | _C.MODEL.SWIN.IN_CHANS = 3 64 | _C.MODEL.SWIN.EMBED_DIM = 96 65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 66 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 67 | _C.MODEL.SWIN.WINDOW_SIZE = 7 68 | _C.MODEL.SWIN.MLP_RATIO = 4. 69 | _C.MODEL.SWIN.QKV_BIAS = True 70 | _C.MODEL.SWIN.QK_SCALE = None 71 | _C.MODEL.SWIN.APE = False 72 | _C.MODEL.SWIN.PATCH_NORM = True 73 | 74 | # ----------------------------------------------------------------------------- 75 | # Training settings 76 | # ----------------------------------------------------------------------------- 77 | _C.TRAIN = CN() 78 | _C.TRAIN.START_EPOCH = 0 79 | _C.TRAIN.EPOCHS = 300 80 | _C.TRAIN.WARMUP_EPOCHS = 20 81 | _C.TRAIN.WEIGHT_DECAY = 0.05 82 | _C.TRAIN.BASE_LR = 5e-4 83 | _C.TRAIN.WARMUP_LR = 5e-7 84 | _C.TRAIN.MIN_LR = 5e-6 85 | # Clip gradient norm 86 | _C.TRAIN.CLIP_GRAD = 5.0 87 | # Auto resume from latest checkpoint 88 | _C.TRAIN.AUTO_RESUME = True 89 | # Gradient accumulation steps 90 | # could be overwritten by command line argument 91 | _C.TRAIN.ACCUMULATION_STEPS = 0 92 | # Whether to use gradient checkpointing to save memory 93 | # could be overwritten by command line argument 94 | _C.TRAIN.USE_CHECKPOINT = False 95 | 96 | # LR scheduler 97 | _C.TRAIN.LR_SCHEDULER = CN() 98 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 99 | # Epoch interval to decay LR, used in StepLRScheduler 100 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 101 | # LR decay rate, used in StepLRScheduler 102 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 103 | 104 | # Optimizer 105 | _C.TRAIN.OPTIMIZER = CN() 106 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 107 | # Optimizer Epsilon 108 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 109 | # Optimizer Betas 110 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 111 | # SGD momentum 112 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 113 | 114 | # ----------------------------------------------------------------------------- 115 | # Augmentation settings 116 | # ----------------------------------------------------------------------------- 117 | _C.AUG = CN() 118 | # Color jitter factor 119 | _C.AUG.COLOR_JITTER = 0.4 120 | # Use AutoAugment policy. 
"v0" or "original" 121 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 122 | # Random erase prob 123 | _C.AUG.REPROB = 0.25 124 | # Random erase mode 125 | _C.AUG.REMODE = 'pixel' 126 | # Random erase count 127 | _C.AUG.RECOUNT = 1 128 | # Mixup alpha, mixup enabled if > 0 129 | _C.AUG.MIXUP = 0.8 130 | # Cutmix alpha, cutmix enabled if > 0 131 | _C.AUG.CUTMIX = 1.0 132 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 133 | _C.AUG.CUTMIX_MINMAX = None 134 | # Probability of performing mixup or cutmix when either/both is enabled 135 | _C.AUG.MIXUP_PROB = 1.0 136 | # Probability of switching to cutmix when both mixup and cutmix enabled 137 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 138 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 139 | _C.AUG.MIXUP_MODE = 'batch' 140 | 141 | # ----------------------------------------------------------------------------- 142 | # Testing settings 143 | # ----------------------------------------------------------------------------- 144 | _C.TEST = CN() 145 | # Whether to use center crop when testing 146 | _C.TEST.CROP = True 147 | 148 | # ----------------------------------------------------------------------------- 149 | # Misc 150 | # ----------------------------------------------------------------------------- 151 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 152 | # overwritten by command line argument 153 | _C.AMP_OPT_LEVEL = '' 154 | # Path to output folder, overwritten by command line argument 155 | _C.OUTPUT = '' 156 | # Tag of experiment, overwritten by command line argument 157 | _C.TAG = 'default' 158 | # Frequency to save checkpoint 159 | _C.SAVE_FREQ = 1 160 | # Frequency to logging info 161 | _C.PRINT_FREQ = 10 162 | # Fixed random seed 163 | _C.SEED = 0 164 | # Perform evaluation only, overwritten by command line argument 165 | _C.EVAL_MODE = False 166 | # Test throughput only, overwritten by command line argument 167 | _C.THROUGHPUT_MODE = False 168 | # local rank for DistributedDataParallel, given by command line argument 169 | _C.LOCAL_RANK = 0 170 | 171 | 172 | def _update_config_from_file(config, cfg_file): 173 | config.defrost() 174 | with open(cfg_file, 'r') as f: 175 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 176 | 177 | for cfg in yaml_cfg.setdefault('BASE', ['']): 178 | if cfg: 179 | _update_config_from_file( 180 | config, os.path.join(os.path.dirname(cfg_file), cfg) 181 | ) 182 | print('=> merge config from {}'.format(cfg_file)) 183 | config.merge_from_file(cfg_file) 184 | config.freeze() 185 | 186 | 187 | def update_config(config, args): 188 | _update_config_from_file(config, args.cfg) 189 | 190 | config.defrost() 191 | if args.opts: 192 | config.merge_from_list(args.opts) 193 | 194 | # merge from specific arguments 195 | if args.batch_size: 196 | config.DATA.BATCH_SIZE = args.batch_size 197 | if args.data_path: 198 | config.DATA.DATA_PATH = args.data_path 199 | if args.zip: 200 | config.DATA.ZIP_MODE = True 201 | if args.cache_mode: 202 | config.DATA.CACHE_MODE = args.cache_mode 203 | if args.resume: 204 | config.MODEL.RESUME = args.resume 205 | if args.accumulation_steps: 206 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 207 | if args.use_checkpoint: 208 | config.TRAIN.USE_CHECKPOINT = True 209 | if args.amp_opt_level: 210 | config.AMP_OPT_LEVEL = args.amp_opt_level 211 | if args.output: 212 | config.OUTPUT = args.output 213 | if args.tag: 214 | config.TAG = args.tag 215 | if args.eval: 216 | config.EVAL_MODE = True 217 | if args.throughput: 218 | 
config.THROUGHPUT_MODE = True 219 | 220 | # set local rank for distributed training 221 | config.LOCAL_RANK = args.local_rank 222 | 223 | # output folder 224 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 225 | 226 | config.freeze() 227 | 228 | 229 | def get_config(args): 230 | """Get a yacs CfgNode object with default values.""" 231 | # Return a clone so that the defaults will not be altered 232 | # This is for the "local variable" use pattern 233 | config = _C.clone() 234 | update_config(config, args) 235 | 236 | return config 237 | -------------------------------------------------------------------------------- /Swin-Transformer/configs/acmix_swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | EPOCHS: 300 3 | 4 | MODEL: 5 | TYPE: swin_acmix 6 | NAME: acmix_swin_small_patch4_window7_224 7 | DROP_PATH_RATE: 0.3 8 | SWIN: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/configs/acmix_swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | EPOCHS: 300 3 | 4 | MODEL: 5 | TYPE: swin_acmix 6 | NAME: acmix_swin_tiny_patch4_window7_224 7 | DROP_PATH_RATE: 0.2 8 | SWIN: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 6, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_base_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_base_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_large_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_large_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_large_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | MODEL: 3 | TYPE: swin 4 | NAME: swin_large_patch4_window7_224 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 
| MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/configs/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Swin-Transformer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /Swin-Transformer/data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | from timm.data.transforms import _pil_interp 17 | 18 | from .cached_image_folder import CachedImageFolder 19 | from .samplers import SubsetRandomSampler 20 | 21 | 22 | def build_loader(config): 23 | config.defrost() 24 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 25 | config.freeze() 26 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 27 | dataset_val, _ = build_dataset(is_train=False, config=config) 28 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 29 | 30 | num_tasks = dist.get_world_size() 31 | global_rank = dist.get_rank() 32 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 33 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 34 | sampler_train = SubsetRandomSampler(indices) 35 | else: 36 | sampler_train = torch.utils.data.DistributedSampler( 37 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 38 | ) 39 | 40 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 41 | sampler_val = SubsetRandomSampler(indices) 42 | 43 | data_loader_train = torch.utils.data.DataLoader( 44 | dataset_train, sampler=sampler_train, 45 | batch_size=config.DATA.BATCH_SIZE, 46 | num_workers=config.DATA.NUM_WORKERS, 47 | pin_memory=config.DATA.PIN_MEMORY, 48 | drop_last=True, 49 | ) 50 | 51 | data_loader_val = torch.utils.data.DataLoader( 52 | dataset_val, sampler=sampler_val, 53 | batch_size=config.DATA.BATCH_SIZE, 54 | shuffle=False, 55 | num_workers=config.DATA.NUM_WORKERS, 56 | pin_memory=config.DATA.PIN_MEMORY, 57 | drop_last=False 58 | ) 59 | 60 | # setup mixup / cutmix 61 | mixup_fn = None 62 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 63 | if mixup_active: 64 | mixup_fn = Mixup( 65 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 66 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 67 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 68 | 69 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 70 | 71 | 72 | def build_dataset(is_train, config): 73 | transform = build_transform(is_train, config) 74 | if config.DATA.DATASET == 'imagenet': 75 | prefix = 'train' if is_train else 'val' 76 | if config.DATA.ZIP_MODE: 77 | ann_file = prefix + "_map.txt" 78 | prefix = prefix + ".zip@/" 79 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 80 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 81 | else: 82 | root = os.path.join(config.DATA.DATA_PATH, prefix) 83 | dataset = datasets.ImageFolder(root, transform=transform) 84 | nb_classes = 1000 85 | else: 86 | raise NotImplementedError("We only support ImageNet Now.") 87 | 88 | return dataset, nb_classes 89 | 90 | 91 | def build_transform(is_train, config): 92 | resize_im = config.DATA.IMG_SIZE > 32 93 | if is_train: 94 | # this should always dispatch to transforms_imagenet_train 95 | transform = create_transform( 96 | input_size=config.DATA.IMG_SIZE, 97 | is_training=True, 98 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 99 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 100 | re_prob=config.AUG.REPROB, 101 | re_mode=config.AUG.REMODE, 102 | re_count=config.AUG.RECOUNT, 103 | interpolation=config.DATA.INTERPOLATION, 104 | ) 105 | if not resize_im: 106 | # replace RandomResizedCropAndInterpolation with 107 | # RandomCrop 108 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 109 | return transform 110 | 111 | t = [] 112 | if resize_im: 113 | if config.TEST.CROP: 114 | size = int((256 / 224) * config.DATA.IMG_SIZE) 115 | t.append( 116 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 117 | # to maintain same ratio w.r.t. 224 images 118 | ) 119 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 120 | else: 121 | t.append( 122 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 123 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 124 | ) 125 | 126 | t.append(transforms.ToTensor()) 127 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 128 | return transforms.Compose(t) 129 | -------------------------------------------------------------------------------- /Swin-Transformer/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 
20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | def accimage_loader(path): 192 | import accimage 193 | try: 194 | return accimage.Image(path) 195 | except IOError: 196 | # Potentially a decoding problem, fall back to PIL.Image 197 | return pil_loader(path) 198 | 199 | 200 | def default_img_loader(path): 201 | from torchvision import get_image_backend 202 | if get_image_backend() == 'accimage': 203 | return accimage_loader(path) 204 | else: 205 | return pil_loader(path) 206 | 207 | 208 | class CachedImageFolder(DatasetFolder): 209 | """A generic data loader where the images are arranged in this way: :: 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | Args: 217 | root (string): Root directory path. 218 | transform (callable, optional): A function/transform that takes in an PIL image 219 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 220 | target_transform (callable, optional): A function/transform that takes in the 221 | target and transforms it. 222 | loader (callable, optional): A function to load an image given its path. 223 | Attributes: 224 | imgs (list): List of (image path, class_index) tuples 225 | """ 226 | 227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 228 | loader=default_img_loader, cache_mode="no"): 229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 230 | ann_file=ann_file, img_prefix=img_prefix, 231 | transform=transform, target_transform=target_transform, 232 | cache_mode=cache_mode) 233 | self.imgs = self.samples 234 | 235 | def __getitem__(self, index): 236 | """ 237 | Args: 238 | index (int): Index 239 | Returns: 240 | tuple: (image, target) where target is class_index of the target class. 
241 | """ 242 | path, target = self.samples[index] 243 | image = self.loader(path) 244 | if self.transform is not None: 245 | img = self.transform(image) 246 | else: 247 | img = image 248 | if self.target_transform is not None: 249 | target = self.target_transform(target) 250 | 251 | return img, target 252 | -------------------------------------------------------------------------------- /Swin-Transformer/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /Swin-Transformer/data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | 
extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /Swin-Transformer/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /Swin-Transformer/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from timm.scheduler.cosine_lr import CosineLRScheduler 10 | from timm.scheduler.step_lr import StepLRScheduler 11 | from timm.scheduler.scheduler import Scheduler 12 | 13 | 
14 | def build_scheduler(config, optimizer, n_iter_per_epoch): 15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 18 | 19 | lr_scheduler = None 20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_steps, 24 | t_mul=1., 25 | lr_min=config.TRAIN.MIN_LR, 26 | warmup_lr_init=config.TRAIN.WARMUP_LR, 27 | warmup_t=warmup_steps, 28 | cycle_limit=1, 29 | t_in_epochs=False, 30 | ) 31 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 32 | lr_scheduler = LinearLRScheduler( 33 | optimizer, 34 | t_initial=num_steps, 35 | lr_min_rate=0.01, 36 | warmup_lr_init=config.TRAIN.WARMUP_LR, 37 | warmup_t=warmup_steps, 38 | t_in_epochs=False, 39 | ) 40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 41 | lr_scheduler = StepLRScheduler( 42 | optimizer, 43 | decay_t=decay_steps, 44 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 45 | warmup_lr_init=config.TRAIN.WARMUP_LR, 46 | warmup_t=warmup_steps, 47 | t_in_epochs=False, 48 | ) 49 | 50 | return lr_scheduler 51 | 52 | 53 | class LinearLRScheduler(Scheduler): 54 | def __init__(self, 55 | optimizer: torch.optim.Optimizer, 56 | t_initial: int, 57 | lr_min_rate: float, 58 | warmup_t=0, 59 | warmup_lr_init=0., 60 | t_in_epochs=True, 61 | noise_range_t=None, 62 | noise_pct=0.67, 63 | noise_std=1.0, 64 | noise_seed=42, 65 | initialize=True, 66 | ) -> None: 67 | super().__init__( 68 | optimizer, param_group_field="lr", 69 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 70 | initialize=initialize) 71 | 72 | self.t_initial = t_initial 73 | self.lr_min_rate = lr_min_rate 74 | self.warmup_t = warmup_t 75 | self.warmup_lr_init = warmup_lr_init 76 | self.t_in_epochs = t_in_epochs 77 | if self.warmup_t: 78 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 79 | super().update_groups(self.warmup_lr_init) 80 | else: 81 | self.warmup_steps = [1 for _ in self.base_values] 82 | 83 | def _get_lr(self, t): 84 | if t < self.warmup_t: 85 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 86 | else: 87 | t = t - self.warmup_t 88 | total_t = self.t_initial - self.warmup_t 89 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 90 | return lrs 91 | 92 | def get_epoch_values(self, epoch: int): 93 | if self.t_in_epochs: 94 | return self._get_lr(epoch) 95 | else: 96 | return None 97 | 98 | def get_update_values(self, num_updates: int): 99 | if not self.t_in_epochs: 100 | return self._get_lr(num_updates) 101 | else: 102 | return None 103 | -------------------------------------------------------------------------------- /Swin-Transformer/main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import argparse 11 | import datetime 12 | import numpy as np 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 19 | from timm.utils import accuracy, AverageMeter 20 | 21 
| from config import get_config 22 | from models import build_model 23 | from data import build_loader 24 | from lr_scheduler import build_scheduler 25 | from optimizer import build_optimizer 26 | from logger import create_logger 27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 28 | 29 | try: 30 | # noinspection PyUnresolvedReferences 31 | from apex import amp 32 | except ImportError: 33 | amp = None 34 | 35 | 36 | def parse_option(): 37 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 38 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 39 | parser.add_argument( 40 | "--opts", 41 | help="Modify config options by adding 'KEY VALUE' pairs. ", 42 | default=None, 43 | nargs='+', 44 | ) 45 | 46 | # easy config modification 47 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 48 | parser.add_argument('--data-path', type=str, help='path to dataset') 49 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 50 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 51 | help='no: no cache, ' 52 | 'full: cache all data, ' 53 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 54 | parser.add_argument('--resume', help='resume from checkpoint') 55 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 56 | parser.add_argument('--use-checkpoint', action='store_true', 57 | help="whether to use gradient checkpointing to save memory") 58 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 59 | help='mixed precision opt level, if O0, no amp is used') 60 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 61 | help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)') 62 | parser.add_argument('--tag', help='tag of experiment') 63 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 64 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 65 | 66 | # distributed training 67 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 68 | 69 | args, unparsed = parser.parse_known_args() 70 | 71 | config = get_config(args) 72 | 73 | return args, config 74 | 75 | 76 | def main(config): 77 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 78 | 79 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 80 | model = build_model(config) 81 | model.cuda() 82 | logger.info(str(model)) 83 | 84 | optimizer = build_optimizer(config, model) 85 | if config.AMP_OPT_LEVEL != "O0": 86 | model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 87 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 88 | model_without_ddp = model.module 89 | 90 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 91 | logger.info(f"number of params: {n_parameters}") 92 | if hasattr(model_without_ddp, 'flops'): 93 | flops = model_without_ddp.flops() 94 | logger.info(f"number of GFLOPs: {flops / 1e9}") 95 | 96 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 97 | 98 | if config.AUG.MIXUP > 0.: 99 | # 
smoothing is handled with mixup label transform 100 | criterion = SoftTargetCrossEntropy() 101 | elif config.MODEL.LABEL_SMOOTHING > 0.: 102 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 103 | else: 104 | criterion = torch.nn.CrossEntropyLoss() 105 | 106 | max_accuracy = 0.0 107 | 108 | if config.TRAIN.AUTO_RESUME: 109 | resume_file = auto_resume_helper(config.OUTPUT) 110 | if resume_file: 111 | if config.MODEL.RESUME: 112 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 113 | config.defrost() 114 | config.MODEL.RESUME = resume_file 115 | config.freeze() 116 | logger.info(f'auto resuming from {resume_file}') 117 | else: 118 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 119 | 120 | if config.MODEL.RESUME: 121 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 122 | acc1, acc5, loss = validate(config, data_loader_val, model) 123 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 124 | if config.EVAL_MODE: 125 | return 126 | 127 | if config.THROUGHPUT_MODE: 128 | throughput(data_loader_val, model, logger) 129 | return 130 | 131 | logger.info("Start training") 132 | start_time = time.time() 133 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 134 | data_loader_train.sampler.set_epoch(epoch) 135 | 136 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 137 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 138 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 139 | 140 | acc1, acc5, loss = validate(config, data_loader_val, model) 141 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 142 | max_accuracy = max(max_accuracy, acc1) 143 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 144 | 145 | total_time = time.time() - start_time 146 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 147 | logger.info('Training time {}'.format(total_time_str)) 148 | 149 | 150 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 151 | model.train() 152 | optimizer.zero_grad() 153 | 154 | num_steps = len(data_loader) 155 | batch_time = AverageMeter() 156 | loss_meter = AverageMeter() 157 | norm_meter = AverageMeter() 158 | 159 | start = time.time() 160 | end = time.time() 161 | for idx, (samples, targets) in enumerate(data_loader): 162 | samples = samples.cuda(non_blocking=True) 163 | targets = targets.cuda(non_blocking=True) 164 | 165 | if mixup_fn is not None: 166 | samples, targets = mixup_fn(samples, targets) 167 | 168 | outputs = model(samples) 169 | 170 | if config.TRAIN.ACCUMULATION_STEPS > 1: 171 | loss = criterion(outputs, targets) 172 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 173 | if config.AMP_OPT_LEVEL != "O0": 174 | with amp.scale_loss(loss, optimizer) as scaled_loss: 175 | scaled_loss.backward() 176 | if config.TRAIN.CLIP_GRAD: 177 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 178 | else: 179 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 180 | else: 181 | loss.backward() 182 | if config.TRAIN.CLIP_GRAD: 183 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 184 | else: 185 | grad_norm = 
get_grad_norm(model.parameters()) 186 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 187 | optimizer.step() 188 | optimizer.zero_grad() 189 | lr_scheduler.step_update(epoch * num_steps + idx) 190 | else: 191 | loss = criterion(outputs, targets) 192 | optimizer.zero_grad() 193 | if config.AMP_OPT_LEVEL != "O0": 194 | with amp.scale_loss(loss, optimizer) as scaled_loss: 195 | scaled_loss.backward() 196 | if config.TRAIN.CLIP_GRAD: 197 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 198 | else: 199 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 200 | else: 201 | loss.backward() 202 | if config.TRAIN.CLIP_GRAD: 203 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 204 | else: 205 | grad_norm = get_grad_norm(model.parameters()) 206 | optimizer.step() 207 | lr_scheduler.step_update(epoch * num_steps + idx) 208 | 209 | torch.cuda.synchronize() 210 | 211 | loss_meter.update(loss.item(), targets.size(0)) 212 | norm_meter.update(grad_norm) 213 | batch_time.update(time.time() - end) 214 | end = time.time() 215 | 216 | if idx % config.PRINT_FREQ == 0: 217 | lr = optimizer.param_groups[0]['lr'] 218 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 219 | etas = batch_time.avg * (num_steps - idx) 220 | logger.info( 221 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 222 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 223 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 224 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 225 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 226 | f'mem {memory_used:.0f}MB') 227 | epoch_time = time.time() - start 228 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 229 | 230 | 231 | @torch.no_grad() 232 | def validate(config, data_loader, model): 233 | criterion = torch.nn.CrossEntropyLoss() 234 | model.eval() 235 | 236 | batch_time = AverageMeter() 237 | loss_meter = AverageMeter() 238 | acc1_meter = AverageMeter() 239 | acc5_meter = AverageMeter() 240 | 241 | end = time.time() 242 | for idx, (images, target) in enumerate(data_loader): 243 | images = images.cuda(non_blocking=True) 244 | target = target.cuda(non_blocking=True) 245 | 246 | # compute output 247 | output = model(images) 248 | 249 | # measure accuracy and record loss 250 | loss = criterion(output, target) 251 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 252 | 253 | acc1 = reduce_tensor(acc1) 254 | acc5 = reduce_tensor(acc5) 255 | loss = reduce_tensor(loss) 256 | 257 | loss_meter.update(loss.item(), target.size(0)) 258 | acc1_meter.update(acc1.item(), target.size(0)) 259 | acc5_meter.update(acc5.item(), target.size(0)) 260 | 261 | # measure elapsed time 262 | batch_time.update(time.time() - end) 263 | end = time.time() 264 | 265 | if idx % config.PRINT_FREQ == 0: 266 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 267 | logger.info( 268 | f'Test: [{idx}/{len(data_loader)}]\t' 269 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 270 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 271 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 272 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 273 | f'Mem {memory_used:.0f}MB') 274 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 275 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 276 | 277 | 278 | @torch.no_grad() 279 | def throughput(data_loader, 
model, logger): 280 | model.eval() 281 | 282 | for idx, (images, _) in enumerate(data_loader): 283 | images = images.cuda(non_blocking=True) 284 | batch_size = images.shape[0] 285 | for i in range(50): 286 | model(images) 287 | torch.cuda.synchronize() 288 | logger.info(f"throughput averaged with 30 times") 289 | tic1 = time.time() 290 | for i in range(30): 291 | model(images) 292 | torch.cuda.synchronize() 293 | tic2 = time.time() 294 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 295 | return 296 | 297 | 298 | if __name__ == '__main__': 299 | _, config = parse_option() 300 | 301 | if config.AMP_OPT_LEVEL != "O0": 302 | assert amp is not None, "amp not installed!" 303 | 304 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 305 | rank = int(os.environ["RANK"]) 306 | world_size = int(os.environ['WORLD_SIZE']) 307 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 308 | else: 309 | rank = -1 310 | world_size = -1 311 | torch.cuda.set_device(config.LOCAL_RANK) 312 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 313 | torch.distributed.barrier() 314 | 315 | seed = config.SEED + dist.get_rank() 316 | torch.manual_seed(seed) 317 | np.random.seed(seed) 318 | cudnn.benchmark = True 319 | 320 | # linear scale the learning rate according to total batch size, may not be optimal 321 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 322 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 323 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 324 | # gradient accumulation also need to scale the learning rate 325 | if config.TRAIN.ACCUMULATION_STEPS > 1: 326 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 327 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 328 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 329 | config.defrost() 330 | config.TRAIN.BASE_LR = linear_scaled_lr 331 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 332 | config.TRAIN.MIN_LR = linear_scaled_min_lr 333 | config.freeze() 334 | 335 | os.makedirs(config.OUTPUT, exist_ok=True) 336 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 337 | 338 | if dist.get_rank() == 0: 339 | path = os.path.join(config.OUTPUT, "config.json") 340 | with open(path, "w") as f: 341 | f.write(config.dump()) 342 | logger.info(f"Full config saved to {path}") 343 | 344 | # print config 345 | logger.info(config.dump()) 346 | 347 | main(config) 348 | -------------------------------------------------------------------------------- /Swin-Transformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /Swin-Transformer/models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | from .swin_transformer_acmix import 
SwinTransformer_acmix 10 | 11 | 12 | def build_model(config): 13 | model_type = config.MODEL.TYPE 14 | if model_type == 'swin': 15 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 16 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 17 | in_chans=config.MODEL.SWIN.IN_CHANS, 18 | num_classes=config.MODEL.NUM_CLASSES, 19 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 20 | depths=config.MODEL.SWIN.DEPTHS, 21 | num_heads=config.MODEL.SWIN.NUM_HEADS, 22 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 23 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 24 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 25 | qk_scale=config.MODEL.SWIN.QK_SCALE, 26 | drop_rate=config.MODEL.DROP_RATE, 27 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 28 | ape=config.MODEL.SWIN.APE, 29 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 30 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 31 | elif model_type == 'swin_acmix': 32 | model = SwinTransformer_acmix(img_size=config.DATA.IMG_SIZE, 33 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 34 | in_chans=config.MODEL.SWIN.IN_CHANS, 35 | num_classes=config.MODEL.NUM_CLASSES, 36 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 37 | depths=config.MODEL.SWIN.DEPTHS, 38 | num_heads=config.MODEL.SWIN.NUM_HEADS, 39 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 40 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 41 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 42 | qk_scale=config.MODEL.SWIN.QK_SCALE, 43 | drop_rate=config.MODEL.DROP_RATE, 44 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 45 | ape=config.MODEL.SWIN.APE, 46 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 47 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 48 | else: 49 | raise NotImplementedError(f"Unknown model: {model_type}") 50 | 51 | return model 52 | -------------------------------------------------------------------------------- /Swin-Transformer/models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | class GELU(nn.Module): 15 | 16 | def forward(self, x): 17 | return F.gelu(x) 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.act(x) 32 | x = self.drop(x) 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | def window_partition(x, window_size): 39 | """ 40 | Args: 41 | x: (B, H, W, C) 42 | window_size (int): window size 43 | 44 | Returns: 45 | windows: (num_windows*B, window_size, window_size, C) 46 | """ 47 | B, H, W, C = x.shape 48 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 49 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 50 | return windows 51 | 52 | 53 | def window_reverse(windows, window_size, H, W): 
54 | """ 55 | Args: 56 | windows: (num_windows*B, window_size, window_size, C) 57 | window_size (int): Window size 58 | H (int): Height of image 59 | W (int): Width of image 60 | 61 | Returns: 62 | x: (B, H, W, C) 63 | """ 64 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 65 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 66 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 67 | return x 68 | 69 | 70 | class WindowAttention(nn.Module): 71 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 72 | It supports both of shifted and non-shifted window. 73 | 74 | Args: 75 | dim (int): Number of input channels. 76 | window_size (tuple[int]): The height and width of the window. 77 | num_heads (int): Number of attention heads. 78 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 79 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 80 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 81 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 82 | """ 83 | 84 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | 86 | super().__init__() 87 | self.dim = dim 88 | self.window_size = window_size # Wh, Ww 89 | self.num_heads = num_heads 90 | head_dim = dim // num_heads 91 | self.scale = qk_scale or head_dim ** -0.5 92 | 93 | # define a parameter table of relative position bias 94 | self.relative_position_bias_table = nn.Parameter( 95 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 96 | 97 | # get pair-wise relative position index for each token inside the window 98 | coords_h = torch.arange(self.window_size[0]) 99 | coords_w = torch.arange(self.window_size[1]) 100 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 101 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 102 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 103 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 104 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 105 | relative_coords[:, :, 1] += self.window_size[1] - 1 106 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 107 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 108 | self.register_buffer("relative_position_index", relative_position_index) 109 | 110 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 111 | self.attn_drop = nn.Dropout(attn_drop) 112 | self.proj = nn.Linear(dim, dim) 113 | self.proj_drop = nn.Dropout(proj_drop) 114 | 115 | trunc_normal_(self.relative_position_bias_table, std=.02) 116 | self.softmax = nn.Softmax(dim=-1) 117 | 118 | def forward(self, x, mask=None): 119 | """ 120 | Args: 121 | x: input features with shape of (num_windows*B, N, C) 122 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 123 | """ 124 | B_, N, C = x.shape 125 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 126 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 127 | 128 | q = q * self.scale 129 | attn = (q @ k.transpose(-2, -1)) 130 | 131 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 132 | self.window_size[0] * 
self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 133 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 134 | attn = attn + relative_position_bias.unsqueeze(0) 135 | 136 | if mask is not None: 137 | nW = mask.shape[0] 138 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 139 | attn = attn.view(-1, self.num_heads, N, N) 140 | attn = self.softmax(attn) 141 | else: 142 | attn = self.softmax(attn) 143 | 144 | attn = self.attn_drop(attn) 145 | 146 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 147 | x = self.proj(x) 148 | x = self.proj_drop(x) 149 | return x 150 | 151 | def extra_repr(self) -> str: 152 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 153 | 154 | def flops(self, N): 155 | # calculate flops for 1 window with token length of N 156 | flops = 0 157 | # qkv = self.qkv(x) 158 | flops += N * self.dim * 3 * self.dim 159 | # attn = (q @ k.transpose(-2, -1)) 160 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 161 | # x = (attn @ v) 162 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 163 | # x = self.proj(x) 164 | flops += N * self.dim * self.dim 165 | return flops 166 | 167 | 168 | class SwinTransformerBlock(nn.Module): 169 | r""" Swin Transformer Block. 170 | 171 | Args: 172 | dim (int): Number of input channels. 173 | input_resolution (tuple[int]): Input resolution. 174 | num_heads (int): Number of attention heads. 175 | window_size (int): Window size. 176 | shift_size (int): Shift size for SW-MSA. 177 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 178 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 179 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 180 | drop (float, optional): Dropout rate. Default: 0.0 181 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 182 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 183 | act_layer (nn.Module, optional): Activation layer. Default: GELU 184 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 185 | """ 186 | 187 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 188 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 189 | act_layer=GELU, norm_layer=nn.LayerNorm): 190 | super().__init__() 191 | self.dim = dim 192 | self.input_resolution = input_resolution 193 | self.num_heads = num_heads 194 | self.window_size = window_size 195 | self.shift_size = shift_size 196 | self.mlp_ratio = mlp_ratio 197 | if min(self.input_resolution) <= self.window_size: 198 | # if window size is larger than input resolution, we don't partition windows 199 | self.shift_size = 0 200 | self.window_size = min(self.input_resolution) 201 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" 202 | 203 | self.norm1 = norm_layer(dim) 204 | self.attn = WindowAttention( 205 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 206 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 207 | 208 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 209 | self.norm2 = norm_layer(dim) 210 | mlp_hidden_dim = int(dim * mlp_ratio) 211 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 212 | 213 | if self.shift_size > 0: 214 | # calculate attention mask for SW-MSA 215 | H, W = self.input_resolution 216 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 217 | h_slices = (slice(0, -self.window_size), 218 | slice(-self.window_size, -self.shift_size), 219 | slice(-self.shift_size, None)) 220 | w_slices = (slice(0, -self.window_size), 221 | slice(-self.window_size, -self.shift_size), 222 | slice(-self.shift_size, None)) 223 | cnt = 0 224 | for h in h_slices: 225 | for w in w_slices: 226 | img_mask[:, h, w, :] = cnt 227 | cnt += 1 228 | 229 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 230 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 231 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 232 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 233 | else: 234 | attn_mask = None 235 | 236 | self.register_buffer("attn_mask", attn_mask) 237 | 238 | def forward(self, x): 239 | H, W = self.input_resolution 240 | B, L, C = x.shape 241 | assert L == H * W, "input feature has wrong size" 242 | 243 | shortcut = x 244 | x = self.norm1(x) 245 | x = x.view(B, H, W, C) 246 | 247 | # cyclic shift 248 | if self.shift_size > 0: 249 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 250 | else: 251 | shifted_x = x 252 | 253 | # partition windows 254 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 255 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 256 | 257 | # W-MSA/SW-MSA 258 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 259 | 260 | # merge windows 261 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 262 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 263 | 264 | # reverse cyclic shift 265 | if self.shift_size > 0: 266 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 267 | else: 268 | x = shifted_x 269 | x = x.view(B, H * W, C) 270 | 271 | # FFN 272 | x = shortcut + self.drop_path(x) 273 | x = x + self.drop_path(self.mlp(self.norm2(x))) 274 | 275 | return x 276 | 277 | def extra_repr(self) -> str: 278 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 279 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 280 | 281 | def flops(self): 282 | flops = 0 283 | H, W = self.input_resolution 284 | # norm1 285 | flops += self.dim * H * W 286 | # W-MSA/SW-MSA 287 | nW = H * W / self.window_size / self.window_size 288 | flops += nW * self.attn.flops(self.window_size * self.window_size) 289 | # mlp 290 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 291 | # norm2 292 | flops += self.dim * H * W 293 | return flops 294 | 295 | 296 | class PatchMerging(nn.Module): 297 | r""" Patch Merging Layer. 298 | 299 | Args: 300 | input_resolution (tuple[int]): Resolution of input feature. 301 | dim (int): Number of input channels. 302 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 303 | """ 304 | 305 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 306 | super().__init__() 307 | self.input_resolution = input_resolution 308 | self.dim = dim 309 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 310 | self.norm = norm_layer(4 * dim) 311 | 312 | def forward(self, x): 313 | """ 314 | x: B, H*W, C 315 | """ 316 | H, W = self.input_resolution 317 | B, L, C = x.shape 318 | assert L == H * W, "input feature has wrong size" 319 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 320 | 321 | x = x.view(B, H, W, C) 322 | 323 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 324 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 325 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 326 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 327 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 328 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 329 | 330 | x = self.norm(x) 331 | x = self.reduction(x) 332 | 333 | return x 334 | 335 | def extra_repr(self) -> str: 336 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 337 | 338 | def flops(self): 339 | H, W = self.input_resolution 340 | flops = H * W * self.dim 341 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 342 | return flops 343 | 344 | 345 | class BasicLayer(nn.Module): 346 | """ A basic Swin Transformer layer for one stage. 347 | 348 | Args: 349 | dim (int): Number of input channels. 350 | input_resolution (tuple[int]): Input resolution. 351 | depth (int): Number of blocks. 352 | num_heads (int): Number of attention heads. 353 | window_size (int): Local window size. 354 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 355 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 356 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 357 | drop (float, optional): Dropout rate. Default: 0.0 358 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 359 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 360 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 361 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 362 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
363 | """ 364 | 365 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 366 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 367 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 368 | 369 | super().__init__() 370 | self.dim = dim 371 | self.input_resolution = input_resolution 372 | self.depth = depth 373 | self.use_checkpoint = use_checkpoint 374 | 375 | # build blocks 376 | self.blocks = nn.ModuleList([ 377 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 378 | num_heads=num_heads, window_size=window_size, 379 | shift_size=0 if (i % 2 == 0) else window_size // 2, 380 | mlp_ratio=mlp_ratio, 381 | qkv_bias=qkv_bias, qk_scale=qk_scale, 382 | drop=drop, attn_drop=attn_drop, 383 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 384 | norm_layer=norm_layer) 385 | for i in range(depth)]) 386 | 387 | # patch merging layer 388 | if downsample is not None: 389 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 390 | else: 391 | self.downsample = None 392 | 393 | def forward(self, x): 394 | for blk in self.blocks: 395 | if self.use_checkpoint: 396 | x = checkpoint.checkpoint(blk, x) 397 | else: 398 | x = blk(x) 399 | if self.downsample is not None: 400 | x = self.downsample(x) 401 | return x 402 | 403 | def extra_repr(self) -> str: 404 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 405 | 406 | def flops(self): 407 | flops = 0 408 | for blk in self.blocks: 409 | flops += blk.flops() 410 | if self.downsample is not None: 411 | flops += self.downsample.flops() 412 | return flops 413 | 414 | 415 | class PatchEmbed(nn.Module): 416 | r""" Image to Patch Embedding 417 | 418 | Args: 419 | img_size (int): Image size. Default: 224. 420 | patch_size (int): Patch token size. Default: 4. 421 | in_chans (int): Number of input image channels. Default: 3. 422 | embed_dim (int): Number of linear projection output channels. Default: 96. 423 | norm_layer (nn.Module, optional): Normalization layer. Default: None 424 | """ 425 | 426 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 427 | super().__init__() 428 | img_size = to_2tuple(img_size) 429 | patch_size = to_2tuple(patch_size) 430 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 431 | self.img_size = img_size 432 | self.patch_size = patch_size 433 | self.patches_resolution = patches_resolution 434 | self.num_patches = patches_resolution[0] * patches_resolution[1] 435 | 436 | self.in_chans = in_chans 437 | self.embed_dim = embed_dim 438 | 439 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 440 | if norm_layer is not None: 441 | self.norm = norm_layer(embed_dim) 442 | else: 443 | self.norm = None 444 | 445 | def forward(self, x): 446 | B, C, H, W = x.shape 447 | # FIXME look at relaxing size constraints 448 | assert H == self.img_size[0] and W == self.img_size[1], \ 449 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
450 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 451 | if self.norm is not None: 452 | x = self.norm(x) 453 | return x 454 | 455 | def flops(self): 456 | Ho, Wo = self.patches_resolution 457 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 458 | if self.norm is not None: 459 | flops += Ho * Wo * self.embed_dim 460 | return flops 461 | 462 | 463 | class SwinTransformer(nn.Module): 464 | r""" Swin Transformer 465 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 466 | https://arxiv.org/pdf/2103.14030 467 | 468 | Args: 469 | img_size (int | tuple(int)): Input image size. Default 224 470 | patch_size (int | tuple(int)): Patch size. Default: 4 471 | in_chans (int): Number of input image channels. Default: 3 472 | num_classes (int): Number of classes for classification head. Default: 1000 473 | embed_dim (int): Patch embedding dimension. Default: 96 474 | depths (tuple(int)): Depth of each Swin Transformer layer. 475 | num_heads (tuple(int)): Number of attention heads in different layers. 476 | window_size (int): Window size. Default: 7 477 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 478 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 479 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 480 | drop_rate (float): Dropout rate. Default: 0 481 | attn_drop_rate (float): Attention dropout rate. Default: 0 482 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 483 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 484 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 485 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 486 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 487 | """ 488 | 489 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 490 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 491 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 492 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 493 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 494 | use_checkpoint=False, **kwargs): 495 | super().__init__() 496 | 497 | self.num_classes = num_classes 498 | self.num_layers = len(depths) 499 | self.embed_dim = embed_dim 500 | self.ape = ape 501 | self.patch_norm = patch_norm 502 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 503 | self.mlp_ratio = mlp_ratio 504 | 505 | # split image into non-overlapping patches 506 | self.patch_embed = PatchEmbed( 507 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 508 | norm_layer=norm_layer if self.patch_norm else None) 509 | num_patches = self.patch_embed.num_patches 510 | patches_resolution = self.patch_embed.patches_resolution 511 | self.patches_resolution = patches_resolution 512 | 513 | # absolute position embedding 514 | if self.ape: 515 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 516 | trunc_normal_(self.absolute_pos_embed, std=.02) 517 | 518 | self.pos_drop = nn.Dropout(p=drop_rate) 519 | 520 | # stochastic depth 521 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 522 | 523 | # build layers 524 | self.layers = nn.ModuleList() 525 | for i_layer in range(self.num_layers): 526 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 527 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 528 | patches_resolution[1] // (2 ** i_layer)), 529 | depth=depths[i_layer], 530 | num_heads=num_heads[i_layer], 531 | window_size=window_size, 532 | mlp_ratio=self.mlp_ratio, 533 | qkv_bias=qkv_bias, qk_scale=qk_scale, 534 | drop=drop_rate, attn_drop=attn_drop_rate, 535 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 536 | norm_layer=norm_layer, 537 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 538 | use_checkpoint=use_checkpoint) 539 | self.layers.append(layer) 540 | 541 | self.norm = norm_layer(self.num_features) 542 | self.avgpool = nn.AdaptiveAvgPool1d(1) 543 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 544 | 545 | self.apply(self._init_weights) 546 | 547 | def _init_weights(self, m): 548 | if isinstance(m, nn.Linear): 549 | trunc_normal_(m.weight, std=.02) 550 | if isinstance(m, nn.Linear) and m.bias is not None: 551 | nn.init.constant_(m.bias, 0) 552 | elif isinstance(m, nn.LayerNorm): 553 | nn.init.constant_(m.bias, 0) 554 | nn.init.constant_(m.weight, 1.0) 555 | 556 | @torch.jit.ignore 557 | def no_weight_decay(self): 558 | return {'absolute_pos_embed'} 559 | 560 | @torch.jit.ignore 561 | def no_weight_decay_keywords(self): 562 | return {'relative_position_bias_table'} 563 | 564 | def forward_features(self, x): 565 | x = self.patch_embed(x) 566 | if self.ape: 567 | x = x + self.absolute_pos_embed 568 | x = self.pos_drop(x) 569 | 570 | for layer in self.layers: 571 | x = layer(x) 572 | 573 | x = self.norm(x) # B L C 574 | x = self.avgpool(x.transpose(1, 2)) # B C 1 575 | x = torch.flatten(x, 1) 576 | return x 577 | 578 | def forward(self, x): 579 | x = self.forward_features(x) 580 | x = self.head(x) 581 | return x 582 | 583 | def flops(self): 584 | flops = 0 585 | flops += self.patch_embed.flops() 
586 | for i, layer in enumerate(self.layers): 587 | flops += layer.flops() 588 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 589 | flops += self.num_features * self.num_classes 590 | return flops 591 | -------------------------------------------------------------------------------- /Swin-Transformer/models/swin_transformer_acmix.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer + ACmix 3 | # Copyright (c) 2021 Xuran Pan 4 | # Written by Xuran Pan 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | 13 | class GELU(nn.Module): 14 | 15 | def forward(self, x): 16 | return F.gelu(x) 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | def window_partition(x, window_size): 38 | """ 39 | Args: 40 | x: (B, H, W, C) 41 | window_size (int): window size 42 | 43 | Returns: 44 | windows: (num_windows*B, window_size, window_size, C) 45 | """ 46 | B, H, W, C = x.shape 47 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 48 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 49 | return windows 50 | 51 | 52 | def window_reverse(windows, window_size, H, W): 53 | """ 54 | Args: 55 | windows: (num_windows*B, window_size, window_size, C) 56 | window_size (int): Window size 57 | H (int): Height of image 58 | W (int): Width of image 59 | 60 | Returns: 61 | x: (B, H, W, C) 62 | """ 63 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 64 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 65 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 66 | return x 67 | 68 | 69 | 70 | def ones(tensor): 71 | if tensor is not None: 72 | tensor.data.fill_(0.5) 73 | 74 | def zeros(tensor): 75 | if tensor is not None: 76 | tensor.data.fill_(0.0) 77 | 78 | 79 | class WindowAttention_acmix(nn.Module): 80 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 81 | It supports both of shifted and non-shifted window. 82 | 83 | Args: 84 | dim (int): Number of input channels. 85 | window_size (tuple[int]): The height and width of the window. 86 | num_heads (int): Number of attention heads. 87 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 88 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 89 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 90 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 91 | """ 92 | 93 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 94 | 95 | super().__init__() 96 | self.dim = dim 97 | self.window_size = window_size # Wh, Ww 98 | self.num_heads = num_heads 99 | head_dim = dim // num_heads 100 | self.scale = qk_scale or head_dim ** -0.5 101 | 102 | # define a parameter table of relative position bias 103 | self.relative_position_bias_table = nn.Parameter( 104 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 105 | 106 | # get pair-wise relative position index for each token inside the window 107 | coords_h = torch.arange(self.window_size[0]) 108 | coords_w = torch.arange(self.window_size[1]) 109 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 110 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 111 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 112 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 113 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 114 | relative_coords[:, :, 1] += self.window_size[1] - 1 115 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 116 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 117 | self.register_buffer("relative_position_index", relative_position_index) 118 | 119 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 120 | self.attn_drop = nn.Dropout(attn_drop) 121 | self.proj = nn.Linear(dim, dim) 122 | self.proj_drop = nn.Dropout(proj_drop) 123 | 124 | trunc_normal_(self.relative_position_bias_table, std=.02) 125 | self.softmax = nn.Softmax(dim=-1) 126 | 127 | # fully connected layer in Fig.2 128 | self.fc = nn.Conv2d(3*self.num_heads, 9, kernel_size=1, bias=True) 129 | # group convolution layer in Fig.3 130 | self.dep_conv = nn.Conv2d(9*dim//self.num_heads, dim, kernel_size=3, bias=True, groups=dim//self.num_heads, padding=1) 131 | # rates for both paths 132 | self.rate1 = torch.nn.Parameter(torch.Tensor(1)) 133 | self.rate2 = torch.nn.Parameter(torch.Tensor(1)) 134 | self.reset_parameters() 135 | 136 | def reset_parameters(self): 137 | ones(self.rate1) 138 | ones(self.rate2) 139 | # shift initialization for group convolution 140 | kernel = torch.zeros(9, 3, 3) 141 | for i in range(9): 142 | kernel[i, i//3, i%3] = 1. 
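        # The 9 kernels built above are one-hot at position (i//3, i%3): used as the group-conv
        # weights, dep_conv initially shifts each of its 9 input feature groups toward a
        # different neighbour of the 3x3 grid and sums them (cf. figure/shift.png), and the
        # kernels stay trainable so the convolution path can adapt during training.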
143 |         kernel = kernel.squeeze(0).repeat(self.dim, 1, 1, 1)  # expand to (dim, 9, 3, 3) to match dep_conv.weight
144 |         self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
145 |         zeros(self.dep_conv.bias)  # zero-init the depthwise-conv bias in place, keeping it a trainable parameter
146 | 
147 | 
148 |     def forward(self, x, H, W, mask=None):
149 |         """
150 |         Args:
151 |             x: input features with shape of (B, H, W, C)
152 |             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
153 |         """
154 | 
155 |         qkv = self.qkv(x)
156 | 
157 |         # fully connected layer
158 |         f_all = qkv.reshape(x.shape[0], H*W, 3*self.num_heads, -1).permute(0, 2, 1, 3) # B, 3*nhead, H*W, C//nhead
159 |         f_conv = self.fc(f_all).permute(0, 3, 1, 2).reshape(x.shape[0], 9*x.shape[-1]//self.num_heads, H, W) # B, 9*C//nhead, H, W
160 |         # group convolution
161 |         out_conv = self.dep_conv(f_conv).permute(0, 2, 3, 1) # B, H, W, C
162 | 
163 |         # partition windows
164 |         qkv = window_partition(qkv, self.window_size[0])  # nW*B, window_size, window_size, C
165 | 
166 |         B_, _, _, C = qkv.shape
167 | 
168 |         qkv = qkv.view(-1, self.window_size[0] * self.window_size[1], C)  # nW*B, window_size*window_size, C
169 | 
170 |         N = self.window_size[0] * self.window_size[1]
171 |         C = C // 3
172 | 
173 |         qkv = qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
174 |         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
175 | 
176 |         q = q * self.scale
177 |         attn = (q @ k.transpose(-2, -1))
178 | 
179 |         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
180 |             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
181 |         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
182 |         attn = attn + relative_position_bias.unsqueeze(0)
183 | 
184 |         if mask is not None:
185 |             nW = mask.shape[0]
186 |             attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
187 |             attn = attn.view(-1, self.num_heads, N, N)
188 |             attn = self.softmax(attn)
189 |         else:
190 |             attn = self.softmax(attn)
191 | 
192 |         attn = self.attn_drop(attn)
193 | 
194 |         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
195 |         x = self.proj(x)
196 | 
197 |         # merge windows
198 |         x = x.view(-1, self.window_size[0], self.window_size[1], C)
199 |         x = window_reverse(x, self.window_size[0], H, W)  # B H' W' C
200 | 
201 |         x = self.rate1 * x + self.rate2 * out_conv
202 | 
203 |         x = self.proj_drop(x)
204 |         return x
205 | 
206 |     def extra_repr(self) -> str:
207 |         return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
208 | 
209 |     def flops(self, N):
210 |         # calculate flops for 1 window with token length of N
211 |         flops = 0
212 |         # qkv = self.qkv(x)
213 |         flops += N * self.dim * 3 * self.dim
214 |         # attn = (q @ k.transpose(-2, -1))
215 |         flops += self.num_heads * N * (self.dim // self.num_heads) * N
216 |         # x = (attn @ v)
217 |         flops += self.num_heads * N * N * (self.dim // self.num_heads)
218 |         # x = self.proj(x)
219 |         flops += N * self.dim * self.dim
220 |         return flops
221 | 
222 | 
223 | class SwinTransformerBlock(nn.Module):
224 |     r""" Swin Transformer Block.
225 | 
226 |     Args:
227 |         dim (int): Number of input channels.
228 |         input_resolution (tuple[int]): Input resolution.
229 |         num_heads (int): Number of attention heads.
230 |         window_size (int): Window size.
231 |         shift_size (int): Shift size for SW-MSA.
232 |         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
233 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 234 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 235 | drop (float, optional): Dropout rate. Default: 0.0 236 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 237 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 238 | act_layer (nn.Module, optional): Activation layer. Default: GELU 239 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 240 | """ 241 | 242 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 243 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 244 | act_layer=GELU, norm_layer=nn.LayerNorm): 245 | super().__init__() 246 | self.dim = dim 247 | self.input_resolution = input_resolution 248 | self.num_heads = num_heads 249 | self.window_size = window_size 250 | self.shift_size = shift_size 251 | self.mlp_ratio = mlp_ratio 252 | if min(self.input_resolution) <= self.window_size: 253 | # if window size is larger than input resolution, we don't partition windows 254 | self.shift_size = 0 255 | self.window_size = min(self.input_resolution) 256 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 257 | 258 | self.norm1 = norm_layer(dim) 259 | self.attn = WindowAttention_acmix( 260 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 261 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 262 | 263 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 264 | self.norm2 = norm_layer(dim) 265 | mlp_hidden_dim = int(dim * mlp_ratio) 266 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 267 | 268 | if self.shift_size > 0: 269 | # calculate attention mask for SW-MSA 270 | H, W = self.input_resolution 271 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 272 | h_slices = (slice(0, -self.window_size), 273 | slice(-self.window_size, -self.shift_size), 274 | slice(-self.shift_size, None)) 275 | w_slices = (slice(0, -self.window_size), 276 | slice(-self.window_size, -self.shift_size), 277 | slice(-self.shift_size, None)) 278 | cnt = 0 279 | for h in h_slices: 280 | for w in w_slices: 281 | img_mask[:, h, w, :] = cnt 282 | cnt += 1 283 | 284 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 285 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 286 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 287 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 288 | else: 289 | attn_mask = None 290 | 291 | self.register_buffer("attn_mask", attn_mask) 292 | 293 | def forward(self, x): 294 | H, W = self.input_resolution 295 | B, L, C = x.shape 296 | assert L == H * W, "input feature has wrong size" 297 | 298 | shortcut = x 299 | x = self.norm1(x) 300 | x = x.view(B, H, W, C) 301 | 302 | # cyclic shift 303 | if self.shift_size > 0: 304 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 305 | else: 306 | shifted_x = x 307 | 308 | shifted_x = self.attn(shifted_x, H, W, mask=self.attn_mask) 309 | 310 | # reverse cyclic shift 311 | if self.shift_size > 0: 312 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 313 | else: 314 | x = shifted_x 315 | x = x.view(B, H * W, C) 316 | 317 | # FFN 318 | 
x = shortcut + self.drop_path(x) 319 | x = x + self.drop_path(self.mlp(self.norm2(x))) 320 | 321 | return x 322 | 323 | def extra_repr(self) -> str: 324 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 325 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 326 | 327 | def flops(self): 328 | flops = 0 329 | H, W = self.input_resolution 330 | # norm1 331 | flops += self.dim * H * W 332 | # W-MSA/SW-MSA 333 | nW = H * W / self.window_size / self.window_size 334 | flops += nW * self.attn.flops(self.window_size * self.window_size) 335 | # mlp 336 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 337 | # norm2 338 | flops += self.dim * H * W 339 | 340 | # FC for ACmix 341 | flops += 3 * self.dim * H * W * 9 342 | # Group convolution for ACmix 343 | flops += 9 * self.dim * 9 * H * W 344 | 345 | return flops 346 | 347 | 348 | class PatchMerging(nn.Module): 349 | r""" Patch Merging Layer. 350 | 351 | Args: 352 | input_resolution (tuple[int]): Resolution of input feature. 353 | dim (int): Number of input channels. 354 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 355 | """ 356 | 357 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 358 | super().__init__() 359 | self.input_resolution = input_resolution 360 | self.dim = dim 361 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 362 | self.norm = norm_layer(4 * dim) 363 | 364 | def forward(self, x): 365 | """ 366 | x: B, H*W, C 367 | """ 368 | H, W = self.input_resolution 369 | B, L, C = x.shape 370 | assert L == H * W, "input feature has wrong size" 371 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 372 | 373 | x = x.view(B, H, W, C) 374 | 375 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 376 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 377 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 378 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 379 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 380 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 381 | 382 | x = self.norm(x) 383 | x = self.reduction(x) 384 | 385 | return x 386 | 387 | def extra_repr(self) -> str: 388 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 389 | 390 | def flops(self): 391 | H, W = self.input_resolution 392 | flops = H * W * self.dim 393 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 394 | return flops 395 | 396 | 397 | class BasicLayer(nn.Module): 398 | """ A basic Swin Transformer layer for one stage. 399 | 400 | Args: 401 | dim (int): Number of input channels. 402 | input_resolution (tuple[int]): Input resolution. 403 | depth (int): Number of blocks. 404 | num_heads (int): Number of attention heads. 405 | window_size (int): Local window size. 406 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 407 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 408 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 409 | drop (float, optional): Dropout rate. Default: 0.0 410 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 411 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 412 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 413 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 414 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
415 | """ 416 | 417 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 418 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 419 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 420 | 421 | super().__init__() 422 | self.dim = dim 423 | self.input_resolution = input_resolution 424 | self.depth = depth 425 | self.use_checkpoint = use_checkpoint 426 | 427 | # build blocks 428 | self.blocks = nn.ModuleList([ 429 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 430 | num_heads=num_heads, window_size=window_size, 431 | shift_size=0 if (i % 2 == 0) else window_size // 2, 432 | mlp_ratio=mlp_ratio, 433 | qkv_bias=qkv_bias, qk_scale=qk_scale, 434 | drop=drop, attn_drop=attn_drop, 435 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 436 | norm_layer=norm_layer) 437 | for i in range(depth)]) 438 | 439 | # patch merging layer 440 | if downsample is not None: 441 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 442 | else: 443 | self.downsample = None 444 | 445 | def forward(self, x): 446 | for blk in self.blocks: 447 | if self.use_checkpoint: 448 | x = checkpoint.checkpoint(blk, x) 449 | else: 450 | x = blk(x) 451 | if self.downsample is not None: 452 | x = self.downsample(x) 453 | return x 454 | 455 | def extra_repr(self) -> str: 456 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 457 | 458 | def flops(self): 459 | flops = 0 460 | for blk in self.blocks: 461 | flops += blk.flops() 462 | if self.downsample is not None: 463 | flops += self.downsample.flops() 464 | return flops 465 | 466 | 467 | class PatchEmbed(nn.Module): 468 | r""" Image to Patch Embedding 469 | 470 | Args: 471 | img_size (int): Image size. Default: 224. 472 | patch_size (int): Patch token size. Default: 4. 473 | in_chans (int): Number of input image channels. Default: 3. 474 | embed_dim (int): Number of linear projection output channels. Default: 96. 475 | norm_layer (nn.Module, optional): Normalization layer. Default: None 476 | """ 477 | 478 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 479 | super().__init__() 480 | img_size = to_2tuple(img_size) 481 | patch_size = to_2tuple(patch_size) 482 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 483 | self.img_size = img_size 484 | self.patch_size = patch_size 485 | self.patches_resolution = patches_resolution 486 | self.num_patches = patches_resolution[0] * patches_resolution[1] 487 | 488 | self.in_chans = in_chans 489 | self.embed_dim = embed_dim 490 | 491 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 492 | if norm_layer is not None: 493 | self.norm = norm_layer(embed_dim) 494 | else: 495 | self.norm = None 496 | 497 | def forward(self, x): 498 | B, C, H, W = x.shape 499 | # FIXME look at relaxing size constraints 500 | assert H == self.img_size[0] and W == self.img_size[1], \ 501 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
502 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 503 | if self.norm is not None: 504 | x = self.norm(x) 505 | return x 506 | 507 | def flops(self): 508 | Ho, Wo = self.patches_resolution 509 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 510 | if self.norm is not None: 511 | flops += Ho * Wo * self.embed_dim 512 | return flops 513 | 514 | 515 | class SwinTransformer_acmix(nn.Module): 516 | r""" Swin Transformer 517 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 518 | https://arxiv.org/pdf/2103.14030 519 | 520 | Args: 521 | img_size (int | tuple(int)): Input image size. Default 224 522 | patch_size (int | tuple(int)): Patch size. Default: 4 523 | in_chans (int): Number of input image channels. Default: 3 524 | num_classes (int): Number of classes for classification head. Default: 1000 525 | embed_dim (int): Patch embedding dimension. Default: 96 526 | depths (tuple(int)): Depth of each Swin Transformer layer. 527 | num_heads (tuple(int)): Number of attention heads in different layers. 528 | window_size (int): Window size. Default: 7 529 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 530 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 531 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 532 | drop_rate (float): Dropout rate. Default: 0 533 | attn_drop_rate (float): Attention dropout rate. Default: 0 534 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 535 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 536 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 537 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 538 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 539 | """ 540 | 541 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 542 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 543 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 544 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 545 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 546 | use_checkpoint=False, **kwargs): 547 | super().__init__() 548 | 549 | self.num_classes = num_classes 550 | self.num_layers = len(depths) 551 | self.embed_dim = embed_dim 552 | self.ape = ape 553 | self.patch_norm = patch_norm 554 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 555 | self.mlp_ratio = mlp_ratio 556 | 557 | # split image into non-overlapping patches 558 | self.patch_embed = PatchEmbed( 559 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 560 | norm_layer=norm_layer if self.patch_norm else None) 561 | num_patches = self.patch_embed.num_patches 562 | patches_resolution = self.patch_embed.patches_resolution 563 | self.patches_resolution = patches_resolution 564 | 565 | # absolute position embedding 566 | if self.ape: 567 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 568 | trunc_normal_(self.absolute_pos_embed, std=.02) 569 | 570 | self.pos_drop = nn.Dropout(p=drop_rate) 571 | 572 | # stochastic depth 573 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 574 | 575 | # build layers 576 | self.layers = nn.ModuleList() 577 | for i_layer in range(self.num_layers): 578 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 579 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 580 | patches_resolution[1] // (2 ** i_layer)), 581 | depth=depths[i_layer], 582 | num_heads=num_heads[i_layer], 583 | window_size=window_size, 584 | mlp_ratio=self.mlp_ratio, 585 | qkv_bias=qkv_bias, qk_scale=qk_scale, 586 | drop=drop_rate, attn_drop=attn_drop_rate, 587 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 588 | norm_layer=norm_layer, 589 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 590 | use_checkpoint=use_checkpoint) 591 | self.layers.append(layer) 592 | 593 | self.norm = norm_layer(self.num_features) 594 | self.avgpool = nn.AdaptiveAvgPool1d(1) 595 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 596 | 597 | self.apply(self._init_weights) 598 | 599 | def _init_weights(self, m): 600 | if isinstance(m, nn.Linear): 601 | trunc_normal_(m.weight, std=.02) 602 | if isinstance(m, nn.Linear) and m.bias is not None: 603 | nn.init.constant_(m.bias, 0) 604 | elif isinstance(m, nn.LayerNorm): 605 | nn.init.constant_(m.bias, 0) 606 | nn.init.constant_(m.weight, 1.0) 607 | 608 | @torch.jit.ignore 609 | def no_weight_decay(self): 610 | return {'absolute_pos_embed'} 611 | 612 | @torch.jit.ignore 613 | def no_weight_decay_keywords(self): 614 | return {'relative_position_bias_table'} 615 | 616 | def forward_features(self, x): 617 | x = self.patch_embed(x) 618 | if self.ape: 619 | x = x + self.absolute_pos_embed 620 | x = self.pos_drop(x) 621 | 622 | for layer in self.layers: 623 | x = layer(x) 624 | 625 | x = self.norm(x) # B L C 626 | x = self.avgpool(x.transpose(1, 2)) # B C 1 627 | x = torch.flatten(x, 1) 628 | return x 629 | 630 | def forward(self, x): 631 | x = self.forward_features(x) 632 | x = self.head(x) 633 | return x 634 | 635 | def flops(self): 636 | flops = 0 637 | flops += self.patch_embed.flops() 
638 | for i, layer in enumerate(self.layers): 639 | flops += layer.flops() 640 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 641 | flops += self.num_features * self.num_classes 642 | return flops 643 | 644 | 645 | if __name__ == '__main__': 646 | 647 | model = SwinTransformer_acmix(depths=[2,2,6,2], drop_path_rate=0.3) 648 | print(model.flops()) 649 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 650 | print(f'{total_trainable_params:,} training parameters.') -------------------------------------------------------------------------------- /Swin-Transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from torch import optim as optim 9 | 10 | 11 | def build_optimizer(config, model): 12 | """ 13 | Build optimizer, set weight decay of normalization to 0 by default. 14 | """ 15 | skip = {} 16 | skip_keywords = {} 17 | if hasattr(model, 'no_weight_decay'): 18 | skip = model.no_weight_decay() 19 | if hasattr(model, 'no_weight_decay_keywords'): 20 | skip_keywords = model.no_weight_decay_keywords() 21 | parameters = set_weight_decay(model, skip, skip_keywords) 22 | 23 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 24 | optimizer = None 25 | if opt_lower == 'sgd': 26 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 28 | elif opt_lower == 'adamw': 29 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 30 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 31 | 32 | return optimizer 33 | 34 | 35 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 36 | has_decay = [] 37 | no_decay = [] 38 | 39 | for name, param in model.named_parameters(): 40 | if not param.requires_grad: 41 | continue # frozen weights 42 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 43 | check_keywords_in_name(name, skip_keywords): 44 | no_decay.append(param) 45 | # print(f"{name} has no weight decay") 46 | else: 47 | has_decay.append(param) 48 | return [{'params': has_decay}, 49 | {'params': no_decay, 'weight_decay': 0.}] 50 | 51 | 52 | def check_keywords_in_name(name, keywords=()): 53 | isin = False 54 | for keyword in keywords: 55 | if keyword in name: 56 | isin = True 57 | return isin 58 | -------------------------------------------------------------------------------- /Swin-Transformer/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | try: 13 | # noinspection PyUnresolvedReferences 14 | from apex import amp 15 | except ImportError: 16 | amp = None 17 | 18 | 19 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 20 | logger.info(f"==============> Resuming form 
{config.MODEL.RESUME}....................") 21 | if config.MODEL.RESUME.startswith('https'): 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 24 | else: 25 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 26 | msg = model.load_state_dict(checkpoint['model'], strict=False) 27 | logger.info(msg) 28 | max_accuracy = 0.0 29 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 30 | optimizer.load_state_dict(checkpoint['optimizer']) 31 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 32 | config.defrost() 33 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 34 | config.freeze() 35 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": 36 | amp.load_state_dict(checkpoint['amp']) 37 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 38 | if 'max_accuracy' in checkpoint: 39 | max_accuracy = checkpoint['max_accuracy'] 40 | 41 | del checkpoint 42 | torch.cuda.empty_cache() 43 | return max_accuracy 44 | 45 | 46 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 47 | save_state = {'model': model.state_dict(), 48 | 'optimizer': optimizer.state_dict(), 49 | 'lr_scheduler': lr_scheduler.state_dict(), 50 | 'max_accuracy': max_accuracy, 51 | 'epoch': epoch, 52 | 'config': config} 53 | if config.AMP_OPT_LEVEL != "O0": 54 | save_state['amp'] = amp.state_dict() 55 | 56 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 57 | logger.info(f"{save_path} saving......") 58 | torch.save(save_state, save_path) 59 | logger.info(f"{save_path} saved !!!") 60 | 61 | 62 | def get_grad_norm(parameters, norm_type=2): 63 | if isinstance(parameters, torch.Tensor): 64 | parameters = [parameters] 65 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 66 | norm_type = float(norm_type) 67 | total_norm = 0 68 | for p in parameters: 69 | param_norm = p.grad.data.norm(norm_type) 70 | total_norm += param_norm.item() ** norm_type 71 | total_norm = total_norm ** (1. 
/ norm_type)
 72 |     return total_norm
 73 | 
 74 | 
 75 | def auto_resume_helper(output_dir):
 76 |     checkpoints = os.listdir(output_dir)
 77 |     checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
 78 |     print(f"All checkpoints found in {output_dir}: {checkpoints}")
 79 |     if len(checkpoints) > 0:
 80 |         latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
 81 |         print(f"The latest checkpoint found: {latest_checkpoint}")
 82 |         resume_file = latest_checkpoint
 83 |     else:
 84 |         resume_file = None
 85 |     return resume_file
 86 | 
 87 | 
 88 | def reduce_tensor(tensor):
 89 |     rt = tensor.clone()
 90 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
 91 |     rt /= dist.get_world_size()
 92 |     return rt
 93 | 
--------------------------------------------------------------------------------
/figure/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ACmix/81dddb6dff98f5e238a7fb6ab174e256489c07fa/figure/main.png
--------------------------------------------------------------------------------
/figure/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ACmix/81dddb6dff98f5e238a7fb6ab174e256489c07fa/figure/result.png
--------------------------------------------------------------------------------
/figure/shift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ACmix/81dddb6dff98f5e238a7fb6ab174e256489c07fa/figure/shift.png
--------------------------------------------------------------------------------
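For convenience, a minimal usage sketch (not part of the repository) for the `SwinTransformer_acmix` model defined in `Swin-Transformer/models/swin_transformer_acmix.py`. It assumes the code is run from the `Swin-Transformer` directory and that an ACmix Swin-T checkpoint has been downloaded locally; the checkpoint filename is a placeholder, and the `'model'` key follows the format used by `save_checkpoint`/`load_checkpoint` in `Swin-Transformer/utils.py`.

```python
import torch

# The constructor defaults already match the Swin-T configuration
# (embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7).
from models.swin_transformer_acmix import SwinTransformer_acmix

model = SwinTransformer_acmix()
model.eval()

# Placeholder path: point this at a downloaded ACmix Swin-T checkpoint.
checkpoint = torch.load('acmix_swin_tiny_patch4_window7_224.pth', map_location='cpu')
# Weights are stored under the 'model' key; strict=False mirrors utils.load_checkpoint.
print(model.load_state_dict(checkpoint['model'], strict=False))

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # ImageNet-1k logits of shape (1, 1000)
print(logits.shape)
```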