├── .gitignore
├── LICENSE
├── README.md
├── Retrieval.py
├── configs
│   ├── PS_cuhk_pedes.yaml
│   ├── PS_icfg_pedes.yaml
│   ├── PS_rstp_reid.yaml
│   └── config_bert.json
├── data_process.py
├── dataset
│   ├── __init__.py
│   ├── ps_dataset.py
│   └── utils.py
├── images
│   └── architecture.jpg
├── models
│   ├── __init__.py
│   ├── model_person_search.py
│   ├── tokenization_bert.py
│   ├── vit.py
│   └── xbert.py
├── optim
│   ├── __init__.py
│   ├── adafactor.py
│   ├── adahessian.py
│   ├── adamp.py
│   ├── adamw.py
│   ├── lookahead.py
│   ├── nadam.py
│   ├── novograd.py
│   ├── nvnovograd.py
│   ├── optim_factory.py
│   ├── radam.py
│   ├── rmsprop_tf.py
│   └── sgdp.py
├── scheduler
│   ├── __init__.py
│   ├── cosine_lr.py
│   ├── plateau_lr.py
│   ├── scheduler.py
│   ├── scheduler_factory.py
│   ├── step_lr.py
│   └── tanh_lr.py
├── shell
│   ├── cuhk-eval.sh
│   ├── cuhk-train.sh
│   ├── data_process.sh
│   ├── icfg-eval.sh
│   ├── icfg-train.sh
│   ├── rstp-eval.sh
│   └── rstp-train.sh
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python template
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 | 
55 | # Translations
56 | *.mo
57 | *.pot
58 | 
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 | 
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 | 
69 | # Scrapy stuff:
70 | .scrapy
71 | 
72 | # Sphinx documentation
73 | docs/_build/
74 | 
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 | 
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 | 
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 | 
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 | 
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 | 
98 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow
99 | __pypackages__/
100 | 
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 | 
105 | # SageMath parsed files
106 | *.sage.py
107 | 
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 | 
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 | 
121 | # Rope project settings
122 | .ropeproject
123 | 
124 | # mkdocs documentation
125 | /site
126 | 
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 | 
132 | # Pyre type checker
133 | .pyre/
134 | 
135 | # pytype static type analyzer
136 | .pytype/
137 | 
138 | # Cython debug symbols
139 | cython_debug/
140 | 
141 | # IDEA
142 | .idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Flame-Chasers
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RaSa: Relation and Sensitivity Aware Representation Learning for Text-based Person Search
2 | [![GitHub](https://img.shields.io/badge/license-MIT-green)](https://github.com/Flame-Chasers/RaSa/blob/main/LICENSE)
3 | 
4 | This is the official PyTorch implementation of the paper [RaSa: Relation and Sensitivity Aware Representation Learning for Text-based Person Search (IJCAI 2023)](https://arxiv.org/abs/2305.13653).
5 | This repository supports training and evaluation on three text-based person search benchmarks: CUHK-PEDES, ICFG-PEDES and RSTPReid.
6 | 
7 | ![](images/architecture.jpg)
8 | 
9 | ## Usage
10 | ### Requirements
11 | - pytorch 1.9.1
12 | - torchvision 0.10.1
13 | - transformers 4.8.1
14 | - timm 0.4.9
15 | 
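A matching environment can be set up with pip, for example (an illustrative sketch; choose the `torch`/`torchvision` wheels that match your CUDA version):

```shell
# pinned versions from the requirements list above
pip install torch==1.9.1 torchvision==0.10.1 transformers==4.8.1 timm==0.4.9
```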
16 | ### Prepare Datasets
17 | 1. Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN) and the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset).
18 | 2. Organize them in `your dataset root dir` folder as follows:
19 | ```
20 | |-- your dataset root dir/
21 | |   |-- <CUHK-PEDES>/
22 | |       |-- imgs
23 | |           |-- cam_a
24 | |           |-- cam_b
25 | |           |-- ...
26 | |       |-- reid_raw.json
27 | |
28 | |   |-- <ICFG-PEDES>/
29 | |       |-- imgs
30 | |           |-- test
31 | |           |-- train
32 | |       |-- ICFG-PEDES.json
33 | |
34 | |   |-- <RSTPReid>/
35 | |       |-- imgs
36 | |       |-- data_captions.json
37 | ```
38 | 3. Split the raw annotations into train.json, val.json and test.json for training, validation and testing.
39 | ```shell
40 | # 1. CUHK-PEDES
41 | bash shell/data_process.sh
42 | # or
43 | python data_process.py --dataset_name "CUHK-PEDES" --dataset_root_dir [CUHK-PEDES DATASET DIRECTORY]
44 | 
45 | # 2. ICFG-PEDES
46 | bash shell/data_process.sh
47 | # or
48 | python data_process.py --dataset_name "ICFG-PEDES" --dataset_root_dir [ICFG-PEDES DATASET DIRECTORY]
49 | 
50 | # 3. RSTPReid
51 | bash shell/data_process.sh
52 | # or
53 | python data_process.py --dataset_name "RSTPReid" --dataset_root_dir [RSTPReid DATASET DIRECTORY]
54 | ```
55 | 4. Organize the datasets as follows:
56 | ```
57 | |-- your dataset root dir/
58 | |   |-- <CUHK-PEDES>/
59 | |       |-- imgs
60 | |           |-- cam_a
61 | |           |-- cam_b
62 | |           |-- ...
63 | |       |-- processed_data
64 | |           |-- train.json
65 | |           |-- val.json
66 | |           |-- test.json
67 | |       |-- reid_raw.json
68 | |
69 | |   |-- <ICFG-PEDES>/
70 | |       |-- imgs
71 | |           |-- test
72 | |           |-- train
73 | |       |-- processed_data
74 | |           |-- train.json
75 | |           |-- val.json
76 | |           |-- test.json
77 | |       |-- ICFG-PEDES.json
78 | |
79 | |   |-- <RSTPReid>/
80 | |       |-- imgs
81 | |       |-- processed_data
82 | |           |-- train.json
83 | |           |-- val.json
84 | |           |-- test.json
85 | |       |-- data_captions.json
86 | ```
87 | 
88 | ### Pretrained Checkpoint
89 | - Please download the [pretrained ALBEF Checkpoint](https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth).
90 | 
91 | ### Training
92 | ```shell
93 | # Usage:
94 | # 1. Training on CUHK-PEDES
95 | bash shell/cuhk-train.sh
96 | # or
97 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
98 |     Retrieval.py \
99 |     --config configs/PS_cuhk_pedes.yaml \
100 |     --output_dir output/cuhk-pedes/train \
101 |     --checkpoint [PRETRAINED ALBEF CHECKPOINT PATH] \
102 |     --eval_mAP
103 | 
104 | # 2. Training on ICFG-PEDES
105 | bash shell/icfg-train.sh
106 | # or
107 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
108 |     Retrieval.py \
109 |     --config configs/PS_icfg_pedes.yaml \
110 |     --output_dir output/icfg-pedes/train \
111 |     --checkpoint [PRETRAINED ALBEF CHECKPOINT PATH] \
112 |     --eval_mAP
113 | 
114 | # 3. Training on RSTPReid
115 | bash shell/rstp-train.sh
116 | # or
117 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
118 |     Retrieval.py \
119 |     --config configs/PS_rstp_reid.yaml \
120 |     --output_dir output/rstp-reid/train \
121 |     --checkpoint [PRETRAINED ALBEF CHECKPOINT FILE PATH] \
122 |     --eval_mAP
123 | ```
124 | 
125 | ### Testing
126 | ```shell
127 | # Usage:
128 | # 1. Testing on CUHK-PEDES
129 | bash shell/cuhk-eval.sh
130 | # or
131 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
132 |     Retrieval.py \
133 |     --config configs/PS_cuhk_pedes.yaml \
134 |     --output_dir output/cuhk-pedes/evaluation \
135 |     --checkpoint [CHECKPOINT FILE PATH] \
136 |     --eval_mAP \
137 |     --evaluate
138 | 
139 | # 2. Testing on ICFG-PEDES
140 | bash shell/icfg-eval.sh
141 | # or
142 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
143 |     Retrieval.py \
144 |     --config configs/PS_icfg_pedes.yaml \
145 |     --output_dir output/icfg-pedes/evaluation \
146 |     --checkpoint [CHECKPOINT FILE PATH] \
147 |     --eval_mAP \
148 |     --evaluate
149 | 
150 | # 3. Testing on RSTPReid
151 | bash shell/rstp-eval.sh
152 | # or
153 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \
154 |     Retrieval.py \
155 |     --config configs/PS_rstp_reid.yaml \
156 |     --output_dir output/rstp-reid/evaluation/ \
157 |     --checkpoint [CHECKPOINT FILE PATH] \
158 |     --eval_mAP \
159 |     --evaluate
160 | ```
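The commands above assume 4 GPUs (`--nproc_per_node=4`). As a hedged sketch (not tested here), evaluation should also run on a single GPU by launching one process; you may need a smaller `k_test` or `batch_size_test` in the config to fit GPU memory:

```shell
# illustrative single-GPU evaluation on CUHK-PEDES
python -m torch.distributed.run --nproc_per_node=1 --rdzv_endpoint=127.0.0.1:29501 \
    Retrieval.py \
    --config configs/PS_cuhk_pedes.yaml \
    --output_dir output/cuhk-pedes/evaluation \
    --checkpoint [CHECKPOINT FILE PATH] \
    --eval_mAP \
    --evaluate
```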
161 | 
162 | ## RaSa Performance on Three Text-based Person Search Benchmarks
163 | ### CUHK-PEDES dataset
164 | 
165 | |     Method      |  Rank-1   |  Rank-5   |  Rank-10  |    mAP    |
166 | |:---------------:|:---------:|:---------:|:---------:|:---------:|
167 | |     CMPM/C      |   49.37   |   71.69   |   79.27   |     -     |
168 | |      ViTAA      |   55.97   |   75.84   |   83.52   |     -     |
169 | |      DSSL       |   59.98   |   80.41   |   87.56   |     -     |
170 | |       SAF       |   64.13   |   82.62   |   88.40   |   58.61   |
171 | |      LGUR       |   65.25   |   83.12   |   89.00   |     -     |
172 | |       IVT       |   65.59   |   83.11   |   89.21   |     -     |
173 | |      CFine      |   69.57   |   85.93   |   91.15   |     -     |
174 | |    **ALBEF**    |   60.28   |   79.52   |   86.34   |   56.67   |
175 | | **RaSa (ours)** | **76.51** | **90.29** | **94.25** | **69.38** |
176 | 
177 | [Model for CUHK-PEDES](https://drive.google.com/file/d/1BC1L-5JuIXHt6NR_l2ENHG3NZhncU91s/view?usp=sharing)
178 | 
179 | ### ICFG-PEDES dataset
180 | 
181 | |     Method      |  Rank-1   |  Rank-5   |  Rank-10  |    mAP    |
182 | |:---------------:|:---------:|:---------:|:---------:|:---------:|
183 | |     CMPM/C      |   43.51   |   65.44   |   74.26   |     -     |
184 | |      SSAN       |   54.23   |   72.63   |   79.53   |     -     |
185 | |       SAF       |   54.86   |   72.13   |   79.13   |   32.76   |
186 | |       IVT       |   56.04   |   73.60   |   80.22   |     -     |
187 | |      CFine      |   60.83   |   76.55   |   82.42   |     -     |
188 | |    **ALBEF**    |   34.46   |   52.32   |   60.40   |   19.62   |
189 | | **RaSa (ours)** | **65.28** | **80.40** | **85.12** | **41.29** |
190 | 
191 | [Model for ICFG-PEDES](https://drive.google.com/file/d/1lLB332ANq87v2jV2bLdsV7rQ7OBAnP9X/view?usp=sharing)
192 | 
193 | ### RSTPReid dataset
194 | 
195 | |     Method      |  Rank-1   |  Rank-5   |  Rank-10  |    mAP    |
196 | |:---------------:|:---------:|:---------:|:---------:|:---------:|
197 | |      DSSL       |   32.43   |   55.08   |   63.19   |     -     |
198 | |      SSAN       |   43.50   |   67.80   |   77.15   |     -     |
199 | |       SAF       |   44.05   |   67.30   |   76.25   |   36.81   |
200 | |       IVT       |   46.70   |   70.00   |   78.80   |     -     |
201 | |      CFine      |   50.55   |   72.50   |   81.60   |     -     |
202 | |    **ALBEF**    |   50.10   |   73.70   |   82.10   |   41.73   |
203 | | **RaSa (ours)** | **66.90** | **86.50** | **91.35** | **52.31** |
204 | 
205 | [Model for RSTPReid](https://drive.google.com/file/d/1e5KPmfoij22J2zZOZxhSodU4SNz76BjX/view?usp=sharing)
206 | 
207 | 
208 | ## Acknowledgments
209 | The implementation of RaSa relies on resources from [ALBEF](https://github.com/salesforce/ALBEF), [Huggingface Transformers](https://github.com/huggingface/transformers), and [timm](https://github.com/rwightman/pytorch-image-models/tree/master/timm). We sincerely thank the original authors for open-sourcing their work.
210 | 
211 | 
212 | ## Citation
213 | If you find this code useful for your research, please cite our paper.
214 | 215 | ```tex 216 | @article{bai2023rasa, 217 | title={RaSa: Relation and Sensitivity Aware Representation Learning for Text-based Person Search}, 218 | author={Bai, Yang and Cao, Min and Gao, Daming and Cao, Ziqiang and Chen, Chen and Fan, Zhenfeng and Nie, Liqiang and Zhang, Min}, 219 | journal={arXiv preprint arXiv:2305.13653}, 220 | year={2023} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /Retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import random 6 | import time 7 | import numpy as np 8 | import ruamel_yaml as yaml 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | from pathlib import Path 14 | 15 | import utils 16 | from dataset import create_dataset, create_sampler, create_loader 17 | from models.model_person_search import ALBEF 18 | from models.tokenization_bert import BertTokenizer 19 | from models.vit import interpolate_pos_embed 20 | from optim import create_optimizer 21 | from scheduler import create_scheduler 22 | 23 | def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config): 24 | # train 25 | model.train() 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | metric_logger.add_meter('loss_cl', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 29 | metric_logger.add_meter('loss_pitm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 30 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 31 | metric_logger.add_meter('loss_prd', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 32 | metric_logger.add_meter('loss_mrtd', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 33 | header = 'Train Epoch: [{}]'.format(epoch) 34 | print_freq = 50 35 | step_size = 100 36 | warmup_iterations = warmup_steps * step_size 37 | for i, (image1, image2, text1, text2, idx, replace) in enumerate( 38 | metric_logger.log_every(data_loader, print_freq, header)): 39 | image1 = image1.to(device, non_blocking=True) 40 | image2 = image2.to(device, non_blocking=True) 41 | idx = idx.to(device, non_blocking=True) 42 | replace = replace.to(device, non_blocking=True) 43 | text_input1 = tokenizer(text1, padding='longest', max_length=config['max_words'], return_tensors="pt").to(device) 44 | text_input2 = tokenizer(text2, padding='longest', max_length=config['max_words'], return_tensors="pt").to(device) 45 | if epoch > 0 or not config['warm_up']: 46 | alpha = config['alpha'] 47 | else: 48 | alpha = config['alpha'] * min(1.0, i / len(data_loader)) 49 | loss_cl, loss_pitm, loss_mlm, loss_prd, loss_mrtd = model(image1, image2, text_input1, text_input2, 50 | alpha=alpha, idx=idx, replace=replace) 51 | loss = 0. 
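# (note) The total objective is a weighted sum of the five losses; config['weights']
# pairs elementwise with (loss_cl, loss_pitm, loss_mlm, loss_prd, loss_mrtd), so with
# the provided configs:
# loss = 0.5*loss_cl + 1.0*loss_pitm + 1.0*loss_mlm + 0.5*loss_prd + 0.5*loss_mrtd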
52 |         for j, los in enumerate((loss_cl, loss_pitm, loss_mlm, loss_prd, loss_mrtd)):
53 |             loss += config['weights'][j] * los
54 |         optimizer.zero_grad()
55 |         loss.backward()
56 |         optimizer.step()
57 |         metric_logger.update(loss_cl=loss_cl.item())
58 |         metric_logger.update(loss_pitm=loss_pitm.item())
59 |         metric_logger.update(loss_mlm=loss_mlm.item())
60 |         metric_logger.update(loss_prd=loss_prd.item())
61 |         metric_logger.update(loss_mrtd=loss_mrtd.item())
62 |         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
63 |         if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
64 |             scheduler.step(i // step_size)
65 |     # gather the stats from all processes
66 |     metric_logger.synchronize_between_processes()
67 |     print("Averaged stats:", metric_logger.global_avg())
68 |     return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
69 | 
70 | @torch.no_grad()
71 | def evaluation(model, data_loader, tokenizer, device, config):
72 |     # evaluate
73 |     model.eval()
74 |     metric_logger = utils.MetricLogger(delimiter=" ")
75 |     header = 'Evaluation:'
76 |     print('Computing features for evaluation...')
77 |     start_time = time.time()
78 |     # extract text features
79 |     texts = data_loader.dataset.text
80 |     num_text = len(texts)
81 |     text_bs = 256
82 |     text_feats = []
83 |     text_embeds = []
84 |     text_atts = []
85 |     for i in range(0, num_text, text_bs):
86 |         text = texts[i: min(num_text, i + text_bs)]
87 |         text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_words'], return_tensors="pt").to(device)
88 |         text_output = model.text_encoder.bert(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
89 |         text_feat = text_output.last_hidden_state
90 |         text_embed = F.normalize(model.text_proj(text_feat[:, 0, :]))
91 |         text_embeds.append(text_embed)
92 |         text_feats.append(text_feat)
93 |         text_atts.append(text_input.attention_mask)
94 |     text_embeds = torch.cat(text_embeds, dim=0)
95 |     text_feats = torch.cat(text_feats, dim=0)
96 |     text_atts = torch.cat(text_atts, dim=0)
97 |     # extract image features
98 |     image_feats = []
99 |     image_embeds = []
100 |     for image, img_id in data_loader:
101 |         image = image.to(device)
102 |         image_feat = model.visual_encoder(image)
103 |         image_embed = model.vision_proj(image_feat[:, 0, :])
104 |         image_embed = F.normalize(image_embed, dim=-1)
105 |         image_feats.append(image_feat.cpu())
106 |         image_embeds.append(image_embed)
107 |     image_feats = torch.cat(image_feats, dim=0)
108 |     image_embeds = torch.cat(image_embeds, dim=0)
109 |     # compute the feature similarity score for all image-text pairs
110 |     sims_matrix = text_embeds @ image_embeds.t()
111 |     score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), -100.0).to(device)
112 |     # take the top-k candidates and calculate their ITM score s_itm for ranking
113 |     num_tasks = utils.get_world_size()
114 |     rank = utils.get_rank()
115 |     step = sims_matrix.size(0) // num_tasks + 1
116 |     start = rank * step
117 |     end = min(sims_matrix.size(0), start + step)
118 |     for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
119 |         topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
120 |         encoder_output = image_feats[topk_idx]
121 |         encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
122 |         output = model.text_encoder.bert(encoder_embeds=text_feats[start + i].repeat(config['k_test'], 1, 1),
123 |                                          attention_mask=text_atts[start + i].repeat(config['k_test'], 1),
124 | 
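# fuse the query text with each of its top-k candidate images ('fusion' mode);
# the ITM head below then re-scores these pairs to produce the final ranking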
encoder_hidden_states=encoder_output.to(device), 125 | encoder_attention_mask=encoder_att, 126 | return_dict=True, 127 | mode='fusion' 128 | ) 129 | score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1] 130 | score_matrix_t2i[start + i, topk_idx] = score 131 | if args.distributed: 132 | dist.barrier() 133 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) 134 | total_time = time.time() - start_time 135 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 136 | print('Evaluation time {}'.format(total_time_str)) 137 | return score_matrix_t2i.cpu() 138 | 139 | @torch.no_grad() 140 | def itm_eval(scores_t2i, img2person, txt2person, eval_mAP): 141 | img2person = torch.tensor(img2person) 142 | txt2person = torch.tensor(txt2person) 143 | index = torch.argsort(scores_t2i, dim=-1, descending=True) 144 | pred_person = img2person[index] 145 | matches = (txt2person.view(-1, 1).eq(pred_person)).long() 146 | 147 | def acc_k(matches, k=1): 148 | matches_k = matches[:, :k].sum(dim=-1) 149 | matches_k = torch.sum((matches_k > 0)) 150 | return 100.0 * matches_k / matches.size(0) 151 | 152 | # Compute metrics 153 | ir1 = acc_k(matches, k=1).item() 154 | ir5 = acc_k(matches, k=5).item() 155 | ir10 = acc_k(matches, k=10).item() 156 | ir_mean = (ir1 + ir5 + ir10) / 3 157 | 158 | if eval_mAP: 159 | real_num = matches.sum(dim=-1) 160 | tmp_cmc = matches.cumsum(dim=-1).float() 161 | order = torch.arange(start=1, end=matches.size(1) + 1, dtype=torch.long) 162 | tmp_cmc /= order 163 | tmp_cmc *= matches 164 | AP = tmp_cmc.sum(dim=-1) / real_num 165 | mAP = AP.mean() * 100.0 166 | eval_result = {'r1': ir1, 167 | 'r5': ir5, 168 | 'r10': ir10, 169 | 'r_mean': ir_mean, 170 | 'mAP': mAP.item() 171 | } 172 | else: 173 | eval_result = {'r1': ir1, 174 | 'r5': ir5, 175 | 'r10': ir10, 176 | 'r_mean': ir_mean, 177 | } 178 | return eval_result 179 | 180 | def main(args, config): 181 | utils.init_distributed_mode(args) 182 | device = torch.device(args.device) 183 | print(args) 184 | print(config) 185 | # fix the seed for reproducibility 186 | seed = args.seed + utils.get_rank() 187 | torch.manual_seed(seed) 188 | torch.cuda.manual_seed(seed) 189 | np.random.seed(seed) 190 | random.seed(seed) 191 | cudnn.deterministic = True 192 | cudnn.benchmark = True 193 | # Dataset 194 | print("Creating retrieval dataset") 195 | train_dataset, val_dataset, test_dataset = create_dataset('ps', config) 196 | if args.distributed: 197 | num_tasks = utils.get_world_size() 198 | global_rank = utils.get_rank() 199 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] 200 | else: 201 | samplers = [None, None, None] 202 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers, 203 | batch_size=[config['batch_size_train']] + [ 204 | config['batch_size_test']] * 2, 205 | num_workers=[4, 4, 4], 206 | is_trains=[True, False, False], 207 | collate_fns=[None, None, None]) 208 | tokenizer = BertTokenizer.from_pretrained(args.text_encoder) 209 | 210 | start_epoch = 0 211 | max_epoch = config['schedular']['epochs'] 212 | warmup_steps = config['schedular']['warmup_epochs'] 213 | best = 0 214 | best_epoch = 0 215 | best_log = '' 216 | 217 | # Model 218 | print("Creating model") 219 | model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer) 220 | model = model.to(device) 221 | # Optimizer and learning rate scheduler 222 | arg_opt = utils.AttrDict(config['optimizer']) 223 | optimizer = 
create_optimizer(arg_opt, model)
224 |     arg_sche = utils.AttrDict(config['schedular'])
225 |     lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
226 | 
227 |     if args.checkpoint:
228 |         checkpoint = torch.load(args.checkpoint, map_location='cpu')
229 |         state_dict = checkpoint['model']
230 |         if args.resume:
231 |             optimizer.load_state_dict(checkpoint['optimizer'])
232 |             lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
233 |             start_epoch = checkpoint['epoch'] + 1
234 |             best = checkpoint['best']
235 |             best_epoch = checkpoint['best_epoch']
236 |         else:
237 |             # reshape the positional embedding to accommodate the image resolution change
238 |             pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
239 |             state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
240 |             m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
241 |                                                          model.visual_encoder_m)
242 |             state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped
243 |         msg = model.load_state_dict(state_dict, strict=False)
244 |         print('load checkpoint from %s' % args.checkpoint)
245 |         print(msg)
246 | 
247 |     model_without_ddp = model
248 |     if args.distributed:
249 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
250 |         model_without_ddp = model.module
251 | 
252 |     print("Start training")
253 |     start_time = time.time()
254 |     for epoch in range(start_epoch, max_epoch):
255 |         if not args.evaluate:
256 |             if epoch > 0:
257 |                 lr_scheduler.step(epoch + warmup_steps)
258 |             if args.distributed:
259 |                 train_loader.sampler.set_epoch(epoch)
260 |             train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler,
261 |                                 config)
262 |         if epoch >= config['eval_epoch'] or args.evaluate:
263 |             score_test_t2i = evaluation(model_without_ddp, test_loader, tokenizer, device, config)
264 |             if utils.is_main_process():
265 |                 test_result = itm_eval(score_test_t2i, test_dataset.img2person, test_dataset.txt2person, args.eval_mAP)
266 |                 print('Test:', test_result, '\n')
267 |                 if args.evaluate:
268 |                     log_stats = {'epoch': epoch,
269 |                                  **{f'test_{k}': v for k, v in test_result.items()}
270 |                                  }
271 |                     with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
272 |                         f.write(json.dumps(log_stats) + "\n")
273 |                 else:
274 |                     log_stats = {'epoch': epoch,
275 |                                  **{f'train_{k}': v for k, v in train_stats.items()},
276 |                                  **{f'test_{k}': v for k, v in test_result.items()},
277 |                                  }
278 |                     with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
279 |                         f.write(json.dumps(log_stats) + "\n")
280 |                     save_obj = {
281 |                         'model': model_without_ddp.state_dict(),
282 |                         'optimizer': optimizer.state_dict(),
283 |                         'lr_scheduler': lr_scheduler.state_dict(),
284 |                         'config': config,
285 |                         'epoch': epoch,
286 |                         'best': best,
287 |                         'best_epoch': best_epoch
288 |                     }
289 |                     torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_epoch%02d.pth' % epoch))
290 |                     if test_result['r1'] > best:
291 |                         torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
292 |                         best = test_result['r1']
293 |                         best_epoch = epoch
294 |                         best_log = log_stats
295 |         if args.evaluate:
296 |             break
297 |         dist.barrier()
298 |         torch.cuda.empty_cache()
299 |     total_time = time.time() - start_time
300 |     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
301 |     print('Training time {}'.format(total_time_str))
302 |     if utils.is_main_process():
303 |         with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
304 |             f.write(f"best epoch:
{best_epoch} / {max_epoch}\n") 305 | f.write(f"{best_log}\n\n") 306 | 307 | if __name__ == '__main__': 308 | parser = argparse.ArgumentParser() 309 | parser.add_argument('--config', default='./configs/PS_cuhk_pedes.yaml') 310 | parser.add_argument('--output_dir', default='output/cuhk-pedes') 311 | parser.add_argument('--checkpoint', default='') 312 | parser.add_argument('--resume', action='store_true') 313 | parser.add_argument('--eval_mAP', action='store_true', help='whether to evaluate mAP') 314 | parser.add_argument('--text_encoder', default='bert-base-uncased') 315 | parser.add_argument('--evaluate', action='store_true') 316 | parser.add_argument('--device', default='cuda') 317 | parser.add_argument('--seed', default=42, type=int) 318 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 319 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 320 | parser.add_argument('--distributed', default=True, type=bool) 321 | args = parser.parse_args() 322 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 323 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 324 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 325 | main(args, config) 326 | -------------------------------------------------------------------------------- /configs/PS_cuhk_pedes.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['../dataset/CUHK-PEDES/processed_data/train.json'] 2 | val_file: '../dataset/CUHK-PEDES/processed_data/val.json' 3 | test_file: '../dataset/CUHK-PEDES/processed_data/test.json' 4 | train_image_root: '../dataset/CUHK-PEDES/imgs' 5 | val_image_root: '../dataset/CUHK-PEDES/imgs' 6 | test_image_root: '../dataset/CUHK-PEDES/imgs' 7 | 8 | bert_config: 'configs/config_bert.json' 9 | 10 | max_words: 50 11 | image_res: 384 12 | batch_size_train: 13 13 | batch_size_test: 64 14 | 15 | mlm_probability: 0.15 16 | weak_pos_pair_probability: 0.1 17 | mrtd_mask_probability: 0.3 18 | queue_size: 65536 19 | momentum: 0.995 20 | vision_width: 768 21 | embed_dim: 256 22 | temp: 0.07 23 | k_test: 128 24 | 25 | alpha: 0.4 26 | warm_up: True 27 | 28 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02, lr_custm: 1e-4} 29 | schedular: {sched: cosine, lr: 1e-5, epochs: 30, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 30 | 31 | eval_epoch: 29 32 | 33 | weights: 34 | - 0.5 35 | - 1 36 | - 1 37 | - 0.5 38 | - 0.5 39 | -------------------------------------------------------------------------------- /configs/PS_icfg_pedes.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['../dataset/ICFG-PEDES/processed_data/train.json'] 2 | val_file: '../dataset/ICFG-PEDES/processed_data/val.json' 3 | test_file: '../dataset/ICFG-PEDES/processed_data/test.json' 4 | train_image_root: '../dataset/ICFG-PEDES/imgs' 5 | val_image_root: '../dataset/ICFG-PEDES/imgs' 6 | test_image_root: '../dataset/ICFG-PEDES/imgs' 7 | 8 | bert_config: 'configs/config_bert.json' 9 | 10 | max_words: 50 11 | image_res: 384 12 | batch_size_train: 13 13 | batch_size_test: 64 14 | 15 | mlm_probability: 0.15 16 | weak_pos_pair_probability: 0.1 17 | mrtd_mask_probability: 0.3 18 | 19 | queue_size: 65536 20 | momentum: 0.995 21 | vision_width: 768 22 | embed_dim: 256 23 | temp: 0.07 24 | k_test: 128 25 | 26 | alpha: 0.4 27 | warm_up: True 28 | 29 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 
0.02, lr_custm: 1e-4} 30 | schedular: {sched: cosine, lr: 1e-5, epochs: 30, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 31 | 32 | eval_epoch: 10 33 | 34 | weights: 35 | - 0.5 36 | - 1 37 | - 1 38 | - 0.5 39 | - 0.5 -------------------------------------------------------------------------------- /configs/PS_rstp_reid.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['../dataset/RSTPReid/processed_data/train.json'] 2 | val_file: '../dataset/RSTPReid/processed_data/val.json' 3 | test_file: '../dataset/RSTPReid/processed_data/test.json' 4 | train_image_root: '../dataset/RSTPReid/imgs' 5 | val_image_root: '../dataset/RSTPReid/imgs' 6 | test_image_root: '../dataset/RSTPReid/imgs' 7 | 8 | bert_config: 'configs/config_bert.json' 9 | 10 | max_words: 50 11 | image_res: 384 12 | batch_size_train: 13 13 | batch_size_test: 64 14 | 15 | mlm_probability: 0.15 16 | weak_pos_pair_probability: 0.1 17 | mrtd_mask_probability: 0.3 18 | 19 | queue_size: 65536 20 | momentum: 0.995 21 | vision_width: 768 22 | embed_dim: 256 23 | temp: 0.07 24 | k_test: 128 25 | 26 | alpha: 0.4 27 | warm_up: True 28 | 29 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02, lr_custm: 1e-4} 30 | schedular: {sched: cosine, lr: 1e-5, epochs: 30, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 31 | 32 | eval_epoch: 0 33 | 34 | weights: 35 | - 0.5 36 | - 1 37 | - 1 38 | - 0.5 39 | - 0.5 40 | -------------------------------------------------------------------------------- /configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 768 21 | } 22 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import errno 4 | import argparse 5 | 6 | def mkdir_if_missing(directory): 7 | if not os.path.exists(directory): 8 | try: 9 | os.makedirs(directory) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--dataset_name', default='CUHK-PEDES', type=str) 17 | parser.add_argument('--dataset_root_dir', default='./CUHK-PEDES', type=str) 18 | args = parser.parse_args() 19 | raw_annotation_file_name = "" 20 | if args.dataset_name == "CUHK-PEDES": 21 | raw_annotation_file_name = "reid_raw.json" 22 | elif args.dataset_name == "ICFG-PEDES": 23 | raw_annotation_file_name = "ICFG-PEDES.json" 24 | elif args.dataset_name == "RSTPReid": 25 | raw_annotation_file_name = "data_captions.json" 26 | raw_annotation_file_path = os.path.join(args.dataset_root_dir, raw_annotation_file_name) 27 | # split raw annotations into training, validation and test dataset 28 | anns = json.load(open(raw_annotation_file_path, "r")) 29 | train = [] 30 | val = [] 31 | test = [] 32 | for ann in anns: 33 | if 
args.dataset_name == "RSTPReid": 34 | ann['file_path'] = ann.pop('img_path') 35 | eval(ann['split']).append(ann) 36 | output_dir = os.path.join(args.dataset_root_dir, "processed_data") 37 | mkdir_if_missing(output_dir) 38 | json.dump(train, open(os.path.join(output_dir, "train.json"), 'w')) 39 | json.dump(val, open(os.path.join(output_dir, "val.json"), 'w')) 40 | json.dump(test, open(os.path.join(output_dir, "test.json"), 'w')) 41 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | from dataset.ps_dataset import ps_train_dataset, ps_eval_dataset 6 | 7 | def create_dataset(dataset, config): 8 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 9 | train_transform = transforms.Compose([ 10 | transforms.Resize((config['image_res'], config['image_res']), interpolation=InterpolationMode.BICUBIC), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | normalize, 14 | ]) 15 | test_transform = transforms.Compose([ 16 | transforms.Resize((config['image_res'], config['image_res']), interpolation=InterpolationMode.BICUBIC), 17 | transforms.ToTensor(), 18 | normalize, 19 | ]) 20 | if dataset == 'ps': 21 | train_dataset = ps_train_dataset(config['train_file'], train_transform, config['train_image_root'], 22 | config['max_words'], config['weak_pos_pair_probability']) 23 | val_dataset = ps_eval_dataset(config['val_file'], test_transform, config['val_image_root'], config['max_words']) 24 | test_dataset = ps_eval_dataset(config['test_file'], test_transform, config['test_image_root'], config['max_words']) 25 | return train_dataset, val_dataset, test_dataset 26 | 27 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 28 | samplers = [] 29 | for dataset, shuffle in zip(datasets, shuffles): 30 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, 31 | shuffle=shuffle) 32 | samplers.append(sampler) 33 | return samplers 34 | 35 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 36 | loaders = [] 37 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, 38 | is_trains, collate_fns): 39 | if is_train: 40 | shuffle = (sampler is None) 41 | drop_last = True 42 | else: 43 | shuffle = False 44 | drop_last = False 45 | loader = DataLoader( 46 | dataset, 47 | batch_size=bs, 48 | num_workers=n_worker, 49 | pin_memory=True, 50 | sampler=sampler, 51 | shuffle=shuffle, 52 | collate_fn=collate_fn, 53 | drop_last=drop_last, 54 | ) 55 | loaders.append(loader) 56 | return loaders 57 | -------------------------------------------------------------------------------- /dataset/ps_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from PIL import ImageFile 6 | from torch.utils.data import Dataset 7 | from collections import defaultdict 8 | from dataset.utils import pre_caption 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | Image.MAX_IMAGE_PIXELS = None 12 | 13 | class ps_train_dataset(Dataset): 14 | def __init__(self, ann_file, transform, image_root, max_words=30, 
weak_pos_pair_probability=0.1):
15 |         anns = []
16 |         for f in ann_file:
17 |             anns += json.load(open(f, 'r'))
18 |         self.transform = transform
19 |         self.image_root = image_root
20 |         self.max_words = max_words
21 |         self.weak_pos_pair_probability = weak_pos_pair_probability  # TODO: to be revised
22 |         self.person2image = defaultdict(list)
23 |         self.person2text = defaultdict(list)
24 |         person_id2idx = {}
25 |         n = 0
26 |         self.pairs = []
27 |         for ann in anns:
28 |             person_id = ann['id']
29 |             if person_id not in person_id2idx.keys():
30 |                 person_id2idx[person_id] = n
31 |                 n += 1
32 |             person_idx = person_id2idx[person_id]
33 |             self.person2image[person_idx].append(ann['file_path'])
34 |             for cap in ann['captions']:
35 |                 self.pairs.append((ann['file_path'], cap, person_idx))
36 |                 self.person2text[person_idx].append(cap)
37 | 
38 |     def __len__(self):
39 |         return len(self.pairs)
40 | 
41 |     def augment(self, caption, person):
42 |         caption_aug = caption
43 |         if np.random.random() < self.weak_pos_pair_probability:
44 |             caption_aug = np.random.choice(self.person2text[person], 1).item()
45 |         if caption_aug == caption:
46 |             replace = 0
47 |         else:
48 |             replace = 1
49 |         return caption_aug, replace
50 | 
51 |     def __getitem__(self, index):
52 |         image_path, caption, person = self.pairs[index]
53 |         caption_aug, replace = self.augment(caption, person)
54 |         image_path = os.path.join(self.image_root, image_path)
55 |         image = Image.open(image_path).convert('RGB')
56 |         image1 = self.transform(image)
57 |         image2 = self.transform(image)
58 |         caption1 = pre_caption(caption, self.max_words)
59 |         caption2 = pre_caption(caption_aug, self.max_words)
60 |         return image1, image2, caption1, caption2, person, replace
61 | 
62 | class ps_eval_dataset(Dataset):
63 |     def __init__(self, ann_file, transform, image_root, max_words=30):
64 |         self.ann = json.load(open(ann_file, 'r'))
65 |         self.transform = transform
66 |         self.image_root = image_root
67 |         self.max_words = max_words
68 |         self.text = []
69 |         self.image = []
70 |         self.txt2person = []
71 |         self.img2person = []
72 |         person2img = defaultdict(list)
73 |         person2txt = defaultdict(list)
74 |         txt_id = 0
75 |         for img_id, ann in enumerate(self.ann):
76 |             self.image.append(ann['file_path'])
77 |             person_id = ann['id']
78 |             person2img[person_id].append(img_id)
79 |             self.img2person.append(person_id)
80 |             for caption in ann['captions']:
81 |                 self.text.append(pre_caption(caption, self.max_words))
82 |                 person2txt[person_id].append(txt_id)
83 |                 self.txt2person.append(person_id)
84 |                 txt_id += 1
85 | 
86 |     def __len__(self):
87 |         return len(self.image)
88 | 
89 |     def __getitem__(self, index):
90 |         image_path = os.path.join(self.image_root, self.ann[index]['file_path'])
91 |         image = Image.open(image_path).convert('RGB')
92 |         image = self.transform(image)
93 |         return image, index
94 | 

--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | def pre_caption(caption, max_words):
4 |     caption = re.sub(
5 |         r"([,.'!?\"()*#:;~])",
6 |         '',
7 |         caption.lower(),
8 |     ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
9 |     caption = re.sub(
10 |         r"\s{2,}",
11 |         ' ',
12 |         caption,
13 |     )
14 |     caption = caption.rstrip('\n')
15 |     caption = caption.strip(' ')
16 |     # truncate caption
17 |     caption_words = caption.split(' ')
18 |     if len(caption_words) > max_words:
19 |         caption = ' '.join(caption_words[:max_words])
20 |     return caption
21 | 
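To make the behavior of `pre_caption` concrete, here is a small illustrative sketch (the input string is made up, not taken from the datasets):

```python
from dataset.utils import pre_caption

# punctuation is removed, '-' and '/' become spaces, text is lower-cased,
# '<person>' is mapped to 'person', whitespace is collapsed, and the result
# is truncated to at most max_words tokens
print(pre_caption("A man, wearing a red T-shirt/blue jeans.", max_words=50))
# -> 'a man wearing a red t shirt blue jeans'
```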
-------------------------------------------------------------------------------- /images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flame-Chasers/RaSa/bd16aa1a15f149548a90d196fcf10a27d7ab7c66/images/architecture.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/model_person_search.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from models.vit import VisionTransformer 6 | from models.xbert import BertConfig, BertForMaskedLM 7 | 8 | class ALBEF(nn.Module): 9 | def __init__(self, 10 | text_encoder=None, 11 | tokenizer=None, 12 | config=None, 13 | ): 14 | super().__init__() 15 | 16 | self.tokenizer = tokenizer 17 | embed_dim = config['embed_dim'] 18 | vision_width = config['vision_width'] 19 | self.visual_encoder = VisionTransformer( 20 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 21 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), ) 22 | bert_config = BertConfig.from_json_file(config['bert_config']) 23 | self.text_encoder = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config) 24 | self.text_width = self.text_encoder.config.hidden_size 25 | self.vision_proj = nn.Linear(vision_width, embed_dim) 26 | self.text_proj = nn.Linear(self.text_width, embed_dim) 27 | self.temp = nn.Parameter(torch.ones([]) * config['temp']) 28 | self.mlm_probability = config['mlm_probability'] 29 | self.mrtd_mask_probability = config['mrtd_mask_probability'] 30 | self.queue_size = config['queue_size'] 31 | self.momentum = config['momentum'] 32 | self.itm_head = nn.Linear(self.text_width, 2) 33 | self.prd_head = nn.Linear(self.text_width, 2) 34 | self.mrtd_head = nn.Linear(self.text_width, 2) 35 | # create momentum models 36 | self.visual_encoder_m = VisionTransformer( 37 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 38 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), ) 39 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 40 | self.text_encoder_m = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config) 41 | self.text_proj_m = nn.Linear(self.text_width, embed_dim) 42 | self.model_pairs = [[self.visual_encoder, self.visual_encoder_m], 43 | [self.vision_proj, self.vision_proj_m], 44 | [self.text_encoder, self.text_encoder_m], 45 | [self.text_proj, self.text_proj_m], 46 | ] 47 | self.copy_params() 48 | # create the queue 49 | self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size)) 50 | self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size)) 51 | self.register_buffer("idx_queue", torch.full((1, self.queue_size), -100)) 52 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 53 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 54 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 55 | 56 | def forward(self, image1, image2, text1, text2, alpha, idx, replace): 57 | # extract image features 58 | image_embeds = self.visual_encoder(image1) 59 | image_atts = torch.ones(image_embeds.size()[:-1], 
dtype=torch.long).to(image1.device)
60 |         image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
61 |         # extract text features
62 |         text_output = self.text_encoder.bert(text2.input_ids, attention_mask=text2.attention_mask,
63 |                                              return_dict=True, mode='text')
64 |         text_embeds = text_output.last_hidden_state
65 |         text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)
66 |         # Contrastive loss
67 |         idx = idx.view(-1, 1)
68 |         idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
69 |         pos_idx = torch.eq(idx, idx_all).float()
70 |         sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
71 |         with torch.no_grad():
72 |             self._momentum_update()
73 |             image_embeds_m = self.visual_encoder_m(image2)
74 |             image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
75 |             image_feat_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)
76 | 
77 |             text_output_m = self.text_encoder_m.bert(text2.input_ids, attention_mask=text2.attention_mask,
78 |                                                      return_dict=True, mode='text')
79 |             text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), dim=-1)
80 |             text_feat_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)
81 | 
82 |             sim_i2t_m = image_feat_m @ text_feat_all / self.temp
83 |             sim_t2i_m = text_feat_m @ image_feat_all / self.temp
84 |             sim_i2i_m = image_feat_m @ image_feat_all / self.temp
85 |             sim_t2t_m = text_feat_m @ text_feat_all / self.temp
86 | 
87 |             sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
88 |             sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
89 |             sim_i2i_targets = alpha * F.softmax(sim_i2i_m, dim=1) + (1 - alpha) * sim_targets
90 |             sim_t2t_targets = alpha * F.softmax(sim_t2t_m, dim=1) + (1 - alpha) * sim_targets
91 | 
92 |         sim_i2t = image_feat @ text_feat_all / self.temp
93 |         sim_t2i = text_feat @ image_feat_all / self.temp
94 |         sim_i2i = image_feat @ image_feat_all / self.temp
95 |         sim_t2t = text_feat @ text_feat_all / self.temp
96 | 
97 |         loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
98 |         loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
99 |         loss_i2i = -torch.sum(F.log_softmax(sim_i2i, dim=1) * sim_i2i_targets, dim=1).mean()
100 |         loss_t2t = -torch.sum(F.log_softmax(sim_t2t, dim=1) * sim_t2t_targets, dim=1).mean()
101 |         loss_cl = (loss_i2t + loss_t2i + loss_i2i + loss_t2t) / 4
102 | 
103 |         self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)
104 | 
105 |         # Relation-aware Learning: Probabilistic Image-Text Matching + Positive Relation Detection
106 |         # Probabilistic Image-Text Matching
107 |         # forward the positive image-text pairs
108 |         output_pos = self.text_encoder.bert(encoder_embeds=text_embeds,
109 |                                             attention_mask=text2.attention_mask,
110 |                                             encoder_hidden_states=image_embeds,
111 |                                             encoder_attention_mask=image_atts,
112 |                                             return_dict=True,
113 |                                             mode='fusion',
114 |                                             )
115 |         with torch.no_grad():
116 |             bs = image1.size(0)
117 |             weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)
118 |             weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)
119 |             mask = torch.eq(idx, idx.T)
120 |             weights_i2t.masked_fill_(mask, 0)
121 |             weights_t2i.masked_fill_(mask, 0)
122 |         # select a negative image for each text
123 |         image_neg_idx = torch.multinomial(weights_t2i, 1).flatten()
124 |         image_embeds_neg = image_embeds[image_neg_idx]
125 |         # select a negative text for each image
126 |         text_neg_idx = torch.multinomial(weights_i2t, 1).flatten()
127
| text_embeds_neg = text_embeds[text_neg_idx] 128 | text_atts_neg = text2.attention_mask[text_neg_idx] 129 | # forward the negative image-text pairs 130 | text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) 131 | text_atts_all = torch.cat([text2.attention_mask, text_atts_neg], dim=0) 132 | image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) 133 | image_atts_all = torch.cat([image_atts, image_atts], dim=0) 134 | output_neg_cross = self.text_encoder.bert(encoder_embeds=text_embeds_all, 135 | attention_mask=text_atts_all, 136 | encoder_hidden_states=image_embeds_all, 137 | encoder_attention_mask=image_atts_all, 138 | return_dict=True, 139 | mode='fusion', 140 | ) 141 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg_cross.last_hidden_state[:, 0, :]], 142 | dim=0) 143 | vl_output = self.itm_head(vl_embeddings) 144 | itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], 145 | dim=0).to(image1.device) 146 | loss_pitm = F.cross_entropy(vl_output, itm_labels) 147 | # Positive Relation Detection 148 | prd_output = self.prd_head(output_pos.last_hidden_state[:, 0, :]) 149 | loss_prd = F.cross_entropy(prd_output, replace) 150 | 151 | # Sensitivity-aware Learning: Masked Language Modeling + Momentum-based Replaced Token Detection 152 | input_ids = text1.input_ids.clone() 153 | labels = input_ids.clone() 154 | mrtd_input_ids = input_ids.clone() 155 | # Masked Language Modeling 156 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 157 | input_ids, labels = self.mask(input_ids, self.text_encoder.config.vocab_size, targets=labels, probability_matrix=probability_matrix) 158 | with torch.no_grad(): 159 | logits_m = self.text_encoder_m(input_ids, 160 | attention_mask=text1.attention_mask, 161 | encoder_hidden_states=image_embeds_m, 162 | encoder_attention_mask=image_atts, 163 | return_dict=True, 164 | return_logits=True, 165 | ) 166 | prediction = F.softmax(logits_m, dim=-1) 167 | mlm_output = self.text_encoder(input_ids, 168 | attention_mask=text1.attention_mask, 169 | encoder_hidden_states=image_embeds, 170 | encoder_attention_mask=image_atts, 171 | return_dict=True, 172 | labels=labels, 173 | soft_labels=prediction, 174 | alpha=alpha 175 | ) 176 | loss_mlm = mlm_output.loss 177 | # Momentum-based Replaced Token Detection 178 | with torch.no_grad(): 179 | probability_matrix = torch.full(labels.shape, self.mrtd_mask_probability) 180 | mrtd_input_ids = self.mask(mrtd_input_ids, self.text_encoder.config.vocab_size, probability_matrix=probability_matrix) 181 | # momentum module is used as generator 182 | mrtd_logits_m = self.text_encoder_m(mrtd_input_ids, 183 | attention_mask=text1.attention_mask, 184 | encoder_hidden_states=image_embeds_m, 185 | encoder_attention_mask=image_atts, 186 | return_dict=True, 187 | return_logits=True, 188 | ) 189 | weights = F.softmax(mrtd_logits_m, dim=-1) 190 | mrtd_input_ids, mrtd_labels = self.mrtd_mask_modeling(mrtd_input_ids, text1.input_ids, text1.attention_mask, weights) 191 | output_mrtd = self.text_encoder.bert(mrtd_input_ids, 192 | attention_mask=text1.attention_mask, 193 | encoder_hidden_states=image_embeds, 194 | encoder_attention_mask=image_atts, 195 | return_dict=True, 196 | ) 197 | mrtd_output = self.mrtd_head(output_mrtd.last_hidden_state.view(-1, self.text_width)) 198 | loss_mrtd = F.cross_entropy(mrtd_output, mrtd_labels.view(-1)) 199 | 200 | return loss_cl, loss_pitm, loss_mlm, loss_prd, loss_mrtd 201 | 202 | @torch.no_grad() 203 | def 
copy_params(self): 204 | for model_pair in self.model_pairs: 205 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 206 | param_m.data.copy_(param.data) # initialize 207 | param_m.requires_grad = False # not update by gradient 208 | 209 | @torch.no_grad() 210 | def _momentum_update(self): 211 | for model_pair in self.model_pairs: 212 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 213 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 214 | 215 | @torch.no_grad() 216 | def _dequeue_and_enqueue(self, image_feat, text_feat, idx): 217 | # gather keys before updating queue 218 | image_feats = concat_all_gather(image_feat) 219 | text_feats = concat_all_gather(text_feat) 220 | idxs = concat_all_gather(idx) 221 | batch_size = image_feats.shape[0] 222 | ptr = int(self.queue_ptr) 223 | # replace the keys at ptr (dequeue and enqueue) 224 | empty = self.image_queue.size(1) - ptr 225 | if batch_size <= empty: 226 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 227 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 228 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 229 | else: 230 | self.image_queue[:, ptr:] = image_feats[:empty].T 231 | self.text_queue[:, ptr:] = text_feats[:empty].T 232 | self.idx_queue[:, ptr:] = idxs[:empty].T 233 | self.image_queue[:, :batch_size - empty] = image_feats[empty:].T 234 | self.text_queue[:, :batch_size - empty] = text_feats[empty:].T 235 | self.idx_queue[:, :batch_size - empty] = idxs[empty:].T 236 | ptr = (ptr + batch_size) % self.queue_size # move pointer 237 | self.queue_ptr[0] = ptr 238 | 239 | def mask(self, input_ids, vocab_size, targets=None, masked_indices=None, probability_matrix=None): 240 | if masked_indices is None: 241 | masked_indices = torch.bernoulli(probability_matrix).bool() 242 | masked_indices[input_ids == self.tokenizer.pad_token_id] = False 243 | masked_indices[input_ids == self.tokenizer.cls_token_id] = False 244 | if targets is not None: 245 | targets[~masked_indices] = -100 # We only compute loss on masked tokens 246 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 247 | indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices 248 | input_ids[indices_replaced] = self.tokenizer.mask_token_id 249 | # 10% of the time, we replace masked input tokens with random word 250 | indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced 251 | random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(input_ids.device) 252 | input_ids[indices_random] = random_words[indices_random] 253 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 254 | if targets is not None: 255 | return input_ids, targets 256 | else: 257 | return input_ids 258 | 259 | def mrtd_mask_modeling(self, mrtd_input_ids, ori_input_ids, attention_mask, weights): 260 | bs = mrtd_input_ids.size(0) 261 | weights = weights.view(-1, weights.size(-1)) 262 | pred = torch.multinomial(weights, 1).view(bs, -1) 263 | pred[:, 0] = self.tokenizer.cls_token_id 264 | # pad_token_id is 0 265 | mrtd_input_ids = pred * attention_mask 266 | mrtd_labels = (pred != ori_input_ids) * attention_mask 267 | mrtd_labels[mrtd_input_ids == self.tokenizer.pad_token_id] = -100 268 | mrtd_labels[mrtd_input_ids == self.tokenizer.cls_token_id] = -100 269 | return mrtd_input_ids, mrtd_labels 270 | 271 | @torch.no_grad() 272 | 
def concat_all_gather(tensor): 273 | """ 274 | Performs all_gather operation on the provided tensors. 275 | *** Warning ***: torch.distributed.all_gather has no gradient. 276 | """ 277 | tensors_gather = [torch.ones_like(tensor) 278 | for _ in range(torch.distributed.get_world_size())] 279 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 280 | 281 | output = torch.cat(tensors_gather, dim=0) 282 | return output 283 | -------------------------------------------------------------------------------- /models/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for Bert.""" 16 | 17 | import collections 18 | import os 19 | import unicodedata 20 | from typing import List, Optional, Tuple 21 | 22 | from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 23 | from transformers.utils import logging 24 | 25 | logger = logging.get_logger(__name__) 26 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 27 | 28 | PRETRAINED_VOCAB_FILES_MAP = { 29 | "vocab_file": { 30 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", 31 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", 32 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", 33 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", 34 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", 35 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", 36 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", 37 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", 38 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", 39 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", 40 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 41 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 42 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", 43 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", 44 | "bert-base-german-dbmdz-uncased": 
"https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", 45 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", 46 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", 47 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", 48 | } 49 | } 50 | 51 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 52 | "bert-base-uncased": 512, 53 | "bert-large-uncased": 512, 54 | "bert-base-cased": 512, 55 | "bert-large-cased": 512, 56 | "bert-base-multilingual-uncased": 512, 57 | "bert-base-multilingual-cased": 512, 58 | "bert-base-chinese": 512, 59 | "bert-base-german-cased": 512, 60 | "bert-large-uncased-whole-word-masking": 512, 61 | "bert-large-cased-whole-word-masking": 512, 62 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 63 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 64 | "bert-base-cased-finetuned-mrpc": 512, 65 | "bert-base-german-dbmdz-cased": 512, 66 | "bert-base-german-dbmdz-uncased": 512, 67 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 68 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 69 | "wietsedv/bert-base-dutch-cased": 512, 70 | } 71 | 72 | PRETRAINED_INIT_CONFIGURATION = { 73 | "bert-base-uncased": {"do_lower_case": True}, 74 | "bert-large-uncased": {"do_lower_case": True}, 75 | "bert-base-cased": {"do_lower_case": False}, 76 | "bert-large-cased": {"do_lower_case": False}, 77 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 78 | "bert-base-multilingual-cased": {"do_lower_case": False}, 79 | "bert-base-chinese": {"do_lower_case": False}, 80 | "bert-base-german-cased": {"do_lower_case": False}, 81 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 82 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 83 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 84 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 85 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 86 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 87 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 88 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 89 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 90 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 91 | } 92 | 93 | def load_vocab(vocab_file): 94 | """Loads a vocabulary file into a dictionary.""" 95 | vocab = collections.OrderedDict() 96 | with open(vocab_file, "r", encoding="utf-8") as reader: 97 | tokens = reader.readlines() 98 | for index, token in enumerate(tokens): 99 | token = token.rstrip("\n") 100 | vocab[token] = index 101 | return vocab 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | class BertTokenizer(PreTrainedTokenizer): 112 | r""" 113 | Construct a BERT tokenizer. Based on WordPiece. 114 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 115 | Users should refer to this superclass for more information regarding those methods. 116 | Args: 117 | vocab_file (:obj:`str`): 118 | File containing the vocabulary. 
119 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 120 | Whether or not to lowercase the input when tokenizing. 121 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 122 | Whether or not to do basic tokenization before WordPiece. 123 | never_split (:obj:`Iterable`, `optional`): 124 | Collection of tokens which will never be split during tokenization. Only has an effect when 125 | :obj:`do_basic_tokenize=True` 126 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 127 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 128 | token instead. 129 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 130 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 131 | sequence classification or for a text and a question for question answering. It is also used as the last 132 | token of a sequence built with special tokens. 133 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 134 | The token used for padding, for example when batching sequences of different lengths. 135 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 136 | The classifier token which is used when doing sequence classification (classification of the whole sequence 137 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 138 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 139 | The token used for masking values. This is the token used when training this model with masked language 140 | modeling. This is the token which the model will try to predict. 141 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 142 | Whether or not to tokenize Chinese characters. 143 | This should likely be deactivated for Japanese (see this `issue 144 | `__). 145 | strip_accents: (:obj:`bool`, `optional`): 146 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 147 | value for :obj:`lowercase` (as in the original BERT). 148 | """ 149 | 150 | vocab_files_names = VOCAB_FILES_NAMES 151 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 152 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 153 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 154 | 155 | def __init__( 156 | self, 157 | vocab_file, 158 | do_lower_case=True, 159 | do_basic_tokenize=True, 160 | never_split=None, 161 | unk_token="[UNK]", 162 | sep_token="[SEP]", 163 | pad_token="[PAD]", 164 | cls_token="[CLS]", 165 | mask_token="[MASK]", 166 | tokenize_chinese_chars=True, 167 | strip_accents=None, 168 | **kwargs 169 | ): 170 | super().__init__( 171 | do_lower_case=do_lower_case, 172 | do_basic_tokenize=do_basic_tokenize, 173 | never_split=never_split, 174 | unk_token=unk_token, 175 | sep_token=sep_token, 176 | pad_token=pad_token, 177 | cls_token=cls_token, 178 | mask_token=mask_token, 179 | tokenize_chinese_chars=tokenize_chinese_chars, 180 | strip_accents=strip_accents, 181 | **kwargs, 182 | ) 183 | 184 | if not os.path.isfile(vocab_file): 185 | raise ValueError( 186 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 188 | ) 189 | self.vocab = load_vocab(vocab_file) 190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 191 | self.do_basic_tokenize = do_basic_tokenize 192 | if do_basic_tokenize: 193 | self.basic_tokenizer = BasicTokenizer( 194 | do_lower_case=do_lower_case, 195 | never_split=never_split, 196 | tokenize_chinese_chars=tokenize_chinese_chars, 197 | strip_accents=strip_accents, 198 | ) 199 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 200 | 201 | @property 202 | def do_lower_case(self): 203 | return self.basic_tokenizer.do_lower_case 204 | 205 | @property 206 | def vocab_size(self): 207 | return len(self.vocab) 208 | 209 | def get_vocab(self): 210 | return dict(self.vocab, **self.added_tokens_encoder) 211 | 212 | def _tokenize(self, text): 213 | split_tokens = [] 214 | if self.do_basic_tokenize: 215 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 216 | 217 | # If the token is part of the never_split set 218 | if token in self.basic_tokenizer.never_split: 219 | split_tokens.append(token) 220 | else: 221 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 222 | else: 223 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 224 | return split_tokens 225 | 226 | def _convert_token_to_id(self, token): 227 | """ Converts a token (str) in an id using the vocab. """ 228 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 229 | 230 | def _convert_id_to_token(self, index): 231 | """Converts an index (integer) in a token (str) using the vocab.""" 232 | return self.ids_to_tokens.get(index, self.unk_token) 233 | 234 | def convert_tokens_to_string(self, tokens): 235 | """ Converts a sequence of tokens (string) in a single string. """ 236 | out_string = " ".join(tokens).replace(" ##", "").strip() 237 | return out_string 238 | 239 | def convert_input_ids_to_text(self, input_ids): 240 | """ Converts a list of index to text """ 241 | assert isinstance(input_ids, list) 242 | return list(map(self._convert_id_to_token, input_ids)) 243 | 244 | def convert_bs_input_ids_to_text(self, bs_input_ids): 245 | assert isinstance(bs_input_ids, list) and isinstance(bs_input_ids[0], list) 246 | return list(map(self.convert_input_ids_to_text, bs_input_ids)) 247 | 248 | def build_inputs_with_special_tokens( 249 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 250 | ) -> List[int]: 251 | """ 252 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 253 | adding special tokens. A BERT sequence has the following format: 254 | - single sequence: ``[CLS] X `` 255 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 256 | Args: 257 | token_ids_0 (:obj:`List[int]`): 258 | List of IDs to which the special tokens will be added. 259 | token_ids_1 (:obj:`List[int]`, `optional`): 260 | Optional second list of IDs for sequence pairs. 261 | Returns: 262 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 
263 | """ 264 | if token_ids_1 is None: 265 | return [self.cls_token_id] + token_ids_0 266 | cls = [self.cls_token_id] 267 | sep = [self.sep_token_id] 268 | return cls + token_ids_0 + sep + token_ids_1 + sep 269 | 270 | def get_special_tokens_mask( 271 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 272 | ) -> List[int]: 273 | """ 274 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 275 | special tokens using the tokenizer ``prepare_for_model`` method. 276 | Args: 277 | token_ids_0 (:obj:`List[int]`): 278 | List of IDs. 279 | token_ids_1 (:obj:`List[int]`, `optional`): 280 | Optional second list of IDs for sequence pairs. 281 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 282 | Whether or not the token list is already formatted with special tokens for the model. 283 | Returns: 284 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 285 | """ 286 | 287 | if already_has_special_tokens: 288 | if token_ids_1 is not None: 289 | raise ValueError( 290 | "You should not supply a second sequence if the provided sequence of " 291 | "ids is already formatted with special tokens for the model." 292 | ) 293 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 294 | 295 | if token_ids_1 is not None: 296 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 297 | return [1] + ([0] * len(token_ids_0)) + [1] 298 | 299 | def create_token_type_ids_from_sequences( 300 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 301 | ) -> List[int]: 302 | """ 303 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence 304 | pair mask has the following format: 305 | :: 306 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 307 | | first sequence | second sequence | 308 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 309 | Args: 310 | token_ids_0 (:obj:`List[int]`): 311 | List of IDs. 312 | token_ids_1 (:obj:`List[int]`, `optional`): 313 | Optional second list of IDs for sequence pairs. 314 | Returns: 315 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 316 | sequence(s). 317 | """ 318 | sep = [self.sep_token_id] 319 | cls = [self.cls_token_id] 320 | if token_ids_1 is None: 321 | return len(cls + token_ids_0 + sep) * [0] 322 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 323 | 324 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 325 | index = 0 326 | if os.path.isdir(save_directory): 327 | vocab_file = os.path.join( 328 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 329 | ) 330 | else: 331 | vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory 332 | with open(vocab_file, "w", encoding="utf-8") as writer: 333 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 334 | if index != token_index: 335 | logger.warning( 336 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 
337 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 338 | ) 339 | index = token_index 340 | writer.write(token + "\n") 341 | index += 1 342 | return (vocab_file,) 343 | 344 | class BasicTokenizer(object): 345 | """ 346 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 347 | Args: 348 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 349 | Whether or not to lowercase the input when tokenizing. 350 | never_split (:obj:`Iterable`, `optional`): 351 | Collection of tokens which will never be split during tokenization. Only has an effect when 352 | :obj:`do_basic_tokenize=True` 353 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 354 | Whether or not to tokenize Chinese characters. 355 | This should likely be deactivated for Japanese (see this `issue 356 | `__). 357 | strip_accents: (:obj:`bool`, `optional`): 358 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 359 | value for :obj:`lowercase` (as in the original BERT). 360 | """ 361 | 362 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): 363 | if never_split is None: 364 | never_split = [] 365 | self.do_lower_case = do_lower_case 366 | self.never_split = set(never_split) 367 | self.tokenize_chinese_chars = tokenize_chinese_chars 368 | self.strip_accents = strip_accents 369 | 370 | def tokenize(self, text, never_split=None): 371 | """ 372 | Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see 373 | WordPieceTokenizer. 374 | Args: 375 | **never_split**: (`optional`) list of str 376 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see 377 | :func:`PreTrainedTokenizer.tokenize`) List of token not to split. 378 | """ 379 | # union() returns a new set by concatenating the two sets. 380 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 381 | text = self._clean_text(text) 382 | 383 | # This was added on November 1st, 2018 for the multilingual and Chinese 384 | # models. This is also applied to the English models now, but it doesn't 385 | # matter since the English models were not trained on any Chinese data 386 | # and generally don't have any Chinese data in them (there are Chinese 387 | # characters in the vocabulary because Wikipedia does have some Chinese 388 | # words in the English Wikipedia.). 
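        # (Illustrative note, not in the original file: with this enabled,
        # _tokenize_chinese_chars pads every CJK character with spaces, e.g.
        # "ab中c" -> "ab 中 c", so each CJK character becomes its own token.)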
389 |         if self.tokenize_chinese_chars:
390 |             text = self._tokenize_chinese_chars(text)
391 |         orig_tokens = whitespace_tokenize(text)
392 |         split_tokens = []
393 |         for token in orig_tokens:
394 |             if token not in never_split:
395 |                 if self.do_lower_case:
396 |                     token = token.lower()
397 |                     if self.strip_accents is not False:
398 |                         token = self._run_strip_accents(token)
399 |                 elif self.strip_accents:
400 |                     token = self._run_strip_accents(token)
401 |             split_tokens.extend(self._run_split_on_punc(token, never_split))
402 | 
403 |         output_tokens = whitespace_tokenize(" ".join(split_tokens))
404 |         return output_tokens
405 | 
406 |     def _run_strip_accents(self, text):
407 |         """Strips accents from a piece of text."""
408 |         text = unicodedata.normalize("NFD", text)
409 |         output = []
410 |         for char in text:
411 |             cat = unicodedata.category(char)
412 |             if cat == "Mn":
413 |                 continue
414 |             output.append(char)
415 |         return "".join(output)
416 | 
417 |     def _run_split_on_punc(self, text, never_split=None):
418 |         """Splits punctuation on a piece of text."""
419 |         if never_split is not None and text in never_split:
420 |             return [text]
421 |         chars = list(text)
422 |         i = 0
423 |         start_new_word = True
424 |         output = []
425 |         while i < len(chars):
426 |             char = chars[i]
427 |             if _is_punctuation(char):
428 |                 output.append([char])
429 |                 start_new_word = True
430 |             else:
431 |                 if start_new_word:
432 |                     output.append([])
433 |                 start_new_word = False
434 |                 output[-1].append(char)
435 |             i += 1
436 | 
437 |         return ["".join(x) for x in output]
438 | 
439 |     def _tokenize_chinese_chars(self, text):
440 |         """Adds whitespace around any CJK character."""
441 |         output = []
442 |         for char in text:
443 |             cp = ord(char)
444 |             if self._is_chinese_char(cp):
445 |                 output.append(" ")
446 |                 output.append(char)
447 |                 output.append(" ")
448 |             else:
449 |                 output.append(char)
450 |         return "".join(output)
451 | 
452 |     def _is_chinese_char(self, cp):
453 |         """Checks whether CP is the codepoint of a CJK character."""
454 |         # This defines a "chinese character" as anything in the CJK Unicode block:
455 |         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
456 |         #
457 |         # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
458 |         # despite its name. The modern Korean Hangul alphabet is a different block,
459 |         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
460 |         # space-separated words, so they are not treated specially and handled
461 |         # like all of the other languages.
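        # (Illustrative, not in the original file: ord("中") == 0x4E2D falls in
        # the 0x4E00-0x9FFF block below, so it is treated as a CJK character,
        # while Hiragana "あ" (0x3042) misses every range and returns False.)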
462 | if ( 463 | (cp >= 0x4E00 and cp <= 0x9FFF) 464 | or (cp >= 0x3400 and cp <= 0x4DBF) # 465 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 466 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 467 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 468 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 469 | or (cp >= 0xF900 and cp <= 0xFAFF) 470 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 471 | ): # 472 | return True 473 | 474 | return False 475 | 476 | def _clean_text(self, text): 477 | """Performs invalid character removal and whitespace cleanup on text.""" 478 | output = [] 479 | for char in text: 480 | cp = ord(char) 481 | if cp == 0 or cp == 0xFFFD or _is_control(char): 482 | continue 483 | if _is_whitespace(char): 484 | output.append(" ") 485 | else: 486 | output.append(char) 487 | return "".join(output) 488 | 489 | class WordpieceTokenizer(object): 490 | """Runs WordPiece tokenization.""" 491 | 492 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 493 | self.vocab = vocab 494 | self.unk_token = unk_token 495 | self.max_input_chars_per_word = max_input_chars_per_word 496 | 497 | def tokenize(self, text): 498 | """ 499 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform 500 | tokenization using the given vocabulary. 501 | For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. 502 | Args: 503 | text: A single token or whitespace separated tokens. This should have 504 | already been passed through `BasicTokenizer`. 505 | Returns: 506 | A list of wordpiece tokens. 507 | """ 508 | 509 | output_tokens = [] 510 | for token in whitespace_tokenize(text): 511 | chars = list(token) 512 | if len(chars) > self.max_input_chars_per_word: 513 | output_tokens.append(self.unk_token) 514 | continue 515 | 516 | is_bad = False 517 | start = 0 518 | sub_tokens = [] 519 | while start < len(chars): 520 | end = len(chars) 521 | cur_substr = None 522 | while start < end: 523 | substr = "".join(chars[start:end]) 524 | if start > 0: 525 | substr = "##" + substr 526 | if substr in self.vocab: 527 | cur_substr = substr 528 | break 529 | end -= 1 530 | if cur_substr is None: 531 | is_bad = True 532 | break 533 | sub_tokens.append(cur_substr) 534 | start = end 535 | 536 | if is_bad: 537 | output_tokens.append(self.unk_token) 538 | else: 539 | output_tokens.extend(sub_tokens) 540 | return output_tokens 541 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.vision_transformer import _cfg, PatchEmbed 7 | from timm.models.registry import register_model 8 | from timm.models.layers import trunc_normal_, DropPath 9 | 10 | class Mlp(nn.Module): 11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 12 | """ 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.act = act_layer() 19 | self.fc2 = nn.Linear(hidden_features, out_features) 20 | self.drop = nn.Dropout(drop) 21 | 22 | def forward(self, x): 23 | x = self.fc1(x) 24 | x = self.act(x) 25 | x = self.drop(x) 26 | x = self.fc2(x) 27 | x = self.drop(x) 28 | return x 
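# (Illustrative usage, not part of the original file: the Mlp block above is
# shape-preserving, e.g.
#     mlp = Mlp(in_features=768, hidden_features=3072)
#     y = mlp(torch.randn(2, 197, 768))   # y.shape == (2, 197, 768)
# tokens are expanded to the hidden width and projected back.)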
29 | 30 | class Attention(nn.Module): 31 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 32 | super().__init__() 33 | self.num_heads = num_heads 34 | head_dim = dim // num_heads 35 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 36 | self.scale = qk_scale or head_dim ** -0.5 37 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | self.proj = nn.Linear(dim, dim) 40 | self.proj_drop = nn.Dropout(proj_drop) 41 | self.attn_gradients = None 42 | self.attention_map = None 43 | 44 | def save_attn_gradients(self, attn_gradients): 45 | self.attn_gradients = attn_gradients 46 | 47 | def get_attn_gradients(self): 48 | return self.attn_gradients 49 | 50 | def save_attention_map(self, attention_map): 51 | self.attention_map = attention_map 52 | 53 | def get_attention_map(self): 54 | return self.attention_map 55 | 56 | def forward(self, x, register_hook=False): 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 60 | attn = (q @ k.transpose(-2, -1)) * self.scale 61 | attn = attn.softmax(dim=-1) 62 | attn = self.attn_drop(attn) 63 | if register_hook: 64 | self.save_attention_map(attn) 65 | attn.register_hook(self.save_attn_gradients) 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | class Block(nn.Module): 72 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 73 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 74 | super().__init__() 75 | self.norm1 = norm_layer(dim) 76 | self.attn = Attention( 77 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 78 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 79 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 80 | self.norm2 = norm_layer(dim) 81 | mlp_hidden_dim = int(dim * mlp_ratio) 82 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 83 | 84 | def forward(self, x, register_hook=False): 85 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 86 | x = x + self.drop_path(self.mlp(self.norm2(x))) 87 | return x 88 | 89 | class VisionTransformer(nn.Module): 90 | """ Vision Transformer 91 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 92 | https://arxiv.org/abs/2010.11929 93 | """ 94 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 95 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 96 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None): 97 | """ 98 | Args: 99 | img_size (int, tuple): input image size 100 | patch_size (int, tuple): patch size 101 | in_chans (int): number of input channels 102 | num_classes (int): number of classes for classification head 103 | embed_dim (int): embedding dimension 104 | depth (int): depth of transformer 105 | num_heads (int): number of attention heads 106 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 107 | qkv_bias (bool): enable bias for qkv if True 108 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 109 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 110 | drop_rate (float): dropout rate 111 | attn_drop_rate (float): attention dropout rate 112 | drop_path_rate (float): stochastic depth rate 113 | norm_layer: (nn.Module): normalization layer 114 | """ 115 | super().__init__() 116 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 117 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 118 | self.patch_embed = PatchEmbed( 119 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 120 | num_patches = self.patch_embed.num_patches 121 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 122 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 123 | self.pos_drop = nn.Dropout(p=drop_rate) 124 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 125 | self.blocks = nn.ModuleList([ 126 | Block( 127 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 128 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 129 | for i in range(depth)]) 130 | self.norm = norm_layer(embed_dim) 131 | trunc_normal_(self.pos_embed, std=.02) 132 | trunc_normal_(self.cls_token, std=.02) 133 | self.apply(self._init_weights) 134 | 135 | def _init_weights(self, m): 136 | if isinstance(m, nn.Linear): 137 | trunc_normal_(m.weight, std=.02) 138 | if isinstance(m, nn.Linear) and m.bias is not None: 139 | nn.init.constant_(m.bias, 0) 140 | elif isinstance(m, nn.LayerNorm): 141 | nn.init.constant_(m.bias, 0) 142 | nn.init.constant_(m.weight, 1.0) 143 | 144 | @torch.jit.ignore 145 | def no_weight_decay(self): 146 | return {'pos_embed', 'cls_token'} 147 | 148 | def forward(self, x, register_blk=-1): 149 | B = x.shape[0] 150 | x = self.patch_embed(x) 151 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 152 | x = torch.cat((cls_tokens, x), dim=1) 153 | x = x + 
self.pos_embed[:,:x.size(1),:] 154 | x = self.pos_drop(x) 155 | for i,blk in enumerate(self.blocks): 156 | x = blk(x, register_blk==i) 157 | x = self.norm(x) 158 | return x 159 | 160 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 161 | # interpolate position embedding 162 | embedding_size = pos_embed_checkpoint.shape[-1] 163 | num_patches = visual_encoder.patch_embed.num_patches 164 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 165 | # height (== width) for the checkpoint position embedding 166 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 167 | # height (== width) for the new position embedding 168 | new_size = int(num_patches ** 0.5) 169 | if orig_size!=new_size: 170 | # class_token and dist_token are kept unchanged 171 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 172 | # only the position tokens are interpolated 173 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 174 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 175 | pos_tokens = torch.nn.functional.interpolate( 176 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 177 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 178 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 179 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 180 | return new_pos_embed 181 | else: 182 | return pos_embed_checkpoint 183 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | 13 | from .optim_factory import create_optimizer 14 | -------------------------------------------------------------------------------- /optim/adafactor.py: -------------------------------------------------------------------------------- 1 | """ Adafactor Optimizer 2 | 3 | Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py 4 | 5 | Original header/copyright below. 6 | 7 | """ 8 | # Copyright (c) Facebook, Inc. and its affiliates. 9 | # 10 | # This source code is licensed under the MIT license found in the 11 | # LICENSE file in the root directory of this source tree. 12 | import torch 13 | import math 14 | 15 | 16 | class Adafactor(torch.optim.Optimizer): 17 | """Implements Adafactor algorithm. 18 | This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` 19 | (see https://arxiv.org/abs/1804.04235) 20 | 21 | Note that this optimizer internally adjusts the learning rate depending on the 22 | *scale_parameter*, *relative_step* and *warmup_init* options. 23 | 24 | To use a manual (external) learning rate schedule you should set `scale_parameter=False` and 25 | `relative_step=False`. 
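    Illustrative note (not in the upstream docstring): in this implementation
    `relative_step` is inferred from `lr`, e.g.

        optimizer = Adafactor(model.parameters(), lr=None)  # relative-step schedule
        optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False)  # external lr

    where the second form uses the supplied learning rate directly.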
26 | 27 | Arguments: 28 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups 29 | lr (float, optional): external learning rate (default: None) 30 | eps (tuple[float, float]): regularization constants for square gradient 31 | and parameter scale respectively (default: (1e-30, 1e-3)) 32 | clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) 33 | decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) 34 | beta1 (float): coefficient used for computing running averages of gradient (default: None) 35 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 36 | scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) 37 | relative_step (bool): if True, time-dependent learning rate is computed 38 | instead of external learning rate (default: True) 39 | warmup_init (bool): time-dependent learning rate computation depends on 40 | whether warm-up initialization is being used (default: False) 41 | """ 42 | 43 | def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, 44 | decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): 45 | relative_step = lr is None 46 | if warmup_init and not relative_step: 47 | raise ValueError('warmup_init requires relative_step=True') 48 | 49 | beta1 = None if betas is None else betas[0] # make it compat with standard betas arg 50 | defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, 51 | beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, 52 | relative_step=relative_step, warmup_init=warmup_init) 53 | super(Adafactor, self).__init__(params, defaults) 54 | 55 | @staticmethod 56 | def _get_lr(param_group, param_state): 57 | if param_group['relative_step']: 58 | min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 59 | lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) 60 | param_scale = 1.0 61 | if param_group['scale_parameter']: 62 | param_scale = max(param_group['eps_scale'], param_state['RMS']) 63 | param_group['lr'] = lr_t * param_scale 64 | return param_group['lr'] 65 | 66 | @staticmethod 67 | def _get_options(param_group, param_shape): 68 | factored = len(param_shape) >= 2 69 | use_first_moment = param_group['beta1'] is not None 70 | return factored, use_first_moment 71 | 72 | @staticmethod 73 | def _rms(tensor): 74 | return tensor.norm(2) / (tensor.numel() ** 0.5) 75 | 76 | def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): 77 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) 78 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 79 | return torch.mul(r_factor, c_factor) 80 | 81 | def step(self, closure=None): 82 | """Performs a single optimization step. 83 | Arguments: 84 | closure (callable, optional): A closure that reevaluates the model and returns the loss. 
85 | """ 86 | loss = None 87 | if closure is not None: 88 | loss = closure() 89 | 90 | for group in self.param_groups: 91 | for p in group['params']: 92 | if p.grad is None: 93 | continue 94 | grad = p.grad.data 95 | if grad.dtype in {torch.float16, torch.bfloat16}: 96 | grad = grad.float() 97 | if grad.is_sparse: 98 | raise RuntimeError('Adafactor does not support sparse gradients.') 99 | 100 | state = self.state[p] 101 | grad_shape = grad.shape 102 | 103 | factored, use_first_moment = self._get_options(group, grad_shape) 104 | # State Initialization 105 | if len(state) == 0: 106 | state['step'] = 0 107 | 108 | if use_first_moment: 109 | # Exponential moving average of gradient values 110 | state['exp_avg'] = torch.zeros_like(grad) 111 | if factored: 112 | state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) 113 | state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 114 | else: 115 | state['exp_avg_sq'] = torch.zeros_like(grad) 116 | 117 | state['RMS'] = 0 118 | else: 119 | if use_first_moment: 120 | state['exp_avg'] = state['exp_avg'].to(grad) 121 | if factored: 122 | state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) 123 | state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) 124 | else: 125 | state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) 126 | 127 | p_data_fp32 = p.data 128 | if p.data.dtype in {torch.float16, torch.bfloat16}: 129 | p_data_fp32 = p_data_fp32.float() 130 | 131 | state['step'] += 1 132 | state['RMS'] = self._rms(p_data_fp32) 133 | lr_t = self._get_lr(group, state) 134 | 135 | beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) 136 | update = grad ** 2 + group['eps'] 137 | if factored: 138 | exp_avg_sq_row = state['exp_avg_sq_row'] 139 | exp_avg_sq_col = state['exp_avg_sq_col'] 140 | 141 | exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) 142 | exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) 143 | #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+ 144 | #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) 145 | 146 | # Approximation of exponential moving average of square of gradient 147 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 148 | update.mul_(grad) 149 | else: 150 | exp_avg_sq = state['exp_avg_sq'] 151 | 152 | exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) 153 | #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+ 154 | update = exp_avg_sq.rsqrt().mul_(grad) 155 | 156 | update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) 157 | update.mul_(lr_t) 158 | 159 | if use_first_moment: 160 | exp_avg = state['exp_avg'] 161 | exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update) 162 | #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+ 163 | update = exp_avg 164 | 165 | if group['weight_decay'] != 0: 166 | p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32) 167 | #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+ 168 | 169 | p_data_fp32.add_(-update) 170 | 171 | if p.data.dtype in {torch.float16, torch.bfloat16}: 172 | p.data.copy_(p_data_fp32) 173 | 174 | return loss -------------------------------------------------------------------------------- /optim/adahessian.py: -------------------------------------------------------------------------------- 1 | """ AdaHessian Optimizer 2 | 3 | Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py 4 | 
Originally licensed MIT, Copyright 2020, David Samuel
5 | """
6 | import torch
7 | 
8 | 
9 | class Adahessian(torch.optim.Optimizer):
10 |     """
11 |     Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"
12 | 
13 |     Arguments:
14 |         params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15 |         lr (float, optional): learning rate (default: 0.1)
16 |         betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17 |             squared hessian trace (default: (0.9, 0.999))
18 |         eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20 |         hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21 |         update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22 |             (to save time) (default: 1)
23 |         n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24 |     """
25 | 
26 |     def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
27 |                  hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
28 |         if not 0.0 <= lr:
29 |             raise ValueError(f"Invalid learning rate: {lr}")
30 |         if not 0.0 <= eps:
31 |             raise ValueError(f"Invalid epsilon value: {eps}")
32 |         if not 0.0 <= betas[0] < 1.0:
33 |             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
34 |         if not 0.0 <= betas[1] < 1.0:
35 |             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
36 |         if not 0.0 <= hessian_power <= 1.0:
37 |             raise ValueError(f"Invalid Hessian power value: {hessian_power}")
38 | 
39 |         self.n_samples = n_samples
40 |         self.update_each = update_each
41 |         self.avg_conv_kernel = avg_conv_kernel
42 | 
43 |         # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
44 |         self.seed = 2147483647
45 |         self.generator = torch.Generator().manual_seed(self.seed)
46 | 
47 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
48 |         super(Adahessian, self).__init__(params, defaults)
49 | 
50 |         for p in self.get_params():
51 |             p.hess = 0.0
52 |             self.state[p]["hessian step"] = 0
53 | 
54 |     @property
55 |     def is_second_order(self):
56 |         return True
57 | 
58 |     def get_params(self):
59 |         """
60 |         Gets all parameters in all param_groups with gradients
61 |         """
62 | 
63 |         return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
64 | 
65 |     def zero_hessian(self):
66 |         """
67 |         Zeros out the accumulated hessian traces.
68 |         """
69 | 
70 |         for p in self.get_params():
71 |             if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
72 |                 p.hess.zero_()
73 | 
74 |     @torch.no_grad()
75 |     def set_hessian(self):
76 |         """
77 |         Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
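        Illustrative note (not in the upstream docstring): this relies on
        second-order autograd, so the preceding backward pass must keep the
        graph, e.g. ``loss.backward(create_graph=True)``, for
        ``torch.autograd.grad`` below to differentiate the gradients again.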
78 | """ 79 | 80 | params = [] 81 | for p in filter(lambda p: p.grad is not None, self.get_params()): 82 | if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step 83 | params.append(p) 84 | self.state[p]["hessian step"] += 1 85 | 86 | if len(params) == 0: 87 | return 88 | 89 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device 90 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed) 91 | 92 | grads = [p.grad for p in params] 93 | 94 | for i in range(self.n_samples): 95 | # Rademacher distribution {-1.0, 1.0} 96 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] 97 | h_zs = torch.autograd.grad( 98 | grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) 99 | for h_z, z, p in zip(h_zs, zs, params): 100 | p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) 101 | 102 | @torch.no_grad() 103 | def step(self, closure=None): 104 | """ 105 | Performs a single optimization step. 106 | Arguments: 107 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) 108 | """ 109 | 110 | loss = None 111 | if closure is not None: 112 | loss = closure() 113 | 114 | self.zero_hessian() 115 | self.set_hessian() 116 | 117 | for group in self.param_groups: 118 | for p in group['params']: 119 | if p.grad is None or p.hess is None: 120 | continue 121 | 122 | if self.avg_conv_kernel and p.dim() == 4: 123 | p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() 124 | 125 | # Perform correct stepweight decay as in AdamW 126 | p.mul_(1 - group['lr'] * group['weight_decay']) 127 | 128 | state = self.state[p] 129 | 130 | # State initialization 131 | if len(state) == 1: 132 | state['step'] = 0 133 | # Exponential moving average of gradient values 134 | state['exp_avg'] = torch.zeros_like(p) 135 | # Exponential moving average of Hessian diagonal square values 136 | state['exp_hessian_diag_sq'] = torch.zeros_like(p) 137 | 138 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] 139 | beta1, beta2 = group['betas'] 140 | state['step'] += 1 141 | 142 | # Decay the first and second moment running average coefficient 143 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) 144 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) 145 | 146 | bias_correction1 = 1 - beta1 ** state['step'] 147 | bias_correction2 = 1 - beta2 ** state['step'] 148 | 149 | k = group['hessian_power'] 150 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) 151 | 152 | # make update 153 | step_size = group['lr'] / bias_correction1 154 | p.addcdiv_(exp_avg, denom, value=-step_size) 155 | 156 | return loss 157 | -------------------------------------------------------------------------------- /optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 
13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. 
values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = 
fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 
56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. 
- self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. - self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | from torch import optim as optim 6 | 7 | from .adafactor import Adafactor 8 | from .adahessian import Adahessian 9 | from .adamp import AdamP 10 | from .lookahead import Lookahead 11 | from .nadam import Nadam 12 | from .novograd import NovoGrad 13 | from .nvnovograd import NvNovoGrad 14 | from .radam import RAdam 15 | from .rmsprop_tf import RMSpropTF 16 | from .sgdp import SGDP 17 | 18 | try: 19 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 20 | has_apex = True 21 | except ImportError: 22 | has_apex = False 23 | 24 | def add_weight_decay(model, lr_custm=1e-4, weight_decay=1e-5, skip_list=()): 25 | decay = [] 26 | no_decay = [] 27 | decay_lr = [] 28 | no_decay_lr = [] 29 | for name, param in model.named_parameters(): 30 | if not param.requires_grad: 31 | continue # frozen weights 32 | if 'prd' in name or 'mrtd' in name: 33 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 34 | no_decay_lr.append(param) 35 | else: 36 | decay_lr.append(param) 37 | else: 38 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 39 | no_decay.append(param) 40 | else: 41 | decay.append(param) 42 | return [ 43 | {'params': no_decay, 'weight_decay': 0.}, 44 | {'params': decay, 'weight_decay': weight_decay}, 45 | {'params': no_decay_lr, 'weight_decay': 0., 'lr': lr_custm}, 46 | {'params': decay_lr, 'weight_decay': weight_decay, 'lr': 
lr_custm}] 47 | 48 | def create_optimizer(args, model, filter_bias_and_bn=True): 49 | opt_lower = args.opt.lower() 50 | weight_decay = args.weight_decay 51 | if weight_decay and filter_bias_and_bn: 52 | skip = {} 53 | if hasattr(model, 'no_weight_decay'): 54 | skip = model.no_weight_decay() 55 | parameters = add_weight_decay(model, args.lr_custm, weight_decay, skip) 56 | weight_decay = 0. 57 | else: 58 | parameters = model.parameters() 59 | if 'fused' in opt_lower: 60 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 61 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 62 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 63 | opt_args['eps'] = args.opt_eps 64 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 65 | opt_args['betas'] = args.opt_betas 66 | if hasattr(args, 'opt_args') and args.opt_args is not None: 67 | opt_args.update(args.opt_args) 68 | opt_split = opt_lower.split('_') 69 | opt_lower = opt_split[-1] 70 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 71 | opt_args.pop('eps', None) 72 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 73 | elif opt_lower == 'momentum': 74 | opt_args.pop('eps', None) 75 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 76 | elif opt_lower == 'adam': 77 | optimizer = optim.Adam(parameters, **opt_args) 78 | elif opt_lower == 'adamw': 79 | optimizer = optim.AdamW(parameters, **opt_args) 80 | elif opt_lower == 'nadam': 81 | optimizer = Nadam(parameters, **opt_args) 82 | elif opt_lower == 'radam': 83 | optimizer = RAdam(parameters, **opt_args) 84 | elif opt_lower == 'adamp': 85 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 86 | elif opt_lower == 'sgdp': 87 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 88 | elif opt_lower == 'adadelta': 89 | optimizer = optim.Adadelta(parameters, **opt_args) 90 | elif opt_lower == 'adafactor': 91 | if not args.lr: 92 | opt_args['lr'] = None 93 | optimizer = Adafactor(parameters, **opt_args) 94 | elif opt_lower == 'adahessian': 95 | optimizer = Adahessian(parameters, **opt_args) 96 | elif opt_lower == 'rmsprop': 97 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 98 | elif opt_lower == 'rmsproptf': 99 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 100 | elif opt_lower == 'novograd': 101 | optimizer = NovoGrad(parameters, **opt_args) 102 | elif opt_lower == 'nvnovograd': 103 | optimizer = NvNovoGrad(parameters, **opt_args) 104 | elif opt_lower == 'fusedsgd': 105 | opt_args.pop('eps', None) 106 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 107 | elif opt_lower == 'fusedmomentum': 108 | opt_args.pop('eps', None) 109 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 110 | elif opt_lower == 'fusedadam': 111 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 112 | elif opt_lower == 'fusedadamw': 113 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 114 | elif opt_lower == 'fusedlamb': 115 | optimizer = FusedLAMB(parameters, **opt_args) 116 | elif opt_lower == 'fusednovograd': 117 | opt_args.setdefault('betas', (0.95, 0.98)) 118 | optimizer = FusedNovoGrad(parameters, **opt_args) 119 | else: 120 | # unrecognized optimizer name: raise with the offending value (the old `assert False and "..."` dropped its message and the bare raise was dead code) 121 | raise ValueError('Invalid optimizer: {}'.format(opt_lower)) 122 | if len(opt_split) > 1: 123 | if opt_split[0] == 'lookahead': 124 | 
optimizer = Lookahead(optimizer) 125 | return optimizer 126 | -------------------------------------------------------------------------------- /optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(RAdam, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(RAdam, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('RAdam does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | 53 | state['step'] += 1 54 | buffered = self.buffer[int(state['step'] % 10)] 55 | if state['step'] == buffered[0]: 56 | N_sma, step_size = buffered[1], buffered[2] 57 | else: 58 | buffered[0] = state['step'] 59 | beta2_t = beta2 ** state['step'] 60 | N_sma_max = 2 / (1 - beta2) - 1 61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 62 | buffered[1] = N_sma 63 | 64 | # more conservative since it's an approximated value 65 | if N_sma >= 5: 66 | step_size = group['lr'] * math.sqrt( 67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 68 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 69 | else: 70 | step_size = group['lr'] / (1 - beta1 ** state['step']) 71 | buffered[2] = step_size 72 | 73 | if group['weight_decay'] != 0: 74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | denom = exp_avg_sq.sqrt().add_(group['eps']) 79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 80 | else: 81 | p_data_fp32.add_(-step_size, exp_avg) 82 | 83 | p.data.copy_(p_data_fp32) 84 | 85 | return loss 86 | 87 | 88 | class PlainRAdam(Optimizer): 89 | 90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 92 | 93 | super(PlainRAdam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(PlainRAdam, self).__setstate__(state) 97 | 98 | def step(self, closure=None): 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group 
in self.param_groups: 105 | 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data.float() 110 | if grad.is_sparse: 111 | raise RuntimeError('RAdam does not support sparse gradients') 112 | 113 | p_data_fp32 = p.data.float() 114 | 115 | state = self.state[p] 116 | 117 | if len(state) == 0: 118 | state['step'] = 0 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | state['step'] += 1 132 | beta2_t = beta2 ** state['step'] 133 | N_sma_max = 2 / (1 - beta2) - 1 134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 135 | 136 | if group['weight_decay'] != 0: 137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 138 | 139 | # more conservative since it's an approximated value 140 | if N_sma >= 5: 141 | step_size = group['lr'] * math.sqrt( 142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 143 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 144 | denom = exp_avg_sq.sqrt().add_(group['eps']) 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | else: 147 | step_size = group['lr'] / (1 - beta1 ** state['step']) 148 | p_data_fp32.add_(-step_size, exp_avg) 149 | 150 | p.data.copy_(p_data_fp32) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | """ RMSProp modified to behave like Tensorflow impl 2 | 3 | Originally cut & paste from PyTorch RMSProp 4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py 5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE 6 | 7 | Modifications Copyright 2020 Ross Wightman 8 | """ 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class RMSpropTF(Optimizer): 15 | """Implements RMSprop algorithm (TensorFlow style epsilon) 16 | 17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 18 | and a few other modifications to closer match Tensorflow for matching hyper-params. 19 | 20 | Noteworthy changes include: 21 | 1. Epsilon applied inside square-root 22 | 2. square_avg initialized to ones 23 | 3. LR scaling of update accumulated in momentum buffer 24 | 25 | Proposed by G. Hinton in his 26 | `course `_. 27 | 28 | The centered version first appears in `Generating Sequences 29 | With Recurrent Neural Networks `_. 
30 | 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-2) 35 | momentum (float, optional): momentum factor (default: 0) 36 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-10) 39 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 40 | the gradient is normalized by an estimation of its variance 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 44 | update as per defaults in Tensorflow 45 | 46 | """ 47 | 48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, 49 | decoupled_decay=False, lr_in_momentum=True): 50 | if not 0.0 <= lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if not 0.0 <= eps: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= momentum: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if not 0.0 <= weight_decay: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | if not 0.0 <= alpha: 59 | raise ValueError("Invalid alpha value: {}".format(alpha)) 60 | 61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, 62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) 63 | super(RMSpropTF, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(RMSpropTF, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('momentum', 0) 69 | group.setdefault('centered', False) 70 | 71 | def step(self, closure=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 77 | """ 78 | loss = None 79 | if closure is not None: 80 | loss = closure() 81 | 82 | for group in self.param_groups: 83 | for p in group['params']: 84 | if p.grad is None: 85 | continue 86 | grad = p.grad.data 87 | if grad.is_sparse: 88 | raise RuntimeError('RMSprop does not support sparse gradients') 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero 95 | if group['momentum'] > 0: 96 | state['momentum_buffer'] = torch.zeros_like(p.data) 97 | if group['centered']: 98 | state['grad_avg'] = torch.zeros_like(p.data) 99 | 100 | square_avg = state['square_avg'] 101 | one_minus_alpha = 1. 
- group['alpha'] 102 | 103 | state['step'] += 1 104 | 105 | if group['weight_decay'] != 0: 106 | if 'decoupled_decay' in group and group['decoupled_decay']: 107 | p.data.add_(-group['weight_decay'], p.data) 108 | else: 109 | grad = grad.add(group['weight_decay'], p.data) 110 | 111 | # Tensorflow order of ops for updating squared avg 112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 114 | 115 | if group['centered']: 116 | grad_avg = state['grad_avg'] 117 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt 120 | else: 121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt 122 | 123 | if group['momentum'] > 0: 124 | buf = state['momentum_buffer'] 125 | # Tensorflow accumulates the LR scaling in the momentum buffer 126 | if 'lr_in_momentum' in group and group['lr_in_momentum']: 127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) 128 | p.data.add_(-buf) 129 | else: 130 | # PyTorch scales the param update by LR 131 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 132 | p.data.add_(-group['lr'], buf) 133 | else: 134 | p.data.addcdiv_(-group['lr'], grad, avg) 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | from pdb import set_trace as breakpoint 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class CosineLRScheduler(Scheduler): 20 | """ 21 | Cosine decay with restarts. 22 | This is described in the paper https://arxiv.org/abs/1608.03983. 23 | 24 | Inspiration from 25 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 26 | """ 27 | 28 | def __init__(self, 29 | optimizer: torch.optim.Optimizer, 30 | t_initial: int, 31 | t_mul: float = 1., 32 | lr_min: float = 0., 33 | decay_rate: float = 1., 34 | warmup_t=0, 35 | warmup_lr_init=0, 36 | warmup_prefix=True, 37 | cycle_limit=0, 38 | t_in_epochs=True, 39 | noise_range_t=None, 40 | noise_pct=0.67, 41 | noise_std=1.0, 42 | noise_seed=42, 43 | initialize=True) -> None: 44 | super().__init__( 45 | optimizer, param_group_field="lr", 46 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 47 | initialize=initialize) 48 | 49 | assert t_initial > 0 50 | assert lr_min >= 0 51 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 52 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 53 | "rate since t_initial = t_mul = eta_mul = 1.") 54 | self.t_initial = t_initial 55 | self.t_mul = t_mul 56 | self.lr_min = lr_min 57 | self.decay_rate = decay_rate 58 | self.cycle_limit = cycle_limit 59 | self.warmup_t = warmup_t 60 | self.warmup_lr_init = warmup_lr_init 61 | self.warmup_prefix = warmup_prefix 62 | self.t_in_epochs = t_in_epochs 63 | if self.warmup_t: 64 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 65 | super().update_groups(self.warmup_lr_init) 66 | else: 67 | self.warmup_steps = [1 for _ in self.base_values] 68 | 69 | def _get_lr(self, t): 70 | if t < self.warmup_t: 71 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 72 | else: 73 | if self.warmup_prefix: 74 | t = t - self.warmup_t 75 | 76 | if self.t_mul != 1: 77 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 78 | t_i = self.t_mul ** i * self.t_initial 79 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 80 | else: 81 | i = t // self.t_initial 82 | t_i = self.t_initial 83 | t_curr = t - (self.t_initial * i) 84 | 85 | gamma = self.decay_rate ** i 86 | lr_min = self.lr_min * gamma 87 | lr_max_values = [v * gamma for v in self.base_values] 88 | 89 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 90 | lrs = [ 91 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 92 | ] 93 | else: 94 | lrs = [self.lr_min for _ in self.base_values] 95 | 96 | return lrs 97 | 98 | def get_epoch_values(self, epoch: int): 99 | if self.t_in_epochs: 100 | return self._get_lr(epoch) 101 | else: 102 | return None 103 | 104 | def get_update_values(self, num_updates: int): 105 | if not self.t_in_epochs: 106 | return self._get_lr(num_updates) 107 | else: 108 | return None 109 | 110 | def get_cycle_length(self, cycles=0): 111 | if not cycles: 112 | cycles = self.cycle_limit 113 | cycles = max(1, cycles) 114 | if self.t_mul == 1.0: 115 | return self.t_initial * cycles 116 | else: 117 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 118 | 
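A usage sketch follows (not part of the repository): it shows how a training loop can drive CosineLRScheduler with explicit epoch counts, per the Scheduler contract. The toy model, optimizer and hyper-parameter values are illustrative assumptions; only the scheduler class and its keyword arguments come from this file.

import torch
from scheduler import CosineLRScheduler

model = torch.nn.Linear(10, 2)                            # stand-in model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = CosineLRScheduler(
    optimizer,
    t_initial=30,          # length of the first cosine cycle, in epochs (assumed value)
    lr_min=1e-5,
    warmup_t=3,            # 3 linear warmup epochs starting from warmup_lr_init
    warmup_lr_init=1e-6,
    cycle_limit=1,
    t_in_epochs=True)

num_epochs = lr_scheduler.get_cycle_length()              # 30 here, since t_mul defaults to 1
for epoch in range(num_epochs):
    # ... one epoch of training would run here ...
    lr_scheduler.step(epoch + 1)                          # the caller tracks epoch counts and passes them explicitly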
-------------------------------------------------------------------------------- /scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__(optimizer, 'lr', initialize=initialize) 34 | 35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 36 | self.optimizer, 37 | patience=patience_t, 38 | factor=decay_rate, 39 | verbose=verbose, 40 | threshold=threshold, 41 | cooldown=cooldown_t, 42 | mode=mode, 43 | min_lr=lr_min 44 | ) 45 | 46 | self.noise_range = noise_range_t 47 | self.noise_pct = noise_pct 48 | self.noise_type = noise_type 49 | self.noise_std = noise_std 50 | self.noise_seed = noise_seed if noise_seed is not None else 42 51 | self.warmup_t = warmup_t 52 | self.warmup_lr_init = warmup_lr_init 53 | if self.warmup_t: 54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 55 | super().update_groups(self.warmup_lr_init) 56 | else: 57 | self.warmup_steps = [1 for _ in self.base_values] 58 | self.restore_lr = None 59 | 60 | def state_dict(self): 61 | return { 62 | 'best': self.lr_scheduler.best, 63 | 'last_epoch': self.lr_scheduler.last_epoch, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | self.lr_scheduler.best = state_dict['best'] 68 | if 'last_epoch' in state_dict: 69 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 70 | 71 | # override the base class step fn completely 72 | def step(self, epoch, metric=None): 73 | if epoch <= self.warmup_t: 74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 75 | super().update_groups(lrs) 76 | else: 77 | if self.restore_lr is not None: 78 | # restore actual LR from before our last noise perturbation before stepping base 79 | for i, param_group in enumerate(self.optimizer.param_groups): 80 | param_group['lr'] = self.restore_lr[i] 81 | self.restore_lr = None 82 | 83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 84 | 85 | if self.noise_range is not None: 86 | if isinstance(self.noise_range, (list, tuple)): 87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 88 | else: 89 | apply_noise = epoch >= self.noise_range 90 | if apply_noise: 91 | self._apply_noise(epoch) 92 | 93 | def _apply_noise(self, epoch): 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + epoch) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | 105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 106 | # stepping of base scheduler 107 | restore_lr = [] 108 
| for i, param_group in enumerate(self.optimizer.param_groups): 109 | old_lr = float(param_group['lr']) 110 | restore_lr.append(old_lr) 111 | new_lr = old_lr + old_lr * noise 112 | param_group['lr'] = new_lr 113 | self.restore_lr = restore_lr 114 | -------------------------------------------------------------------------------- /scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
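# Descriptive note (added): the noise_* fields below configure _add_noise(), which applies a seeded random perturbation to the scheduled values once t enters noise_range_t.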
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | 9 | 10 | def create_scheduler(args, optimizer): 11 | num_epochs = args.epochs 12 | if getattr(args, 'lr_noise', None) is not None: 13 | lr_noise = getattr(args, 'lr_noise') 14 | if isinstance(lr_noise, (list, tuple)): 15 | noise_range = [n * num_epochs for n in lr_noise] 16 | if len(noise_range) == 1: 17 | noise_range = noise_range[0] 18 | else: 19 | noise_range = lr_noise * num_epochs 20 | else: 21 | noise_range = None 22 | 23 | lr_scheduler = None 24 | if args.sched == 'cosine': 25 | lr_scheduler = CosineLRScheduler( 26 | optimizer, 27 | t_initial=num_epochs, 28 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 29 | lr_min=args.min_lr, 30 | decay_rate=args.decay_rate, 31 | warmup_lr_init=args.warmup_lr, 32 | warmup_t=args.warmup_epochs, 33 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 34 | t_in_epochs=True, 35 | noise_range_t=noise_range, 36 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 37 | noise_std=getattr(args, 'lr_noise_std', 
1.), 38 | noise_seed=getattr(args, 'seed', 42), 39 | ) 40 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 41 | elif args.sched == 'tanh': 42 | lr_scheduler = TanhLRScheduler( 43 | optimizer, 44 | t_initial=num_epochs, 45 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 46 | lr_min=args.min_lr, 47 | warmup_lr_init=args.warmup_lr, 48 | warmup_t=args.warmup_epochs, 49 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 50 | t_in_epochs=True, 51 | noise_range_t=noise_range, 52 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 53 | noise_std=getattr(args, 'lr_noise_std', 1.), 54 | noise_seed=getattr(args, 'seed', 42), 55 | ) 56 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 57 | elif args.sched == 'step': 58 | lr_scheduler = StepLRScheduler( 59 | optimizer, 60 | decay_t=args.decay_epochs, 61 | decay_rate=args.decay_rate, 62 | warmup_lr_init=args.warmup_lr, 63 | warmup_t=args.warmup_epochs, 64 | noise_range_t=noise_range, 65 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 66 | noise_std=getattr(args, 'lr_noise_std', 1.), 67 | noise_seed=getattr(args, 'seed', 42), 68 | ) 69 | elif args.sched == 'plateau': 70 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 71 | lr_scheduler = PlateauLRScheduler( 72 | optimizer, 73 | decay_rate=args.decay_rate, 74 | patience_t=args.patience_epochs, 75 | lr_min=args.min_lr, 76 | mode=mode, 77 | warmup_lr_init=args.warmup_lr, 78 | warmup_t=args.warmup_epochs, 79 | cooldown_t=0, 80 | noise_range_t=noise_range, 81 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 82 | noise_std=getattr(args, 'lr_noise_std', 1.), 83 | noise_seed=getattr(args, 'seed', 42), 84 | ) 85 | return lr_scheduler, num_epochs 86 | -------------------------------------------------------------------------------- /scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """Basic step LR decay schedule with warmup and optional noise. 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class TanhLRScheduler(Scheduler): 19 | """ 20 | Hyperbolic-Tangent decay with restarts. 
21 | This is described in the paper https://arxiv.org/abs/1806.01593 22 | """ 23 | 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_initial: int, 27 | lb: float = -6., 28 | ub: float = 4., 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | assert lb < ub 50 | assert cycle_limit >= 0 51 | assert warmup_t >= 0 52 | assert warmup_lr_init >= 0 53 | self.lb = lb 54 | self.ub = ub 55 | self.t_initial = t_initial 56 | self.t_mul = t_mul 57 | self.lr_min = lr_min 58 | self.decay_rate = decay_rate 59 | self.cycle_limit = cycle_limit 60 | self.warmup_t = warmup_t 61 | self.warmup_lr_init = warmup_lr_init 62 | self.warmup_prefix = warmup_prefix 63 | self.t_in_epochs = t_in_epochs 64 | if self.warmup_t: 65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 67 | super().update_groups(self.warmup_lr_init) 68 | else: 69 | self.warmup_steps = [1 for _ in self.base_values] 70 | 71 | def _get_lr(self, t): 72 | if t < self.warmup_t: 73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 74 | else: 75 | if self.warmup_prefix: 76 | t = t - self.warmup_t 77 | 78 | if self.t_mul != 1: 79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 80 | t_i = self.t_mul ** i * self.t_initial 81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 82 | else: 83 | i = t // self.t_initial 84 | t_i = self.t_initial 85 | t_curr = t - (self.t_initial * i) 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | gamma = self.decay_rate ** i 89 | lr_min = self.lr_min * gamma 90 | lr_max_values = [v * gamma for v in self.base_values] 91 | 92 | tr = t_curr / t_i 93 | lrs = [ 94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 95 | for lr_max in lr_max_values 96 | ] 97 | else: 98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 99 | return lrs 100 | 101 | def get_epoch_values(self, epoch: int): 102 | if self.t_in_epochs: 103 | return self._get_lr(epoch) 104 | else: 105 | return None 106 | 107 | def get_update_values(self, num_updates: int): 108 | if not self.t_in_epochs: 109 | return self._get_lr(num_updates) 110 | else: 111 | return None 112 | 113 | def get_cycle_length(self, cycles=0): 114 | if not cycles: 115 | cycles = self.cycle_limit 116 | cycles = max(1, cycles) 117 | if self.t_mul == 1.0: 118 | return self.t_initial * cycles 119 | else: 120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 121 | -------------------------------------------------------------------------------- /shell/cuhk-eval.sh: -------------------------------------------------------------------------------- 1 | #export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.launch --nproc_per_node=4 --use_env --rdzv_endpoint=127.0.0.1:29501 \ 5 | Retrieval.py \ 6 | --config configs/PS_cuhk_pedes.yaml \ 7 | --output_dir output/cuhk-pedes/evaluation \ 8 | --checkpoint ../rasa_checkpoint/rasa_cuhk_checkpoint.pth \ 9 | --eval_mAP \ 10 | --evaluate 11 | -------------------------------------------------------------------------------- /shell/cuhk-train.sh: -------------------------------------------------------------------------------- 1 | #export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \ 5 | Retrieval.py \ 6 | --config configs/PS_cuhk_pedes.yaml \ 7 | --output_dir output/cuhk-pedes/train \ 8 | --checkpoint /data/ALBEF/ALBEF.pth \ 9 | --eval_mAP 10 | -------------------------------------------------------------------------------- /shell/data_process.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python data_process.py --dataset_name "CUHK-PEDES" --dataset_root_dir "../dataset/CUHK-PEDES" 4 | #python data_process.py --dataset_name "ICFG-PEDES" --dataset_root_dir "../dataset/ICFG-PEDES" 5 | #python data_process.py --dataset_name "RSTPReid" --dataset_root_dir "../dataset/RSTPReid" 6 | -------------------------------------------------------------------------------- /shell/icfg-eval.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | #export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.launch --nproc_per_node=4 --use_env --rdzv_endpoint=127.0.0.1:29501 \ 5 | Retrieval.py \ 6 | --config configs/PS_icfg_pedes.yaml \ 7 | --output_dir output/icfg-pedes/evaluation \ 8 | --checkpoint ../rasa_checkpoint/rasa_icfg_checkpoint.pth \ 9 | --eval_mAP \ 10 | --evaluate 11 | -------------------------------------------------------------------------------- /shell/icfg-train.sh: -------------------------------------------------------------------------------- 1 | #export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29502 \ 5 | Retrieval.py \ 6 | --config configs/PS_icfg_pedes.yaml \ 7 | --output_dir output/icfg-pedes/train \ 8 | --checkpoint /data/ALBEF/ALBEF.pth \ 9 | --eval_mAP 10 | 
-------------------------------------------------------------------------------- /shell/rstp-eval.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | #export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29501 \ 5 | Retrieval.py \ 6 | --config configs/PS_rstp_reid.yaml \ 7 | --output_dir output/rstp-reid/evaluation/ \ 8 | --checkpoint ../rasa_checkpoint/rasa_rstp_checkpoint.pth \ 9 | --eval_mAP \ 10 | --evaluate 11 | -------------------------------------------------------------------------------- /shell/rstp-train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | #export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | python -m torch.distributed.run --nproc_per_node=4 --rdzv_endpoint=127.0.0.1:29502 \ 5 | Retrieval.py \ 6 | --config configs/PS_rstp_reid.yaml \ 7 | --output_dir output/rstp-reid/train \ 8 | --checkpoint /data1/byyoung/data/ALBEF/ALBEF.pth \ 9 | --eval_mAP 10 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | from collections import defaultdict, deque 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | class SmoothedValue(object): 10 | """Track a series of values and provide access to smoothed values over a 11 | window or the global series average. 12 | """ 13 | def __init__(self, window_size=20, fmt=None): 14 | if fmt is None: 15 | fmt = "{median:.4f} ({global_avg:.4f})" 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | self.fmt = fmt 20 | 21 | def update(self, value, n=1): 22 | self.deque.append(value) 23 | self.count += n 24 | self.total += value * n 25 | 26 | def synchronize_between_processes(self): 27 | """ 28 | Warning: does not synchronize the deque! 
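Only count and total are all-reduced across processes; the windowed statistics (median/avg/max/value) remain local to each rank.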
29 | """ 30 | if not is_dist_avail_and_initialized(): 31 | return 32 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 33 | dist.barrier() 34 | dist.all_reduce(t) 35 | t = t.tolist() 36 | self.count = int(t[0]) 37 | self.total = t[1] 38 | 39 | @property 40 | def median(self): 41 | d = torch.tensor(list(self.deque)) 42 | return d.median().item() 43 | 44 | @property 45 | def avg(self): 46 | d = torch.tensor(list(self.deque), dtype=torch.float32) 47 | return d.mean().item() 48 | 49 | @property 50 | def global_avg(self): 51 | return self.total / self.count 52 | 53 | @property 54 | def max(self): 55 | return max(self.deque) 56 | 57 | @property 58 | def value(self): 59 | return self.deque[-1] 60 | 61 | def __str__(self): 62 | return self.fmt.format( 63 | median=self.median, 64 | avg=self.avg, 65 | global_avg=self.global_avg, 66 | max=self.max, 67 | value=self.value) 68 | 69 | class MetricLogger(object): 70 | def __init__(self, delimiter="\t"): 71 | self.meters = defaultdict(SmoothedValue) 72 | self.delimiter = delimiter 73 | 74 | def update(self, **kwargs): 75 | for k, v in kwargs.items(): 76 | if isinstance(v, torch.Tensor): 77 | v = v.item() 78 | assert isinstance(v, (float, int)) 79 | self.meters[k].update(v) 80 | 81 | def __getattr__(self, attr): 82 | if attr in self.meters: 83 | return self.meters[attr] 84 | if attr in self.__dict__: 85 | return self.__dict__[attr] 86 | raise AttributeError("'{}' object has no attribute '{}'".format( 87 | type(self).__name__, attr)) 88 | 89 | def __str__(self): 90 | loss_str = [] 91 | for name, meter in self.meters.items(): 92 | loss_str.append( 93 | "{}: {}".format(name, str(meter)) 94 | ) 95 | return self.delimiter.join(loss_str) 96 | 97 | def global_avg(self): 98 | loss_str = [] 99 | for name, meter in self.meters.items(): 100 | loss_str.append( 101 | "{}: {:.4f}".format(name, meter.global_avg) 102 | ) 103 | return self.delimiter.join(loss_str) 104 | 105 | def synchronize_between_processes(self): 106 | for meter in self.meters.values(): 107 | meter.synchronize_between_processes() 108 | 109 | def add_meter(self, name, meter): 110 | self.meters[name] = meter 111 | 112 | def log_every(self, iterable, print_freq, header=None): 113 | i = 0 114 | if not header: 115 | header = '' 116 | start_time = time.time() 117 | end = time.time() 118 | iter_time = SmoothedValue(fmt='{avg:.4f}') 119 | data_time = SmoothedValue(fmt='{avg:.4f}') 120 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 121 | log_msg = [ 122 | header, 123 | '[{0' + space_fmt + '}/{1}]', 124 | 'eta: {eta}', 125 | '{meters}', 126 | 'time: {time}', 127 | 'data: {data}' 128 | ] 129 | if torch.cuda.is_available(): 130 | log_msg.append('max mem: {memory:.0f}') 131 | log_msg = self.delimiter.join(log_msg) 132 | MB = 1024.0 * 1024.0 133 | for obj in iterable: 134 | data_time.update(time.time() - end) 135 | yield obj 136 | iter_time.update(time.time() - end) 137 | if i % print_freq == 0 or i == len(iterable) - 1: 138 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 139 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 140 | if torch.cuda.is_available(): 141 | print(log_msg.format( 142 | i, len(iterable), eta=eta_string, 143 | meters=str(self), 144 | time=str(iter_time), data=str(data_time), 145 | memory=torch.cuda.max_memory_allocated() / MB)) 146 | else: 147 | print(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | meters=str(self), 150 | time=str(iter_time), data=str(data_time))) 151 | i += 1 152 | end = time.time() 153 | total_time 
= time.time() - start_time 154 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 155 | print('{} Total time: {} ({:.4f} s / it)'.format( 156 | header, total_time_str, total_time / len(iterable))) 157 | 158 | class AttrDict(dict): 159 | def __init__(self, *args, **kwargs): 160 | super(AttrDict, self).__init__(*args, **kwargs) 161 | self.__dict__ = self 162 | 163 | def setup_for_distributed(is_master): 164 | """ 165 | This function disables printing when not in master process 166 | """ 167 | import builtins as __builtin__ 168 | builtin_print = __builtin__.print 169 | 170 | def print(*args, **kwargs): 171 | force = kwargs.pop('force', False) 172 | if is_master or force: 173 | builtin_print(*args, **kwargs) 174 | __builtin__.print = print 175 | 176 | def is_dist_avail_and_initialized(): 177 | if not dist.is_available(): 178 | return False 179 | if not dist.is_initialized(): 180 | return False 181 | return True 182 | 183 | def get_world_size(): 184 | if not is_dist_avail_and_initialized(): 185 | return 1 186 | return dist.get_world_size() 187 | 188 | def get_rank(): 189 | if not is_dist_avail_and_initialized(): 190 | return 0 191 | return dist.get_rank() 192 | 193 | def is_main_process(): 194 | return get_rank() == 0 195 | 196 | def init_distributed_mode(args): 197 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 198 | args.rank = int(os.environ["RANK"]) 199 | args.world_size = int(os.environ['WORLD_SIZE']) 200 | args.gpu = int(os.environ['LOCAL_RANK']) 201 | elif 'SLURM_PROCID' in os.environ: 202 | args.rank = int(os.environ['SLURM_PROCID']) 203 | args.gpu = args.rank % torch.cuda.device_count() 204 | else: 205 | print('Not using distributed mode') 206 | args.distributed = False 207 | return 208 | args.distributed = True 209 | torch.cuda.set_device(args.gpu) 210 | args.dist_backend = 'nccl' 211 | print('| distributed init (rank {}): {}'.format( 212 | args.rank, args.dist_url), flush=True) 213 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 214 | world_size=args.world_size, rank=args.rank) 215 | torch.distributed.barrier() 216 | setup_for_distributed(args.rank == 0) 217 | --------------------------------------------------------------------------------
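To close, a small usage sketch (not part of the repository) of the logging utilities above; the dummy batches, loss values and learning rate are assumptions made for illustration, while MetricLogger, SmoothedValue, add_meter, log_every, update and global_avg come from utils.py.

import torch
from utils import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

batches = [(torch.randn(4, 3), torch.randint(0, 2, (4,))) for _ in range(100)]  # dummy data
for images, targets in metric_logger.log_every(batches, print_freq=50, header='Train:'):
    loss = torch.rand(())                    # stand-in for a real forward/backward pass
    metric_logger.update(loss=loss.item(), lr=1e-4)

print('Averaged stats:', metric_logger.global_avg())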