├── BOAT-CSWin ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── __pycache__ │ ├── checkpoint_saver.cpython-38.pyc │ └── labeled_memcached_dataset.cpython-38.pyc ├── checkpoint_saver.py ├── finetune.py ├── finetune.sh ├── install_req.sh ├── labeled_memcached_dataset.py ├── main.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── cswin.cpython-38.pyc │ │ ├── cswin_boat.cpython-38.pyc │ │ ├── cswin_boat.cpython-39.pyc │ │ ├── cswin_conv.cpython-38.pyc │ │ ├── cswin_full.cpython-38.pyc │ │ ├── cswin_kmeans.cpython-38.pyc │ │ └── cswinmlpplus.cpython-38.pyc │ └── cswin_boat.py ├── output │ └── train │ │ └── 20220328-162150-CSWin_64_12211_tiny_224-224 │ │ └── args.yaml ├── segmentation │ ├── README.md │ ├── backbone │ │ └── cswin_transformer.py │ ├── configs │ │ ├── _base │ │ │ └── upernet_cswin.py │ │ └── cswin │ │ │ ├── upernet_cswin_base.py │ │ │ ├── upernet_cswin_small.py │ │ │ └── upernet_cswin_tiny.py │ ├── install_req.sh │ └── mmcv_custom │ │ └── checkpoint.py ├── teaser.png └── train.sh ├── BOAT-Swin ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config.py ├── configs │ ├── swin_base_patch4_window12_384_22kto1k_finetune.yaml │ ├── swin_base_patch4_window12_384_finetune.yaml │ ├── swin_base_patch4_window7_224.yaml │ ├── swin_base_patch4_window7_224_22k.yaml │ ├── swin_base_patch4_window7_224_22kto1k_finetune.yaml │ ├── swin_large_patch4_window12_384_22kto1k_finetune.yaml │ ├── swin_large_patch4_window7_224_22k.yaml │ ├── swin_large_patch4_window7_224_22kto1k_finetune.yaml │ ├── swin_mlp_base_patch4_window7_224.yaml │ ├── swin_mlp_tiny_c12_patch4_window8_256.yaml │ ├── swin_mlp_tiny_c24_patch4_window8_256.yaml │ ├── swin_mlp_tiny_c6_patch4_window8_256.yaml │ ├── swin_small_patch4_window7_224.yaml │ ├── swin_tiny_c24_patch4_window8_256.yaml │ └── swin_tiny_patch4_window7_224.yaml ├── data │ ├── __init__.py │ ├── build.py │ ├── cached_image_folder.py │ ├── map22kto1k.txt │ ├── samplers.py │ └── zipreader.py ├── figures │ └── teaser.png ├── get_started.md ├── logger.py ├── lr_scheduler.py ├── main.py ├── models │ ├── __init__.py │ ├── boat_swin_transformer.py │ ├── build.py │ ├── swin_mlp.py │ └── swin_transformer.py ├── optimizer.py └── utils.py └── README.md /BOAT-CSWin/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /BOAT-CSWin/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 | 
--------------------------------------------------------------------------------
/BOAT-CSWin/README.md:
--------------------------------------------------------------------------------
1 | # CSWin-BOAT
2 | 
3 | This implementation is based on the official implementation of ["CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows"](https://arxiv.org/pdf/2107.00652.pdf).
4 | 
5 | ## Requirements
6 | 
7 | The main dependencies are timm==0.3.4, pytorch>=1.4, opencv, ...; to install them, run:
8 | 
9 | ```
10 | bash install_req.sh
11 | ```
12 | 
13 | Apex is used for mixed-precision training during finetuning. To install apex, run:
14 | 
15 | ```
16 | git clone https://github.com/NVIDIA/apex
17 | cd apex
18 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
19 | ```
20 | 
21 | Data preparation: organize ImageNet with the folder structure below (a quick sanity check is sketched after the tree); you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
22 | Please follow the train/val splits used by CSWin.
23 | 
24 | ```
25 | │imagenet/
26 | ├──train/
27 | │  ├── n01440764
28 | │  │   ├── n01440764_10026.JPEG
29 | │  │   ├── n01440764_10027.JPEG
30 | │  │   ├── ......
31 | │  ├── ......
32 | ├──val/
33 | │  ├── n01440764
34 | │  │   ├── ILSVRC2012_val_00000293.JPEG
35 | │  │   ├── ILSVRC2012_val_00002138.JPEG
36 | │  │   ├── ......
37 | │  ├── ......
38 | ```
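A minimal sanity check for the layout above (not part of this repository; the dataset path is a placeholder and torchvision is assumed to be installed):

```python
# Verify that both splits are readable as class-per-folder datasets before launching training.
from torchvision.datasets import ImageFolder

for split in ('train', 'val'):
    ds = ImageFolder(f'/path/to/imagenet/{split}')
    print(f'{split}: {len(ds.classes)} classes, {len(ds)} images')
# Expect 1000 classes in each split (~1.28M train images, 50,000 val images).
```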
39 | 
40 | ## Train
41 | 
42 | Train the three variants, CSWin-Tiny, CSWin-Small, and CSWin-Base:
43 | ```
44 | bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
45 | ```
46 | ```
47 | bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
48 | ```
49 | ```
50 | bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5
51 | ```
52 | 
53 | ## Pre-trained models
54 | 
55 | [BOAT-CSwin-Tiny](https://www.dropbox.com/s/rsmtu6r0v2lt0y5/cswin_tiny.pth.tar?dl=0)
56 | 
57 | [BOAT-CSwin-Small](https://www.dropbox.com/s/cnl00d1faxxoi19/cswin_small.pth.tar?dl=0)
58 | 
59 | [BOAT-CSwin-Base](https://www.dropbox.com/s/92sr8r8zhng1mqg/cswin_base.pth.tar?dl=0)
60 | 
61 | ## Acknowledgement
62 | This repository is developed based on the official CSWin Transformer implementation.
63 | 
--------------------------------------------------------------------------------
/BOAT-CSWin/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /BOAT-CSWin/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /BOAT-CSWin/__pycache__/checkpoint_saver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/__pycache__/checkpoint_saver.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/__pycache__/labeled_memcached_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/__pycache__/labeled_memcached_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/checkpoint_saver.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | import glob 9 | import operator 10 | import os 11 | import logging 12 | 13 | import torch 14 | 15 | from timm.utils.model import unwrap_model, get_state_dict 16 | import shutil 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | class CheckpointSaver: 22 | def __init__( 23 | self, 24 | model, 25 | optimizer, 26 | args=None, 27 | model_ema=None, 28 | amp_scaler=None, 29 | checkpoint_prefix='checkpoint', 30 | recovery_prefix='recovery', 31 | checkpoint_dir='', 32 | recovery_dir='', 33 | decreasing=False, 34 | max_history=10, 35 | unwrap_fn=unwrap_model): 36 | 37 | # objects to save state_dicts of 38 | self.model = model 39 | self.optimizer = optimizer 40 | self.args = args 41 | self.model_ema = model_ema 42 | self.amp_scaler = amp_scaler 43 | 44 | # state 45 | self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness 46 | self.best_epoch = None 47 | self.best_metric = None 48 | self.curr_recovery_file = '' 49 | self.last_recovery_file = '' 50 | 51 | # config 52 | self.checkpoint_dir = checkpoint_dir 53 | self.recovery_dir = recovery_dir 54 | self.save_prefix = checkpoint_prefix 55 | self.recovery_prefix = recovery_prefix 56 | self.extension = '.pth.tar' 57 | self.decreasing = decreasing # a lower metric is better if True 58 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs 59 | self.max_history = max_history 60 | self.unwrap_fn = unwrap_fn 61 | assert self.max_history >= 1 62 | 63 | def save_checkpoint(self, epoch, metric=None): 64 | assert epoch >= 0 65 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) 66 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) 67 | self._save(tmp_save_path, epoch, metric) 68 | if os.path.exists(last_save_path): 69 | #os.unlink(last_save_path) # required for Windows support. 
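# (The os.remove + os.rename calls below, and shutil.copyfile further down, replace the
#  original os.unlink/os.link calls so that checkpoint promotion also works on Windows and
#  across filesystems; the new checkpoint is staged in 'tmp.pth.tar' and only then renamed
#  over 'last.pth.tar', so an interrupted save cannot corrupt the previous checkpoint.)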
70 | os.remove(last_save_path) 71 | os.rename(tmp_save_path, last_save_path) 72 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None 73 | if (len(self.checkpoint_files) < self.max_history 74 | or metric is None or self.cmp(metric, worst_file[1])): 75 | if len(self.checkpoint_files) >= self.max_history: 76 | self._cleanup_checkpoints(1) 77 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension 78 | save_path = os.path.join(self.checkpoint_dir, filename) 79 | #os.link(last_save_path, save_path) 80 | shutil.copyfile(last_save_path, save_path) 81 | self.checkpoint_files.append((save_path, metric)) 82 | self.checkpoint_files = sorted( 83 | self.checkpoint_files, key=lambda x: x[1], 84 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better 85 | 86 | checkpoints_str = "Current checkpoints:\n" 87 | for c in self.checkpoint_files: 88 | checkpoints_str += ' {}\n'.format(c) 89 | _logger.info(checkpoints_str) 90 | 91 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): 92 | self.best_epoch = epoch 93 | self.best_metric = metric 94 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) 95 | if os.path.exists(best_save_path): 96 | os.unlink(best_save_path) 97 | #os.link(last_save_path, best_save_path) 98 | shutil.copyfile(last_save_path, best_save_path) 99 | 100 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) 101 | 102 | def _save(self, save_path, epoch, metric=None): 103 | save_state = { 104 | 'epoch': epoch, 105 | 'arch': type(self.model).__name__.lower(), 106 | 'state_dict': get_state_dict(self.model, self.unwrap_fn), 107 | 'optimizer': self.optimizer.state_dict(), 108 | 'version': 2, # version < 2 increments epoch before save 109 | } 110 | if self.args is not None: 111 | save_state['arch'] = self.args.model 112 | save_state['args'] = self.args 113 | if self.amp_scaler is not None: 114 | save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() 115 | if self.model_ema is not None: 116 | save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) 117 | if metric is not None: 118 | save_state['metric'] = metric 119 | torch.save(save_state, save_path) 120 | 121 | def _cleanup_checkpoints(self, trim=0): 122 | trim = min(len(self.checkpoint_files), trim) 123 | delete_index = self.max_history - trim 124 | if delete_index <= 0 or len(self.checkpoint_files) <= delete_index: 125 | return 126 | to_delete = self.checkpoint_files[delete_index:] 127 | for d in to_delete: 128 | try: 129 | _logger.debug("Cleaning checkpoint: {}".format(d)) 130 | os.remove(d[0]) 131 | except Exception as e: 132 | _logger.error("Exception '{}' while deleting checkpoint".format(e)) 133 | self.checkpoint_files = self.checkpoint_files[:delete_index] 134 | 135 | def save_recovery(self, epoch, batch_idx=0): 136 | assert epoch >= 0 137 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension 138 | save_path = os.path.join(self.recovery_dir, filename) 139 | self._save(save_path, epoch) 140 | if os.path.exists(self.last_recovery_file): 141 | try: 142 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) 143 | os.remove(self.last_recovery_file) 144 | except Exception as e: 145 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) 146 | self.last_recovery_file = self.curr_recovery_file 147 | self.curr_recovery_file = save_path 148 | 149 | def 
find_recovery(self): 150 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) 151 | files = glob.glob(recovery_path + '*' + self.extension) 152 | files = sorted(files) 153 | if len(files): 154 | return files[0] 155 | else: 156 | return '' 157 | -------------------------------------------------------------------------------- /BOAT-CSWin/finetune.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | NUM_PROC=$1 9 | shift 10 | python -m torch.distributed.launch --nproc_per_node=$NUM_PROC finetune.py "$@" 11 | 12 | -------------------------------------------------------------------------------- /BOAT-CSWin/install_req.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | pip install --user bcolz mxnet tensorboardX matplotlib easydict opencv-python einops --no-cache-dir -U | cat 4 | pip install --user scikit-image imgaug PyTurboJPEG --no-cache-dir -U | cat 5 | pip install --user scikit-learn --no-cache-dir -U | cat 6 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir -U | cat 7 | pip install --user termcolor imgaug prettytable --no-cache-dir -U | cat 8 | pip install --user timm==0.3.4 --no-cache-dir -U | cat 9 | 10 | -------------------------------------------------------------------------------- /BOAT-CSWin/labeled_memcached_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 
5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | from torch.utils.data import Dataset 9 | import numpy as np 10 | import io 11 | from PIL import Image 12 | import os 13 | import json 14 | import random 15 | def load_img(filepath): 16 | img = Image.open(filepath).convert('RGB') 17 | return img 18 | 19 | class McDataset(Dataset): 20 | def __init__(self, data_root, file_list, phase = 'train', transform=None): 21 | self.transform = transform 22 | self.root = os.path.join(data_root, phase) 23 | 24 | temp_label = json.load(open('./dataset/imagenet_class_index.json', 'r')) 25 | self.labels = {} 26 | for i in range(1000): 27 | self.labels[temp_label[str(i)][0]] = i 28 | self.A_paths = [] 29 | self.A_labels = [] 30 | with open(file_list, 'r') as f: 31 | temp_path = f.readlines() 32 | for path in temp_path: 33 | label = self.labels[path.split('/')[0]] 34 | self.A_paths.append(os.path.join(self.root, path.strip())) 35 | self.A_labels.append(label) 36 | 37 | self.num = len(self.A_paths) 38 | self.A_size = len(self.A_paths) 39 | 40 | def __len__(self): 41 | return self.num 42 | 43 | def __getitem__(self, index): 44 | try: 45 | return self.load_img(index) 46 | except: 47 | return self.__getitem__(random.randint(0, self.__len__()-1)) 48 | 49 | def load_img(self, index): 50 | A_path = self.A_paths[index % self.A_size] 51 | A = load_img(A_path) 52 | if self.transform is not None: 53 | A = self.transform(A) 54 | A_label = self.A_labels[index % self.A_size] 55 | return A, A_label 56 | -------------------------------------------------------------------------------- /BOAT-CSWin/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 
5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | from .cswin_boat import * 9 | -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin_boat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin_boat.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin_boat.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin_boat.cpython-39.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin_conv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin_conv.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin_full.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin_full.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswin_kmeans.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswin_kmeans.cpython-38.pyc -------------------------------------------------------------------------------- /BOAT-CSWin/models/__pycache__/cswinmlpplus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/models/__pycache__/cswinmlpplus.cpython-38.pyc 
-------------------------------------------------------------------------------- /BOAT-CSWin/models/cswin_boat.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from functools import partial 13 | 14 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 15 | from timm.models.helpers import load_pretrained 16 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 17 | from timm.models.registry import register_model 18 | from einops.layers.torch import Rearrange 19 | import torch.utils.checkpoint as checkpoint 20 | import numpy as np 21 | import time 22 | from einops import rearrange, repeat 23 | import math 24 | 25 | 26 | 27 | def _cfg(url='', **kwargs): 28 | return { 29 | 'url': url, 30 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 31 | 'crop_pct': .9, 'interpolation': 'bicubic', 32 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 33 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 34 | **kwargs 35 | } 36 | 37 | 38 | default_cfgs = { 39 | 'cswin_224': _cfg(), 40 | 'cswin_384': _cfg( 41 | crop_pct=1.0 42 | ), 43 | 44 | } 45 | 46 | 47 | class Mlp(nn.Module): 48 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 49 | super().__init__() 50 | out_features = out_features or in_features 51 | hidden_features = hidden_features or in_features 52 | self.fc1 = nn.Linear(in_features, hidden_features) 53 | self.act = act_layer() 54 | self.fc2 = nn.Linear(hidden_features, out_features) 55 | self.drop = nn.Dropout(drop) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x = self.act(x) 60 | x = self.drop(x) 61 | x = self.fc2(x) 62 | x = self.drop(x) 63 | return x 64 | 65 | class LePEAttention(nn.Module): 66 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., qk_scale=None): 67 | super().__init__() 68 | self.dim = dim 69 | self.dim_out = dim_out or dim 70 | self.resolution = resolution 71 | self.split_size = split_size 72 | self.num_heads = num_heads 73 | head_dim = dim // num_heads 74 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 75 | self.scale = qk_scale or head_dim ** -0.5 76 | if idx == -1: 77 | H_sp, W_sp = self.resolution, self.resolution 78 | elif idx == 0: 79 | H_sp, W_sp = self.resolution, self.split_size 80 | elif idx == 1: 81 | W_sp, H_sp = self.resolution, self.split_size 82 | else: 83 | print ("ERROR MODE", idx) 84 | exit(0) 85 | self.H_sp = H_sp 86 | self.W_sp = W_sp 87 | stride = 1 88 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) 89 | 90 | self.attn_drop = nn.Dropout(attn_drop) 91 | 92 | def im2cswin(self, x): 93 | B, N, C = x.shape 94 | H = W = int(np.sqrt(N)) 95 | x = x.transpose(-2,-1).contiguous().view(B, C, H, W) 96 | x = img2windows(x, self.H_sp, self.W_sp) 97 | x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 98 | return x 99 | 100 | def get_lepe(self, x, func): 101 | B, N, C = x.shape 102 | H = W = int(np.sqrt(N)) 103 | x = x.transpose(-2,-1).contiguous().view(B, C, H, W) 104 | 105 | 
H_sp, W_sp = self.H_sp, self.W_sp 106 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 107 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 108 | 109 | lepe = func(x) ### B', C, H', W' 110 | lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous() 111 | 112 | x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp* self.W_sp).permute(0, 1, 3, 2).contiguous() 113 | return x, lepe 114 | 115 | def forward(self, qkv): 116 | """ 117 | x: B L C 118 | """ 119 | q,k,v = qkv[0], qkv[1], qkv[2] 120 | 121 | ### Img2Window 122 | H = W = self.resolution 123 | B, L, C = q.shape 124 | assert L == H * W, "flatten img_tokens has wrong size" 125 | 126 | q = self.im2cswin(q) 127 | k = self.im2cswin(k) 128 | v, lepe = self.get_lepe(v, self.get_v) 129 | 130 | q = q * self.scale 131 | attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N 132 | attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) 133 | attn = self.attn_drop(attn) 134 | 135 | x = (attn @ v) + lepe 136 | x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C 137 | 138 | ### Window2Img 139 | x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) # B H' W' C 140 | 141 | return x 142 | 143 | 144 | class ContentAttention(nn.Module): 145 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 146 | super().__init__() 147 | self.dim = dim 148 | self.window_size = window_size # Wh, Ww 149 | self.ws = window_size 150 | self.num_heads = num_heads 151 | head_dim = dim // num_heads 152 | self.scale = qk_scale or head_dim ** -0.5 153 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 154 | self.attn_drop = nn.Dropout(attn_drop) 155 | self.proj = nn.Linear(dim, dim) 156 | self.proj_drop = nn.Dropout(proj_drop) 157 | self.softmax = nn.Softmax(dim=-1) 158 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) 159 | def forward(self, x, mask=None): 160 | #B_, W, H, C = x.shape 161 | #x = x.view(B_,W*H,C) 162 | B_, N, C = x.shape 163 | 164 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# 3, B_, self.num_heads,N,D 165 | if True: 166 | q_pre = qkv[0].reshape(B_*self.num_heads,N, C // self.num_heads).permute(0,2,1)#qkv_pre[:,0].reshape(b*self.num_heads,qkvhd//3//self.num_heads,hh*ww) 167 | ntimes = int(math.log(N//49,2)) 168 | q_idx_last = torch.arange(N).cuda().unsqueeze(0).expand(B_*self.num_heads,N) 169 | for i in range(ntimes): 170 | bh,d,n = q_pre.shape 171 | q_pre_new = q_pre.reshape(bh,d,2,n//2) 172 | q_avg = q_pre_new.mean(dim=-1)#.reshape(b*self.num_heads,qkvhd//3//self.num_heads,) 173 | q_avg = torch.nn.functional.normalize(q_avg,dim=-2) 174 | iters = 2 175 | for i in range(iters): 176 | q_scores = torch.nn.functional.normalize(q_pre.permute(0,2,1),dim=-1).bmm(q_avg) 177 | soft_assign = torch.nn.functional.softmax(q_scores*100, dim=-1).detach() 178 | q_avg = q_pre.bmm(soft_assign) 179 | q_avg = torch.nn.functional.normalize(q_avg,dim=-2) 180 | q_scores = torch.nn.functional.normalize(q_pre.permute(0,2,1),dim=-1).bmm(q_avg).reshape(bh,n,2)#.unsqueeze(2) 181 | q_idx = (q_scores[:,:,0]+1)/(q_scores[:,:,1]+1) 182 | _,q_idx = torch.sort(q_idx,dim=-1) 183 | q_idx_last = q_idx_last.gather(dim=-1,index=q_idx).reshape(bh*2,n//2) 184 | q_idx = q_idx.unsqueeze(1).expand(q_pre.size()) 185 | q_pre = 
q_pre.gather(dim=-1,index=q_idx).reshape(bh,d,2,n//2).permute(0,2,1,3).reshape(bh*2,d,n//2) 186 | 187 | q_idx = q_idx_last.view(B_,self.num_heads,N) 188 | _,q_idx_rev = torch.sort(q_idx,dim=-1) 189 | q_idx = q_idx.unsqueeze(0).unsqueeze(4).expand(qkv.size()) 190 | qkv_pre = qkv.gather(dim=-2,index=q_idx) 191 | q, k, v = rearrange(qkv_pre, 'qkv b h (nw ws) c -> qkv (b nw) h ws c', ws=49) 192 | 193 | k = k.view(B_*((N//49))//2,2,self.num_heads,49,-1) 194 | k_over1 = k[:,1,:,:20].unsqueeze(1)#.expand(-1,2,-1,-1,-1) 195 | k_over2 = k[:,0,:,29:].unsqueeze(1)#.expand(-1,2,-1,-1,-1) 196 | k_over = torch.cat([k_over1,k_over2],1) 197 | k = torch.cat([k,k_over],3).contiguous().view(B_*((N//49)),self.num_heads,49+20,-1) 198 | 199 | v = v.view(B_*((N//49))//2,2,self.num_heads,49,-1) 200 | v_over1 = v[:,1,:,:20].unsqueeze(1)#.expand(-1,2,-1,-1,-1) 201 | v_over2 = v[:,0,:,29:].unsqueeze(1)#.expand(-1,2,-1,-1,-1) 202 | v_over = torch.cat([v_over1,v_over2],1) 203 | v = torch.cat([v,v_over],3).contiguous().view(B_*((N//49)),self.num_heads,49+20,-1) 204 | 205 | #v = rearrange(v[:,:,:49,:], '(b nw) h ws d -> b h d (nw ws)', h=self.num_heads, b=B_) 206 | #W = int(math.sqrt(N)) 207 | 208 | 209 | attn = (q @ k.transpose(-2, -1))*self.scale 210 | attn = self.softmax(attn) 211 | 212 | attn = self.attn_drop(attn) 213 | out = attn @ v 214 | 215 | out = rearrange(out, '(b nw) h ws d -> b (h d) nw ws', h=self.num_heads, b=B_) 216 | out = out.reshape(B_,self.num_heads,C//self.num_heads,-1) 217 | q_idx_rev = q_idx_rev.unsqueeze(2).expand(out.size()) 218 | x = out.gather(dim=-1,index=q_idx_rev).reshape(B_,C,N).permute(0,2,1) 219 | 220 | 221 | v = rearrange(v[:,:,:49,:], '(b nw) h ws d -> b h d (nw ws)', h=self.num_heads, b=B_) 222 | W = int(math.sqrt(N)) 223 | v = v.gather(dim=-1,index=q_idx_rev).reshape(B_,C,W,W) 224 | v = self.get_v(v) 225 | v = v.reshape(B_,C,N).permute(0,2,1) 226 | x = x + v 227 | 228 | 229 | 230 | x = self.proj(x) 231 | x = self.proj_drop(x) 232 | return x 233 | 234 | 235 | 236 | 237 | class CSWinBlock(nn.Module): 238 | 239 | def __init__(self, dim, reso, num_heads, 240 | split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None, 241 | drop=0., attn_drop=0., drop_path=0., 242 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 243 | last_stage=False, content=False): 244 | super().__init__() 245 | self.dim = dim 246 | self.num_heads = num_heads 247 | self.patches_resolution = reso 248 | self.split_size = split_size 249 | self.mlp_ratio = mlp_ratio 250 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 251 | self.norm1 = norm_layer(dim) 252 | 253 | if self.patches_resolution == split_size: 254 | last_stage = True 255 | if last_stage: 256 | self.branch_num = 1 257 | else: 258 | self.branch_num = 2 259 | self.proj = nn.Linear(dim, dim) 260 | self.proj_drop = nn.Dropout(drop) 261 | self.content = content 262 | if last_stage: 263 | self.attns = nn.ModuleList([ 264 | LePEAttention( 265 | dim, resolution=self.patches_resolution, idx = -1, 266 | split_size=split_size, num_heads=num_heads, dim_out=dim, 267 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 268 | for i in range(self.branch_num)]) 269 | else: 270 | self.attns = nn.ModuleList([ 271 | LePEAttention( 272 | dim//2, resolution=self.patches_resolution, idx = i, 273 | split_size=split_size, num_heads=num_heads//2, dim_out=dim//2, 274 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 275 | for i in range(self.branch_num)]) 276 | if self.content: 277 | self.content_attn = ContentAttention(dim=dim, window_size=split_size, num_heads=num_heads, 
qkv_bias=qkv_bias, qk_scale=qkv_bias, attn_drop=attn_drop, proj_drop=attn_drop) 278 | self.norm3 = norm_layer(dim) 279 | mlp_hidden_dim = int(dim * mlp_ratio) 280 | 281 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 282 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop) 283 | self.norm2 = norm_layer(dim) 284 | 285 | def forward(self, x): 286 | """ 287 | x: B, H*W, C 288 | """ 289 | 290 | H = W = self.patches_resolution 291 | B, L, C = x.shape 292 | assert L == H * W, "flatten img_tokens has wrong size" 293 | img = self.norm1(x) 294 | qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) 295 | 296 | if self.branch_num == 2: 297 | x1 = self.attns[0](qkv[:,:,:,:C//2]) 298 | x2 = self.attns[1](qkv[:,:,:,C//2:]) 299 | attened_x = torch.cat([x1,x2], dim=2) 300 | else: 301 | attened_x = self.attns[0](qkv) 302 | attened_x = self.proj(attened_x) 303 | x = x + self.drop_path(attened_x) 304 | if self.content: 305 | x = x + self.drop_path(self.content_attn(self.norm3(x))) 306 | 307 | x = x + self.drop_path(self.mlp(self.norm2(x))) 308 | 309 | return x 310 | 311 | def img2windows(img, H_sp, W_sp): 312 | """ 313 | img: B C H W 314 | """ 315 | B, C, H, W = img.shape 316 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 317 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) 318 | return img_perm 319 | 320 | def windows2img(img_splits_hw, H_sp, W_sp, H, W): 321 | """ 322 | img_splits_hw: B' H W C 323 | """ 324 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) 325 | 326 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) 327 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 328 | return img 329 | 330 | class Merge_Block(nn.Module): 331 | def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm): 332 | super().__init__() 333 | self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1) 334 | self.norm = norm_layer(dim_out) 335 | 336 | def forward(self, x): 337 | B, new_HW, C = x.shape 338 | H = W = int(np.sqrt(new_HW)) 339 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 340 | x = self.conv(x) 341 | B, C = x.shape[:2] 342 | x = x.view(B, C, -1).transpose(-2, -1).contiguous() 343 | x = self.norm(x) 344 | 345 | return x 346 | 347 | class CSWinTransformer(nn.Module): 348 | """ Vision Transformer with support for patch or hybrid CNN input stage 349 | """ 350 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=96, depth=[2,2,6,2], split_size = [3,5,7], 351 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., 352 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False): 353 | super().__init__() 354 | self.use_chk = use_chk 355 | self.num_classes = num_classes 356 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 357 | heads=num_heads 358 | 359 | self.stage1_conv_embed = nn.Sequential( 360 | nn.Conv2d(in_chans, embed_dim, 7, 4, 2), 361 | Rearrange('b c h w -> b (h w) c', h = img_size//4, w = img_size//4), 362 | nn.LayerNorm(embed_dim) 363 | ) 364 | 365 | curr_dim = embed_dim 366 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule 367 | self.stage1 = nn.ModuleList([ 368 | CSWinBlock( 369 | dim=curr_dim, num_heads=heads[0], reso=img_size//4, mlp_ratio=mlp_ratio, 370 | qkv_bias=qkv_bias, qk_scale=qk_scale, 
split_size=split_size[0], 371 | drop=drop_rate, attn_drop=attn_drop_rate, 372 | drop_path=dpr[i], norm_layer=norm_layer, content=(i%2==0)) 373 | for i in range(depth[0])]) 374 | 375 | self.merge1 = Merge_Block(curr_dim, curr_dim*2) 376 | curr_dim = curr_dim*2 377 | self.stage2 = nn.ModuleList( 378 | [CSWinBlock( 379 | dim=curr_dim, num_heads=heads[1], reso=img_size//8, mlp_ratio=mlp_ratio, 380 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1], 381 | drop=drop_rate, attn_drop=attn_drop_rate, 382 | drop_path=dpr[np.sum(depth[:1])+i], norm_layer=norm_layer,content=(i%2==0)) 383 | for i in range(depth[1])]) 384 | 385 | self.merge2 = Merge_Block(curr_dim, curr_dim*2) 386 | curr_dim = curr_dim*2 387 | temp_stage3 = [] 388 | temp_stage3.extend( 389 | [CSWinBlock( 390 | dim=curr_dim, num_heads=heads[2], reso=img_size//16, mlp_ratio=mlp_ratio, 391 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2], 392 | drop=drop_rate, attn_drop=attn_drop_rate, 393 | drop_path=dpr[np.sum(depth[:2])+i], norm_layer=norm_layer,content=(i%2==0)) 394 | for i in range(depth[2])]) 395 | 396 | self.stage3 = nn.ModuleList(temp_stage3) 397 | 398 | self.merge3 = Merge_Block(curr_dim, curr_dim*2) 399 | curr_dim = curr_dim*2 400 | self.stage4 = nn.ModuleList( 401 | [CSWinBlock( 402 | dim=curr_dim, num_heads=heads[3], reso=img_size//32, mlp_ratio=mlp_ratio, 403 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1], 404 | drop=drop_rate, attn_drop=attn_drop_rate, 405 | drop_path=dpr[np.sum(depth[:-1])+i], norm_layer=norm_layer, last_stage=True,content=(i%2==0)) 406 | for i in range(depth[-1])]) 407 | 408 | self.norm = norm_layer(curr_dim) 409 | # Classifier head 410 | self.head = nn.Linear(curr_dim, num_classes) if num_classes > 0 else nn.Identity() 411 | 412 | trunc_normal_(self.head.weight, std=0.02) 413 | self.apply(self._init_weights) 414 | def _init_weights(self, m): 415 | if isinstance(m, nn.Linear): 416 | trunc_normal_(m.weight, std=.02) 417 | if isinstance(m, nn.Linear) and m.bias is not None: 418 | nn.init.constant_(m.bias, 0) 419 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 420 | nn.init.constant_(m.bias, 0) 421 | nn.init.constant_(m.weight, 1.0) 422 | 423 | @torch.jit.ignore 424 | def no_weight_decay(self): 425 | return {'pos_embed', 'cls_token'} 426 | 427 | def get_classifier(self): 428 | return self.head 429 | 430 | def reset_classifier(self, num_classes, global_pool=''): 431 | if self.num_classes != num_classes: 432 | print ('reset head to', num_classes) 433 | self.num_classes = num_classes 434 | self.head = nn.Linear(self.out_dim, num_classes) if num_classes > 0 else nn.Identity() 435 | self.head = self.head.cuda() 436 | trunc_normal_(self.head.weight, std=.02) 437 | if self.head.bias is not None: 438 | nn.init.constant_(self.head.bias, 0) 439 | 440 | def forward_features(self, x): 441 | B = x.shape[0] 442 | x = self.stage1_conv_embed(x) 443 | for blk in self.stage1: 444 | if self.use_chk: 445 | x = checkpoint.checkpoint(blk, x) 446 | else: 447 | x = blk(x) 448 | for pre, blocks in zip([self.merge1, self.merge2, self.merge3], 449 | [self.stage2, self.stage3, self.stage4]): 450 | x = pre(x) 451 | for blk in blocks: 452 | if self.use_chk: 453 | x = checkpoint.checkpoint(blk, x) 454 | else: 455 | x = blk(x) 456 | x = self.norm(x) 457 | return torch.mean(x, dim=1) 458 | 459 | def forward(self, x): 460 | x = self.forward_features(x) 461 | x = self.head(x) 462 | return x 463 | 464 | 465 | def _conv_filter(state_dict, patch_size=16): 466 | """ convert patch embedding 
weight from manual patchify + linear proj to conv""" 467 | out_dict = {} 468 | for k, v in state_dict.items(): 469 | if 'patch_embed.proj.weight' in k: 470 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 471 | out_dict[k] = v 472 | return out_dict 473 | 474 | ### 224 models 475 | 476 | @register_model 477 | def CSWin_64_12211_tiny_224(pretrained=False, **kwargs): 478 | model = CSWinTransformer(patch_size=4, embed_dim=64, depth=[1,2,21,1], 479 | split_size=[1,2,7,7], num_heads=[2,4,8,16], mlp_ratio=4., **kwargs) 480 | model.default_cfg = default_cfgs['cswin_224'] 481 | return model 482 | 483 | @register_model 484 | def CSWin_64_24322_small_224(pretrained=False, **kwargs): 485 | model = CSWinTransformer(patch_size=4, embed_dim=64, depth=[2,4,32,2], 486 | split_size=[1,2,7,7], num_heads=[2,4,8,16], mlp_ratio=4., **kwargs) 487 | model.default_cfg = default_cfgs['cswin_224'] 488 | return model 489 | 490 | @register_model 491 | def CSWin_96_24322_base_224(pretrained=False, **kwargs): 492 | model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], 493 | split_size=[1,2,7,7], num_heads=[4,8,16,32], mlp_ratio=4., **kwargs) 494 | model.default_cfg = default_cfgs['cswin_224'] 495 | return model 496 | 497 | @register_model 498 | def CSWin_144_24322_large_224(pretrained=False, **kwargs): 499 | model = CSWinTransformer(patch_size=4, embed_dim=144, depth=[2,4,32,2], 500 | split_size=[1,2,7,7], num_heads=[6,12,24,24], mlp_ratio=4., **kwargs) 501 | model.default_cfg = default_cfgs['cswin_224'] 502 | return model 503 | 504 | ### 384 models 505 | 506 | @register_model 507 | def CSWin_96_24322_base_384(pretrained=False, **kwargs): 508 | model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], 509 | split_size=[1,2,12,12], num_heads=[4,8,16,32], mlp_ratio=4., **kwargs) 510 | model.default_cfg = default_cfgs['cswin_384'] 511 | return model 512 | 513 | @register_model 514 | def CSWin_144_24322_large_384(pretrained=False, **kwargs): 515 | model = CSWinTransformer(patch_size=4, embed_dim=144, depth=[2,4,32,2], 516 | split_size=[1,2,12,12], num_heads=[6,12,24,24], mlp_ratio=4., **kwargs) 517 | model.default_cfg = default_cfgs['cswin_384'] 518 | return model 519 | 520 | -------------------------------------------------------------------------------- /BOAT-CSWin/output/train/20220328-162150-CSWin_64_12211_tiny_224-224/args.yaml: -------------------------------------------------------------------------------- 1 | aa: rand-m9-mstd0.5-inc1 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 2 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | clip_grad: null 11 | color_jitter: 0.4 12 | cooldown_epochs: 10 13 | crop_pct: null 14 | cutmix: 1.0 15 | cutmix_minmax: null 16 | data: /home/yutan/ILSVRC2012_w/tmp 17 | decay_epochs: 30 18 | decay_rate: 0.1 19 | dist_bn: '' 20 | drop: 0.0 21 | drop_block: null 22 | drop_connect: null 23 | drop_path: 0.2 24 | epochs: 300 25 | eval_checkpoint: '' 26 | eval_metric: top1 27 | gp: null 28 | hflip: 0.5 29 | img_size: 224 30 | initial_checkpoint: '' 31 | interpolation: '' 32 | jsd: false 33 | local_rank: 0 34 | log_interval: 50 35 | lr: 0.002 36 | lr_cycle_limit: 1 37 | lr_cycle_mul: 1.0 38 | lr_noise: null 39 | lr_noise_pct: 0.67 40 | lr_noise_std: 1.0 41 | mean: null 42 | min_lr: 1.0e-05 43 | mixup: 0.8 44 | mixup_mode: batch 45 | mixup_off_epoch: 0 46 | mixup_prob: 1.0 47 | mixup_switch_prob: 0.5 48 | model: CSWin_64_12211_tiny_224 49 | model_ema: true 50 | model_ema_decay: 0.99984 51 | 
model_ema_force_cpu: false
52 | momentum: 0.9
53 | native_amp: false
54 | no_aug: false
55 | no_prefetcher: false
56 | no_resume_opt: false
57 | num_classes: 1000
58 | num_gpu: 1
59 | opt: adamw
60 | opt_betas: null
61 | opt_eps: null
62 | output: ''
63 | patience_epochs: 10
64 | pin_mem: false
65 | pretrained: false
66 | ratio:
67 | - 0.75
68 | - 1.3333333333333333
69 | recount: 1
70 | recovery_interval: 0
71 | remode: pixel
72 | reprob: 0.25
73 | resplit: false
74 | resume: ''
75 | save_images: false
76 | scale:
77 | - 0.08
78 | - 1.0
79 | sched: cosine
80 | seed: 42
81 | smoothing: 0.1
82 | split_bn: false
83 | start_epoch: null
84 | std: null
85 | sync_bn: false
86 | train_interpolation: random
87 | tta: 0
88 | use_chk: false
89 | use_multi_epochs_loader: false
90 | validation_batch_size_multiplier: 1
91 | vflip: 0.0
92 | warmup_epochs: 20
93 | warmup_lr: 1.0e-06
94 | weight_decay: 0.05
95 | workers: 8
96 | 
--------------------------------------------------------------------------------
/BOAT-CSWin/segmentation/README.md:
--------------------------------------------------------------------------------
1 | # ADE20k Semantic segmentation with CSWin
2 | 
3 | 
4 | ## Results and Models
5 | 
6 | | Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs | config | model | log |
7 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
8 | | CSWin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 49.3 | 50.7 | 60M | 959G | [`config`](configs/cswin/upernet_cswin_tiny.py) | [model](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_tiny.pth) | [log](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_tiny.log.json) |
9 | | CSWin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 50.4 | 51.5 | 65M | 1027G | [`config`](configs/cswin/upernet_cswin_small.py) | [model](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_small.pth) | [log](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_small.log.json) |
10 | | CSWin-B | UPerNet | ImageNet-1K | 512x512 | 160K | 51.1 | 52.2 | 109M | 1222G | [`config`](configs/cswin/upernet_cswin_base.py) | [model](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_base.pth) | [log](https://github.com/microsoft/CSWin-Transformer/releases/download/v0.2.0/upernet_cswin_base.log.json) |
11 | 
12 | 
13 | ## Getting started
14 | 
15 | 1. Clone the [Swin_Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation) repository and install the required packages.
16 | 
17 | ```bash
18 | git clone https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation
19 | bash install_req.sh
20 | ```
21 | 
22 | 2. Move the CSWin configs and backbone file into the corresponding folders of that repository; `<MMSEG_PATH>` below is a placeholder for its root (a registration sanity check is sketched after this list).
23 | 
24 | ```bash
25 | cp -r configs/cswin <MMSEG_PATH>/configs/
26 | cp configs/_base/upernet_cswin.py <MMSEG_PATH>/configs/_base_/models
27 | cp backbone/cswin_transformer.py <MMSEG_PATH>/mmseg/models/backbones/
28 | cp mmcv_custom/checkpoint.py <MMSEG_PATH>/mmcv_custom/
29 | ```
30 | 
31 | 3. Install [apex](https://github.com/NVIDIA/apex) for mixed-precision training.
32 | 
33 | ```bash
34 | git clone https://github.com/NVIDIA/apex
35 | cd apex
36 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
37 | ```
38 | 
39 | 4. Follow the guide in [mmseg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/dataset_prepare.md) to prepare the ADE20k dataset.
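Optionally, after step 2 you can check that the copied backbone is visible to mmseg's registry. This is a minimal sketch (not part of this repository), assuming an mmseg 0.x install and that it is run from the segmentation repo root so that `mmcv_custom` is importable:

```python
# Importing the copied module runs its @BACKBONES.register_module() decorator,
# after which the 'CSWin' type can be built from a plain config dict.
import mmseg.models.backbones.cswin_transformer  # noqa: F401
from mmseg.models import build_backbone

backbone = build_backbone(dict(
    type='CSWin', embed_dim=64, depth=[1, 2, 21, 1],
    num_heads=[2, 4, 8, 16], split_size=[1, 2, 7, 7], mlp_ratio=4.))
print(sum(p.numel() for p in backbone.parameters()) / 1e6, 'M backbone parameters')
```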
40 | 
41 | ## Fine-tuning
42 | 
43 | Command format:
44 | ```
45 | tools/dist_train.sh <config-file> <gpu-num> --options model.pretrained=<pretrained-model>
46 | ```
47 | 
48 | For example, using a CSWin-T backbone with UPerNet:
49 | ```bash
50 | bash tools/dist_train.sh \
51 | configs/cswin/upernet_cswin_tiny.py 8 \
52 | --options model.pretrained=<pretrained-model>
53 | ```
54 | 
55 | Pretrained models can be found on the [main page](https://github.com/microsoft/CSWin-Transformer).
56 | 
57 | More config files can be found at [`configs/cswin`](configs/cswin).
58 | 
59 | 
60 | ## Evaluation
61 | 
62 | Command format:
63 | ```
64 | tools/dist_test.sh <config-file> <checkpoint> <gpu-num> --eval mIoU
65 | tools/dist_test.sh <config-file> <checkpoint> <gpu-num> --eval mIoU --aug-test
66 | ```
67 | 
68 | For example, evaluate a CSWin-T backbone with UPerNet:
69 | ```bash
70 | bash tools/dist_test.sh configs/cswin/upernet_cswin_tiny.py \
71 | <checkpoint> 8 --eval mIoU
72 | ```
73 | 
74 | 
75 | ---
76 | 
77 | ## Acknowledgment
78 | 
79 | This code is built using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library, the [Timm](https://github.com/rwightman/pytorch-image-models) library, and the [Swin](https://github.com/microsoft/Swin-Transformer) repository.
80 | 
--------------------------------------------------------------------------------
/BOAT-CSWin/segmentation/backbone/cswin_transformer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------
2 | # CSWin Transformer
3 | # Copyright (c) Microsoft Corporation.
4 | # Licensed under the MIT License.
5 | # written By Xiaoyi Dong
6 | # ------------------------------------------
7 | 
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from functools import partial
13 | 
14 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15 | from timm.models.helpers import load_pretrained
16 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17 | from timm.models.resnet import resnet26d, resnet50d
18 | from timm.models.registry import register_model
19 | from einops.layers.torch import Rearrange
20 | import numpy as np
21 | import time
22 | 
23 | from mmcv_custom import load_checkpoint
24 | from mmseg.utils import get_root_logger
25 | from ..builder import BACKBONES
26 | 
27 | import torch.utils.checkpoint as checkpoint
28 | 
29 | 
30 | class Mlp(nn.Module):
31 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
32 |         super().__init__()
33 |         out_features = out_features or in_features
34 |         hidden_features = hidden_features or in_features
35 |         self.fc1 = nn.Linear(in_features, hidden_features)
36 |         self.act = act_layer()
37 |         self.fc2 = nn.Linear(hidden_features, out_features)
38 |         self.drop = nn.Dropout(drop)
39 | 
40 |     def forward(self, x):
41 |         x = self.fc1(x)
42 |         x = self.act(x)
43 |         x = self.drop(x)
44 |         x = self.fc2(x)
45 |         x = self.drop(x)
46 |         return x
47 | 
48 | 
49 | class LePEAttention(nn.Module):
50 |     def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
51 |         """Not supported now, since we have cls_tokens now.....
52 | """ 53 | super().__init__() 54 | self.dim = dim 55 | self.dim_out = dim_out or dim 56 | self.resolution = resolution 57 | self.split_size = split_size 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 61 | self.scale = qk_scale or head_dim ** -0.5 62 | self.idx = idx 63 | if idx == -1: 64 | H_sp, W_sp = self.resolution, self.resolution 65 | elif idx == 0: 66 | H_sp, W_sp = self.resolution, self.split_size 67 | elif idx == 1: 68 | W_sp, H_sp = self.resolution, self.split_size 69 | else: 70 | print ("ERROR MODE", idx) 71 | exit(0) 72 | self.H_sp = H_sp 73 | self.W_sp = W_sp 74 | 75 | self.H_sp_ = self.H_sp 76 | self.W_sp_ = self.W_sp 77 | 78 | stride = 1 79 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) 80 | 81 | self.attn_drop = nn.Dropout(attn_drop) 82 | 83 | def im2cswin(self, x): 84 | B, C, H, W = x.shape 85 | x = img2windows(x, self.H_sp, self.W_sp) 86 | x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 87 | return x 88 | 89 | def get_rpe(self, x, func): 90 | B, C, H, W = x.shape 91 | H_sp, W_sp = self.H_sp, self.W_sp 92 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 93 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 94 | 95 | rpe = func(x) ### B', C, H', W' 96 | rpe = rpe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous() 97 | 98 | x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp* self.W_sp).permute(0, 1, 3, 2).contiguous() 99 | return x, rpe 100 | 101 | def forward(self, temp): 102 | """ 103 | x: B N C 104 | mask: B N N 105 | """ 106 | B, _, C, H, W = temp.shape 107 | 108 | 109 | idx = self.idx 110 | if idx == -1: 111 | H_sp, W_sp = H, W 112 | elif idx == 0: 113 | H_sp, W_sp = H, self.split_size 114 | elif idx == 1: 115 | H_sp, W_sp = self.split_size, W 116 | else: 117 | print ("ERROR MODE in forward", idx) 118 | exit(0) 119 | self.H_sp = H_sp 120 | self.W_sp = W_sp 121 | 122 | ### padding for split window 123 | H_pad = (self.H_sp - H % self.H_sp) % self.H_sp 124 | W_pad = (self.W_sp - W % self.W_sp) % self.W_sp 125 | top_pad = H_pad//2 126 | down_pad = H_pad - top_pad 127 | left_pad = W_pad//2 128 | right_pad = W_pad - left_pad 129 | H_ = H + H_pad 130 | W_ = W + W_pad 131 | 132 | qkv = F.pad(temp, (left_pad, right_pad, top_pad, down_pad)) ### B,3,C,H',W' 133 | qkv = qkv.permute(1, 0, 2, 3, 4) 134 | q,k,v = qkv[0], qkv[1], qkv[2] 135 | 136 | q = self.im2cswin(q) 137 | k = self.im2cswin(k) 138 | v, rpe = self.get_rpe(v, self.get_v) 139 | 140 | ### Local attention 141 | q = q * self.scale 142 | attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N 143 | 144 | attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) 145 | 146 | attn = self.attn_drop(attn) 147 | 148 | x = (attn @ v) + rpe 149 | x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C 150 | 151 | ### Window2Img 152 | x = windows2img(x, self.H_sp, self.W_sp, H_, W_) # B H_ W_ C 153 | x = x[:, top_pad:H+top_pad, left_pad:W+left_pad, :] 154 | x = x.reshape(B, -1, C) 155 | 156 | return x 157 | 158 | class CSWinBlock(nn.Module): 159 | 160 | def __init__(self, dim, patches_resolution, num_heads, 161 | split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None, 162 | drop=0., attn_drop=0., drop_path=0., 163 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 
164 | last_stage=False): 165 | super().__init__() 166 | self.dim = dim 167 | self.num_heads = num_heads 168 | self.patches_resolution = patches_resolution 169 | self.split_size = split_size 170 | self.mlp_ratio = mlp_ratio 171 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 172 | self.norm1 = norm_layer(dim) 173 | 174 | if last_stage: 175 | self.branch_num = 1 176 | else: 177 | self.branch_num = 2 178 | self.proj = nn.Linear(dim, dim) 179 | self.proj_drop = nn.Dropout(drop) 180 | 181 | if last_stage: 182 | self.attns = nn.ModuleList([ 183 | LePEAttention( 184 | dim, resolution=self.patches_resolution, idx = -1, 185 | split_size=split_size, num_heads=num_heads, dim_out=dim, 186 | qkv_bias=qkv_bias, qk_scale=qk_scale, 187 | attn_drop=attn_drop, proj_drop=drop) 188 | for i in range(self.branch_num)]) 189 | else: 190 | self.attns = nn.ModuleList([ 191 | LePEAttention( 192 | dim//2, resolution=self.patches_resolution, idx = i, 193 | split_size=split_size, num_heads=num_heads//2, dim_out=dim//2, 194 | qkv_bias=qkv_bias, qk_scale=qk_scale, 195 | attn_drop=attn_drop, proj_drop=drop) 196 | for i in range(self.branch_num)]) 197 | mlp_hidden_dim = int(dim * mlp_ratio) 198 | 199 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop) 201 | self.norm2 = norm_layer(dim) 202 | 203 | atten_mask_matrix = None 204 | 205 | self.register_buffer("atten_mask_matrix", atten_mask_matrix) 206 | self.H = None 207 | self.W = None 208 | 209 | def forward(self, x): 210 | """ 211 | x: B, H*W, C 212 | """ 213 | B, L, C = x.shape 214 | H = self.H 215 | W = self.W 216 | assert L == H * W, "flatten img_tokens has wrong size" 217 | img = self.norm1(x) 218 | temp = self.qkv(img).reshape(B, H, W, 3, C).permute(0, 3, 4, 1, 2) 219 | 220 | if self.branch_num == 2: 221 | x1 = self.attns[0](temp[:,:,:C//2,:,:]) 222 | x2 = self.attns[1](temp[:,:,C//2:,:,:]) 223 | attened_x = torch.cat([x1,x2], dim=2) 224 | else: 225 | attened_x = self.attns[0](temp) 226 | attened_x = self.proj(attened_x) 227 | x = x + self.drop_path(attened_x) 228 | x = x + self.drop_path(self.mlp(self.norm2(x))) 229 | 230 | return x 231 | 232 | def img2windows(img, H_sp, W_sp): 233 | """ 234 | img: B C H W 235 | """ 236 | B, C, H, W = img.shape 237 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 238 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) 239 | return img_perm 240 | 241 | def windows2img(img_splits_hw, H_sp, W_sp, H, W): 242 | """ 243 | img_splits_hw: B' H W C 244 | """ 245 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) 246 | 247 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) 248 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 249 | return img 250 | 251 | 252 | class Merge_Block(nn.Module): 253 | def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm): 254 | super().__init__() 255 | self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1) 256 | self.norm = norm_layer(dim_out) 257 | 258 | def forward(self, x, H, W): 259 | B, new_HW, C = x.shape 260 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 261 | x = self.conv(x) 262 | B, C, H, W = x.shape 263 | x = x.view(B, C, -1).transpose(-2, -1).contiguous() 264 | x = self.norm(x) 265 | 266 | return x, H, W 267 | 268 | @BACKBONES.register_module() 269 | class CSWin(nn.Module): 270 | """ Vision Transformer with support for patch or hybrid CNN input stage 271 | """ 272 
| def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=64, depth=[1,2,21,1], split_size = 7, 273 | num_heads=[1,2,4,8], mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 274 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False): 275 | super().__init__() 276 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 277 | 278 | heads=num_heads 279 | self.use_chk = use_chk 280 | self.stage1_conv_embed = nn.Sequential( 281 | nn.Conv2d(in_chans, embed_dim, 7, 4, 2), 282 | Rearrange('b c h w -> b (h w) c', h = img_size//4, w = img_size//4), 283 | nn.LayerNorm(embed_dim) 284 | ) 285 | 286 | self.norm1 = nn.LayerNorm(embed_dim) 287 | 288 | curr_dim = embed_dim 289 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule 290 | self.stage1 = nn.ModuleList([ 291 | CSWinBlock( 292 | dim=curr_dim, num_heads=heads[0], patches_resolution=224//4, mlp_ratio=mlp_ratio, 293 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0], 294 | drop=drop_rate, attn_drop=attn_drop_rate, 295 | drop_path=dpr[i], norm_layer=norm_layer) 296 | for i in range(depth[0])]) 297 | 298 | self.merge1 = Merge_Block(curr_dim, curr_dim*(heads[1]//heads[0])) 299 | curr_dim = curr_dim*(heads[1]//heads[0]) 300 | self.norm2 = nn.LayerNorm(curr_dim) 301 | self.stage2 = nn.ModuleList( 302 | [CSWinBlock( 303 | dim=curr_dim, num_heads=heads[1], patches_resolution=224//8, mlp_ratio=mlp_ratio, 304 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1], 305 | drop=drop_rate, attn_drop=attn_drop_rate, 306 | drop_path=dpr[np.sum(depth[:1])+i], norm_layer=norm_layer) 307 | for i in range(depth[1])]) 308 | 309 | self.merge2 = Merge_Block(curr_dim, curr_dim*(heads[2]//heads[1])) 310 | curr_dim = curr_dim*(heads[2]//heads[1]) 311 | self.norm3 = nn.LayerNorm(curr_dim) 312 | temp_stage3 = [] 313 | temp_stage3.extend( 314 | [CSWinBlock( 315 | dim=curr_dim, num_heads=heads[2], patches_resolution=224//16, mlp_ratio=mlp_ratio, 316 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2], 317 | drop=drop_rate, attn_drop=attn_drop_rate, 318 | drop_path=dpr[np.sum(depth[:2])+i], norm_layer=norm_layer) 319 | for i in range(depth[2])]) 320 | 321 | self.stage3 = nn.ModuleList(temp_stage3) 322 | 323 | self.merge3 = Merge_Block(curr_dim, curr_dim*(heads[3]//heads[2])) 324 | curr_dim = curr_dim*(heads[3]//heads[2]) 325 | self.stage4 = nn.ModuleList( 326 | [CSWinBlock( 327 | dim=curr_dim, num_heads=heads[3], patches_resolution=224//32, mlp_ratio=mlp_ratio, 328 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1], 329 | drop=drop_rate, attn_drop=attn_drop_rate, 330 | drop_path=dpr[np.sum(depth[:-1])+i], norm_layer=norm_layer, last_stage=True) 331 | for i in range(depth[-1])]) 332 | 333 | self.norm4 = norm_layer(curr_dim) 334 | 335 | 336 | def init_weights(self, pretrained=None): 337 | def _init_weights(m): 338 | if isinstance(m, nn.Linear): 339 | trunc_normal_(m.weight, std=.02) 340 | if isinstance(m, nn.Linear) and m.bias is not None: 341 | nn.init.constant_(m.bias, 0) 342 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 343 | nn.init.constant_(m.bias, 0) 344 | nn.init.constant_(m.weight, 1.0) 345 | if isinstance(pretrained, str): 346 | self.apply(_init_weights) 347 | logger = get_root_logger() 348 | load_checkpoint(self, pretrained, strict=False, logger=logger) 349 | elif pretrained is None: 350 | self.apply(_init_weights) 351 | else: 352 | raise 
TypeError('pretrained must be a str or None') 353 | 354 | def save_out(self, x, norm, H, W): 355 | x = norm(x) 356 | B, N, C = x.shape 357 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() 358 | return x 359 | 360 | def forward_features(self, x): 361 | B = x.shape[0] 362 | x = self.stage1_conv_embed[0](x) ### B, C, H, W 363 | B, C, H, W = x.size() 364 | x = x.reshape(B, C, -1).transpose(-1,-2).contiguous() 365 | x = self.stage1_conv_embed[2](x) 366 | 367 | out = [] 368 | for blk in self.stage1: 369 | blk.H = H 370 | blk.W = W 371 | if self.use_chk: 372 | x = checkpoint.checkpoint(blk, x) 373 | else: 374 | x = blk(x) 375 | 376 | out.append(self.save_out(x, self.norm1, H, W)) 377 | 378 | for pre, blocks, norm in zip([self.merge1, self.merge2, self.merge3], 379 | [self.stage2, self.stage3, self.stage4], 380 | [self.norm2 , self.norm3 , self.norm4 ]): 381 | 382 | x, H, W = pre(x, H, W) 383 | for blk in blocks: 384 | blk.H = H 385 | blk.W = W 386 | if self.use_chk: 387 | x = checkpoint.checkpoint(blk, x) 388 | else: 389 | x = blk(x) 390 | 391 | out.append(self.save_out(x, norm, H, W)) 392 | 393 | return tuple(out) 394 | 395 | def forward(self, x): 396 | x = self.forward_features(x) 397 | return x 398 | 399 | 400 | def _conv_filter(state_dict, patch_size=16): 401 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 402 | out_dict = {} 403 | for k, v in state_dict.items(): 404 | if 'patch_embed.proj.weight' in k: 405 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 406 | out_dict[k] = v 407 | return out_dict 408 | 409 | 410 | -------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/configs/_base/upernet_cswin.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='CSWin', 8 | embed_dim=64, 9 | patch_size=4, 10 | depth=[1, 2, 21, 1], 11 | num_heads=[2,4,8,16], 12 | split_size=[1,2,7,7], 13 | mlp_ratio=4., 14 | qkv_bias=True, 15 | qk_scale=None, 16 | drop_rate=0., 17 | attn_drop_rate=0., 18 | drop_path_rate=0.1), 19 | decode_head=dict( 20 | type='UPerHead', 21 | in_channels=[96, 192, 384, 768], 22 | in_index=[0, 1, 2, 3], 23 | pool_scales=(1, 2, 3, 6), 24 | channels=512, 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=384, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/configs/cswin/upernet_cswin_base.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/upernet_cswin.py', '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | backbone=dict( 7 | type='CSWin', 8 | embed_dim=96, 9 | depth=[2,4,32,2], 10 | num_heads=[4,8,16,32], 11 | 
split_size=[1,2,7,7], 12 | drop_path_rate=0.6, 13 | use_chk=False, 14 | ), 15 | decode_head=dict( 16 | in_channels=[96,192,384,768], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=384, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | data=dict(samples_per_gpu=2) 37 | 38 | -------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/configs/cswin/upernet_cswin_small.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/upernet_cswin.py', '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | backbone=dict( 7 | type='CSWin', 8 | embed_dim=64, 9 | depth=[2,4,32,2], 10 | num_heads=[2,4,8,16], 11 | split_size=[1,2,7,7], 12 | drop_path_rate=0.4, 13 | use_chk=False, 14 | ), 15 | decode_head=dict( 16 | in_channels=[64,128,256,512], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=256, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | data=dict(samples_per_gpu=2) 37 | -------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/configs/cswin/upernet_cswin_tiny.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/upernet_cswin.py', '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | model = dict( 6 | backbone=dict( 7 | type='CSWin', 8 | embed_dim=64, 9 | depth=[1,2,21,1], 10 | num_heads=[2,4,8,16], 11 | split_size=[1,2,7,7], 12 | drop_path_rate=0.3, 13 | use_chk=False, 14 | ), 15 | decode_head=dict( 16 | in_channels=[64,128,256,512], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=256, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | data=dict(samples_per_gpu=2) 37 | 38 | 
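The `decode_head.in_channels` values in the three configs above follow directly from the backbone settings: in the CSWin backbone each `Merge_Block` multiplies the channel width by `heads[i+1] // heads[i]`, starting from `embed_dim`. A small sketch (plain Python, not part of the repository) that reproduces the per-stage widths for the tiny/small and base variants:

```python
# Sketch: derive the four stage output widths that feed the UPerNet decode_head.
# Each Merge_Block scales the channel dimension by heads[i+1] // heads[i].
def stage_channels(embed_dim, num_heads):
    dims = [embed_dim]
    for prev, cur in zip(num_heads, num_heads[1:]):
        dims.append(dims[-1] * (cur // prev))
    return dims

print(stage_channels(64, [2, 4, 8, 16]))   # [64, 128, 256, 512]  -> tiny / small configs
print(stage_channels(96, [4, 8, 16, 32]))  # [96, 192, 384, 768]  -> base config
```

The `auxiliary_head.in_channels` is simply the stage-3 entry of this list, since it taps the backbone output at `in_index=2`.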
-------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/install_req.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | pip install --user bcolz mxnet tensorboardX matplotlib easydict opencv-python einops --no-cache-dir -U | cat 4 | pip install --user scikit-image imgaug PyTurboJPEG --no-cache-dir -U | cat 5 | pip install --user scikit-learn --no-cache-dir -U | cat 6 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir -U | cat 7 | pip install --user termcolor imgaug prettytable --no-cache-dir -U | cat 8 | pip install --user timm==0.3.4 --no-cache-dir -U | cat 9 | pip install mmcv-full==1.3.0 --user --no-cache-dir -U | cat 10 | -------------------------------------------------------------------------------- /BOAT-CSWin/segmentation/mmcv_custom/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import io 3 | import os 4 | import os.path as osp 5 | import pkgutil 6 | import time 7 | import warnings 8 | from collections import OrderedDict 9 | from importlib import import_module 10 | from tempfile import TemporaryDirectory 11 | 12 | import torch 13 | import torchvision 14 | from torch.optim import Optimizer 15 | from torch.utils import model_zoo 16 | from torch.nn import functional as F 17 | 18 | import mmcv 19 | from mmcv.fileio import FileClient 20 | from mmcv.fileio import load as load_file 21 | from mmcv.parallel import is_module_wrapper 22 | from mmcv.utils import mkdir_or_exist 23 | from mmcv.runner import get_dist_info 24 | 25 | ENV_MMCV_HOME = 'MMCV_HOME' 26 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' 27 | DEFAULT_CACHE_DIR = '~/.cache' 28 | 29 | 30 | def _get_mmcv_home(): 31 | mmcv_home = os.path.expanduser( 32 | os.getenv( 33 | ENV_MMCV_HOME, 34 | os.path.join( 35 | os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) 36 | 37 | mkdir_or_exist(mmcv_home) 38 | return mmcv_home 39 | 40 | 41 | def load_state_dict(module, state_dict, strict=False, logger=None): 42 | """Load state_dict to a module. 43 | 44 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 45 | Default value for ``strict`` is set to ``False`` and the message for 46 | param mismatch will be shown even if strict is False. 47 | 48 | Args: 49 | module (Module): Module that receives the state_dict. 50 | state_dict (OrderedDict): Weights. 51 | strict (bool): whether to strictly enforce that the keys 52 | in :attr:`state_dict` match the keys returned by this module's 53 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 54 | logger (:obj:`logging.Logger`, optional): Logger to log the error 55 | message. If not specified, print function will be used. 
56 | """ 57 | unexpected_keys = [] 58 | all_missing_keys = [] 59 | err_msg = [] 60 | 61 | metadata = getattr(state_dict, '_metadata', None) 62 | state_dict = state_dict.copy() 63 | if metadata is not None: 64 | state_dict._metadata = metadata 65 | 66 | # use _load_from_state_dict to enable checkpoint version control 67 | def load(module, prefix=''): 68 | # recursively check parallel module in case that the model has a 69 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 70 | if is_module_wrapper(module): 71 | module = module.module 72 | local_metadata = {} if metadata is None else metadata.get( 73 | prefix[:-1], {}) 74 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 75 | all_missing_keys, unexpected_keys, 76 | err_msg) 77 | for name, child in module._modules.items(): 78 | if child is not None: 79 | load(child, prefix + name + '.') 80 | 81 | load(module) 82 | load = None # break load->load reference cycle 83 | 84 | # ignore "num_batches_tracked" of BN layers 85 | missing_keys = [ 86 | key for key in all_missing_keys if 'num_batches_tracked' not in key 87 | ] 88 | 89 | if unexpected_keys: 90 | err_msg.append('unexpected key in source ' 91 | f'state_dict: {", ".join(unexpected_keys)}\n') 92 | if missing_keys: 93 | err_msg.append( 94 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n') 95 | 96 | rank, _ = get_dist_info() 97 | if len(err_msg) > 0 and rank == 0: 98 | err_msg.insert( 99 | 0, 'The model and loaded state dict do not match exactly\n') 100 | err_msg = '\n'.join(err_msg) 101 | if strict: 102 | raise RuntimeError(err_msg) 103 | elif logger is not None: 104 | logger.warning(err_msg) 105 | else: 106 | print(err_msg) 107 | 108 | 109 | def load_url_dist(url, model_dir=None): 110 | """In distributed setting, this function only download checkpoint at local 111 | rank 0.""" 112 | rank, world_size = get_dist_info() 113 | rank = int(os.environ.get('LOCAL_RANK', rank)) 114 | if rank == 0: 115 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 116 | if world_size > 1: 117 | torch.distributed.barrier() 118 | if rank > 0: 119 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 120 | return checkpoint 121 | 122 | 123 | def load_pavimodel_dist(model_path, map_location=None): 124 | """In distributed setting, this function only download checkpoint at local 125 | rank 0.""" 126 | try: 127 | from pavi import modelcloud 128 | except ImportError: 129 | raise ImportError( 130 | 'Please install pavi to load checkpoint from modelcloud.') 131 | rank, world_size = get_dist_info() 132 | rank = int(os.environ.get('LOCAL_RANK', rank)) 133 | if rank == 0: 134 | model = modelcloud.get(model_path) 135 | with TemporaryDirectory() as tmp_dir: 136 | downloaded_file = osp.join(tmp_dir, model.name) 137 | model.download(downloaded_file) 138 | checkpoint = torch.load(downloaded_file, map_location=map_location) 139 | if world_size > 1: 140 | torch.distributed.barrier() 141 | if rank > 0: 142 | model = modelcloud.get(model_path) 143 | with TemporaryDirectory() as tmp_dir: 144 | downloaded_file = osp.join(tmp_dir, model.name) 145 | model.download(downloaded_file) 146 | checkpoint = torch.load( 147 | downloaded_file, map_location=map_location) 148 | return checkpoint 149 | 150 | 151 | def load_fileclient_dist(filename, backend, map_location): 152 | """In distributed setting, this function only download checkpoint at local 153 | rank 0.""" 154 | rank, world_size = get_dist_info() 155 | rank = int(os.environ.get('LOCAL_RANK', rank)) 156 | allowed_backends = 
['ceph'] 157 | if backend not in allowed_backends: 158 | raise ValueError(f'Load from Backend {backend} is not supported.') 159 | if rank == 0: 160 | fileclient = FileClient(backend=backend) 161 | buffer = io.BytesIO(fileclient.get(filename)) 162 | checkpoint = torch.load(buffer, map_location=map_location) 163 | if world_size > 1: 164 | torch.distributed.barrier() 165 | if rank > 0: 166 | fileclient = FileClient(backend=backend) 167 | buffer = io.BytesIO(fileclient.get(filename)) 168 | checkpoint = torch.load(buffer, map_location=map_location) 169 | return checkpoint 170 | 171 | 172 | def get_torchvision_models(): 173 | model_urls = dict() 174 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): 175 | if ispkg: 176 | continue 177 | _zoo = import_module(f'torchvision.models.{name}') 178 | if hasattr(_zoo, 'model_urls'): 179 | _urls = getattr(_zoo, 'model_urls') 180 | model_urls.update(_urls) 181 | return model_urls 182 | 183 | 184 | def get_external_models(): 185 | mmcv_home = _get_mmcv_home() 186 | default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') 187 | default_urls = load_file(default_json_path) 188 | assert isinstance(default_urls, dict) 189 | external_json_path = osp.join(mmcv_home, 'open_mmlab.json') 190 | if osp.exists(external_json_path): 191 | external_urls = load_file(external_json_path) 192 | assert isinstance(external_urls, dict) 193 | default_urls.update(external_urls) 194 | 195 | return default_urls 196 | 197 | 198 | def get_mmcls_models(): 199 | mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') 200 | mmcls_urls = load_file(mmcls_json_path) 201 | 202 | return mmcls_urls 203 | 204 | 205 | def get_deprecated_model_names(): 206 | deprecate_json_path = osp.join(mmcv.__path__[0], 207 | 'model_zoo/deprecated.json') 208 | deprecate_urls = load_file(deprecate_json_path) 209 | assert isinstance(deprecate_urls, dict) 210 | 211 | return deprecate_urls 212 | 213 | 214 | def _process_mmcls_checkpoint(checkpoint): 215 | state_dict = checkpoint['state_dict'] 216 | new_state_dict = OrderedDict() 217 | for k, v in state_dict.items(): 218 | if k.startswith('backbone.'): 219 | new_state_dict[k[9:]] = v 220 | new_checkpoint = dict(state_dict=new_state_dict) 221 | 222 | return new_checkpoint 223 | 224 | 225 | def _load_checkpoint(filename, map_location=None): 226 | """Load checkpoint from somewhere (modelzoo, file, url). 227 | 228 | Args: 229 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 230 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 231 | details. 232 | map_location (str | None): Same as :func:`torch.load`. Default: None. 233 | 234 | Returns: 235 | dict | OrderedDict: The loaded checkpoint. It can be either an 236 | OrderedDict storing model weights or a dict containing other 237 | information, which depends on the checkpoint. 
238 | """ 239 | if filename.startswith('modelzoo://'): 240 | warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 241 | 'use "torchvision://" instead') 242 | model_urls = get_torchvision_models() 243 | model_name = filename[11:] 244 | checkpoint = load_url_dist(model_urls[model_name]) 245 | elif filename.startswith('torchvision://'): 246 | model_urls = get_torchvision_models() 247 | model_name = filename[14:] 248 | checkpoint = load_url_dist(model_urls[model_name]) 249 | elif filename.startswith('open-mmlab://'): 250 | model_urls = get_external_models() 251 | model_name = filename[13:] 252 | deprecated_urls = get_deprecated_model_names() 253 | if model_name in deprecated_urls: 254 | warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' 255 | f'of open-mmlab://{deprecated_urls[model_name]}') 256 | model_name = deprecated_urls[model_name] 257 | model_url = model_urls[model_name] 258 | # check if is url 259 | if model_url.startswith(('http://', 'https://')): 260 | checkpoint = load_url_dist(model_url) 261 | else: 262 | filename = osp.join(_get_mmcv_home(), model_url) 263 | if not osp.isfile(filename): 264 | raise IOError(f'{filename} is not a checkpoint file') 265 | checkpoint = torch.load(filename, map_location=map_location) 266 | elif filename.startswith('mmcls://'): 267 | model_urls = get_mmcls_models() 268 | model_name = filename[8:] 269 | checkpoint = load_url_dist(model_urls[model_name]) 270 | checkpoint = _process_mmcls_checkpoint(checkpoint) 271 | elif filename.startswith(('http://', 'https://')): 272 | checkpoint = load_url_dist(filename) 273 | elif filename.startswith('pavi://'): 274 | model_path = filename[7:] 275 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location) 276 | elif filename.startswith('s3://'): 277 | checkpoint = load_fileclient_dist( 278 | filename, backend='ceph', map_location=map_location) 279 | else: 280 | if not osp.isfile(filename): 281 | raise IOError(f'{filename} is not a checkpoint file') 282 | checkpoint = torch.load(filename, map_location=map_location) 283 | return checkpoint 284 | 285 | 286 | def load_checkpoint(model, 287 | filename, 288 | map_location='cpu', 289 | strict=False, 290 | logger=None): 291 | """Load checkpoint from a file or URI. 292 | 293 | Args: 294 | model (Module): Module to load checkpoint. 295 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 296 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 297 | details. 298 | map_location (str): Same as :func:`torch.load`. 299 | strict (bool): Whether to allow different params for the model and 300 | checkpoint. 301 | logger (:mod:`logging.Logger` or None): The logger for error message. 302 | 303 | Returns: 304 | dict or OrderedDict: The loaded checkpoint. 
305 | """ 306 | checkpoint = _load_checkpoint(filename, map_location) 307 | # OrderedDict is a subclass of dict 308 | if not isinstance(checkpoint, dict): 309 | raise RuntimeError( 310 | f'No state_dict found in checkpoint file {filename}') 311 | # get state_dict from checkpoint 312 | if 'state_dict' in checkpoint: 313 | state_dict = checkpoint['state_dict'] 314 | elif 'state_dict_ema' in checkpoint: 315 | state_dict = checkpoint['state_dict_ema'] 316 | elif 'model' in checkpoint: 317 | state_dict = checkpoint['model'] 318 | else: 319 | state_dict = checkpoint 320 | # strip prefix of state_dict 321 | if list(state_dict.keys())[0].startswith('module.'): 322 | state_dict = {k[7:]: v for k, v in state_dict.items()} 323 | 324 | # for MoBY, load model of online branch 325 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): 326 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} 327 | 328 | # reshape absolute position embedding 329 | if state_dict.get('absolute_pos_embed') is not None: 330 | absolute_pos_embed = state_dict['absolute_pos_embed'] 331 | N1, L, C1 = absolute_pos_embed.size() 332 | N2, C2, H, W = model.absolute_pos_embed.size() 333 | if N1 != N2 or C1 != C2 or L != H*W: 334 | logger.warning("Error in loading absolute_pos_embed, pass") 335 | else: 336 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) 337 | 338 | # interpolate position bias table if needed 339 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 340 | for table_key in relative_position_bias_table_keys: 341 | table_pretrained = state_dict[table_key] 342 | table_current = model.state_dict()[table_key] 343 | L1, nH1 = table_pretrained.size() 344 | L2, nH2 = table_current.size() 345 | if nH1 != nH2: 346 | logger.warning(f"Error in loading {table_key}, pass") 347 | else: 348 | if L1 != L2: 349 | S1 = int(L1 ** 0.5) 350 | S2 = int(L2 ** 0.5) 351 | table_pretrained_resized = F.interpolate( 352 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1), 353 | size=(S2, S2), mode='bicubic') 354 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) 355 | 356 | # load state_dict 357 | load_state_dict(model, state_dict, strict, logger) 358 | return checkpoint 359 | 360 | 361 | def weights_to_cpu(state_dict): 362 | """Copy a model state_dict to cpu. 363 | 364 | Args: 365 | state_dict (OrderedDict): Model weights on GPU. 366 | 367 | Returns: 368 | OrderedDict: Model weights on GPU. 369 | """ 370 | state_dict_cpu = OrderedDict() 371 | for key, val in state_dict.items(): 372 | state_dict_cpu[key] = val.cpu() 373 | return state_dict_cpu 374 | 375 | 376 | def _save_to_state_dict(module, destination, prefix, keep_vars): 377 | """Saves module state to `destination` dictionary. 378 | 379 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. 380 | 381 | Args: 382 | module (nn.Module): The module to generate state_dict. 383 | destination (dict): A dict where state will be stored. 384 | prefix (str): The prefix for parameters and buffers used in this 385 | module. 
386 | """ 387 | for name, param in module._parameters.items(): 388 | if param is not None: 389 | destination[prefix + name] = param if keep_vars else param.detach() 390 | for name, buf in module._buffers.items(): 391 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d 392 | if buf is not None: 393 | destination[prefix + name] = buf if keep_vars else buf.detach() 394 | 395 | 396 | def get_state_dict(module, destination=None, prefix='', keep_vars=False): 397 | """Returns a dictionary containing a whole state of the module. 398 | 399 | Both parameters and persistent buffers (e.g. running averages) are 400 | included. Keys are corresponding parameter and buffer names. 401 | 402 | This method is modified from :meth:`torch.nn.Module.state_dict` to 403 | recursively check parallel module in case that the model has a complicated 404 | structure, e.g., nn.Module(nn.Module(DDP)). 405 | 406 | Args: 407 | module (nn.Module): The module to generate state_dict. 408 | destination (OrderedDict): Returned dict for the state of the 409 | module. 410 | prefix (str): Prefix of the key. 411 | keep_vars (bool): Whether to keep the variable property of the 412 | parameters. Default: False. 413 | 414 | Returns: 415 | dict: A dictionary containing a whole state of the module. 416 | """ 417 | # recursively check parallel module in case that the model has a 418 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 419 | if is_module_wrapper(module): 420 | module = module.module 421 | 422 | # below is the same as torch.nn.Module.state_dict() 423 | if destination is None: 424 | destination = OrderedDict() 425 | destination._metadata = OrderedDict() 426 | destination._metadata[prefix[:-1]] = local_metadata = dict( 427 | version=module._version) 428 | _save_to_state_dict(module, destination, prefix, keep_vars) 429 | for name, child in module._modules.items(): 430 | if child is not None: 431 | get_state_dict( 432 | child, destination, prefix + name + '.', keep_vars=keep_vars) 433 | for hook in module._state_dict_hooks.values(): 434 | hook_result = hook(module, destination, prefix, local_metadata) 435 | if hook_result is not None: 436 | destination = hook_result 437 | return destination 438 | 439 | 440 | def save_checkpoint(model, filename, optimizer=None, meta=None): 441 | """Save checkpoint to file. 442 | 443 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and 444 | ``optimizer``. By default ``meta`` will contain version and time info. 445 | 446 | Args: 447 | model (Module): Module whose params are to be saved. 448 | filename (str): Checkpoint filename. 449 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 450 | meta (dict, optional): Metadata to be saved in checkpoint. 
451 | """ 452 | if meta is None: 453 | meta = {} 454 | elif not isinstance(meta, dict): 455 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 456 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 457 | 458 | if is_module_wrapper(model): 459 | model = model.module 460 | 461 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 462 | # save class name to the meta 463 | meta.update(CLASSES=model.CLASSES) 464 | 465 | checkpoint = { 466 | 'meta': meta, 467 | 'state_dict': weights_to_cpu(get_state_dict(model)) 468 | } 469 | # save optimizer state dict in the checkpoint 470 | if isinstance(optimizer, Optimizer): 471 | checkpoint['optimizer'] = optimizer.state_dict() 472 | elif isinstance(optimizer, dict): 473 | checkpoint['optimizer'] = {} 474 | for name, optim in optimizer.items(): 475 | checkpoint['optimizer'][name] = optim.state_dict() 476 | 477 | if filename.startswith('pavi://'): 478 | try: 479 | from pavi import modelcloud 480 | from pavi.exception import NodeNotFoundError 481 | except ImportError: 482 | raise ImportError( 483 | 'Please install pavi to load checkpoint from modelcloud.') 484 | model_path = filename[7:] 485 | root = modelcloud.Folder() 486 | model_dir, model_name = osp.split(model_path) 487 | try: 488 | model = modelcloud.get(model_dir) 489 | except NodeNotFoundError: 490 | model = root.create_training_model(model_dir) 491 | with TemporaryDirectory() as tmp_dir: 492 | checkpoint_file = osp.join(tmp_dir, model_name) 493 | with open(checkpoint_file, 'wb') as f: 494 | torch.save(checkpoint, f) 495 | f.flush() 496 | model.create_file(checkpoint_file, name=model_name) 497 | else: 498 | mmcv.mkdir_or_exist(osp.dirname(filename)) 499 | # immediately flush buffer 500 | with open(filename, 'wb') as f: 501 | torch.save(checkpoint, f) 502 | f.flush() 503 | -------------------------------------------------------------------------------- /BOAT-CSWin/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-CSWin/teaser.png -------------------------------------------------------------------------------- /BOAT-CSWin/train.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # CSWin Transformer 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT License. 5 | # written By Xiaoyi Dong 6 | # ------------------------------------------ 7 | 8 | NUM_PROC=$1 9 | shift 10 | python -m torch.distributed.launch --nproc_per_node=$NUM_PROC --master_port 20028 main.py "$@" 11 | 12 | -------------------------------------------------------------------------------- /BOAT-Swin/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /BOAT-Swin/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /BOAT-Swin/README.md: -------------------------------------------------------------------------------- 1 | # Swin-BOAT 2 | 3 | This is developed based on the official version of [Swin Transformer](https://github.com/microsoft/Swin-Transformer) 4 | We only change ./model/swin_transformer.py to ./model/boat_swin_transformer.py and keep other codes unchanged. 5 | 6 | 7 | ## Start 8 | 9 | Please refer to [Start for Swin](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md) for installing the prerequisite. 
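To instantiate the BOAT-Swin backbone outside of `main.py`, the minimal sketch below shows one possible route. It assumes the upstream Swin entry points (`get_config` in `config.py` and `build_model` in `models/build.py`) keep their original behaviour, which is plausible since only the transformer module was replaced; treat the exact calls as assumptions rather than documented API.

```python
# Minimal sketch (not part of the repo): build the model from a config file.
# Assumes config.get_config() and models.build_model() behave as in upstream Swin.
from types import SimpleNamespace

from config import get_config
from models import build_model

# get_config()/update_config() read all of these attributes, so they must exist
# even when unused; None/False leaves the YAML values and defaults untouched.
args = SimpleNamespace(
    cfg='configs/swin_tiny_patch4_window7_224.yaml',
    opts=None, batch_size=None, data_path=None, zip=False, cache_mode=None,
    pretrained=None, resume=None, accumulation_steps=None, use_checkpoint=False,
    amp_opt_level=None, output=None, tag=None, eval=False, throughput=False,
    local_rank=0,
)
config = get_config(args)
model = build_model(config)  # assumed to return the model from models/boat_swin_transformer.py
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```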
10 | 11 | ## Training 12 | 13 | `BOAT-Swin-T`: 14 | 15 | ```bash 16 | python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 main.py \ 17 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path --batch-size 256 18 | ``` 19 | 20 | `BOAT-Swin-S`: 21 | 22 | ```bash 23 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 24 | --cfg configs/swin_small_patch4_window7_224.yaml --data-path --batch-size 128 25 | ``` 26 | 27 | `BOAT-Swin-B`: 28 | 29 | ```bash 30 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 31 | --cfg configs/swin_base_patch4_window7_224.yaml --data-path --batch-size 128 \ 32 | ``` 33 | 34 | ## Evaluation 35 | 36 | To evaluate a pre-trained `BOAT-Swin Transformer` on ImageNet val, run: 37 | 38 | ```bash 39 | python -m torch.distributed.launch --nproc_per_node --master_port 12345 main.py --eval \ 40 | --cfg --resume --data-path 41 | ``` 42 | 43 | ## Pre-trained models 44 | 45 | [BOAT-Swin-Tiny](https://www.dropbox.com/s/xa94uewsrvjglnn/tiny.pth?dl=0) 46 | 47 | [BOAT-Swin-Small](https://www.dropbox.com/s/7ih1zvii3bvdcgd/small.pth?dl=0) 48 | 49 | [BOAT-Swin-Base](https://www.dropbox.com/s/70hr7h0smcr0gr9/base.pth?dl=0) 50 | 51 | ## Acknowledgement 52 | This is developed based on the official version of [Swin Transformer](https://github.com/microsoft/Swin-Transformer) 53 | -------------------------------------------------------------------------------- /BOAT-Swin/SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /BOAT-Swin/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /BOAT-Swin/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Pretrained weight from checkpoint, could be imagenet22k pretrained weight 50 | # could be overwritten by command line argument 51 | _C.MODEL.PRETRAINED = '' 52 | # Checkpoint to resume, could be overwritten by command line argument 53 | _C.MODEL.RESUME = '' 54 | # Number of classes, overwritten in data preparation 55 | _C.MODEL.NUM_CLASSES = 1000 56 | # Dropout rate 57 | _C.MODEL.DROP_RATE = 0.0 58 | # Drop path rate 59 | _C.MODEL.DROP_PATH_RATE = 0.1 60 | # Label Smoothing 61 | _C.MODEL.LABEL_SMOOTHING = 0.1 62 | 63 | # Swin Transformer parameters 64 | _C.MODEL.SWIN = CN() 65 | _C.MODEL.SWIN.PATCH_SIZE = 4 66 | _C.MODEL.SWIN.IN_CHANS = 3 67 | _C.MODEL.SWIN.EMBED_DIM = 96 68 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 69 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 70 | _C.MODEL.SWIN.WINDOW_SIZE = 7 71 | _C.MODEL.SWIN.MLP_RATIO = 4. 72 | _C.MODEL.SWIN.QKV_BIAS = True 73 | _C.MODEL.SWIN.QK_SCALE = None 74 | _C.MODEL.SWIN.APE = False 75 | _C.MODEL.SWIN.PATCH_NORM = True 76 | 77 | # Swin MLP parameters 78 | _C.MODEL.SWIN_MLP = CN() 79 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 80 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 81 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 82 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 83 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 84 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 85 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 
86 | _C.MODEL.SWIN_MLP.APE = False 87 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 88 | 89 | # ----------------------------------------------------------------------------- 90 | # Training settings 91 | # ----------------------------------------------------------------------------- 92 | _C.TRAIN = CN() 93 | _C.TRAIN.START_EPOCH = 0 94 | _C.TRAIN.EPOCHS = 300 95 | _C.TRAIN.WARMUP_EPOCHS = 20 96 | _C.TRAIN.WEIGHT_DECAY = 0.05 97 | _C.TRAIN.BASE_LR = 5e-4 98 | _C.TRAIN.WARMUP_LR = 5e-7 99 | _C.TRAIN.MIN_LR = 5e-6 100 | # Clip gradient norm 101 | _C.TRAIN.CLIP_GRAD = 5.0 102 | # Auto resume from latest checkpoint 103 | _C.TRAIN.AUTO_RESUME = True 104 | # Gradient accumulation steps 105 | # could be overwritten by command line argument 106 | _C.TRAIN.ACCUMULATION_STEPS = 0 107 | # Whether to use gradient checkpointing to save memory 108 | # could be overwritten by command line argument 109 | _C.TRAIN.USE_CHECKPOINT = False 110 | 111 | # LR scheduler 112 | _C.TRAIN.LR_SCHEDULER = CN() 113 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 114 | # Epoch interval to decay LR, used in StepLRScheduler 115 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 116 | # LR decay rate, used in StepLRScheduler 117 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 118 | 119 | # Optimizer 120 | _C.TRAIN.OPTIMIZER = CN() 121 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 122 | # Optimizer Epsilon 123 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 124 | # Optimizer Betas 125 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 126 | # SGD momentum 127 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 128 | 129 | # ----------------------------------------------------------------------------- 130 | # Augmentation settings 131 | # ----------------------------------------------------------------------------- 132 | _C.AUG = CN() 133 | # Color jitter factor 134 | _C.AUG.COLOR_JITTER = 0.4 135 | # Use AutoAugment policy. "v0" or "original" 136 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 137 | # Random erase prob 138 | _C.AUG.REPROB = 0.25 139 | # Random erase mode 140 | _C.AUG.REMODE = 'pixel' 141 | # Random erase count 142 | _C.AUG.RECOUNT = 1 143 | # Mixup alpha, mixup enabled if > 0 144 | _C.AUG.MIXUP = 0.8 145 | # Cutmix alpha, cutmix enabled if > 0 146 | _C.AUG.CUTMIX = 1.0 147 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 148 | _C.AUG.CUTMIX_MINMAX = None 149 | # Probability of performing mixup or cutmix when either/both is enabled 150 | _C.AUG.MIXUP_PROB = 1.0 151 | # Probability of switching to cutmix when both mixup and cutmix enabled 152 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 153 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 154 | _C.AUG.MIXUP_MODE = 'batch' 155 | 156 | # ----------------------------------------------------------------------------- 157 | # Testing settings 158 | # ----------------------------------------------------------------------------- 159 | _C.TEST = CN() 160 | # Whether to use center crop when testing 161 | _C.TEST.CROP = True 162 | # Whether to use SequentialSampler as validation sampler 163 | _C.TEST.SEQUENTIAL = False 164 | 165 | # ----------------------------------------------------------------------------- 166 | # Misc 167 | # ----------------------------------------------------------------------------- 168 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 169 | # overwritten by command line argument 170 | _C.AMP_OPT_LEVEL = '' 171 | # Path to output folder, overwritten by command line argument 172 | _C.OUTPUT = '' 173 | # Tag of experiment, overwritten by command line argument 174 | _C.TAG = 'default' 175 | # Frequency to save checkpoint 176 | _C.SAVE_FREQ = 1 177 | # Frequency to logging info 178 | _C.PRINT_FREQ = 10 179 | # Fixed random seed 180 | _C.SEED = 0 181 | # Perform evaluation only, overwritten by command line argument 182 | _C.EVAL_MODE = False 183 | # Test throughput only, overwritten by command line argument 184 | _C.THROUGHPUT_MODE = False 185 | # local rank for DistributedDataParallel, given by command line argument 186 | _C.LOCAL_RANK = 0 187 | 188 | 189 | def _update_config_from_file(config, cfg_file): 190 | config.defrost() 191 | with open(cfg_file, 'r') as f: 192 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 193 | 194 | for cfg in yaml_cfg.setdefault('BASE', ['']): 195 | if cfg: 196 | _update_config_from_file( 197 | config, os.path.join(os.path.dirname(cfg_file), cfg) 198 | ) 199 | print('=> merge config from {}'.format(cfg_file)) 200 | config.merge_from_file(cfg_file) 201 | config.freeze() 202 | 203 | 204 | def update_config(config, args): 205 | _update_config_from_file(config, args.cfg) 206 | 207 | config.defrost() 208 | if args.opts: 209 | config.merge_from_list(args.opts) 210 | 211 | # merge from specific arguments 212 | if args.batch_size: 213 | config.DATA.BATCH_SIZE = args.batch_size 214 | if args.data_path: 215 | config.DATA.DATA_PATH = args.data_path 216 | if args.zip: 217 | config.DATA.ZIP_MODE = True 218 | if args.cache_mode: 219 | config.DATA.CACHE_MODE = args.cache_mode 220 | if args.pretrained: 221 | config.MODEL.PRETRAINED = args.pretrained 222 | if args.resume: 223 | config.MODEL.RESUME = args.resume 224 | if args.accumulation_steps: 225 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 226 | if args.use_checkpoint: 227 | config.TRAIN.USE_CHECKPOINT = True 228 | if args.amp_opt_level: 229 | config.AMP_OPT_LEVEL = args.amp_opt_level 230 | if args.output: 231 | config.OUTPUT = args.output 232 | if args.tag: 233 | config.TAG = args.tag 234 | if args.eval: 235 | config.EVAL_MODE = True 236 | if args.throughput: 237 | config.THROUGHPUT_MODE = True 238 | 239 | # set local rank for distributed training 240 | config.LOCAL_RANK = args.local_rank 241 | 242 | # output folder 243 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 244 | 245 | config.freeze() 246 | 247 | 248 | def get_config(args): 249 | """Get a yacs CfgNode object with default values.""" 250 | # Return a clone so that the defaults will not be altered 251 | # This is for the "local variable" use pattern 252 | config = _C.clone() 253 | update_config(config, args) 254 | 255 | return config 256 | 
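A quick way to see what the YAML files in `configs/` actually change: every key present in the YAML overrides the corresponding `_C` default above, and everything else keeps its default value. A minimal sketch (not part of the repo) using the helpers defined in `config.py`:

```python
# Sketch: merge one of the shipped YAML configs over the defaults in config.py.
from config import _C, _update_config_from_file

config = _C.clone()
_update_config_from_file(config, 'configs/swin_base_patch4_window7_224.yaml')

assert config.MODEL.SWIN.EMBED_DIM == 128           # overridden by the YAML
assert config.MODEL.SWIN.DEPTHS == [2, 2, 18, 2]    # overridden by the YAML
assert config.DATA.BATCH_SIZE == 128                # untouched default
```

`update_config()` then layers the command-line overrides (`--opts`, `--batch-size`, `--data-path`, ...) on top of this merged result before freezing the config.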
-------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_base_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_base_patch4_window12_384_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_finetune 6 | DROP_PATH_RATE: 0.5 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_base_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_large_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False 
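One detail worth noting in the 384-resolution finetune configs above: `IMG_SIZE: 384` is paired with `WINDOW_SIZE: 12` rather than 7 because the classification-time window partition assumes each stage's feature map divides evenly by the window size. A small check (plain Python, illustrative only):

```python
# Sketch: per-stage feature-map sizes for patch_size=4 and four 2x-downsampling
# stages, and whether they tile exactly into the configured window size.
def stage_resolutions(img_size, patch_size=4, num_stages=4):
    return [img_size // patch_size // (2 ** i) for i in range(num_stages)]

for img_size, window in [(224, 7), (384, 12)]:
    res = stage_resolutions(img_size)
    print(img_size, window, res, all(r % window == 0 for r in res))
# 224 7  [56, 28, 14, 7]   True
# 384 12 [96, 48, 24, 12]  True
```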
-------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_large_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_large_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_large_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_mlp_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin_mlp 3 | NAME: swin_mlp_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN_MLP: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c12_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 8, 16, 32, 64 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c6_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 16, 32, 64, 128 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- 
/BOAT-Swin/configs/swin_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /BOAT-Swin/configs/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /BOAT-Swin/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /BOAT-Swin/data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | 17 | from .cached_image_folder import CachedImageFolder 18 | from .samplers import SubsetRandomSampler 19 | 20 | try: 21 | from torchvision.transforms import InterpolationMode 22 | 23 | 24 | def _pil_interp(method): 25 | if method == 'bicubic': 26 | return InterpolationMode.BICUBIC 27 | elif method == 'lanczos': 28 | return InterpolationMode.LANCZOS 29 | elif method == 'hamming': 30 | return InterpolationMode.HAMMING 31 | else: 32 | # default bilinear, do we want to allow nearest? 
33 | return InterpolationMode.BILINEAR 34 | except: 35 | from timm.data.transforms import _pil_interp 36 | 37 | 38 | def build_loader(config): 39 | config.defrost() 40 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 41 | config.freeze() 42 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 43 | dataset_val, _ = build_dataset(is_train=False, config=config) 44 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 45 | 46 | num_tasks = dist.get_world_size() 47 | global_rank = dist.get_rank() 48 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 49 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 50 | sampler_train = SubsetRandomSampler(indices) 51 | else: 52 | sampler_train = torch.utils.data.DistributedSampler( 53 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 54 | ) 55 | 56 | if config.TEST.SEQUENTIAL: 57 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 58 | else: 59 | sampler_val = torch.utils.data.distributed.DistributedSampler( 60 | dataset_val, shuffle=False 61 | ) 62 | 63 | data_loader_train = torch.utils.data.DataLoader( 64 | dataset_train, sampler=sampler_train, 65 | batch_size=config.DATA.BATCH_SIZE, 66 | num_workers=config.DATA.NUM_WORKERS, 67 | pin_memory=config.DATA.PIN_MEMORY, 68 | drop_last=True, 69 | ) 70 | 71 | data_loader_val = torch.utils.data.DataLoader( 72 | dataset_val, sampler=sampler_val, 73 | batch_size=config.DATA.BATCH_SIZE, 74 | shuffle=False, 75 | num_workers=config.DATA.NUM_WORKERS, 76 | pin_memory=config.DATA.PIN_MEMORY, 77 | drop_last=False 78 | ) 79 | 80 | # setup mixup / cutmix 81 | mixup_fn = None 82 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 83 | if mixup_active: 84 | mixup_fn = Mixup( 85 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 86 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 87 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 88 | 89 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 90 | 91 | 92 | def build_dataset(is_train, config): 93 | transform = build_transform(is_train, config) 94 | if config.DATA.DATASET == 'imagenet': 95 | prefix = 'train' if is_train else 'val' 96 | if config.DATA.ZIP_MODE: 97 | ann_file = prefix + "_map.txt" 98 | prefix = prefix + ".zip@/" 99 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 100 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 101 | else: 102 | root = os.path.join(config.DATA.DATA_PATH, prefix) 103 | dataset = datasets.ImageFolder(root, transform=transform) 104 | nb_classes = 1000 105 | elif config.DATA.DATASET == 'imagenet22K': 106 | raise NotImplementedError("Imagenet-22K will come soon.") 107 | else: 108 | raise NotImplementedError("We only support ImageNet Now.") 109 | 110 | return dataset, nb_classes 111 | 112 | 113 | def build_transform(is_train, config): 114 | resize_im = config.DATA.IMG_SIZE > 32 115 | if is_train: 116 | # this should always dispatch to transforms_imagenet_train 117 | transform = create_transform( 118 | input_size=config.DATA.IMG_SIZE, 119 | is_training=True, 120 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 121 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 122 | re_prob=config.AUG.REPROB, 123 | re_mode=config.AUG.REMODE, 124 | re_count=config.AUG.RECOUNT, 125 | interpolation=config.DATA.INTERPOLATION, 126 | ) 127 | if not resize_im: 128 | # replace RandomResizedCropAndInterpolation with 129 | # RandomCrop 130 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 131 | return transform 132 | 133 | t = [] 134 | if resize_im: 135 | if config.TEST.CROP: 136 | size = int((256 / 224) * config.DATA.IMG_SIZE) 137 | t.append( 138 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 139 | # to maintain same ratio w.r.t. 
224 images 140 | ) 141 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 142 | else: 143 | t.append( 144 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 145 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 146 | ) 147 | 148 | t.append(transforms.ToTensor()) 149 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 150 | return transforms.Compose(t) 151 | -------------------------------------------------------------------------------- /BOAT-Swin/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | return img.convert('RGB') 190 | 191 | 192 | def accimage_loader(path): 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_img_loader(path): 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class CachedImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | root/dog/xxx.png 212 | root/dog/xxy.png 213 | root/dog/xxz.png 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in an PIL image 220 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | Attributes: 225 | imgs (list): List of (image path, class_index) tuples 226 | """ 227 | 228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 229 | loader=default_img_loader, cache_mode="no"): 230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 231 | ann_file=ann_file, img_prefix=img_prefix, 232 | transform=transform, target_transform=target_transform, 233 | cache_mode=cache_mode) 234 | self.imgs = self.samples 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is class_index of the target class. 
242 | """ 243 | path, target = self.samples[index] 244 | image = self.loader(path) 245 | if self.transform is not None: 246 | img = self.transform(image) 247 | else: 248 | img = image 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | 252 | return img, target 253 | -------------------------------------------------------------------------------- /BOAT-Swin/data/map22kto1k.txt: -------------------------------------------------------------------------------- 1 | 359 2 | 368 3 | 460 4 | 475 5 | 486 6 | 492 7 | 496 8 | 514 9 | 516 10 | 525 11 | 547 12 | 548 13 | 556 14 | 563 15 | 575 16 | 641 17 | 648 18 | 723 19 | 733 20 | 765 21 | 801 22 | 826 23 | 852 24 | 858 25 | 878 26 | 896 27 | 900 28 | 905 29 | 908 30 | 910 31 | 935 32 | 946 33 | 947 34 | 994 35 | 999 36 | 1003 37 | 1005 38 | 1010 39 | 1027 40 | 1029 41 | 1048 42 | 1055 43 | 1064 44 | 1065 45 | 1069 46 | 1075 47 | 1079 48 | 1081 49 | 1085 50 | 1088 51 | 1093 52 | 1106 53 | 1143 54 | 1144 55 | 1145 56 | 1147 57 | 1168 58 | 1171 59 | 1178 60 | 1187 61 | 1190 62 | 1197 63 | 1205 64 | 1216 65 | 1223 66 | 1230 67 | 1236 68 | 1241 69 | 1245 70 | 1257 71 | 1259 72 | 1260 73 | 1267 74 | 1268 75 | 1269 76 | 1271 77 | 1272 78 | 1273 79 | 1277 80 | 1303 81 | 1344 82 | 1349 83 | 1355 84 | 1357 85 | 1384 86 | 1388 87 | 1391 88 | 1427 89 | 1429 90 | 1432 91 | 1437 92 | 1450 93 | 1461 94 | 1462 95 | 1474 96 | 1502 97 | 1503 98 | 1512 99 | 1552 100 | 1555 101 | 1577 102 | 1584 103 | 1587 104 | 1589 105 | 1599 106 | 1615 107 | 1616 108 | 1681 109 | 1692 110 | 1701 111 | 1716 112 | 1729 113 | 1757 114 | 1759 115 | 1764 116 | 1777 117 | 1786 118 | 1822 119 | 1841 120 | 1842 121 | 1848 122 | 1850 123 | 1856 124 | 1860 125 | 1861 126 | 1864 127 | 1876 128 | 1897 129 | 1898 130 | 1910 131 | 1913 132 | 1918 133 | 1922 134 | 1928 135 | 1932 136 | 1935 137 | 1947 138 | 1951 139 | 1953 140 | 1970 141 | 1977 142 | 1979 143 | 2001 144 | 2017 145 | 2067 146 | 2081 147 | 2087 148 | 2112 149 | 2128 150 | 2135 151 | 2147 152 | 2174 153 | 2175 154 | 2176 155 | 2177 156 | 2178 157 | 2181 158 | 2183 159 | 2184 160 | 2187 161 | 2189 162 | 2190 163 | 2191 164 | 2192 165 | 2193 166 | 2197 167 | 2202 168 | 2203 169 | 2206 170 | 2208 171 | 2209 172 | 2211 173 | 2212 174 | 2213 175 | 2214 176 | 2215 177 | 2216 178 | 2217 179 | 2219 180 | 2222 181 | 2223 182 | 2224 183 | 2225 184 | 2226 185 | 2227 186 | 2228 187 | 2229 188 | 2230 189 | 2236 190 | 2238 191 | 2240 192 | 2241 193 | 2242 194 | 2243 195 | 2244 196 | 2245 197 | 2247 198 | 2248 199 | 2249 200 | 2250 201 | 2251 202 | 2252 203 | 2255 204 | 2256 205 | 2257 206 | 2262 207 | 2263 208 | 2264 209 | 2265 210 | 2266 211 | 2268 212 | 2270 213 | 2271 214 | 2272 215 | 2273 216 | 2275 217 | 2276 218 | 2279 219 | 2280 220 | 2281 221 | 2282 222 | 2285 223 | 2289 224 | 2292 225 | 2295 226 | 2296 227 | 2297 228 | 2298 229 | 2299 230 | 2300 231 | 2301 232 | 2302 233 | 2303 234 | 2304 235 | 2305 236 | 2306 237 | 2309 238 | 2310 239 | 2312 240 | 2313 241 | 2314 242 | 2315 243 | 2316 244 | 2318 245 | 2319 246 | 2321 247 | 2322 248 | 2326 249 | 2329 250 | 2330 251 | 2331 252 | 2332 253 | 2334 254 | 2335 255 | 2336 256 | 2337 257 | 2338 258 | 2339 259 | 2341 260 | 2342 261 | 2343 262 | 2344 263 | 2346 264 | 2348 265 | 2349 266 | 2351 267 | 2352 268 | 2353 269 | 2355 270 | 2357 271 | 2358 272 | 2359 273 | 2360 274 | 2364 275 | 2365 276 | 2368 277 | 2369 278 | 2377 279 | 2382 280 | 2383 281 | 2385 282 | 2397 283 | 2398 284 | 2400 285 | 2402 286 | 2405 287 | 2412 288 | 2421 289 | 2428 290 | 2431 291 
| 2432 292 | 2433 293 | 2436 294 | 2441 295 | 2445 296 | 2450 297 | 2453 298 | 2454 299 | 2465 300 | 2469 301 | 2532 302 | 2533 303 | 2538 304 | 2544 305 | 2547 306 | 2557 307 | 2565 308 | 2578 309 | 2612 310 | 2658 311 | 2702 312 | 2722 313 | 2731 314 | 2738 315 | 2741 316 | 2747 317 | 2810 318 | 2818 319 | 2833 320 | 2844 321 | 2845 322 | 2867 323 | 2874 324 | 2882 325 | 2884 326 | 2888 327 | 2889 328 | 3008 329 | 3012 330 | 3019 331 | 3029 332 | 3033 333 | 3042 334 | 3091 335 | 3106 336 | 3138 337 | 3159 338 | 3164 339 | 3169 340 | 3280 341 | 3296 342 | 3311 343 | 3318 344 | 3320 345 | 3324 346 | 3330 347 | 3366 348 | 3375 349 | 3381 350 | 3406 351 | 3419 352 | 3432 353 | 3434 354 | 3435 355 | 3493 356 | 3495 357 | 3503 358 | 3509 359 | 3511 360 | 3513 361 | 3517 362 | 3521 363 | 3526 364 | 3546 365 | 3554 366 | 3600 367 | 3601 368 | 3606 369 | 3612 370 | 3613 371 | 3616 372 | 3622 373 | 3623 374 | 3627 375 | 3632 376 | 3634 377 | 3636 378 | 3638 379 | 3644 380 | 3646 381 | 3649 382 | 3650 383 | 3651 384 | 3656 385 | 3663 386 | 3673 387 | 3674 388 | 3689 389 | 3690 390 | 3702 391 | 3733 392 | 3769 393 | 3971 394 | 3974 395 | 4065 396 | 4068 397 | 4073 398 | 4102 399 | 4136 400 | 4140 401 | 4151 402 | 4159 403 | 4165 404 | 4207 405 | 4219 406 | 4226 407 | 4249 408 | 4256 409 | 4263 410 | 4270 411 | 4313 412 | 4321 413 | 4378 414 | 4386 415 | 4478 416 | 4508 417 | 4512 418 | 4536 419 | 4542 420 | 4550 421 | 4560 422 | 4562 423 | 4570 424 | 4571 425 | 4572 426 | 4583 427 | 4588 428 | 4594 429 | 4604 430 | 4608 431 | 4623 432 | 4634 433 | 4636 434 | 4646 435 | 4651 436 | 4652 437 | 4686 438 | 4688 439 | 4691 440 | 4699 441 | 4724 442 | 4727 443 | 4737 444 | 4770 445 | 4774 446 | 4789 447 | 4802 448 | 4807 449 | 4819 450 | 4880 451 | 4886 452 | 4908 453 | 4927 454 | 4931 455 | 4936 456 | 4964 457 | 4976 458 | 4993 459 | 5028 460 | 5033 461 | 5043 462 | 5046 463 | 5096 464 | 5111 465 | 5114 466 | 5131 467 | 5132 468 | 5183 469 | 5199 470 | 5235 471 | 5275 472 | 5291 473 | 5293 474 | 5294 475 | 5343 476 | 5360 477 | 5362 478 | 5364 479 | 5390 480 | 5402 481 | 5418 482 | 5428 483 | 5430 484 | 5437 485 | 5443 486 | 5473 487 | 5484 488 | 5486 489 | 5505 490 | 5507 491 | 5508 492 | 5510 493 | 5567 494 | 5578 495 | 5580 496 | 5584 497 | 5606 498 | 5613 499 | 5629 500 | 5672 501 | 5676 502 | 5692 503 | 5701 504 | 5760 505 | 5769 506 | 5770 507 | 5779 508 | 5814 509 | 5850 510 | 5871 511 | 5893 512 | 5911 513 | 5949 514 | 5954 515 | 6005 516 | 6006 517 | 6012 518 | 6017 519 | 6023 520 | 6024 521 | 6040 522 | 6050 523 | 6054 524 | 6087 525 | 6105 526 | 6157 527 | 6235 528 | 6237 529 | 6256 530 | 6259 531 | 6286 532 | 6291 533 | 6306 534 | 6339 535 | 6341 536 | 6343 537 | 6379 538 | 6383 539 | 6393 540 | 6405 541 | 6479 542 | 6511 543 | 6517 544 | 6541 545 | 6561 546 | 6608 547 | 6611 548 | 6615 549 | 6678 550 | 6682 551 | 6707 552 | 6752 553 | 6798 554 | 6850 555 | 6880 556 | 6885 557 | 6890 558 | 6920 559 | 6981 560 | 7000 561 | 7009 562 | 7038 563 | 7049 564 | 7050 565 | 7052 566 | 7073 567 | 7078 568 | 7098 569 | 7111 570 | 7165 571 | 7198 572 | 7204 573 | 7280 574 | 7283 575 | 7286 576 | 7287 577 | 7293 578 | 7294 579 | 7305 580 | 7318 581 | 7341 582 | 7346 583 | 7354 584 | 7382 585 | 7427 586 | 7428 587 | 7435 588 | 7445 589 | 7450 590 | 7455 591 | 7467 592 | 7469 593 | 7497 594 | 7502 595 | 7506 596 | 7514 597 | 7523 598 | 7651 599 | 7661 600 | 7664 601 | 7672 602 | 7679 603 | 7685 604 | 7696 605 | 7730 606 | 7871 607 | 7873 608 | 7895 609 | 7914 610 | 7915 611 | 7920 612 | 7934 613 | 7935 614 | 
7949 615 | 8009 616 | 8036 617 | 8051 618 | 8065 619 | 8074 620 | 8090 621 | 8112 622 | 8140 623 | 8164 624 | 8168 625 | 8178 626 | 8182 627 | 8198 628 | 8212 629 | 8216 630 | 8230 631 | 8242 632 | 8288 633 | 8289 634 | 8295 635 | 8318 636 | 8352 637 | 8368 638 | 8371 639 | 8375 640 | 8376 641 | 8401 642 | 8416 643 | 8419 644 | 8436 645 | 8460 646 | 8477 647 | 8478 648 | 8482 649 | 8498 650 | 8500 651 | 8539 652 | 8543 653 | 8552 654 | 8555 655 | 8580 656 | 8584 657 | 8586 658 | 8594 659 | 8598 660 | 8601 661 | 8606 662 | 8610 663 | 8611 664 | 8622 665 | 8627 666 | 8639 667 | 8649 668 | 8650 669 | 8653 670 | 8654 671 | 8667 672 | 8672 673 | 8673 674 | 8674 675 | 8676 676 | 8684 677 | 8720 678 | 8723 679 | 8750 680 | 8753 681 | 8801 682 | 8815 683 | 8831 684 | 8835 685 | 8842 686 | 8845 687 | 8858 688 | 8897 689 | 8916 690 | 8951 691 | 8954 692 | 8959 693 | 8970 694 | 8976 695 | 8981 696 | 8983 697 | 8989 698 | 8991 699 | 8993 700 | 9019 701 | 9039 702 | 9042 703 | 9043 704 | 9056 705 | 9057 706 | 9070 707 | 9087 708 | 9098 709 | 9106 710 | 9130 711 | 9131 712 | 9155 713 | 9171 714 | 9183 715 | 9198 716 | 9199 717 | 9201 718 | 9204 719 | 9212 720 | 9221 721 | 9225 722 | 9229 723 | 9250 724 | 9260 725 | 9271 726 | 9279 727 | 9295 728 | 9300 729 | 9310 730 | 9322 731 | 9345 732 | 9352 733 | 9376 734 | 9377 735 | 9382 736 | 9392 737 | 9401 738 | 9405 739 | 9441 740 | 9449 741 | 9464 742 | 9475 743 | 9502 744 | 9505 745 | 9514 746 | 9515 747 | 9545 748 | 9567 749 | 9576 750 | 9608 751 | 9609 752 | 9624 753 | 9633 754 | 9639 755 | 9643 756 | 9656 757 | 9674 758 | 9740 759 | 9752 760 | 9760 761 | 9767 762 | 9778 763 | 9802 764 | 9820 765 | 9839 766 | 9879 767 | 9924 768 | 9956 769 | 9961 770 | 9963 771 | 9970 772 | 9997 773 | 10010 774 | 10031 775 | 10040 776 | 10052 777 | 10073 778 | 10075 779 | 10078 780 | 10094 781 | 10097 782 | 10109 783 | 10118 784 | 10121 785 | 10124 786 | 10158 787 | 10226 788 | 10276 789 | 10304 790 | 10307 791 | 10314 792 | 10315 793 | 10332 794 | 10337 795 | 10338 796 | 10413 797 | 10423 798 | 10451 799 | 10463 800 | 10465 801 | 10487 802 | 10519 803 | 10522 804 | 10523 805 | 10532 806 | 10534 807 | 10535 808 | 10551 809 | 10559 810 | 10574 811 | 10583 812 | 10586 813 | 10589 814 | 10612 815 | 10626 816 | 10635 817 | 10638 818 | 10677 819 | 10683 820 | 10726 821 | 10776 822 | 10782 823 | 10783 824 | 10807 825 | 10837 826 | 10840 827 | 10848 828 | 10859 829 | 10871 830 | 10881 831 | 10884 832 | 10908 833 | 10914 834 | 10921 835 | 10936 836 | 10947 837 | 10951 838 | 10952 839 | 10957 840 | 10999 841 | 11003 842 | 11018 843 | 11023 844 | 11025 845 | 11027 846 | 11045 847 | 11055 848 | 11095 849 | 11110 850 | 11137 851 | 5564 852 | 11168 853 | 11186 854 | 11221 855 | 11223 856 | 11242 857 | 11255 858 | 11259 859 | 11279 860 | 11306 861 | 11311 862 | 11331 863 | 11367 864 | 11377 865 | 11389 866 | 11392 867 | 11401 868 | 11407 869 | 11437 870 | 11449 871 | 11466 872 | 11469 873 | 11473 874 | 11478 875 | 11483 876 | 11484 877 | 11507 878 | 11536 879 | 11558 880 | 11566 881 | 11575 882 | 11584 883 | 11594 884 | 11611 885 | 11612 886 | 11619 887 | 11621 888 | 11640 889 | 11643 890 | 11664 891 | 11674 892 | 11689 893 | 11709 894 | 11710 895 | 11716 896 | 11721 897 | 11726 898 | 11729 899 | 11743 900 | 11760 901 | 11771 902 | 11837 903 | 11839 904 | 11856 905 | 11876 906 | 11878 907 | 11884 908 | 11889 909 | 11896 910 | 11917 911 | 11923 912 | 11930 913 | 11944 914 | 11952 915 | 11980 916 | 11984 917 | 12214 918 | 12229 919 | 12239 920 | 12241 921 | 12242 922 | 12247 923 | 12283 
924 | 12349 925 | 12369 926 | 12373 927 | 12422 928 | 12560 929 | 12566 930 | 12575 931 | 12688 932 | 12755 933 | 12768 934 | 12778 935 | 12780 936 | 12812 937 | 12832 938 | 12835 939 | 12836 940 | 12843 941 | 12847 942 | 12849 943 | 12850 944 | 12856 945 | 12858 946 | 12873 947 | 12938 948 | 12971 949 | 13017 950 | 13038 951 | 13046 952 | 13059 953 | 13085 954 | 13086 955 | 13088 956 | 13094 957 | 13134 958 | 13182 959 | 13230 960 | 13406 961 | 13444 962 | 13614 963 | 13690 964 | 13698 965 | 13709 966 | 13749 967 | 13804 968 | 13982 969 | 14051 970 | 14059 971 | 14219 972 | 14246 973 | 14256 974 | 14264 975 | 14294 976 | 14324 977 | 14367 978 | 14389 979 | 14394 980 | 14438 981 | 14442 982 | 14965 983 | 15732 984 | 16744 985 | 18037 986 | 18205 987 | 18535 988 | 18792 989 | 19102 990 | 20019 991 | 20462 992 | 21026 993 | 21045 994 | 21163 995 | 21171 996 | 21181 997 | 21196 998 | 21200 999 | 21369 1000 | 21817 -------------------------------------------------------------------------------- /BOAT-Swin/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /BOAT-Swin/data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 
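        # Illustrative aside (added comment, not in the original file): the zip-style
        # paths handled here follow the "<archive>.zip@/<member>" convention used by
        # data/build.py and CachedImageFolder. For example, split_zip_style_path on
        #   'data/ImageNet-Zip/train.zip@/n01440764/n01440764_10026.JPEG'
        # returns ('data/ImageNet-Zip/train.zip', 'n01440764/n01440764_10026.JPEG'),
        # i.e. the archive on disk and the member path inside it.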
51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /BOAT-Swin/figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahaoyuHKU/pytorch-boat/b94a66c13c16a673e31c5c1bd246fe6f7dc99d71/BOAT-Swin/figures/teaser.png -------------------------------------------------------------------------------- /BOAT-Swin/get_started.md: -------------------------------------------------------------------------------- 1 | # Swin Transformer for Image Classification 2 | 3 | This folder contains the implementation of the Swin Transformer for image classification. 
4 | 5 | ## Model Zoo 6 | 7 | ### Regular ImageNet-1K trained models 8 | 9 | | name | resolution |acc@1 | acc@5 | #params | FLOPs | model | 10 | |:---:|:---:|:---:|:---:| :---:| :---:|:---:| 11 | | Swin-T | 224x224 | 81.2 | 95.5 | 28M | 4.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) | 12 | | Swin-S | 224x224 | 83.2 | 96.2 | 50M | 8.7G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) | 13 | | Swin-B | 224x224 | 83.5 | 96.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) | 14 | | Swin-B | 384x384 | 84.5 | 97.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) | 15 | 16 | ### ImageNet-22K pre-trained models 17 | 18 | | name | resolution |acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model | 19 | |:---: |:---: |:---:|:---:|:---:|:---:|:---:|:---:| 20 | | Swin-B | 224x224 | 85.2 | 97.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) | 21 | | Swin-B | 384x384 | 86.4 | 98.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) | 22 | | Swin-L | 224x224 | 86.3 | 97.9 | 197M | 34.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) | 23 | | Swin-L | 384x384 | 87.3 | 98.2 | 197M | 103.9G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) | 24 | 25 | Note: access code for `baidu` is `swin`. 
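As a quick, self-contained way to inspect one of the checkpoints listed above before wiring it into `--resume`/`--pretrained`, the hedged sketch below downloads the Swin-T weights with plain PyTorch. It assumes the released `.pth` files keep the weights under a `'model'` key (the convention of the official Swin releases); the repository's supported loading path is the `load_checkpoint`/`load_pretrained` helpers in `utils.py`.

```python
# Hedged sketch (not the repository's loading path): fetch a checkpoint from the
# Model Zoo above and inspect it. Assumption: weights are stored under a 'model'
# key; fall back to the raw state dict otherwise.
import torch

url = ('https://github.com/SwinTransformer/storage/releases/download/'
       'v1.0.0/swin_tiny_patch4_window7_224.pth')
ckpt = torch.hub.load_state_dict_from_url(url, map_location='cpu')
weights = ckpt.get('model', ckpt)

print(len(weights), 'entries')
print(sorted(weights.keys())[:5])  # peek at a few parameter names

# To actually use it, build the matching network first (e.g. build_model() with
# configs/swin_tiny_patch4_window7_224.yaml), then:
# model.load_state_dict(weights, strict=False)
```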
26 | 27 | ## Usage 28 | 29 | ### Install 30 | 31 | - Clone this repo: 32 | 33 | ```bash 34 | git clone https://github.com/microsoft/Swin-Transformer.git 35 | cd Swin-Transformer 36 | ``` 37 | 38 | - Create a conda virtual environment and activate it: 39 | 40 | ```bash 41 | conda create -n swin python=3.7 -y 42 | conda activate swin 43 | ``` 44 | 45 | - Install `CUDA==10.1` with `cudnn7` following 46 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 47 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`: 48 | 49 | ```bash 50 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch 51 | ``` 52 | 53 | - Install `timm==0.3.2`: 54 | 55 | ```bash 56 | pip install timm==0.3.2 57 | ``` 58 | 59 | - Install `Apex`: 60 | 61 | ```bash 62 | git clone https://github.com/NVIDIA/apex 63 | cd apex 64 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 65 | ``` 66 | 67 | - Install other requirements: 68 | 69 | ```bash 70 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 71 | ``` 72 | 73 | ### Data preparation 74 | 75 | We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to 76 | load data: 77 | 78 | - For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like: 79 | ```bash 80 | $ tree data 81 | imagenet 82 | ├── train 83 | │ ├── class1 84 | │ │ ├── img1.jpeg 85 | │ │ ├── img2.jpeg 86 | │ │ └── ... 87 | │ ├── class2 88 | │ │ ├── img3.jpeg 89 | │ │ └── ... 90 | │ └── ... 91 | └── val 92 | ├── class1 93 | │ ├── img4.jpeg 94 | │ ├── img5.jpeg 95 | │ └── ... 96 | ├── class2 97 | │ ├── img6.jpeg 98 | │ └── ... 99 | └── ... 100 | 101 | ``` 102 | - To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes 103 | four files: 104 | - `train.zip`, `val.zip`: which store the zipped folder for train and validate splits. 105 | - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth 106 | label. 
Make sure the data folder looks like this: 107 | 108 | ```bash 109 | $ tree data 110 | data 111 | └── ImageNet-Zip 112 | ├── train_map.txt 113 | ├── train.zip 114 | ├── val_map.txt 115 | └── val.zip 116 | 117 | $ head -n 5 data/ImageNet-Zip/val_map.txt 118 | ILSVRC2012_val_00000001.JPEG 65 119 | ILSVRC2012_val_00000002.JPEG 970 120 | ILSVRC2012_val_00000003.JPEG 230 121 | ILSVRC2012_val_00000004.JPEG 809 122 | ILSVRC2012_val_00000005.JPEG 516 123 | 124 | $ head -n 5 data/ImageNet-Zip/train_map.txt 125 | n01440764/n01440764_10026.JPEG 0 126 | n01440764/n01440764_10027.JPEG 0 127 | n01440764/n01440764_10029.JPEG 0 128 | n01440764/n01440764_10040.JPEG 0 129 | n01440764/n01440764_10042.JPEG 0 130 | ``` 131 | 132 | ### Evaluation 133 | 134 | To evaluate a pre-trained `Swin Transformer` on ImageNet val, run: 135 | 136 | ```bash 137 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \ 138 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 139 | ``` 140 | 141 | For example, to evaluate the `Swin-B` with a single GPU: 142 | 143 | ```bash 144 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \ 145 | --cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path> 146 | ``` 147 | 148 | ### Training from scratch 149 | 150 | To train a `Swin Transformer` on ImageNet from scratch, run: 151 | 152 | ```bash 153 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \ 154 | --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 155 | ``` 156 | 157 | **Notes**: 158 | 159 | - To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters (a sketch for producing the zip and map files follows these notes). 160 | - To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which will 161 | shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU. 162 | - When GPU memory is not enough, you can try the following suggestions: 163 | - Use gradient accumulation by adding `--accumulation-steps <steps>`, setting an appropriate `<steps>` according to your needs. 164 | - Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`. 165 | Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details. 166 | - We recommend using multi-node training with more GPUs for very large models; a tutorial can be found 167 | on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html). 168 | - To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g., 169 | `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5. 170 | - For additional options, see [config](config.py) and run `python main.py --help` to get detailed messages.
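The `--zip` path expects `train.zip`/`val.zip` plus tab-separated `train_map.txt`/`val_map.txt` under the directory passed via `--data-path`, as shown in the data preparation section above. Below is a hedged sketch for producing them from the standard folder layout; it assumes labels should follow the sorted order of the class folders (matching what the folder-based loader's `find_classes` would assign) and that each map line is `<relative path>\t<label>`, because `make_dataset_with_ann` in `data/cached_image_folder.py` splits on a tab.

```python
# Hedged sketch (not part of the repository): build train.zip and train_map.txt
# in the layout expected by --zip / CachedImageFolder. Assumptions: class indices
# follow the sorted class-folder order, and each map line is
# "<relative path>\t<label>". Place both outputs under the --data-path directory.
import os
import zipfile

src = 'imagenet/train'  # standard folder layout shown earlier
classes = sorted(d for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)))
class_to_idx = {cls: idx for idx, cls in enumerate(classes)}

with zipfile.ZipFile('train.zip', 'w', zipfile.ZIP_STORED) as zf, \
        open('train_map.txt', 'w') as map_file:
    for cls in classes:
        for fname in sorted(os.listdir(os.path.join(src, cls))):
            rel = f'{cls}/{fname}'                      # member path inside the zip
            zf.write(os.path.join(src, cls, fname), arcname=rel)
            map_file.write(f'{rel}\t{class_to_idx[cls]}\n')
```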
171 | 172 | For example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run: 173 | 174 | `Swin-T`: 175 | 176 | ```bash 177 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 178 | --cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 179 | ``` 180 | 181 | `Swin-S`: 182 | 183 | ```bash 184 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 185 | --cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 186 | ``` 187 | 188 | `Swin-B`: 189 | 190 | ```bash 191 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 192 | --cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \ 193 | --accumulation-steps 2 [--use-checkpoint] 194 | ``` 195 | 196 | ### Fine-tuning on higher resolution 197 | 198 | For example, to fine-tune a `Swin-B` model pre-trained on 224x224 resolution to 384x384 resolution: 199 | 200 | ```bash 201 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 202 | --cfg configs/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \ 203 | --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] 204 | ``` 205 | 206 | ### Fine-tuning from an ImageNet-22K(21K) pre-trained model 207 | 208 | For example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K(21K): 209 | 210 | ```bash 211 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 212 | --cfg configs/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \ 213 | --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] 214 | ``` 215 | 216 | ### Throughput 217 | 218 | To measure the throughput, run: 219 | 220 | ```bash 221 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \ 222 | --cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0 223 | ``` 224 | -------------------------------------------------------------------------------- /BOAT-Swin/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'),
mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /BOAT-Swin/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from timm.scheduler.cosine_lr import CosineLRScheduler 10 | from timm.scheduler.step_lr import StepLRScheduler 11 | from timm.scheduler.scheduler import Scheduler 12 | 13 | 14 | def build_scheduler(config, optimizer, n_iter_per_epoch): 15 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 16 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 17 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 18 | 19 | lr_scheduler = None 20 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_steps, 24 | t_mul=1., 25 | lr_min=config.TRAIN.MIN_LR, 26 | warmup_lr_init=config.TRAIN.WARMUP_LR, 27 | warmup_t=warmup_steps, 28 | cycle_limit=1, 29 | t_in_epochs=False, 30 | ) 31 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 32 | lr_scheduler = LinearLRScheduler( 33 | optimizer, 34 | t_initial=num_steps, 35 | lr_min_rate=0.01, 36 | warmup_lr_init=config.TRAIN.WARMUP_LR, 37 | warmup_t=warmup_steps, 38 | t_in_epochs=False, 39 | ) 40 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 41 | lr_scheduler = StepLRScheduler( 42 | optimizer, 43 | decay_t=decay_steps, 44 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 45 | warmup_lr_init=config.TRAIN.WARMUP_LR, 46 | warmup_t=warmup_steps, 47 | t_in_epochs=False, 48 | ) 49 | 50 | return lr_scheduler 51 | 52 | 53 | class LinearLRScheduler(Scheduler): 54 | def __init__(self, 55 | optimizer: torch.optim.Optimizer, 56 | t_initial: int, 57 | lr_min_rate: float, 58 | warmup_t=0, 59 | warmup_lr_init=0., 60 | t_in_epochs=True, 61 | noise_range_t=None, 62 | noise_pct=0.67, 63 | noise_std=1.0, 64 | noise_seed=42, 65 | initialize=True, 66 | ) -> None: 67 | super().__init__( 68 | optimizer, param_group_field="lr", 69 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 70 | initialize=initialize) 71 | 72 | self.t_initial = t_initial 73 | self.lr_min_rate = lr_min_rate 74 | self.warmup_t = warmup_t 75 | self.warmup_lr_init = warmup_lr_init 76 | self.t_in_epochs = t_in_epochs 77 | if self.warmup_t: 78 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 79 | super().update_groups(self.warmup_lr_init) 80 | else: 81 | self.warmup_steps = [1 for _ in self.base_values] 82 | 83 | def _get_lr(self, t): 84 | if t < self.warmup_t: 85 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 86 | else: 87 | t = t - self.warmup_t 88 | total_t = self.t_initial - self.warmup_t 89 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 90 | return lrs 91 | 92 | def get_epoch_values(self, epoch: int): 93 | if self.t_in_epochs: 94 | return self._get_lr(epoch) 95 | else: 96 | return None 97 | 98 | def get_update_values(self, num_updates: int): 99 | if not self.t_in_epochs: 100 | return 
self._get_lr(num_updates) 101 | else: 102 | return None 103 | -------------------------------------------------------------------------------- /BOAT-Swin/main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import random 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 20 | from timm.utils import accuracy, AverageMeter 21 | 22 | from config import get_config 23 | from models import build_model 24 | from data import build_loader 25 | from lr_scheduler import build_scheduler 26 | from optimizer import build_optimizer 27 | from logger import create_logger 28 | from utils import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 29 | 30 | try: 31 | # noinspection PyUnresolvedReferences 32 | from apex import amp 33 | except ImportError: 34 | amp = None 35 | 36 | 37 | def parse_option(): 38 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 39 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 40 | parser.add_argument( 41 | "--opts", 42 | help="Modify config options by adding 'KEY VALUE' pairs. ", 43 | default=None, 44 | nargs='+', 45 | ) 46 | 47 | # easy config modification 48 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 49 | parser.add_argument('--data-path', type=str, help='path to dataset') 50 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 51 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 52 | help='no: no cache, ' 53 | 'full: cache all data, ' 54 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 55 | parser.add_argument('--pretrained', 56 | help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') 57 | parser.add_argument('--resume', help='resume from checkpoint') 58 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 59 | parser.add_argument('--use-checkpoint', action='store_true', 60 | help="whether to use gradient checkpointing to save memory") 61 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 62 | help='mixed precision opt level, if O0, no amp is used') 63 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 64 | help='root of output folder, the full path is // (default: output)') 65 | parser.add_argument('--tag', help='tag of experiment') 66 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 67 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 68 | 69 | # distributed training 70 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 71 | 72 | args, unparsed = parser.parse_known_args() 73 | 74 | config = get_config(args) 75 | 76 | return args, config 77 | 78 | 79 | 
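# --- Illustrative aside (added; not part of the original main.py) -----------
# The __main__ block further down rescales the learning rates linearly with the
# total batch size before training:
#     scaled_lr = TRAIN.BASE_LR * DATA.BATCH_SIZE * world_size / 512
# and multiplies by TRAIN.ACCUMULATION_STEPS when gradient accumulation is used.
# The helper below is a hedged worked example of that rule only; the 5e-4 base
# learning rate is an assumed default for the ImageNet-1K configs.
def _linear_lr_scaling_example(base_lr=5e-4, batch_per_gpu=128, world_size=8,
                               accumulation_steps=1):
    """Mirror the linear LR scaling applied in __main__, for illustration only."""
    lr = base_lr * batch_per_gpu * world_size / 512.0
    if accumulation_steps > 1:
        lr = lr * accumulation_steps
    return lr
# _linear_lr_scaling_example() -> 1e-3 for 8 GPUs x 128 images per GPU.
# -----------------------------------------------------------------------------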
def main(config): 80 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 81 | 82 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 83 | model = build_model(config) 84 | model.cuda() 85 | logger.info(str(model)) 86 | 87 | optimizer = build_optimizer(config, model) 88 | if config.AMP_OPT_LEVEL != "O0": 89 | model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 90 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 91 | model_without_ddp = model.module 92 | 93 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 94 | logger.info(f"number of params: {n_parameters}") 95 | if hasattr(model_without_ddp, 'flops'): 96 | flops = model_without_ddp.flops() 97 | logger.info(f"number of GFLOPs: {flops / 1e9}") 98 | 99 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 100 | 101 | if config.AUG.MIXUP > 0.: 102 | # smoothing is handled with mixup label transform 103 | criterion = SoftTargetCrossEntropy() 104 | elif config.MODEL.LABEL_SMOOTHING > 0.: 105 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 106 | else: 107 | criterion = torch.nn.CrossEntropyLoss() 108 | 109 | max_accuracy = 0.0 110 | 111 | if config.TRAIN.AUTO_RESUME: 112 | resume_file = auto_resume_helper(config.OUTPUT) 113 | if resume_file: 114 | if config.MODEL.RESUME: 115 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 116 | config.defrost() 117 | config.MODEL.RESUME = resume_file 118 | config.freeze() 119 | logger.info(f'auto resuming from {resume_file}') 120 | else: 121 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 122 | 123 | if config.MODEL.RESUME: 124 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 125 | acc1, acc5, loss = validate(config, data_loader_val, model) 126 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 127 | if config.EVAL_MODE: 128 | return 129 | 130 | if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): 131 | load_pretrained(config, model_without_ddp, logger) 132 | acc1, acc5, loss = validate(config, data_loader_val, model) 133 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 134 | 135 | if config.THROUGHPUT_MODE: 136 | throughput(data_loader_val, model, logger) 137 | return 138 | 139 | logger.info("Start training") 140 | start_time = time.time() 141 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 142 | data_loader_train.sampler.set_epoch(epoch) 143 | 144 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 145 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 146 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 147 | 148 | acc1, acc5, loss = validate(config, data_loader_val, model) 149 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 150 | max_accuracy = max(max_accuracy, acc1) 151 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 152 | 153 | total_time = time.time() - start_time 154 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 155 | logger.info('Training time {}'.format(total_time_str)) 156 | 157 | 158 | def 
train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 159 | model.train() 160 | optimizer.zero_grad() 161 | 162 | num_steps = len(data_loader) 163 | batch_time = AverageMeter() 164 | loss_meter = AverageMeter() 165 | norm_meter = AverageMeter() 166 | 167 | start = time.time() 168 | end = time.time() 169 | for idx, (samples, targets) in enumerate(data_loader): 170 | samples = samples.cuda(non_blocking=True) 171 | targets = targets.cuda(non_blocking=True) 172 | 173 | if mixup_fn is not None: 174 | samples, targets = mixup_fn(samples, targets) 175 | 176 | outputs = model(samples) 177 | 178 | if config.TRAIN.ACCUMULATION_STEPS > 1: 179 | loss = criterion(outputs, targets) 180 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 181 | if config.AMP_OPT_LEVEL != "O0": 182 | with amp.scale_loss(loss, optimizer) as scaled_loss: 183 | scaled_loss.backward() 184 | if config.TRAIN.CLIP_GRAD: 185 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 186 | else: 187 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 188 | else: 189 | loss.backward() 190 | if config.TRAIN.CLIP_GRAD: 191 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 192 | else: 193 | grad_norm = get_grad_norm(model.parameters()) 194 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 195 | optimizer.step() 196 | optimizer.zero_grad() 197 | lr_scheduler.step_update(epoch * num_steps + idx) 198 | else: 199 | loss = criterion(outputs, targets) 200 | optimizer.zero_grad() 201 | if config.AMP_OPT_LEVEL != "O0": 202 | with amp.scale_loss(loss, optimizer) as scaled_loss: 203 | scaled_loss.backward() 204 | if config.TRAIN.CLIP_GRAD: 205 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 206 | else: 207 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 208 | else: 209 | loss.backward() 210 | if config.TRAIN.CLIP_GRAD: 211 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 212 | else: 213 | grad_norm = get_grad_norm(model.parameters()) 214 | optimizer.step() 215 | lr_scheduler.step_update(epoch * num_steps + idx) 216 | 217 | torch.cuda.synchronize() 218 | 219 | loss_meter.update(loss.item(), targets.size(0)) 220 | norm_meter.update(grad_norm) 221 | batch_time.update(time.time() - end) 222 | end = time.time() 223 | 224 | if idx % config.PRINT_FREQ == 0: 225 | lr = optimizer.param_groups[0]['lr'] 226 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 227 | etas = batch_time.avg * (num_steps - idx) 228 | logger.info( 229 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 230 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 231 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 232 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 233 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 234 | f'mem {memory_used:.0f}MB') 235 | epoch_time = time.time() - start 236 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 237 | 238 | 239 | @torch.no_grad() 240 | def validate(config, data_loader, model): 241 | criterion = torch.nn.CrossEntropyLoss() 242 | model.eval() 243 | 244 | batch_time = AverageMeter() 245 | loss_meter = AverageMeter() 246 | acc1_meter = AverageMeter() 247 | acc5_meter = AverageMeter() 248 | 249 | end = time.time() 250 | for idx, (images, target) in enumerate(data_loader): 251 | images = 
images.cuda(non_blocking=True) 252 | target = target.cuda(non_blocking=True) 253 | 254 | # compute output 255 | output = model(images) 256 | 257 | # measure accuracy and record loss 258 | loss = criterion(output, target) 259 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 260 | 261 | acc1 = reduce_tensor(acc1) 262 | acc5 = reduce_tensor(acc5) 263 | loss = reduce_tensor(loss) 264 | 265 | loss_meter.update(loss.item(), target.size(0)) 266 | acc1_meter.update(acc1.item(), target.size(0)) 267 | acc5_meter.update(acc5.item(), target.size(0)) 268 | 269 | # measure elapsed time 270 | batch_time.update(time.time() - end) 271 | end = time.time() 272 | 273 | if idx % config.PRINT_FREQ == 0: 274 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 275 | logger.info( 276 | f'Test: [{idx}/{len(data_loader)}]\t' 277 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 278 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 279 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 280 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 281 | f'Mem {memory_used:.0f}MB') 282 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 283 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 284 | 285 | 286 | @torch.no_grad() 287 | def throughput(data_loader, model, logger): 288 | model.eval() 289 | 290 | for idx, (images, _) in enumerate(data_loader): 291 | images = images.cuda(non_blocking=True) 292 | batch_size = images.shape[0] 293 | for i in range(50): 294 | model(images) 295 | torch.cuda.synchronize() 296 | logger.info(f"throughput averaged with 30 times") 297 | tic1 = time.time() 298 | for i in range(30): 299 | model(images) 300 | torch.cuda.synchronize() 301 | tic2 = time.time() 302 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 303 | return 304 | 305 | 306 | if __name__ == '__main__': 307 | _, config = parse_option() 308 | 309 | if config.AMP_OPT_LEVEL != "O0": 310 | assert amp is not None, "amp not installed!" 
311 | 312 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 313 | rank = int(os.environ["RANK"]) 314 | world_size = int(os.environ['WORLD_SIZE']) 315 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 316 | else: 317 | rank = -1 318 | world_size = -1 319 | torch.cuda.set_device(config.LOCAL_RANK) 320 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 321 | torch.distributed.barrier() 322 | 323 | seed = config.SEED + dist.get_rank() 324 | torch.manual_seed(seed) 325 | torch.cuda.manual_seed(seed) 326 | np.random.seed(seed) 327 | random.seed(seed) 328 | cudnn.benchmark = True 329 | 330 | # linear scale the learning rate according to total batch size, may not be optimal 331 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 332 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 333 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 334 | # gradient accumulation also need to scale the learning rate 335 | if config.TRAIN.ACCUMULATION_STEPS > 1: 336 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 337 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 338 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 339 | config.defrost() 340 | config.TRAIN.BASE_LR = linear_scaled_lr 341 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 342 | config.TRAIN.MIN_LR = linear_scaled_min_lr 343 | config.freeze() 344 | 345 | os.makedirs(config.OUTPUT, exist_ok=True) 346 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 347 | 348 | if dist.get_rank() == 0: 349 | path = os.path.join(config.OUTPUT, "config.json") 350 | with open(path, "w") as f: 351 | f.write(config.dump()) 352 | logger.info(f"Full config saved to {path}") 353 | 354 | # print config 355 | logger.info(config.dump()) 356 | 357 | main(config) 358 | -------------------------------------------------------------------------------- /BOAT-Swin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /BOAT-Swin/models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .boat_swin_transformer import SwinTransformer 9 | from .swin_mlp import SwinMLP 10 | 11 | 12 | def build_model(config): 13 | model_type = config.MODEL.TYPE 14 | if model_type == 'swin': 15 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 16 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 17 | in_chans=config.MODEL.SWIN.IN_CHANS, 18 | num_classes=config.MODEL.NUM_CLASSES, 19 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 20 | depths=config.MODEL.SWIN.DEPTHS, 21 | num_heads=config.MODEL.SWIN.NUM_HEADS, 22 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 23 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 24 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 25 | qk_scale=config.MODEL.SWIN.QK_SCALE, 26 | drop_rate=config.MODEL.DROP_RATE, 27 | 
drop_path_rate=config.MODEL.DROP_PATH_RATE, 28 | ape=config.MODEL.SWIN.APE, 29 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 30 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 31 | elif model_type == 'swin_mlp': 32 | model = SwinMLP(img_size=config.DATA.IMG_SIZE, 33 | patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, 34 | in_chans=config.MODEL.SWIN_MLP.IN_CHANS, 35 | num_classes=config.MODEL.NUM_CLASSES, 36 | embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, 37 | depths=config.MODEL.SWIN_MLP.DEPTHS, 38 | num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, 39 | window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, 40 | mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, 41 | drop_rate=config.MODEL.DROP_RATE, 42 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 43 | ape=config.MODEL.SWIN_MLP.APE, 44 | patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, 45 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 46 | else: 47 | raise NotImplementedError(f"Unkown model: {model_type}") 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /BOAT-Swin/models/swin_mlp.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | def window_partition(x, window_size): 35 | """ 36 | Args: 37 | x: (B, H, W, C) 38 | window_size (int): window size 39 | 40 | Returns: 41 | windows: (num_windows*B, window_size, window_size, C) 42 | """ 43 | B, H, W, C = x.shape 44 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 45 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 46 | return windows 47 | 48 | 49 | def window_reverse(windows, window_size, H, W): 50 | """ 51 | Args: 52 | windows: (num_windows*B, window_size, window_size, C) 53 | window_size (int): Window size 54 | H (int): Height of image 55 | W (int): Width of image 56 | 57 | Returns: 58 | x: (B, H, W, C) 59 | """ 60 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 61 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 62 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 63 | return x 64 | 65 | 66 | class SwinMLPBlock(nn.Module): 67 | r""" Swin MLP Block. 68 | 69 | Args: 70 | dim (int): Number of input channels. 71 | input_resolution (tuple[int]): Input resulotion. 72 | num_heads (int): Number of attention heads. 73 | window_size (int): Window size. 74 | shift_size (int): Shift size for SW-MSA. 
75 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 76 | drop (float, optional): Dropout rate. Default: 0.0 77 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 78 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 79 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 80 | """ 81 | 82 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 83 | mlp_ratio=4., drop=0., drop_path=0., 84 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 85 | super().__init__() 86 | self.dim = dim 87 | self.input_resolution = input_resolution 88 | self.num_heads = num_heads 89 | self.window_size = window_size 90 | self.shift_size = shift_size 91 | self.mlp_ratio = mlp_ratio 92 | if min(self.input_resolution) <= self.window_size: 93 | # if window size is larger than input resolution, we don't partition windows 94 | self.shift_size = 0 95 | self.window_size = min(self.input_resolution) 96 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 97 | 98 | self.padding = [self.window_size - self.shift_size, self.shift_size, 99 | self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b 100 | 101 | self.norm1 = norm_layer(dim) 102 | # use group convolution to implement multi-head MLP 103 | self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2, 104 | self.num_heads * self.window_size ** 2, 105 | kernel_size=1, 106 | groups=self.num_heads) 107 | 108 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 109 | self.norm2 = norm_layer(dim) 110 | mlp_hidden_dim = int(dim * mlp_ratio) 111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 112 | 113 | def forward(self, x): 114 | H, W = self.input_resolution 115 | B, L, C = x.shape 116 | assert L == H * W, "input feature has wrong size" 117 | 118 | shortcut = x 119 | x = self.norm1(x) 120 | x = x.view(B, H, W, C) 121 | 122 | # shift 123 | if self.shift_size > 0: 124 | P_l, P_r, P_t, P_b = self.padding 125 | shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0) 126 | else: 127 | shifted_x = x 128 | _, _H, _W, _ = shifted_x.shape 129 | 130 | # partition windows 131 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 132 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 133 | 134 | # Window/Shifted-Window Spatial MLP 135 | x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads) 136 | x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH 137 | x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size, 138 | C // self.num_heads) 139 | spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH 140 | spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size, 141 | C // self.num_heads).transpose(1, 2) 142 | spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C) 143 | 144 | # merge windows 145 | spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C) 146 | shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C 147 | 148 | # reverse shift 149 | if self.shift_size > 0: 150 | P_l, P_r, P_t, P_b = self.padding 151 | x = shifted_x[:, 
P_t:-P_b, P_l:-P_r, :].contiguous() 152 | else: 153 | x = shifted_x 154 | x = x.view(B, H * W, C) 155 | 156 | # FFN 157 | x = shortcut + self.drop_path(x) 158 | x = x + self.drop_path(self.mlp(self.norm2(x))) 159 | 160 | return x 161 | 162 | def extra_repr(self) -> str: 163 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 164 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 165 | 166 | def flops(self): 167 | flops = 0 168 | H, W = self.input_resolution 169 | # norm1 170 | flops += self.dim * H * W 171 | 172 | # Window/Shifted-Window Spatial MLP 173 | if self.shift_size > 0: 174 | nW = (H / self.window_size + 1) * (W / self.window_size + 1) 175 | else: 176 | nW = H * W / self.window_size / self.window_size 177 | flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size) 178 | # mlp 179 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 180 | # norm2 181 | flops += self.dim * H * W 182 | return flops 183 | 184 | 185 | class PatchMerging(nn.Module): 186 | r""" Patch Merging Layer. 187 | 188 | Args: 189 | input_resolution (tuple[int]): Resolution of input feature. 190 | dim (int): Number of input channels. 191 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 192 | """ 193 | 194 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 195 | super().__init__() 196 | self.input_resolution = input_resolution 197 | self.dim = dim 198 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 199 | self.norm = norm_layer(4 * dim) 200 | 201 | def forward(self, x): 202 | """ 203 | x: B, H*W, C 204 | """ 205 | H, W = self.input_resolution 206 | B, L, C = x.shape 207 | assert L == H * W, "input feature has wrong size" 208 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 209 | 210 | x = x.view(B, H, W, C) 211 | 212 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 213 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 214 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 215 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 216 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 217 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 218 | 219 | x = self.norm(x) 220 | x = self.reduction(x) 221 | 222 | return x 223 | 224 | def extra_repr(self) -> str: 225 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 226 | 227 | def flops(self): 228 | H, W = self.input_resolution 229 | flops = H * W * self.dim 230 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 231 | return flops 232 | 233 | 234 | class BasicLayer(nn.Module): 235 | """ A basic Swin MLP layer for one stage. 236 | 237 | Args: 238 | dim (int): Number of input channels. 239 | input_resolution (tuple[int]): Input resolution. 240 | depth (int): Number of blocks. 241 | num_heads (int): Number of attention heads. 242 | window_size (int): Local window size. 243 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 244 | drop (float, optional): Dropout rate. Default: 0.0 245 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 246 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 247 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 248 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
249 | """ 250 | 251 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 252 | mlp_ratio=4., drop=0., drop_path=0., 253 | norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 254 | 255 | super().__init__() 256 | self.dim = dim 257 | self.input_resolution = input_resolution 258 | self.depth = depth 259 | self.use_checkpoint = use_checkpoint 260 | 261 | # build blocks 262 | self.blocks = nn.ModuleList([ 263 | SwinMLPBlock(dim=dim, input_resolution=input_resolution, 264 | num_heads=num_heads, window_size=window_size, 265 | shift_size=0 if (i % 2 == 0) else window_size // 2, 266 | mlp_ratio=mlp_ratio, 267 | drop=drop, 268 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 269 | norm_layer=norm_layer) 270 | for i in range(depth)]) 271 | 272 | # patch merging layer 273 | if downsample is not None: 274 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 275 | else: 276 | self.downsample = None 277 | 278 | def forward(self, x): 279 | for blk in self.blocks: 280 | if self.use_checkpoint: 281 | x = checkpoint.checkpoint(blk, x) 282 | else: 283 | x = blk(x) 284 | if self.downsample is not None: 285 | x = self.downsample(x) 286 | return x 287 | 288 | def extra_repr(self) -> str: 289 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 290 | 291 | def flops(self): 292 | flops = 0 293 | for blk in self.blocks: 294 | flops += blk.flops() 295 | if self.downsample is not None: 296 | flops += self.downsample.flops() 297 | return flops 298 | 299 | 300 | class PatchEmbed(nn.Module): 301 | r""" Image to Patch Embedding 302 | 303 | Args: 304 | img_size (int): Image size. Default: 224. 305 | patch_size (int): Patch token size. Default: 4. 306 | in_chans (int): Number of input image channels. Default: 3. 307 | embed_dim (int): Number of linear projection output channels. Default: 96. 308 | norm_layer (nn.Module, optional): Normalization layer. Default: None 309 | """ 310 | 311 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 312 | super().__init__() 313 | img_size = to_2tuple(img_size) 314 | patch_size = to_2tuple(patch_size) 315 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 316 | self.img_size = img_size 317 | self.patch_size = patch_size 318 | self.patches_resolution = patches_resolution 319 | self.num_patches = patches_resolution[0] * patches_resolution[1] 320 | 321 | self.in_chans = in_chans 322 | self.embed_dim = embed_dim 323 | 324 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 325 | if norm_layer is not None: 326 | self.norm = norm_layer(embed_dim) 327 | else: 328 | self.norm = None 329 | 330 | def forward(self, x): 331 | B, C, H, W = x.shape 332 | # FIXME look at relaxing size constraints 333 | assert H == self.img_size[0] and W == self.img_size[1], \ 334 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 335 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 336 | if self.norm is not None: 337 | x = self.norm(x) 338 | return x 339 | 340 | def flops(self): 341 | Ho, Wo = self.patches_resolution 342 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 343 | if self.norm is not None: 344 | flops += Ho * Wo * self.embed_dim 345 | return flops 346 | 347 | 348 | class SwinMLP(nn.Module): 349 | r""" Swin MLP 350 | 351 | Args: 352 | img_size (int | tuple(int)): Input image size. 
Default 224 353 | patch_size (int | tuple(int)): Patch size. Default: 4 354 | in_chans (int): Number of input image channels. Default: 3 355 | num_classes (int): Number of classes for classification head. Default: 1000 356 | embed_dim (int): Patch embedding dimension. Default: 96 357 | depths (tuple(int)): Depth of each Swin MLP layer. 358 | num_heads (tuple(int)): Number of attention heads in different layers. 359 | window_size (int): Window size. Default: 7 360 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 361 | drop_rate (float): Dropout rate. Default: 0 362 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 363 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 364 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 365 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 366 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 367 | """ 368 | 369 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 370 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 371 | window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, 372 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 373 | use_checkpoint=False, **kwargs): 374 | super().__init__() 375 | 376 | self.num_classes = num_classes 377 | self.num_layers = len(depths) 378 | self.embed_dim = embed_dim 379 | self.ape = ape 380 | self.patch_norm = patch_norm 381 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 382 | self.mlp_ratio = mlp_ratio 383 | 384 | # split image into non-overlapping patches 385 | self.patch_embed = PatchEmbed( 386 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 387 | norm_layer=norm_layer if self.patch_norm else None) 388 | num_patches = self.patch_embed.num_patches 389 | patches_resolution = self.patch_embed.patches_resolution 390 | self.patches_resolution = patches_resolution 391 | 392 | # absolute position embedding 393 | if self.ape: 394 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 395 | trunc_normal_(self.absolute_pos_embed, std=.02) 396 | 397 | self.pos_drop = nn.Dropout(p=drop_rate) 398 | 399 | # stochastic depth 400 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 401 | 402 | # build layers 403 | self.layers = nn.ModuleList() 404 | for i_layer in range(self.num_layers): 405 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 406 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 407 | patches_resolution[1] // (2 ** i_layer)), 408 | depth=depths[i_layer], 409 | num_heads=num_heads[i_layer], 410 | window_size=window_size, 411 | mlp_ratio=self.mlp_ratio, 412 | drop=drop_rate, 413 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 414 | norm_layer=norm_layer, 415 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 416 | use_checkpoint=use_checkpoint) 417 | self.layers.append(layer) 418 | 419 | self.norm = norm_layer(self.num_features) 420 | self.avgpool = nn.AdaptiveAvgPool1d(1) 421 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 422 | 423 | self.apply(self._init_weights) 424 | 425 | def _init_weights(self, m): 426 | if isinstance(m, (nn.Linear, nn.Conv1d)): 427 | trunc_normal_(m.weight, std=.02) 428 | if m.bias is not None: 429 | nn.init.constant_(m.bias, 0) 430 | elif 
isinstance(m, nn.LayerNorm): 431 | nn.init.constant_(m.bias, 0) 432 | nn.init.constant_(m.weight, 1.0) 433 | 434 | @torch.jit.ignore 435 | def no_weight_decay(self): 436 | return {'absolute_pos_embed'} 437 | 438 | @torch.jit.ignore 439 | def no_weight_decay_keywords(self): 440 | return {'relative_position_bias_table'} 441 | 442 | def forward_features(self, x): 443 | x = self.patch_embed(x) 444 | if self.ape: 445 | x = x + self.absolute_pos_embed 446 | x = self.pos_drop(x) 447 | 448 | for layer in self.layers: 449 | x = layer(x) 450 | 451 | x = self.norm(x) # B L C 452 | x = self.avgpool(x.transpose(1, 2)) # B C 1 453 | x = torch.flatten(x, 1) 454 | return x 455 | 456 | def forward(self, x): 457 | x = self.forward_features(x) 458 | x = self.head(x) 459 | return x 460 | 461 | def flops(self): 462 | flops = 0 463 | flops += self.patch_embed.flops() 464 | for i, layer in enumerate(self.layers): 465 | flops += layer.flops() 466 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 467 | flops += self.num_features * self.num_classes 468 | return flops 469 | -------------------------------------------------------------------------------- /BOAT-Swin/optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from torch import optim as optim 9 | 10 | 11 | def build_optimizer(config, model): 12 | """ 13 | Build optimizer, set weight decay of normalization to 0 by default. 14 | """ 15 | skip = {} 16 | skip_keywords = {} 17 | if hasattr(model, 'no_weight_decay'): 18 | skip = model.no_weight_decay() 19 | if hasattr(model, 'no_weight_decay_keywords'): 20 | skip_keywords = model.no_weight_decay_keywords() 21 | parameters = set_weight_decay(model, skip, skip_keywords) 22 | 23 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 24 | optimizer = None 25 | if opt_lower == 'sgd': 26 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 28 | elif opt_lower == 'adamw': 29 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 30 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 31 | 32 | return optimizer 33 | 34 | 35 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 36 | has_decay = [] 37 | no_decay = [] 38 | 39 | for name, param in model.named_parameters(): 40 | if not param.requires_grad: 41 | continue # frozen weights 42 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 43 | check_keywords_in_name(name, skip_keywords): 44 | no_decay.append(param) 45 | # print(f"{name} has no weight decay") 46 | else: 47 | has_decay.append(param) 48 | return [{'params': has_decay}, 49 | {'params': no_decay, 'weight_decay': 0.}] 50 | 51 | 52 | def check_keywords_in_name(name, keywords=()): 53 | isin = False 54 | for keyword in keywords: 55 | if keyword in name: 56 | isin = True 57 | return isin 58 | -------------------------------------------------------------------------------- /BOAT-Swin/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # 
Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | try: 13 | # noinspection PyUnresolvedReferences 14 | from apex import amp 15 | except ImportError: 16 | amp = None 17 | 18 | 19 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 20 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 21 | if config.MODEL.RESUME.startswith('https'): 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 24 | else: 25 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 26 | msg = model.load_state_dict(checkpoint['model'], strict=False) 27 | logger.info(msg) 28 | max_accuracy = 0.0 29 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 30 | optimizer.load_state_dict(checkpoint['optimizer']) 31 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 32 | config.defrost() 33 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 34 | config.freeze() 35 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": 36 | amp.load_state_dict(checkpoint['amp']) 37 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 38 | if 'max_accuracy' in checkpoint: 39 | max_accuracy = checkpoint['max_accuracy'] 40 | 41 | del checkpoint 42 | torch.cuda.empty_cache() 43 | return max_accuracy 44 | 45 | 46 | def load_pretrained(config, model, logger): 47 | logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 48 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 49 | state_dict = checkpoint['model'] 50 | 51 | # delete relative_position_index since we always re-init it 52 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 53 | for k in relative_position_index_keys: 54 | del state_dict[k] 55 | 56 | # delete relative_coords_table since we always re-init it 57 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 58 | for k in relative_position_index_keys: 59 | del state_dict[k] 60 | 61 | # delete attn_mask since we always re-init it 62 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 63 | for k in attn_mask_keys: 64 | del state_dict[k] 65 | 66 | # bicubic interpolate relative_position_bias_table if not match 67 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 68 | for k in relative_position_bias_table_keys: 69 | relative_position_bias_table_pretrained = state_dict[k] 70 | relative_position_bias_table_current = model.state_dict()[k] 71 | L1, nH1 = relative_position_bias_table_pretrained.size() 72 | L2, nH2 = relative_position_bias_table_current.size() 73 | if nH1 != nH2: 74 | logger.warning(f"Error in loading {k}, passing......") 75 | else: 76 | if L1 != L2: 77 | # bicubic interpolate relative_position_bias_table if not match 78 | S1 = int(L1 ** 0.5) 79 | S2 = int(L2 ** 0.5) 80 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 81 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 82 | mode='bicubic') 83 | state_dict[k] = 
relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 84 | 85 | # bicubic interpolate absolute_pos_embed if not match 86 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] 87 | for k in absolute_pos_embed_keys: 88 | # absolute position embedding 89 | absolute_pos_embed_pretrained = state_dict[k] 90 | absolute_pos_embed_current = model.state_dict()[k] 91 | _, L1, C1 = absolute_pos_embed_pretrained.size() 92 | _, L2, C2 = absolute_pos_embed_current.size() 93 | if C1 != C2: 94 | logger.warning(f"Error in loading {k}, passing......") 95 | else: 96 | if L1 != L2: 97 | S1 = int(L1 ** 0.5) 98 | S2 = int(L2 ** 0.5) 99 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) 100 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) 101 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 102 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') 103 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 104 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 105 | state_dict[k] = absolute_pos_embed_pretrained_resized 106 | 107 | # check classifier, if not match, then re-init classifier to zero 108 | head_bias_pretrained = state_dict['head.bias'] 109 | Nc1 = head_bias_pretrained.shape[0] 110 | Nc2 = model.head.bias.shape[0] 111 | if (Nc1 != Nc2): 112 | if Nc1 == 21841 and Nc2 == 1000: 113 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 114 | map22kto1k_path = 'data/map22kto1k.txt' 115 | with open(map22kto1k_path) as f: 116 | map22kto1k = f.readlines() 117 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 118 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 119 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 120 | else: 121 | torch.nn.init.constant_(model.head.bias, 0.) 122 | torch.nn.init.constant_(model.head.weight, 0.) 123 | del state_dict['head.weight'] 124 | del state_dict['head.bias'] 125 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 126 | 127 | msg = model.load_state_dict(state_dict, strict=False) 128 | logger.warning(msg) 129 | 130 | logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") 131 | 132 | del checkpoint 133 | torch.cuda.empty_cache() 134 | 135 | 136 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 137 | save_state = {'model': model.state_dict(), 138 | 'optimizer': optimizer.state_dict(), 139 | 'lr_scheduler': lr_scheduler.state_dict(), 140 | 'max_accuracy': max_accuracy, 141 | 'epoch': epoch, 142 | 'config': config} 143 | if config.AMP_OPT_LEVEL != "O0": 144 | save_state['amp'] = amp.state_dict() 145 | 146 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 147 | logger.info(f"{save_path} saving......") 148 | torch.save(save_state, save_path) 149 | logger.info(f"{save_path} saved !!!") 150 | 151 | 152 | def get_grad_norm(parameters, norm_type=2): 153 | if isinstance(parameters, torch.Tensor): 154 | parameters = [parameters] 155 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 156 | norm_type = float(norm_type) 157 | total_norm = 0 158 | for p in parameters: 159 | param_norm = p.grad.data.norm(norm_type) 160 | total_norm += param_norm.item() ** norm_type 161 | total_norm = total_norm ** (1. 
/ norm_type) 162 | return total_norm 163 | 164 | 165 | def auto_resume_helper(output_dir): 166 | checkpoints = os.listdir(output_dir) 167 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 168 | print(f"All checkpoints found in {output_dir}: {checkpoints}") 169 | if len(checkpoints) > 0: 170 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 171 | print(f"The latest checkpoint found: {latest_checkpoint}") 172 | resume_file = latest_checkpoint 173 | else: 174 | resume_file = None 175 | return resume_file 176 | 177 | 178 | def reduce_tensor(tensor): 179 | rt = tensor.clone() 180 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 181 | rt /= dist.get_world_size() 182 | return rt 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BOAT: Bilateral Local Attention Vision Transformer 2 | 3 | 4 | This is an unofficial implementation of the paper "BOAT: Bilateral Local Attention Vision Transformer". 5 | 6 | The [Swin variant](https://github.com/mahaoyuHKU/pytorch-boat/tree/main/BOAT-Swin) is based on [Swin Transformer](https://github.com/microsoft/Swin-Transformer). 7 | 8 | The [CSwin variant](https://github.com/mahaoyuHKU/pytorch-boat/tree/main/BOAT-CSwin) is based on [CSWin Transformer](https://github.com/microsoft/CSWin-Transformer). 9 | 10 | Please check the corresponding folders for detailed installation, training and evaluation instructions. 11 | 12 | # Pre-trained models 13 | 14 | [BOAT-Swin-Tiny](https://www.dropbox.com/s/xa94uewsrvjglnn/tiny.pth?dl=0) 15 | 16 | [BOAT-Swin-Small](https://www.dropbox.com/s/7ih1zvii3bvdcgd/small.pth?dl=0) 17 | 18 | [BOAT-Swin-Base](https://www.dropbox.com/s/70hr7h0smcr0gr9/base.pth?dl=0) 19 | 20 | [BOAT-CSwin-Tiny](https://www.dropbox.com/s/rsmtu6r0v2lt0y5/cswin_tiny.pth.tar?dl=0) 21 | 22 | [BOAT-CSwin-Small](https://www.dropbox.com/s/cnl00d1faxxoi19/cswin_small.pth.tar?dl=0) 23 | 24 | [BOAT-CSwin-Base](https://www.dropbox.com/s/92sr8r8zhng1mqg/cswin_base.pth.tar?dl=0) 25 | 26 | ## Acknowledgement 27 | This repository is developed based on CSWin Transformer and Swin Transformer. 28 | 29 | 30 | # If you use this code for your research, please consider citing: 31 | 32 | ```bibtex 33 | @article{BOAT, 34 | author = {Tan Yu and Gangming Zhao and Ping Li and Yizhou Yu}, 35 | title = {{BOAT:} Bilateral Local Attention Vision Transformer}, 36 | journal = {CoRR}, 37 | volume = {abs/2201.13027}, 38 | year = {2022}, 39 | url = {https://arxiv.org/abs/2201.13027}, 40 | } 41 | ``` 42 | --------------------------------------------------------------------------------
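Note (editor's addition, for illustration only): the pre-trained weights linked in the README above are assumed to follow the checkpoint layout produced by `save_checkpoint()` in `BOAT-Swin/utils.py`, i.e. a dictionary whose `'model'` entry holds the network `state_dict`; the CSwin `.pth.tar` files may instead store a different wrapper. Below is a minimal sketch of inspecting such a file outside the training script; the file name `tiny.pth` is a placeholder for a downloaded checkpoint.

```python
import torch

# Sketch only: inspect a downloaded BOAT-Swin checkpoint on CPU.
ckpt = torch.load('tiny.pth', map_location='cpu')

# Checkpoints written by save_checkpoint() contain 'model', 'optimizer',
# 'lr_scheduler', 'max_accuracy', 'epoch' and 'config'; fall back to treating
# the file as a bare state_dict if the wrapper keys are absent.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
print(sorted(state_dict.keys())[:5])

# For evaluation, a model built with BOAT-Swin/models/build.py::build_model(config)
# would then consume this state_dict, e.g.:
#   model = build_model(config)
#   msg = model.load_state_dict(state_dict, strict=False)
```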