├── LICENSE ├── README.md ├── Retrieval.py ├── assets ├── badsamples.jpg ├── chart1.jpg ├── cuhk_freq50.jpg ├── examples.jpg ├── framework.jpg ├── mals_fre50.jpg ├── mals_rst_30.jpg ├── readme.txt └── result.jpg ├── configs ├── Retrieval_cuhk.yaml ├── Retrieval_gene.yaml ├── Retrieval_icfg.yaml ├── Retrieval_pa100k.yaml ├── Retrieval_rstp.yaml ├── config_bert.json └── config_swinB_384.json ├── dataset ├── __init__.py ├── eda.py ├── randaugment.py ├── random_erasing.py ├── re_dataset.py └── utils.py ├── models ├── __init__.py ├── aptm.py ├── bert.py ├── model_retrieval.py ├── swin_transformer.py └── tokenization_bert.py ├── optim.py ├── reTools.py ├── requirements.txt ├── run.py ├── scheduler.py ├── train_pa100ks.py ├── train_tools.py ├── trains.py └── utils └── __init__.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shuyu Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APTM 2 | 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/towards-unified-text-based-person-retrieval-a/nlp-based-person-retrival-on-cuhk-pedes)](https://paperswithcode.com/sota/nlp-based-person-retrival-on-cuhk-pedes?p=towards-unified-text-based-person-retrieval-a) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/towards-unified-text-based-person-retrieval-a/text-based-person-retrieval-on-icfg-pedes)](https://paperswithcode.com/sota/text-based-person-retrieval-on-icfg-pedes?p=towards-unified-text-based-person-retrieval-a) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/towards-unified-text-based-person-retrieval-a/text-based-person-retrieval-on-rstpreid-1)](https://paperswithcode.com/sota/text-based-person-retrieval-on-rstpreid-1?p=towards-unified-text-based-person-retrieval-a) 7 | 8 | **APTM (ACM MM 2023)** is a new joint **A**ttribute **P**rompt Learning and **T**ext **M**atching Learning framework, considering the shared knowledge between attribute and text. As the name implies, APTM contains an attribute prompt learning stream and a text matching learning stream. 
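The two streams are learned jointly so that attribute and text knowledge is shared. Purely as an illustration of this idea (not the actual APTM code; the module names, the binary attribute objective, and the 0.8 weight below are assumptions), a minimal self-contained sketch:

```python
# Conceptual sketch only: one image-text matching loss plus one attribute-prompt
# loss on a shared image embedding. All names and weights are illustrative.
import torch
import torch.nn.functional as F

def itm_contrastive_loss(img_emb, txt_emb, temp=0.07):
    # symmetric InfoNCE-style image-text matching loss
    img_emb = F.normalize(img_emb, dim=-1)
    txt_emb = F.normalize(txt_emb, dim=-1)
    logits = img_emb @ txt_emb.t() / temp
    targets = torch.arange(logits.size(0))
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

def attribute_prompt_loss(img_emb, attr_prompt_emb, attr_labels):
    # score each attribute prompt against the image; MALS marks unknown labels as -1
    logits = F.normalize(img_emb, dim=-1) @ F.normalize(attr_prompt_emb, dim=-1).t()
    known = attr_labels >= 0
    return F.binary_cross_entropy_with_logits(logits[known], attr_labels[known].float())

batch, dim, num_attrs = 8, 256, 27
img_emb, txt_emb = torch.randn(batch, dim), torch.randn(batch, dim)
attr_prompts = torch.randn(num_attrs, dim)
attr_labels = torch.randint(-1, 2, (batch, num_attrs))

loss = itm_contrastive_loss(img_emb, txt_emb) + 0.8 * attribute_prompt_loss(img_emb, attr_prompts, attr_labels)
print(loss.item())
```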
9 | 10 | We also present a large Multi-Attribute and Language Search dataset for text-based person retrieval, called **MALS**, and explore the feasibility of performing pre-training on both attribute recognition and image-text matching tasks in one go. In particular, MALS contains 1,510,330 image-text pairs, about 37.5× larger than the prevailing CUHK-PEDES, and all images are annotated with 27 attributes. 11 | 12 | Extensive experiments validate the effectiveness of pre-training on MALS: APTM achieves state-of-the-art retrieval performance on three challenging real-world benchmarks. In particular, APTM achieves a consistent improvement of +6.60%, +7.39%, and +15.90% Recall@1 accuracy on the CUHK-PEDES, ICFG-PEDES, and RSTPReid datasets, respectively. More details can be found in our paper: [Towards Unified Text-based Person Retrieval: A Large-scale Multi-Attribute and Language Search Benchmark](https://arxiv.org/abs/2306.02898) 13 |
14 | 15 | ## News 16 | * The **OneDrive** link of the **MALS** dataset is released! 17 | * The **APTM** code and the **MALS** dataset are released. Welcome to communicate! 18 | 19 | ## MALS 20 | MALS leverages generative models to build a large-scale dataset of 1.5M image-text pairs. Each image-text pair in MALS is annotated with one corresponding description and several appropriate attribute labels, so MALS is not only effective for text-image matching and attribute prompt learning, but also enables pre-training for both attribute recognition and image-text matching in one go. **The dataset is released at [Baidu Yun](https://pan.baidu.com/s/1HMvNIIFlquI2w0R6f0G7Dg) [4kq0] and [OneDrive](https://1drv.ms/f/s!Ak2z-VJ5LcCvgdZGSTJbaHOMMFZi9A?e=gCBnv0) [mals].** 21 | 22 | **Note that MALS can only be used for research; any commercial usage is forbidden.** 23 | 24 | This is the comparison between MALS and other text-based person retrieval datasets. 25 |
26 | These are example image-text pairs from our MALS dataset and CUHK-PEDES. 27 |
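Each annotation file is a plain JSON list in the format shown right below. Purely for illustration, a minimal sketch of loading one of the MALS attribute files (the file name follows the `data/finetune/gene_attrs/` layout from the Usage section and is only an example):

```python
# Hedged sketch, not part of the repository: inspect one MALS annotation file.
import json

with open('data/finetune/gene_attrs/g_c_g_a_0_attrs.json') as f:
    anns = json.load(f)  # list of {"image", "caption", "image_id", "label"}

print(len(anns), 'image-text pairs')
first = anns[0]
print(first['image'], '->', first['caption'][:60], '...')
# each label entry is 1 (positive), 0 (negative) or -1 (unknown) per attribute
print('number of attribute labels:', len(first['label']))
```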
28 | Annotation format: 29 | 30 | ``` 31 | [{"image": "gene_crop/c_g_a_0/0.jpg", 32 | "caption": "a young boy wearing a black hoodie leaning against a wall with his hands on his hips and his hands on his hips wearing jeans and a baseball cap", 33 | "image_id": "c_g_a_0_0", 34 | "label": [1, 0, ..., 1, 1]}, 35 | ... 36 | {"image": "gene_crop/c_g_a_0/20217.jpg", 37 | "caption": "a woman in a white top and black pants posing for a picture in front of a brick wall with a pink carpet in front of her", 38 | "image_id": "c_g_a_0_20217", 39 | "label": [0, 1, ..., -1, -1]}] 40 | ``` 41 | 42 | ## Models and Weights 43 | 44 | The checkpoints have been released at [Baidu Yun](https://pan.baidu.com/s/1oAkenOKaVEYWpNh2hznkGA) [b2l8] and [Google Drive](https://drive.google.com/drive/folders/1N1Lumvb4epP0awHLcJ3RzQmv5zwrAFBh?usp=sharing) 45 | 46 | 47 | ## Usage 48 | 49 | ### Install Requirements 50 | 51 | We use 4 A100 80G GPUs for training and evaluation. 52 | 53 | Create the conda environment: 54 | 55 | ``` 56 | conda create -n aptm python=3.8 57 | conda activate aptm 58 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 59 | pip3 install -r requirements.txt 60 | ``` 61 | 62 | ### Datasets Prepare 63 | 64 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the PA-100K dataset from [here](https://github.com/xh-liu/HydraPlus-Net), the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset), and the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN). Download the processed JSON files of the above four datasets from [here](https://pan.baidu.com/s/1oAkenOKaVEYWpNh2hznkGA) [b2l8]. 65 | 66 | Download pre-trained models for parameter initialization: 67 | 68 | image encoder: [swin-transformer-base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) 69 | 70 | text encoder: [bert-base](https://huggingface.co/bert-base-uncased/tree/main) 71 | 72 | Organize the `data` folder as follows: 73 | 74 | ``` 75 | |-- data/ 76 | | |-- bert-base-uncased/ 77 | | |-- finetune/ 78 | | |-- gene_attrs/ 79 | | |-- g_4x_attrs.json 80 | | |-- g_c_g_a_0_attrs.json 81 | | |-- ... 82 | | |-- cuhk_train.json 83 | | |-- ... 84 | | |-- icfg_train.json 85 | | |-- ... 86 | | |-- rstp_train.json 87 | | |-- ... 88 | | |-- PA100K_train.json 89 | | |-- ... 90 | | |-- swin_base_patch4_window7_224_22k.pth 91 | ``` 92 | 93 | Then organize the four datasets in the `images` folder as follows (directory names follow the `image_root` paths in `configs/`): 94 | 95 | ``` 96 | |-- images/ 97 | | |-- CUHK-PEDES/ 98 | | |-- imgs/ 99 | | |-- cam_a/ 100 | | |-- cam_b/ 101 | | |-- ... 102 | | |-- train_query/ 103 | | |-- gene_crop/ 104 | | |-- 4x/ 105 | | |-- c_g_a/ 106 | | |-- ... 107 | | |-- i_g_a_43/ 108 | | 109 | | |-- ICFG-PEDES/ 110 | | |-- test/ 111 | | |-- train/ 112 | | 113 | | |-- pa100k/ 114 | | |-- release_data/ 115 | | 116 | | |-- RSTPReid/ 117 | ``` 118 | 119 | ### Pretraining 120 | We pretrain our APTM using MALS as follows: 121 | 122 | ``` 123 | python3 run.py --task "itr_gene" --dist "f4" --output_dir "output/pretrained" 124 | ``` 125 | 126 | ### Fine-tuning 127 | We fine-tune our APTM using existing text-based person re-ID datasets. Performance can be improved by replacing the backbone with our pre-trained model. 
Taking CUHK-PEDES as example: 128 | 129 | ``` 130 | python3 run.py --task "itr_cuhk" --dist "f4" --output_dir "output/ft_cuhk" --checkpoint "output/pretrained/checkpoint_31.pth" 131 | ``` 132 | 133 | ### Evaluation 134 | 135 | ``` 136 | python3 run.py --task "itr_cuhk" --evaluate --dist "f4" --output_dir "output/ft_cuhk/test" --checkpoint "output/ft_cuhk/checkpoint_best.pth" 137 | ``` 138 | 139 | ## Reference 140 | If you use APTM in your research, please cite it by the following BibTeX entry: 141 | 142 | ```bibtex 143 | @inproceedings{yang2023towards, 144 | title={Towards Unified Text-based Person Retrieval: A Large-scale Multi-Attribute and Language Search Benchmark}, 145 | author={Yang, Shuyu and Zhou, Yinan and Wang, Yaxiong and Wu, Yujiao and Zhu, Li and Zheng, Zhedong}, 146 | booktitle = {Proceedings of the 2023 {ACM} on Multimedia Conference}, 147 | year={2023} 148 | } 149 | 150 | ``` 151 | -------------------------------------------------------------------------------- /Retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | 5 | import ruamel.yaml as yaml 6 | import numpy as np 7 | import random 8 | import time 9 | import datetime 10 | import json 11 | from pathlib import Path 12 | from prettytable import PrettyTable 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | 18 | from models.model_retrieval import APTM_Retrieval 19 | from models.tokenization_bert import BertTokenizer 20 | 21 | import utils 22 | from dataset import create_dataset, create_sampler, create_loader 23 | from dataset.re_dataset import TextMaskingGenerator 24 | from scheduler import create_scheduler 25 | from optim import create_optimizer 26 | 27 | from trains import train, train_attr 28 | from train_pa100ks import train_pa100k, train_pa100k_only_img_classifier 29 | 30 | from reTools import evaluation, mAP 31 | from reTools import evaluation_attr, itm_eval_attr 32 | from reTools import evaluation_attr_only_img_classifier, itm_eval_attr_only_img_classifier 33 | 34 | 35 | def main(args, config): 36 | utils.init_distributed_mode(args) 37 | device = torch.device(args.device) 38 | world_size = utils.get_world_size() 39 | 40 | if args.bs > 0: 41 | config['batch_size_train'] = args.bs // world_size 42 | if args.epo > 0: 43 | config['schedular']['epochs'] = args.epo 44 | 45 | seed = args.seed + utils.get_rank() 46 | torch.manual_seed(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | cudnn.benchmark = True 50 | 51 | print("Creating model", flush=True) 52 | tokenizer = BertTokenizer.from_pretrained(config['text_encoder']) 53 | model = APTM_Retrieval(config=config) 54 | if config['load_pretrained']: 55 | model.load_pretrained(args.checkpoint, config, is_eval=args.evaluate) 56 | model = model.to(device) 57 | 58 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad), flush=True) 59 | 60 | model_without_ddp = model 61 | if args.distributed: 62 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 63 | model_without_ddp = model.module 64 | 65 | print("Creating retrieval dataset", flush=True) 66 | if args.task == "itr_icfg": 67 | train_dataset, test_dataset = create_dataset('re_icfg', config, args.evaluate) 68 | elif args.task == "itr_rstp": 69 | train_dataset, val_dataset, test_dataset = create_dataset('re_rstp', config, args.evaluate) 70 | elif args.task == "itr_cuhk": 71 | train_dataset, val_dataset, 
test_dataset = create_dataset('re_cuhk', config, args.evaluate) 72 | elif args.task == "itr_pa100k": 73 | train_dataset, val_dataset, test_dataset = create_dataset('re_pa100k', config, args.evaluate) 74 | else: 75 | train_dataset, val_dataset, test_dataset = create_dataset('re_gene', config, args.evaluate) 76 | 77 | start_time = time.time() 78 | print("### output_dir, ", args.output_dir, flush=True) 79 | 80 | if args.evaluate: 81 | print("Start evaluating", flush=True) 82 | # if args.task not in ["itr_icfg", "itr_pa100k"]: 83 | # print("val_dataset", flush=True) 84 | # val_loader = create_loader([val_dataset], [None], 85 | # batch_size=[config['batch_size_test']], 86 | # num_workers=[4], 87 | # is_trains=[False], 88 | # collate_fns=[None])[0] 89 | # score_val_t2i = evaluation(model_without_ddp, val_loader, 90 | # tokenizer, device, config, args) 91 | 92 | print("test_dataset", flush=True) 93 | test_loader = create_loader([test_dataset], [None], 94 | batch_size=[config['batch_size_test']], 95 | num_workers=[4], 96 | is_trains=[False], 97 | collate_fns=[None])[0] 98 | if args.task == "itr_pa100k": 99 | if model_without_ddp.pa100k_only_img_classifier: 100 | score_test_i2t_attr = evaluation_attr_only_img_classifier(model_without_ddp, test_loader, 101 | tokenizer, device, config, args) 102 | else: 103 | score_test_i2t_attr = evaluation_attr(model_without_ddp, test_loader, 104 | tokenizer, device, config, args) 105 | else: 106 | score_test_t2i = evaluation(model_without_ddp, test_loader, 107 | tokenizer, device, config, args) 108 | 109 | if utils.is_main_process(): 110 | # if args.task not in ["itr_icfg", "itr_pa100k"]: 111 | # print('val_result:', flush=True) 112 | # mAP(score_val_t2i, val_loader.dataset.g_pids, val_loader.dataset.q_pids) 113 | if args.task == "itr_pa100k": 114 | if model_without_ddp.pa100k_only_img_classifier: 115 | test_result_attr = itm_eval_attr_only_img_classifier(score_test_i2t_attr, test_loader.dataset) 116 | else: 117 | test_result_attr = itm_eval_attr(score_test_i2t_attr, test_loader.dataset) 118 | print('test_result_attr:', flush=True) 119 | print(test_result_attr, flush=True) 120 | else: 121 | print('test_result:', flush=True) 122 | mAP(score_test_t2i, test_loader.dataset.g_pids, test_loader.dataset.q_pids) 123 | 124 | dist.barrier() 125 | 126 | else: 127 | print("Start training", flush=True) 128 | train_dataset_size = len(train_dataset) 129 | if utils.is_main_process(): 130 | print(f"### data {train_dataset_size}, batch size, {config['batch_size_train']} x {world_size}") 131 | if args.task == "itr_pa100k": 132 | table = PrettyTable(["epoch", "label_mA", "ins_acc", "ins_prec", "ins_rec", "ins_f1"]) 133 | else: 134 | table = PrettyTable(["epoch", "R1", "R5", "R10", "mAP", "mINP"]) 135 | table.custom_format["R1"] = lambda f, v: f"{v:.3f}" 136 | table.custom_format["R5"] = lambda f, v: f"{v:.3f}" 137 | table.custom_format["R10"] = lambda f, v: f"{v:.3f}" 138 | table.custom_format["mAP"] = lambda f, v: f"{v:.3f}" 139 | table.custom_format["mINP"] = lambda f, v: f"{v:.3f}" 140 | if args.distributed: 141 | num_tasks = utils.get_world_size() 142 | global_rank = utils.get_rank() 143 | if args.task == "itr_icfg": 144 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None] 145 | else: 146 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] 147 | else: 148 | if args.task == "itr_icfg": 149 | samplers = [None, None] 150 | else: 151 | samplers = [None, None, None] 152 | 153 | if args.task == "itr_icfg": 154 | 
train_loader, test_loader = create_loader([train_dataset, test_dataset], samplers, 155 | batch_size=[config['batch_size_train']] + [ 156 | config['batch_size_test']], 157 | num_workers=[4, 4], is_trains=[True, False], 158 | collate_fns=[None, None]) 159 | else: 160 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers, 161 | batch_size=[config['batch_size_train']] + [ 162 | config['batch_size_test']] * 2, 163 | num_workers=[4, 4, 4], is_trains=[True, False, False], 164 | collate_fns=[None, None, None]) 165 | 166 | arg_opt = utils.AttrDict(config['optimizer']) 167 | optimizer = create_optimizer(arg_opt, model_without_ddp) 168 | arg_sche = utils.AttrDict(config['schedular']) 169 | arg_sche['step_per_epoch'] = math.ceil(train_dataset_size / (config['batch_size_train'] * world_size)) 170 | lr_scheduler = create_scheduler(arg_sche, optimizer) 171 | 172 | max_epoch = config['schedular']['epochs'] 173 | best = 0 174 | best_epoch = 0 175 | 176 | if config['mlm']: 177 | mask_generator = TextMaskingGenerator(tokenizer, config['mask_prob'], config['max_masks'], 178 | config['skipgram_prb'], config['skipgram_size'], 179 | config['mask_whole_word']) 180 | else: 181 | mask_generator = None 182 | 183 | for epoch in range(0, max_epoch): 184 | if args.distributed: 185 | train_loader.sampler.set_epoch(epoch) 186 | 187 | if args.task == "itr_pa100k": 188 | if model_without_ddp.pa100k_only_img_classifier: 189 | train_stats = train_pa100k_only_img_classifier(model, train_loader, optimizer, tokenizer, epoch, 190 | device, lr_scheduler, config, mask_generator) 191 | else: 192 | train_stats = train_pa100k(model, train_loader, optimizer, tokenizer, epoch, 193 | device, lr_scheduler, config, mask_generator) 194 | else: 195 | if ('attr' in config.keys()) and config['attr']: 196 | train_stats = train_attr(model, train_loader, optimizer, tokenizer, epoch, 197 | device, lr_scheduler, config, mask_generator) 198 | else: 199 | train_stats = train(model, train_loader, optimizer, tokenizer, epoch, 200 | device, lr_scheduler, config, mask_generator) 201 | 202 | if (epoch + 1) % 1 == 0: 203 | # if args.task not in ["itr_icfg", "itr_pa100k"]: 204 | # score_val_t2i = evaluation(model_without_ddp, val_loader, tokenizer, 205 | # device, config, args) 206 | if args.task == "itr_pa100k": 207 | if model_without_ddp.pa100k_only_img_classifier: 208 | score_test_i2t = evaluation_attr_only_img_classifier(model_without_ddp, test_loader, 209 | tokenizer, device, config, args) 210 | else: 211 | score_test_i2t = evaluation_attr(model_without_ddp, test_loader, 212 | tokenizer, device, config, args) 213 | else: 214 | score_test_t2i = evaluation(model_without_ddp, test_loader, 215 | tokenizer, device, config, args) 216 | 217 | if utils.is_main_process(): 218 | # if args.task not in ["itr_icfg", "itr_pa100k"]: 219 | # val_result = mAP(score_val_t2i, val_loader.dataset.g_pids, val_loader.dataset.q_pids, table) 220 | if args.task == "itr_pa100k": 221 | if model_without_ddp.pa100k_only_img_classifier: 222 | test_result = itm_eval_attr_only_img_classifier(score_test_i2t, test_loader.dataset) 223 | else: 224 | test_result = itm_eval_attr(score_test_i2t, test_loader.dataset) 225 | table.add_row([epoch, test_result['label_mA'] * 100, test_result['ins_acc'] * 100, 226 | test_result['ins_prec'] * 100, test_result['ins_rec'] * 100, 227 | test_result['ins_f1'] * 100]) 228 | test_result_log = test_result 229 | else: 230 | test_result = mAP(score_test_t2i, test_loader.dataset.g_pids, 
test_loader.dataset.q_pids, table) 231 | table.add_row([epoch, test_result['R1'], test_result['R5'], test_result['R10'], 232 | test_result['mAP'], test_result['mINP']]) 233 | test_result_log = {} 234 | for k, v in test_result.items(): 235 | test_result_log[k] = str(np.around(v, 3)) 236 | print(table, flush=True) 237 | 238 | log_stats = {'e': epoch, 239 | **{k: v for k, v in test_result_log.items()}, 240 | **{k: v for k, v in train_stats.items()}, 241 | # **{f'val_{k}': v for k, v in val_result.items()}, 242 | } 243 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 244 | f.write(json.dumps(log_stats) + "\n") 245 | 246 | if args.task == "itr_pa100k": 247 | result = test_result['label_mA'] 248 | else: 249 | result = test_result['R1'] 250 | 251 | if result > best: 252 | save_obj = {'model': model_without_ddp.state_dict(), 'config': config, } 253 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 254 | best = result 255 | best_epoch = epoch 256 | elif epoch >= max_epoch - 1: 257 | save_obj = { 258 | 'model': model_without_ddp.state_dict(), 259 | # 'optimizer': optimizer.state_dict(), 260 | # 'lr_scheduler': lr_scheduler.state_dict(), 261 | 'config': config, 262 | # 'epoch': epoch, 263 | } 264 | torch.save(save_obj, os.path.join(args.output_dir, f'checkpoint_{epoch}.pth')) 265 | dist.barrier() 266 | torch.cuda.empty_cache() 267 | 268 | if utils.is_main_process(): 269 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 270 | f.write("best epoch: %d" % best_epoch) 271 | 272 | os.system(f"cat {args.output_dir}/log.txt") 273 | 274 | total_time = time.time() - start_time 275 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 276 | print(' ### Time {}'.format(total_time_str)) 277 | 278 | 279 | if __name__ == '__main__': 280 | parser = argparse.ArgumentParser() 281 | parser.add_argument('--checkpoint', type=str) 282 | parser.add_argument('--config', type=str, required=True) 283 | parser.add_argument('--task', type=str, required=True) 284 | parser.add_argument('--output_dir', type=str, required=True) 285 | parser.add_argument('--device', default='cuda') 286 | parser.add_argument('--seed', default=42, type=int) 287 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 288 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 289 | parser.add_argument('--distributed', action='store_false') 290 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus") 291 | parser.add_argument('--epo', default=-1, type=int, help="epoch") 292 | parser.add_argument('--evaluate', action='store_true') 293 | 294 | args = parser.parse_args() 295 | 296 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 297 | 298 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 299 | 300 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 301 | 302 | main(args, config) 303 | -------------------------------------------------------------------------------- /assets/badsamples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/badsamples.jpg -------------------------------------------------------------------------------- /assets/chart1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/chart1.jpg -------------------------------------------------------------------------------- /assets/cuhk_freq50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/cuhk_freq50.jpg -------------------------------------------------------------------------------- /assets/examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/examples.jpg -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/framework.jpg -------------------------------------------------------------------------------- /assets/mals_fre50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/mals_fre50.jpg -------------------------------------------------------------------------------- /assets/mals_rst_30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/mals_rst_30.jpg -------------------------------------------------------------------------------- /assets/readme.txt: -------------------------------------------------------------------------------- 1 | assets 2 | -------------------------------------------------------------------------------- /assets/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/result.jpg -------------------------------------------------------------------------------- /configs/Retrieval_cuhk.yaml: -------------------------------------------------------------------------------- 1 | image_root: 'images/CUHK-PEDES/' 2 | test_file: 'data/finetune/cuhk_test.json' 3 | val_file: 'data/finetune/cuhk_val.json' 4 | train_file: ['data/finetune/cuhk_train.json'] 5 | 6 | 7 | ## Vision Encoder 8 | vision_config: 'configs/config_swinB_384.json' 9 | image_res: 384 10 | patch_size: 32 11 | h: 384 12 | w: 128 13 | 14 | 15 | ## Text Encoder 16 | text_config: 'configs/config_bert.json' 17 | text_encoder: 'data/bert-base-uncased' 18 | 19 | 20 | ## Training 21 | batch_size_train: 120 22 | batch_size_test: 150 23 | batch_size_test_text: 750 24 | 25 | max_tokens: 56 26 | max_words: 56 27 | 28 | embed_dim: 256 29 | temp: 0.07 30 | k_test: 128 31 | 32 | 33 | ## mlm loss 34 | mlm: True 35 | mask_prob: 0.25 36 | max_masks: 10 37 | skipgram_prb: 0.2 38 | skipgram_size: 3 39 | mask_whole_word: True 40 | 41 | 42 | ## Other Settings 43 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 44 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1} 45 | 46 | pa100k: False 47 | icfg_rstp: False 48 | 49 | lr_2: True 50 | load_params: False 51 | load_pretrained: True 52 | 53 | eda: True 54 | eda_p: 1 55 | erasing_p: 0.6 56 | LabelSmooth: 0 -------------------------------------------------------------------------------- 
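These retrieval configs are consumed directly by `run.py`/`Retrieval.py`. As a rough illustration of how a few of the fields interact, a hedged sketch (the training-set size below is a placeholder, not a real number):

```python
# Hedged sketch (not part of the repo): how Retrieval.py reads a config and
# derives the scheduler's steps per epoch. The dataset size is a placeholder.
import math
import ruamel.yaml as yaml  # same call as in Retrieval.py; needs a ruamel.yaml version with the legacy load()

config = yaml.load(open('configs/Retrieval_cuhk.yaml'), Loader=yaml.Loader)

world_size = 4                               # the README setup uses 4 GPUs
per_gpu_batch = config['batch_size_train']   # 120 in Retrieval_cuhk.yaml, interpreted per GPU
train_size = 100000                          # placeholder; the script prints the real value as "### data ..."
steps_per_epoch = math.ceil(train_size / (per_gpu_batch * world_size))

print(config['schedular']['epochs'], 'epochs,', steps_per_epoch, 'steps per epoch')
```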
/configs/Retrieval_gene.yaml: -------------------------------------------------------------------------------- 1 | image_root: 'images/CUHK-PEDES/' 2 | test_file: 'data/finetune/cuhk_test.json' 3 | val_file: 'data/finetune/cuhk_val.json' 4 | train_file: ['data/finetune/gene_attrs/g_4x_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_attrs.json', 5 | 'data/finetune/gene_attrs/g_c_g_a_0_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_1_attrs.json', 6 | 'data/finetune/gene_attrs/g_c_g_a_2_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_3_attrs.json', 7 | 'data/finetune/gene_attrs/g_c_g_a_4_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_5_attrs.json', 8 | 'data/finetune/gene_attrs/g_c_g_a_6_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_7_attrs.json', 9 | 'data/finetune/gene_attrs/g_c_g_a_8_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_9_attrs.json', 10 | 'data/finetune/gene_attrs/g_c_g_a_10_attrs.json','data/finetune/gene_attrs/g_c_g_a_11_attrs.json', 11 | 'data/finetune/gene_attrs/g_c_g_a_12_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_13_attrs.json', 12 | 'data/finetune/gene_attrs/g_c_g_a_14_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_15_attrs.json', 13 | 'data/finetune/gene_attrs/g_c_g_a_16_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_17_attrs.json', 14 | 'data/finetune/gene_attrs/g_c_g_a_18_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_19_attrs.json', 15 | 'data/finetune/gene_attrs/g_c_g_a_20_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_21_attrs.json', 16 | 'data/finetune/gene_attrs/g_c_g_a_22_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_23_attrs.json', 17 | 'data/finetune/gene_attrs/g_c_g_a_24_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_25_attrs.json', 18 | 'data/finetune/gene_attrs/g_c_g_a_26_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_27_attrs.json', 19 | 'data/finetune/gene_attrs/g_c_g_a_28_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_29_attrs.json', 20 | 'data/finetune/gene_attrs/g_c_g_a_30_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_31_attrs.json', 21 | 'data/finetune/gene_attrs/g_c_g_a_32_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_33_attrs.json', 22 | 'data/finetune/gene_attrs/g_c_g_a_34_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_35_attrs.json', 23 | 'data/finetune/gene_attrs/g_c_g_a_36_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_37_attrs.json', 24 | 'data/finetune/gene_attrs/g_c_g_a_38_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_39_attrs.json', 25 | 'data/finetune/gene_attrs/g_c_g_a_40_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_41_attrs.json', 26 | 'data/finetune/gene_attrs/g_c_g_a_42_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_43_attrs.json', 27 | 'data/finetune/gene_attrs/g_i_g_a_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_0_attrs.json', 28 | 'data/finetune/gene_attrs/g_i_g_a_1_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_2_attrs.json', 29 | 'data/finetune/gene_attrs/g_i_g_a_3_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_4_attrs.json', 30 | 'data/finetune/gene_attrs/g_i_g_a_5_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_6_attrs.json', 31 | 'data/finetune/gene_attrs/g_i_g_a_7_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_8_attrs.json', 32 | 'data/finetune/gene_attrs/g_i_g_a_9_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_10_attrs.json', 33 | 'data/finetune/gene_attrs/g_i_g_a_11_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_12_attrs.json', 34 | 'data/finetune/gene_attrs/g_i_g_a_13_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_14_attrs.json', 35 | 'data/finetune/gene_attrs/g_i_g_a_15_attrs.json', 
'data/finetune/gene_attrs/g_i_g_a_16_attrs.json', 36 | 'data/finetune/gene_attrs/g_i_g_a_17_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_18_attrs.json', 37 | 'data/finetune/gene_attrs/g_i_g_a_19_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_20_attrs.json', 38 | 'data/finetune/gene_attrs/g_i_g_a_21_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_22_attrs.json', 39 | 'data/finetune/gene_attrs/g_i_g_a_23_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_24_attrs.json', 40 | 'data/finetune/gene_attrs/g_i_g_a_25_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_26_attrs.json', 41 | 'data/finetune/gene_attrs/g_i_g_a_27_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_28_attrs.json', 42 | 'data/finetune/gene_attrs/g_i_g_a_29_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_30_attrs.json', 43 | 'data/finetune/gene_attrs/g_i_g_a_31_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_32_attrs.json', 44 | 'data/finetune/gene_attrs/g_i_g_a_33_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_34_attrs.json', 45 | 'data/finetune/gene_attrs/g_i_g_a_35_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_36_attrs.json', 46 | 'data/finetune/gene_attrs/g_i_g_a_37_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_38_attrs.json', 47 | 'data/finetune/gene_attrs/g_i_g_a_39_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_40_attrs.json', 48 | 'data/finetune/gene_attrs/g_i_g_a_41_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_42_attrs.json', 49 | 'data/finetune/gene_attrs/g_i_g_a_43_attrs.json'] 50 | 51 | 52 | ## Vision Encoder 53 | vision_config: 'configs/config_swinB_384.json' 54 | image_res: 384 55 | patch_size: 32 56 | h: 384 57 | w: 128 58 | 59 | ## Text Encoder 60 | text_config: 'configs/config_bert.json' 61 | text_encoder: 'data/bert-base-uncased' 62 | 63 | 64 | ## Training 65 | batch_size_train: 150 66 | batch_size_test: 150 67 | batch_size_test_text: 750 68 | 69 | max_tokens: 56 70 | max_words: 56 71 | 72 | embed_dim: 256 73 | temp: 0.07 74 | k_test: 128 75 | 76 | 77 | ## mlm loss 78 | mlm: True 79 | mask_prob: 0.25 80 | max_masks: 10 81 | skipgram_prb: 0.2 82 | skipgram_size: 3 83 | mask_whole_word: True 84 | 85 | 86 | ## Other Settings 87 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 88 | schedular: {sched: linear, lr: 1e-4, epochs: 32, num_warmup_steps: 2600} 89 | 90 | pa100k: False 91 | icfg_rstp: False 92 | 93 | lr_2: False 94 | init_cross: True 95 | load_params: True 96 | load_pretrained: False 97 | 98 | erasing_p: 0.6 99 | eda: False 100 | eda_p: 1 101 | 102 | attr: True 103 | LabelSmooth: 0.4 104 | t: 0.8 105 | -------------------------------------------------------------------------------- /configs/Retrieval_icfg.yaml: -------------------------------------------------------------------------------- 1 | image_root: 'images/ICFG-PEDES/' 2 | train_file: ['data/finetune/icfg_train.json'] 3 | test_file: 'data/finetune/icfg_test.json' 4 | 5 | 6 | ## Vision Encoder 7 | vision_config: 'configs/config_swinB_384.json' 8 | image_res: 384 9 | patch_size: 32 10 | h: 384 11 | w: 128 12 | 13 | 14 | ## Text Encoder 15 | text_config: 'configs/config_bert.json' 16 | text_encoder: 'data/bert-base-uncased' 17 | 18 | 19 | ## Training 20 | batch_size_train: 120 21 | batch_size_test: 150 22 | batch_size_test_text: 750 23 | 24 | max_tokens: 56 25 | max_words: 56 26 | 27 | embed_dim: 256 28 | temp: 0.07 29 | k_test: 128 30 | 31 | 32 | ## mlm loss 33 | mlm: True 34 | mask_prob: 0.25 35 | max_masks: 10 36 | skipgram_prb: 0.2 37 | skipgram_size: 3 38 | mask_whole_word: True 39 | 40 | 41 | ## Other Settings 42 | optimizer: {opt: 
adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 43 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1} 44 | 45 | pa100k: False 46 | icfg_rstp: True 47 | 48 | lr_2: True 49 | load_params: False 50 | load_pretrained: True 51 | 52 | erasing_p: 0.6 53 | eda: True 54 | eda_p: 1 55 | LabelSmooth: 0 -------------------------------------------------------------------------------- /configs/Retrieval_pa100k.yaml: -------------------------------------------------------------------------------- 1 | image_root: 'images/pa100k/release_data' 2 | test_file: 'data/finetune/PA100K_test.json' 3 | val_file: 'data/finetune/PA100K_val.json' 4 | train_file: ['data/finetune/PA100K_train.json'] 5 | 6 | 7 | ## Vision Encoder 8 | vision_config: 'configs/config_swinB_384.json' 9 | image_res: 384 10 | patch_size: 32 11 | h: 384 12 | w: 128 13 | 14 | ## Text Encoder 15 | text_config: 'configs/config_bert.json' 16 | text_encoder: 'data/bert-base-uncased' 17 | 18 | 19 | ## Training 20 | batch_size_train: 200 21 | batch_size_test: 200 22 | batch_size_test_text: 1000 23 | 24 | max_tokens: 15 25 | max_words: 56 26 | 27 | embed_dim: 256 28 | temp: 0.07 29 | k_test: 128 30 | 31 | 32 | ## mlm loss 33 | mlm: True 34 | mask_prob: 0.25 35 | max_masks: 10 36 | skipgram_prb: 0.2 37 | skipgram_size: 3 38 | mask_whole_word: True 39 | 40 | 41 | ## Other Settings 42 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 43 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1} 44 | 45 | lr_2: True 46 | load_params: False 47 | load_pretrained: True 48 | 49 | pa100k: True 50 | #pa100k_only_img_classifier: True 51 | #dop: 0.1 52 | 53 | erasing_p: 0.6 54 | LabelSmooth: 0 # 0.4 -------------------------------------------------------------------------------- /configs/Retrieval_rstp.yaml: -------------------------------------------------------------------------------- 1 | image_root: 'images/RSTPReid/' 2 | train_file: ['data/finetune/rstp_train.json'] 3 | val_file: 'data/finetune/rstp_val.json' 4 | test_file: 'data/finetune/rstp_test.json' 5 | 6 | 7 | ## Vision Encoder 8 | vision_config: 'configs/config_swinB_384.json' 9 | image_res: 384 10 | patch_size: 32 11 | h: 384 12 | w: 128 13 | 14 | 15 | ## Text Encoder 16 | text_config: 'configs/config_bert.json' 17 | text_encoder: 'data/bert-base-uncased' 18 | 19 | 20 | ## Training 21 | batch_size_train: 120 22 | batch_size_test: 150 23 | batch_size_test_text: 750 24 | 25 | max_tokens: 56 26 | max_words: 56 27 | 28 | embed_dim: 256 29 | temp: 0.07 30 | k_test: 128 31 | 32 | 33 | ## mlm loss 34 | mlm: True 35 | mask_prob: 0.25 36 | max_masks: 10 37 | skipgram_prb: 0.2 38 | skipgram_size: 3 39 | mask_whole_word: True 40 | 41 | 42 | ## Other Settings 43 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 44 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1} 45 | 46 | pa100k: False 47 | icfg_rstp: True 48 | 49 | lr_2: True 50 | load_params: False 51 | load_pretrained: True 52 | 53 | erasing_p: 0.6 54 | eda: True 55 | eda_p: 1 56 | LabelSmooth: 0 -------------------------------------------------------------------------------- /configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | 
"max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 1024 21 | } 22 | -------------------------------------------------------------------------------- /configs/config_swinB_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "data/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 384, 5 | "h": 384, 6 | "w": 128, 7 | "window_size": 8, 8 | "embed_dim": 128, 9 | "depths": [ 2, 2, 18, 2 ], 10 | "num_heads": [ 4, 8, 16, 32 ] 11 | } 12 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms 5 | from torchvision.transforms import InterpolationMode 6 | from PIL import Image 7 | 8 | from dataset.re_dataset import re_train_dataset, re_test_dataset, re_test_dataset_icfg, re_train_dataset_attr, \ 9 | re_test_dataset_attr 10 | from dataset.randaugment import RandomAugment 11 | from dataset.random_erasing import RandomErasing 12 | 13 | 14 | def create_dataset(dataset, config, evaluate=False): 15 | # gene 16 | gene_norm = transforms.Normalize((0.4416847, 0.41812873, 0.4237452), (0.3088255, 0.29743394, 0.301009)) 17 | # cuhk 18 | cuhk_norm = transforms.Normalize((0.38901278, 0.3651612, 0.34836376), (0.24344306, 0.23738699, 0.23368555)) 19 | # icfg 20 | icfg_norm = transforms.Normalize((0.30941582, 0.28956893, 0.30347288), (0.25849792, 0.24547698, 0.2366199)) 21 | # rstp 22 | rstp_norm = transforms.Normalize((0.27722597, 0.26065794, 0.3036557), (0.2609547, 0.2508087, 0.25293276)) 23 | # pa100k 24 | pa100k_norm = transforms.Normalize((0.46485138, 0.45038012, 0.4632019), (0.25088054, 0.24609283, 0.24240193)) 25 | 26 | if dataset == 're_cuhk': 27 | train_norm = cuhk_norm 28 | test_norm = cuhk_norm 29 | elif dataset == 're_icfg': 30 | train_norm = icfg_norm 31 | test_norm = icfg_norm 32 | elif dataset == 're_rstp': 33 | train_norm = rstp_norm 34 | test_norm = rstp_norm 35 | elif dataset == 're_gene': 36 | train_norm = gene_norm 37 | test_norm = cuhk_norm 38 | elif dataset == 're_pa100k': 39 | train_norm = pa100k_norm 40 | test_norm = pa100k_norm 41 | 42 | train_transform = transforms.Compose([ 43 | # transforms.RandomResizedCrop((config['h'], config['h']), 44 | # scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC), 45 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC), 46 | transforms.RandomHorizontalFlip(), 47 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 48 | 'Brightness', 'Sharpness', 'ShearX', 49 | 'ShearY', 'TranslateX', 'TranslateY', 50 | 'Rotate']), 51 | transforms.ToTensor(), 52 | train_norm, 53 | RandomErasing(probability=config['erasing_p'], mean=[0.0, 0.0, 0.0]) 54 | ]) 55 | 56 | pre_transform = transforms.Compose([ 57 | transforms.RandomResizedCrop((config['h'], config['h']), 58 | scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC), 59 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC), 60 | transforms.RandomHorizontalFlip(), 61 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 62 | 'Brightness', 'Sharpness', 'ShearX', 63 | 'ShearY', 
'TranslateX', 'TranslateY', 64 | 'Rotate']), 65 | transforms.ToTensor(), 66 | train_norm, 67 | RandomErasing(probability=config['erasing_p'], mean=[0.0, 0.0, 0.0]) 68 | ]) 69 | 70 | test_transform = transforms.Compose([ 71 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC), 72 | transforms.ToTensor(), 73 | test_norm, 74 | ]) 75 | 76 | if dataset == 're_icfg': 77 | test_dataset = re_test_dataset_icfg(config, test_transform) 78 | if evaluate: 79 | return None, test_dataset 80 | train_dataset = re_train_dataset(config, train_transform, pre_transform) 81 | return train_dataset, test_dataset 82 | elif dataset == 're_pa100k': 83 | test_dataset = re_test_dataset_attr(config['test_file'], config, test_transform) 84 | val_dataset = re_test_dataset_attr(config['val_file'], config, test_transform) 85 | if evaluate: 86 | return None, val_dataset, test_dataset 87 | train_dataset = re_train_dataset_attr(config, train_transform) 88 | return train_dataset, val_dataset, test_dataset 89 | else: 90 | test_dataset = re_test_dataset(config['test_file'], config, test_transform) 91 | val_dataset = re_test_dataset(config['val_file'], config, test_transform) 92 | if evaluate: 93 | return None, val_dataset, test_dataset 94 | train_dataset = re_train_dataset(config, train_transform, pre_transform) 95 | return train_dataset, val_dataset, test_dataset 96 | 97 | 98 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 99 | samplers = [] 100 | for dataset, shuffle in zip(datasets, shuffles): 101 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, 102 | shuffle=shuffle) 103 | samplers.append(sampler) 104 | return samplers 105 | 106 | 107 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 108 | loaders = [] 109 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, 110 | is_trains, collate_fns): 111 | if is_train: 112 | shuffle = (sampler is None) 113 | drop_last = True 114 | else: 115 | shuffle = False 116 | drop_last = False 117 | loader = DataLoader( 118 | dataset, 119 | batch_size=bs, 120 | num_workers=n_worker, 121 | pin_memory=True, 122 | sampler=sampler, 123 | shuffle=shuffle, 124 | collate_fn=collate_fn, 125 | drop_last=drop_last, 126 | ) 127 | loaders.append(loader) 128 | 129 | if len(loaders) <= 1: 130 | print(f"### be careful: func create_loader returns a list length of {len(loaders)}") 131 | 132 | return loaders 133 | -------------------------------------------------------------------------------- /dataset/eda.py: -------------------------------------------------------------------------------- 1 | # Easy data augmentation techniques for text classification 2 | # Jason Wei and Kai Zou 3 | 4 | import random 5 | from random import shuffle 6 | 7 | random.seed(1) 8 | 9 | # stop words list 10 | stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 11 | 'ours', 'ourselves', 'you', 'your', 'yours', 12 | 'yourself', 'yourselves', 'he', 'him', 'his', 13 | 'himself', 'she', 'her', 'hers', 'herself', 14 | 'it', 'its', 'itself', 'they', 'them', 'their', 15 | 'theirs', 'themselves', 'what', 'which', 'who', 16 | 'whom', 'this', 'that', 'these', 'those', 'am', 17 | 'is', 'are', 'was', 'were', 'be', 'been', 'being', 18 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 19 | 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 20 | 'because', 'as', 'until', 'while', 'of', 'at', 21 | 'by', 'for', 'with', 'about', 'against', 'between', 
22 | 'into', 'through', 'during', 'before', 'after', 23 | 'above', 'below', 'to', 'from', 'up', 'down', 'in', 24 | 'out', 'on', 'off', 'over', 'under', 'again', 25 | 'further', 'then', 'once', 'here', 'there', 'when', 26 | 'where', 'why', 'how', 'all', 'any', 'both', 'each', 27 | 'few', 'more', 'most', 'other', 'some', 'such', 'no', 28 | 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 29 | 'very', 's', 't', 'can', 'will', 'just', 'don', 30 | 'should', 'now', ''] 31 | 32 | # cleaning up text 33 | import re 34 | 35 | 36 | def get_only_chars(line): 37 | clean_line = "" 38 | 39 | line = line.replace("’", "") 40 | line = line.replace("'", "") 41 | line = line.replace("-", " ") # replace hyphens with spaces 42 | line = line.replace("\t", " ") 43 | line = line.replace("\n", " ") 44 | line = line.lower() 45 | 46 | for char in line: 47 | if char in 'qwertyuiopasdfghjklzxcvbnm ': 48 | clean_line += char 49 | else: 50 | clean_line += ' ' 51 | 52 | clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces 53 | if clean_line[0] == ' ': 54 | clean_line = clean_line[1:] 55 | return clean_line 56 | 57 | 58 | ######################################################################## 59 | # Synonym replacement 60 | # Replace n words in the sentence with synonyms from wordnet 61 | ######################################################################## 62 | 63 | # for the first time you use wordnet 64 | # import nltk 65 | # nltk.download('wordnet') 66 | from nltk.corpus import wordnet 67 | 68 | 69 | def synonym_replacement(words, n): 70 | new_words = words.copy() 71 | random_word_list = list(set([word for word in words if word not in stop_words])) 72 | random.shuffle(random_word_list) 73 | num_replaced = 0 74 | for random_word in random_word_list: 75 | synonyms = get_synonyms(random_word) 76 | if len(synonyms) >= 1: 77 | synonym = random.choice(list(synonyms)) 78 | new_words = [synonym if word == random_word else word for word in new_words] 79 | # print("replaced", random_word, "with", synonym) 80 | num_replaced += 1 81 | if num_replaced >= n: # only replace up to n words 82 | break 83 | 84 | # this is stupid but we need it, trust me 85 | sentence = ' '.join(new_words) 86 | new_words = sentence.split(' ') 87 | 88 | return new_words 89 | 90 | 91 | def get_synonyms(word): 92 | synonyms = set() 93 | for syn in wordnet.synsets(word): 94 | for l in syn.lemmas(): 95 | synonym = l.name().replace("_", " ").replace("-", " ").lower() 96 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm']) 97 | synonyms.add(synonym) 98 | if word in synonyms: 99 | synonyms.remove(word) 100 | return list(synonyms) 101 | 102 | 103 | ######################################################################## 104 | # Random deletion 105 | # Randomly delete words from the sentence with probability p 106 | ######################################################################## 107 | 108 | def random_deletion(words, p): 109 | # obviously, if there's only one word, don't delete it 110 | if len(words) == 1: 111 | return words 112 | 113 | # randomly delete words with probability p 114 | new_words = [] 115 | for word in words: 116 | r = random.uniform(0, 1) 117 | if r > p: 118 | new_words.append(word) 119 | 120 | # if you end up deleting all words, just return a random word 121 | if len(new_words) == 0: 122 | rand_int = random.randint(0, len(words) - 1) 123 | return [words[rand_int]] 124 | 125 | return new_words 126 | 127 | 128 | 
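# --- Usage sketch (added for illustration; not part of the original eda.py) ---
# The helpers above operate on a plain list of words, e.g.:
#   words = get_only_chars("A woman in a white top and black pants.").split(' ')
#   words = [w for w in words if w != '']
#   print(synonym_replacement(words, 2))   # replace up to 2 non-stop-words with synonyms
#   print(random_deletion(words, 0.1))     # drop each word with probability 0.1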
######################################################################## 129 | # Random swap 130 | # Randomly swap two words in the sentence n times 131 | ######################################################################## 132 | 133 | def random_swap(words, n): 134 | new_words = words.copy() 135 | for _ in range(n): 136 | new_words = swap_word(new_words) 137 | return new_words 138 | 139 | 140 | def swap_word(new_words): 141 | random_idx_1 = random.randint(0, len(new_words) - 1) 142 | random_idx_2 = random_idx_1 143 | counter = 0 144 | while random_idx_2 == random_idx_1: 145 | random_idx_2 = random.randint(0, len(new_words) - 1) 146 | counter += 1 147 | if counter > 3: 148 | return new_words 149 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] 150 | return new_words 151 | 152 | 153 | ######################################################################## 154 | # Random insertion 155 | # Randomly insert n words into the sentence 156 | ######################################################################## 157 | 158 | def random_insertion(words, n): 159 | new_words = words.copy() 160 | for _ in range(n): 161 | add_word(new_words) 162 | return new_words 163 | 164 | 165 | def add_word(new_words): 166 | synonyms = [] 167 | counter = 0 168 | while len(synonyms) < 1: 169 | random_word = new_words[random.randint(0, len(new_words) - 1)] 170 | synonyms = get_synonyms(random_word) 171 | counter += 1 172 | if counter >= 10: 173 | return 174 | random_synonym = synonyms[0] 175 | random_idx = random.randint(0, len(new_words) - 1) 176 | new_words.insert(random_idx, random_synonym) 177 | 178 | 179 | ######################################################################## 180 | # main data augmentation function 181 | ######################################################################## 182 | 183 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9): 184 | sentence = get_only_chars(sentence) 185 | words = sentence.split(' ') 186 | words = [word for word in words if word != ''] 187 | num_words = len(words) 188 | 189 | augmented_sentences = [] 190 | num_new_per_technique = int(num_aug / 4) + 1 191 | 192 | # sr 193 | if (alpha_sr > 0): 194 | n_sr = max(1, int(alpha_sr * num_words)) 195 | for _ in range(num_new_per_technique): 196 | a_words = synonym_replacement(words, n_sr) 197 | augmented_sentences.append(' '.join(a_words)) 198 | 199 | # ri 200 | if (alpha_ri > 0): 201 | n_ri = max(1, int(alpha_ri * num_words)) 202 | for _ in range(num_new_per_technique): 203 | a_words = random_insertion(words, n_ri) 204 | augmented_sentences.append(' '.join(a_words)) 205 | 206 | # rs 207 | if (alpha_rs > 0): 208 | n_rs = max(1, int(alpha_rs * num_words)) 209 | for _ in range(num_new_per_technique): 210 | a_words = random_swap(words, n_rs) 211 | augmented_sentences.append(' '.join(a_words)) 212 | 213 | # rd 214 | if (p_rd > 0): 215 | for _ in range(num_new_per_technique): 216 | a_words = random_deletion(words, p_rd) 217 | augmented_sentences.append(' '.join(a_words)) 218 | 219 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences] 220 | shuffle(augmented_sentences) 221 | 222 | # trim so that we have the desired number of augmented sentences 223 | if num_aug >= 1: 224 | augmented_sentences = augmented_sentences[:num_aug] 225 | else: 226 | keep_prob = num_aug / len(augmented_sentences) 227 | augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob] 228 | 229 | # append the 
original sentence 230 | augmented_sentences.append(sentence) 231 | 232 | return augmented_sentences -------------------------------------------------------------------------------- /dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def 
level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | 
ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /dataset/random_erasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | #from torchvision.transforms import * 4 | 5 | #from PIL import Image 6 | import random 7 | import math 8 | #import numpy as np 9 | 10 | class RandomErasing(object): 11 | """ Randomly selects a rectangle region in an image and erases its pixels. 12 | 'Random Erasing Data Augmentation' by Zhong et al. 13 | See https://arxiv.org/pdf/1708.04896.pdf 14 | Args: 15 | probability: The probability that the Random Erasing operation will be performed. 16 | sl: Minimum proportion of erased area against input image. 17 | sh: Maximum proportion of erased area against input image. 18 | r1: Minimum aspect ratio of erased area. 19 | mean: Erasing value. 20 | """ 21 | 22 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 23 | self.probability = probability 24 | self.mean = mean 25 | self.sl = sl 26 | self.sh = sh 27 | self.r1 = r1 28 | 29 | def __call__(self, img): 30 | 31 | if random.uniform(0, 1) > self.probability: 32 | return img 33 | 34 | for attempt in range(100): 35 | area = img.size()[1] * img.size()[2] 36 | 37 | target_area = random.uniform(self.sl, self.sh) * area 38 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 39 | 40 | h = int(round(math.sqrt(target_area * aspect_ratio))) 41 | w = int(round(math.sqrt(target_area / aspect_ratio))) 42 | 43 | if w < img.size()[2] and h < img.size()[1]: 44 | x1 = random.randint(0, img.size()[1] - h) 45 | y1 = random.randint(0, img.size()[2] - w) 46 | if img.size()[0] == 3: 47 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 48 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 49 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 50 | else: 51 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 52 | return img 53 | 54 | return img 55 | 56 | 57 | class RandomGrayscaleErasing(object): 58 | """ Randomly selects a rectangle region in an image and use grayscale image 59 | instead of its pixels. 60 | 'Local Grayscale Transfomation' by Yunpeng Gong. 61 | See https://arxiv.org/pdf/2101.08533.pdf 62 | Args: 63 | probability: The probability that the Random Grayscale Erasing operation will be performed. 64 | sl: Minimum proportion of erased area against input image. 65 | sh: Maximum proportion of erased area against input image. 66 | r1: Minimum aspect ratio of erased area. 
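    Example (illustrative usage only; assumes a CHW float tensor such as the
    output of torchvision's ToTensor(), as noted in __call__ below):
        >>> import torch
        >>> aug = RandomGrayscaleErasing(probability=1.0)
        >>> img = torch.rand(3, 384, 128)   # C, H, W person crop
        >>> out = aug(img)                  # a random patch now holds its grayscale values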
67 | """ 68 | 69 | def __init__(self, probability: float = 0.2, sl: float = 0.02, sh: float = 0.4, r1: float = 0.3): 70 | self.probability = probability 71 | self.sl = sl 72 | self.sh = sh 73 | self.r1 = r1 74 | 75 | def __call__(self, img): 76 | """ 77 | Args: 78 | img: after ToTensor() and Normalize([...]), img's type is Tensor 79 | """ 80 | if random.uniform(0, 1) > self.probability: 81 | return img 82 | 83 | height, width = img.size()[-2], img.size()[-1] 84 | area = height * width 85 | 86 | for _ in range(100): 87 | 88 | target_area = random.uniform(self.sl, self.sh) * area 89 | aspect_ratio = random.uniform(self.r1, 1/self.r1) # height / width 90 | 91 | h = int(round(math.sqrt(target_area * aspect_ratio))) 92 | w = int(round(math.sqrt(target_area / aspect_ratio))) 93 | 94 | if w < width and h < height: 95 | # tl 96 | x = random.randint(0, height - h) 97 | y = random.randint(0, width - w) 98 | # unbind channel dim 99 | r, g, b = img.unbind(dim=-3) 100 | # Weighted average method -> grayscale patch 101 | l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) 102 | l_img = l_img.unsqueeze(dim=-3) # rebind channel 103 | # erasing 104 | img[0, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w] 105 | img[1, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w] 106 | img[2, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w] 107 | 108 | return img 109 | 110 | return img -------------------------------------------------------------------------------- /dataset/re_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | from random import randint, shuffle 7 | from random import random as rand 8 | from PIL import Image 9 | from PIL import ImageFile 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | 14 | from dataset.utils import pre_caption 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | Image.MAX_IMAGE_PIXELS = None 18 | 19 | 20 | class TextMaskingGenerator: 21 | def __init__(self, tokenizer, mask_prob, mask_max, skipgram_prb=0.2, skipgram_size=3, mask_whole_word=True, 22 | use_roberta=False): 23 | self.id2token = {i: w for w, i in tokenizer.get_vocab().items()} 24 | self.use_roberta = use_roberta 25 | for i in range(len(self.id2token)): 26 | assert i in self.id2token.keys() # check 27 | self.cls_token_id = tokenizer.cls_token_id 28 | self.mask_token_id = tokenizer.mask_token_id 29 | self.mask_max = mask_max 30 | self.mask_prob = mask_prob 31 | self.skipgram_prb = skipgram_prb 32 | self.skipgram_size = skipgram_size 33 | self.mask_whole_word = mask_whole_word 34 | 35 | print("len(tokenizer.id2token): ", len(self.id2token), " ---- cls_token_id: ", self.cls_token_id, 36 | " ---- mask_token_id: ", self.mask_token_id, flush=True) 37 | 38 | def get_random_word(self): 39 | i = randint(0, len(self.id2token) - 1) 40 | return i # self.id2token[i] 41 | 42 | def __call__(self, text_ids): # tokens: [CLS] + ... 
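        # Decide how many tokens to mask: a fraction `mask_prob` of the sequence,
        # at least 1 and at most `mask_max` positions. Position 0 ([CLS]) is never
        # masked. Candidate positions are shuffled and may be expanded into whole
        # words (WordPiece '##' / RoBERTa 'Ġ' continuation pieces) or short
        # skip-grams before masking. Each selected position follows the usual BERT
        # recipe: ~80% become [MASK], ~10% a random token, ~10% stay unchanged.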
43 | n_pred = min(self.mask_max, max(1, int(round(len(text_ids) * self.mask_prob)))) 44 | 45 | # candidate positions of masked tokens 46 | assert text_ids[0] == self.cls_token_id 47 | special_pos = set([0]) # will not be masked 48 | cand_pos = list(range(1, len(text_ids))) 49 | 50 | shuffle(cand_pos) 51 | masked_pos = set() 52 | max_cand_pos = max(cand_pos) 53 | for pos in cand_pos: 54 | if len(masked_pos) >= n_pred: 55 | break 56 | if pos in masked_pos: 57 | continue 58 | 59 | def _expand_whole_word(st, end): 60 | new_st, new_end = st, end 61 | 62 | if self.use_roberta: 63 | while (new_st > 1) and (self.id2token[text_ids[new_st].item()][0] != 'Ġ'): 64 | new_st -= 1 65 | while (new_end < len(text_ids)) and (self.id2token[text_ids[new_end].item()][0] != 'Ġ'): 66 | new_end += 1 67 | else: 68 | # bert, WordPiece 69 | while (new_st >= 0) and self.id2token[text_ids[new_st].item()].startswith('##'): 70 | new_st -= 1 71 | while (new_end < len(text_ids)) and self.id2token[text_ids[new_end].item()].startswith('##'): 72 | new_end += 1 73 | 74 | return new_st, new_end 75 | 76 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 77 | # ngram 78 | cur_skipgram_size = randint(2, self.skipgram_size) 79 | if self.mask_whole_word: 80 | st_pos, end_pos = _expand_whole_word( 81 | pos, pos + cur_skipgram_size) 82 | else: 83 | st_pos, end_pos = pos, pos + cur_skipgram_size 84 | else: 85 | if self.mask_whole_word: 86 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 87 | else: 88 | st_pos, end_pos = pos, pos + 1 89 | 90 | for mp in range(st_pos, end_pos): 91 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 92 | masked_pos.add(mp) 93 | else: 94 | break 95 | 96 | masked_pos = list(masked_pos) 97 | n_real_pred = len(masked_pos) 98 | if n_real_pred > n_pred: 99 | shuffle(masked_pos) 100 | masked_pos = masked_pos[:n_pred] 101 | 102 | for pos in masked_pos: 103 | if rand() < 0.8: # 80% 104 | text_ids[pos] = self.mask_token_id 105 | elif rand() < 0.5: # 10% 106 | text_ids[pos] = self.get_random_word() 107 | 108 | return text_ids, masked_pos 109 | 110 | 111 | class re_train_dataset(Dataset): 112 | def __init__(self, config, transform, pre_transform): 113 | self.image_root = config['image_root'] 114 | self.max_words = config['max_words'] 115 | self.icfg_rstp = config['icfg_rstp'] 116 | self.eda = config['eda'] 117 | self.eda_p = config['eda_p'] 118 | ann_file = config['train_file'] 119 | 120 | if ('attr' in config.keys()) and config['attr']: 121 | self.attr = True 122 | else: 123 | self.attr = False 124 | 125 | self.transform = transform 126 | self.pre_transform = pre_transform 127 | self.ann = [] 128 | for f in ann_file: 129 | self.ann += json.load(open(f, 'r')) 130 | 131 | self.img_ids = {} 132 | 133 | n = 1 134 | for ann in self.ann: 135 | img_id = ann['image_id'] 136 | if img_id not in self.img_ids.keys(): 137 | self.img_ids[img_id] = n 138 | n += 1 139 | 140 | def __len__(self): 141 | return len(self.ann) 142 | 143 | def __getitem__(self, index): 144 | 145 | ann = self.ann[index] 146 | try: 147 | image_path = os.path.join(self.image_root, ann['image']) 148 | except: 149 | print("self.image_root", self.image_root) 150 | print("ann['image']", ann['image']) 151 | image = Image.open(image_path).convert('RGB') 152 | image1 = self.transform(image) 153 | 154 | caption = pre_caption(ann['caption'], self.max_words) 155 | if self.eda: 156 | caption1 = pre_caption(ann['caption'], self.max_words, self.icfg_rstp, True, self.eda_p) 157 | return image1, caption, caption1, 
self.img_ids[ann['image_id']] 158 | elif self.attr: 159 | label = torch.tensor(ann['label']) 160 | return image1, caption, self.img_ids[ann['image_id']], label 161 | else: 162 | return image1, caption, self.img_ids[ann['image_id']] 163 | 164 | 165 | class re_test_dataset(Dataset): 166 | def __init__(self, ann_file, config, transform): 167 | self.ann = json.load(open(ann_file, 'r')) 168 | self.transform = transform 169 | self.image_root = config['image_root'] 170 | self.max_words = config['max_words'] 171 | self.icfg_rstp = config['icfg_rstp'] 172 | 173 | self.text = [] 174 | self.image = [] 175 | self.txt2img = {} 176 | self.img2txt = {} 177 | 178 | self.g_pids = [] 179 | self.q_pids = [] 180 | 181 | txt_id = 0 182 | for img_id, ann in enumerate(self.ann): 183 | self.g_pids.append(ann['image_id']) 184 | self.image.append(ann['image']) 185 | self.img2txt[img_id] = [] 186 | 187 | t = 0 188 | for i, caption in enumerate(ann['caption']): 189 | self.q_pids.append(ann['image_id']) 190 | self.text.append(pre_caption(caption, self.max_words, icfg_rstp=self.icfg_rstp)) 191 | self.img2txt[img_id].append(txt_id) 192 | self.txt2img[txt_id] = [] 193 | self.txt2img[txt_id].append(img_id) 194 | txt_id += 1 195 | t += 1 196 | 197 | txt_id1 = 0 198 | for img_id1, ann1 in enumerate(self.ann): 199 | for i1, caption1 in enumerate(ann1['caption']): 200 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1: 201 | self.img2txt[img_id].append(txt_id1) 202 | txt_id1 += 1 203 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1: 204 | for temp in range(t): 205 | self.txt2img[txt_id - 1 - temp].append(img_id1) 206 | 207 | def __len__(self): 208 | return len(self.image) 209 | 210 | def __getitem__(self, index): 211 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 212 | image = Image.open(image_path).convert('RGB') 213 | image = self.transform(image) 214 | return image, index 215 | 216 | 217 | class re_test_dataset_icfg(Dataset): 218 | def __init__(self, config, transform): 219 | ann_file = config['test_file'] 220 | self.ann = json.load(open(ann_file, 'r')) 221 | self.transform = transform 222 | self.image_root = config['image_root'] 223 | self.max_words = config['max_words'] 224 | 225 | self.text = [] 226 | self.image = [] 227 | self.txt2img = {} 228 | self.img2txt = {} 229 | 230 | self.g_pids = [] 231 | self.q_pids = [] 232 | 233 | for img_id, ann in enumerate(self.ann): 234 | self.image.append(ann['image']) 235 | self.g_pids.append(ann['image_id']) 236 | self.img2txt[img_id] = [] 237 | self.img2txt[img_id].append(img_id) 238 | 239 | self.text.append(pre_caption(ann['caption'][0], self.max_words, icfg_rstp=True)) 240 | self.q_pids.append(ann['image_id']) 241 | 242 | self.txt2img[img_id] = [] 243 | self.txt2img[img_id].append(img_id) 244 | 245 | for img_id1, ann1 in enumerate(self.ann): 246 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1: 247 | self.txt2img[img_id].append(img_id1) 248 | self.img2txt[img_id].append(img_id1) 249 | 250 | def __len__(self): 251 | return len(self.image) 252 | 253 | def __getitem__(self, index): 254 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 255 | image = Image.open(image_path).convert('RGB') 256 | image = self.transform(image) 257 | return image, index 258 | 259 | 260 | class re_train_dataset_attr(Dataset): 261 | def __init__(self, config, transform): 262 | ann_file = config['train_file'] 263 | self.ann = [] 264 | for f in ann_file: 265 | self.ann += json.load(open(f, 'r')) 266 | self.transform = transform 
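        # Attribute-only training split: __getitem__ below returns an image and
        # its attribute label vector (ann['label']); captions are not used here,
        # so max_words is stored but never read by this class.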
267 | self.image_root = config['image_root'] 268 | self.max_words = config['max_words'] 269 | 270 | def __len__(self): 271 | return len(self.ann) 272 | 273 | def __getitem__(self, index): 274 | ann = self.ann[index] 275 | image_path = os.path.join(self.image_root, ann['image']) 276 | image = Image.open(image_path).convert('RGB') 277 | image = self.transform(image) 278 | label = torch.tensor(ann['label']) 279 | return image, label 280 | 281 | 282 | class re_test_dataset_attr(Dataset): 283 | def __init__(self, ann_file, config, transform): 284 | self.ann = json.load(open(ann_file, 'r')) 285 | self.transform = transform 286 | self.image_root = config['image_root'] 287 | self.max_words = config['max_words'] 288 | 289 | self.image = [] 290 | self.label = [] 291 | for img_id, ann in enumerate(self.ann): 292 | self.image.append(ann['image']) 293 | self.label.append(ann['label']) 294 | self.label = np.array(self.label) 295 | 296 | def __len__(self): 297 | return len(self.image) 298 | 299 | def __getitem__(self, index): 300 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 301 | image = Image.open(image_path).convert('RGB') 302 | image = self.transform(image) 303 | return image, index 304 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | 10 | import utils 11 | from tqdm import tqdm 12 | from dataset.eda import * 13 | 14 | def pre_caption(caption, max_words, icfg_rstp=False, is_eda=False, eda_p=0.5): 15 | if icfg_rstp: 16 | try: 17 | caption = re.sub( 18 | r'[^0-9a-z]+', 19 | ' ', 20 | caption.lower(), 21 | ) 22 | except: 23 | print(caption) 24 | caption_words = caption.split() 25 | caption = ' '.join(caption_words) 26 | 27 | # eda 28 | if is_eda and random.random() < eda_p: 29 | caption = eda(caption, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=1)[0] 30 | 31 | # truncate caption 32 | caption_words = caption.split() 33 | if len(caption_words) > max_words: 34 | caption = ' '.join(caption_words[:max_words]) 35 | 36 | if not len(caption): 37 | raise ValueError("pre_caption yields invalid text") 38 | 39 | return caption 40 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.aptm import APTM 2 | from models.aptm import load_pretrained 3 | from models.aptm import AllGather -------------------------------------------------------------------------------- /models/aptm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.distributed as dist 9 | from torch.nn import init 10 | from timm.models.layers import trunc_normal_ 11 | from functools import partial 12 | 13 | from models.swin_transformer import SwinTransformer, interpolate_relative_pos_embed 14 | from models.bert import BertConfig, BertForMaskedLM, BertModel 15 | from utils import read_json 16 | 17 | 18 | class CrossEntropyLabelSmooth(nn.Module): 19 | """Cross entropy loss with label smoothing regularizer. 20 | Reference: 21 | Szegedy et al. 
Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 22 | Equation: y = (1 - epsilon) * y + epsilon / K. 23 | Args: 24 | epsilon (float): weight. 25 | """ 26 | 27 | def __init__(self, epsilon=0.1, use_gpu=True): 28 | super(CrossEntropyLabelSmooth, self).__init__() 29 | self.epsilon = epsilon 30 | self.use_gpu = use_gpu 31 | self.logsoftmax = nn.LogSoftmax(dim=1) 32 | 33 | def forward(self, inputs, targets): 34 | """ 35 | Args: 36 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 37 | targets: ground truth labels with shape (num_classes) 38 | """ 39 | _, num_classes = inputs.shape 40 | log_probs = self.logsoftmax(inputs) 41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 42 | if self.use_gpu: targets = targets.cuda() 43 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes 44 | loss = (- targets * log_probs).mean(0).sum() 45 | return loss 46 | 47 | 48 | class AllGather(torch.autograd.Function): 49 | """An autograd function that performs allgather on a tensor.""" 50 | 51 | @staticmethod 52 | def forward(ctx, tensor, rank, world_size): 53 | output = [torch.empty_like(tensor) for _ in range(world_size)] 54 | dist.all_gather(output, tensor) 55 | ctx.rank = rank 56 | ctx.batch_size = tensor.shape[0] 57 | return torch.cat(output, 0) 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | return ( 62 | grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], 63 | None, 64 | None 65 | ) 66 | 67 | 68 | allgather = AllGather.apply 69 | 70 | 71 | def build_vision_encoder(config, load_params=False): 72 | """ 73 | Args: load_params: False when building fine-tuning models 74 | """ 75 | 76 | print('use_swin') 77 | vision_config = read_json(config['vision_config']) 78 | assert config['image_res'] == vision_config['image_res'] 79 | assert config['patch_size'] == 32 80 | vision_width = vision_config['vision_width'] 81 | 82 | vision_encoder = SwinTransformer(img_size=vision_config['image_res'], 83 | h=vision_config['h'], 84 | w=vision_config['w'], 85 | patch_size=4, 86 | in_chans=3, 87 | embed_dim=vision_config['embed_dim'], 88 | depths=vision_config['depths'], 89 | num_heads=vision_config['num_heads'], 90 | window_size=vision_config['window_size'], 91 | mlp_ratio=4., 92 | qkv_bias=True, 93 | drop_rate=0.0, 94 | drop_path_rate=0.1, 95 | ape=False, 96 | patch_norm=True, 97 | use_checkpoint=False) 98 | 99 | if load_params: 100 | # download from https://github.com/microsoft/Swin-Transformer 101 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 102 | window_size = vision_config['window_size'] 103 | 104 | for k in list(state_dict.keys()): 105 | if 'relative_position_bias_table' in k: 106 | if 'layers.3' in k: 107 | window_size = 4 108 | dst_num_pos = (2 * window_size - 1) ** 2 109 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 110 | elif ('relative_position_index' in k) or ('attn_mask' in k): 111 | del state_dict[k] 112 | print("### build_vision_encoder: ", flush=True) 113 | msg = vision_encoder.load_state_dict(state_dict, strict=False) 114 | print("missing_keys: ", msg.missing_keys) 115 | print("unexpected_keys: ", msg.unexpected_keys) 116 | 117 | return vision_encoder, vision_width 118 | 119 | 120 | def build_text_encoder(config, vision_width, load_text_params=False, use_mlm_loss=False, config_text=None): 121 | init_params = [] # train from scratch with larger lr 122 | 123 | if config_text is None: 124 | 
config_text = BertConfig.from_json_file(config['text_config']) 125 | 126 | config_text.encoder_width = vision_width 127 | 128 | if use_mlm_loss: 129 | text_encoder, msg = BertForMaskedLM.from_pretrained(config['text_encoder'], config=config_text, 130 | output_loading_info=True) 131 | if ('init_cross' in config.keys()) and config['init_cross']: 132 | init_params.extend(['text_encoder.' + n for n in msg['missing_keys']]) # of cross attention 133 | print("### init_params.extend --> cross attention ###") 134 | 135 | if load_text_params: 136 | print("### build_text_encoder --> Load BERT: ") 137 | for k, v in msg.items(): 138 | print(f"{k}: {sorted(v)}") 139 | return text_encoder, init_params 140 | 141 | 142 | def build_mlp(input_dim, output_dim): 143 | return nn.Sequential( 144 | nn.Linear(input_dim, input_dim * 2), 145 | nn.LayerNorm(input_dim * 2), 146 | nn.GELU(), 147 | nn.Linear(input_dim * 2, output_dim) 148 | ) 149 | 150 | 151 | def attr_mlp(input_dim, inter_dim, output_dim, after_cross, dropout_p): 152 | if after_cross: 153 | new_mlp = nn.Sequential( 154 | nn.Flatten(), 155 | nn.Linear(input_dim, inter_dim), 156 | nn.LayerNorm(inter_dim), 157 | nn.Dropout(p=dropout_p), 158 | nn.Linear(inter_dim, output_dim) 159 | ) 160 | else: 161 | new_mlp = nn.Sequential( 162 | nn.Flatten(), 163 | nn.Linear(input_dim, inter_dim), 164 | nn.BatchNorm1d(inter_dim), 165 | nn.Dropout(p=dropout_p), 166 | nn.Linear(inter_dim, output_dim) 167 | ) 168 | init.normal_(new_mlp[1].weight.data, std=0.00001) 169 | init.constant_(new_mlp[1].bias.data, 0.0) 170 | init.normal_(new_mlp[4].weight.data, std=0.00001) 171 | init.constant_(new_mlp[4].bias.data, 0.0) 172 | return new_mlp 173 | 174 | 175 | def load_pretrained(ckpt_rpath, config, is_eval=False, load_text=False): 176 | checkpoint = torch.load(ckpt_rpath, map_location='cpu') 177 | state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint 178 | if is_eval: 179 | return state_dict 180 | 181 | print("### Loading pretrained vision encoder", flush=True) 182 | 183 | if load_text: 184 | print("### Loading pretrained text encoder", flush=True) 185 | for key in list(state_dict.keys()): 186 | if 'text_encoder.' in key: 187 | if not config['mlm']: 188 | if 'bert.' in key: 189 | encoder_key = key.replace('bert.', '') 190 | state_dict[encoder_key] = state_dict[key] 191 | del state_dict[key] 192 | else: 193 | if 'bert.' 
not in key and 'cls' not in key: 194 | encoder_key = key.replace('text_encoder.', 'text_encoder.bert.') 195 | state_dict[encoder_key] = state_dict[key] 196 | del state_dict[key] 197 | 198 | return state_dict 199 | 200 | 201 | class APTM(nn.Module): 202 | def __init__(self, config=None, load_vision_params=False, load_text_params=False, 203 | use_contrastive_loss=False, use_matching_loss=False, 204 | use_mlm_loss=False, config_text=None): 205 | super().__init__() 206 | self.init_params = [] # train from scratch with larger lr 207 | 208 | self.vision_encoder, vision_width = build_vision_encoder(config, load_params=load_vision_params) 209 | self.vision_width = vision_width 210 | 211 | if ('pa100k_only_img_classifier' in config.keys()) and config['pa100k_only_img_classifier']: 212 | self.pa100k_only_img_classifier = config['pa100k_only_img_classifier'] 213 | self.img_cls = attr_mlp(self.vision_width, config['embed_dim'], 26, False, config['dop']) 214 | self.criterion = nn.BCEWithLogitsLoss() 215 | self.criterion = self.criterion.cuda() 216 | 217 | else: 218 | self.pa100k_only_img_classifier = False 219 | # text & cross-modal 220 | self.text_encoder, init_params = build_text_encoder(config, vision_width=vision_width, 221 | load_text_params=load_text_params, 222 | use_mlm_loss=use_mlm_loss, config_text=config_text) 223 | self.text_width = self.text_encoder.config.hidden_size # i.e. cross_width 224 | self.init_params.extend(init_params) 225 | if 0 < config['LabelSmooth'] < 1: 226 | self.new_cross_entropy = CrossEntropyLabelSmooth(epsilon=config['LabelSmooth']) 227 | self.add_label_smooth = True 228 | else: 229 | self.add_label_smooth = False 230 | 231 | # lr * x 232 | if config['lr_2']: 233 | # vision encoder 234 | for i in range(2, 4): 235 | for name, param in self.vision_encoder.layers[i].named_parameters(): 236 | # param.requires_grad = False 237 | self.init_params.extend(['vision_encoder.layers.' + str(i) + '.' + name]) 238 | # text encoder 239 | if config['mlm']: 240 | self.init_params.extend( 241 | ['text_encoder.cls.' + n for n, _ in self.text_encoder.cls.named_parameters()]) 242 | temp_name = 'text_encoder.bert.encoder.layer.' 243 | temp_encoder = self.text_encoder.bert 244 | else: 245 | temp_name = 'text_encoder.encoder.layer.' 246 | temp_encoder = self.text_encoder 247 | temp_list = [4, 5, 10, 11] 248 | for i in temp_list: 249 | for name, param in temp_encoder.encoder.layer[i].named_parameters(): 250 | self.init_params.extend([temp_name + str(i) + '.' + name]) 251 | 252 | if use_contrastive_loss: 253 | self.embed_dim = config['embed_dim'] 254 | self.vision_proj = nn.Linear(self.vision_width, self.embed_dim) 255 | self.text_proj = nn.Linear(self.text_width, self.embed_dim) 256 | self.temp = nn.Parameter(torch.ones([]) * config['temp']) 257 | if config['lr_2']: 258 | self.init_params.extend(['vision_proj.' + n for n, _ in self.vision_proj.named_parameters()]) 259 | self.init_params.extend(['text_proj.' + n for n, _ in self.text_proj.named_parameters()]) 260 | 261 | if use_matching_loss: 262 | self.itm_head = build_mlp(input_dim=self.text_width, output_dim=2) 263 | if config['lr_2']: 264 | self.init_params.extend(['itm_head.' 
+ n for n, _ in self.itm_head.named_parameters()]) 265 | 266 | def load_pretrained(self, ckpt_rpath, config, is_eval=False): 267 | state_dict = load_pretrained(ckpt_rpath, config, is_eval=is_eval, load_text=True) 268 | msg = self.load_state_dict(state_dict, strict=False) 269 | print('load checkpoint from %s' % ckpt_rpath) 270 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) 271 | print("unexpected_keys: ", msg.unexpected_keys) 272 | 273 | def get_vision_embeds(self, image): 274 | image_embeds = self.vision_encoder(image) 275 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) 276 | return image_embeds, image_atts 277 | 278 | def get_text_embeds(self, text_ids, text_atts): 279 | encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder 280 | return encoder(text_ids, attention_mask=text_atts, return_dict=True, mode='text').last_hidden_state 281 | 282 | def get_cross_embeds(self, image_embeds, image_atts, text_ids=None, text_embeds=None, text_atts=None): 283 | assert text_atts is not None 284 | encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder 285 | if text_embeds is not None: 286 | return encoder(encoder_embeds=text_embeds, 287 | attention_mask=text_atts, 288 | encoder_hidden_states=image_embeds, 289 | encoder_attention_mask=image_atts, 290 | return_dict=True, 291 | mode='fusion', 292 | ).last_hidden_state 293 | elif text_ids is not None: 294 | return encoder(text_ids, 295 | attention_mask=text_atts, 296 | encoder_hidden_states=image_embeds, 297 | encoder_attention_mask=image_atts, 298 | return_dict=True, 299 | ).last_hidden_state 300 | else: 301 | raise ValueError 302 | 303 | def get_features(self, image_embeds=None, text_embeds=None): 304 | if image_embeds is None: 305 | text_feat = self.text_proj(text_embeds[:, 0, :]) 306 | return text_feat 307 | elif text_embeds is None: 308 | image_feat = self.vision_proj(image_embeds[:, 0, :]) 309 | return image_feat 310 | else: 311 | image_feat = self.vision_proj(image_embeds[:, 0, :]) 312 | text_feat = self.text_proj(text_embeds[:, 0, :]) 313 | return image_feat, text_feat 314 | 315 | def get_contrastive_loss(self, image_feat, text_feat, idx=None): 316 | assert image_feat.size(-1) == self.embed_dim 317 | assert text_feat.size(-1) == self.embed_dim 318 | image_feat = F.normalize(image_feat, dim=-1) 319 | text_feat = F.normalize(text_feat, dim=-1) 320 | 321 | image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) 322 | text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) 323 | logits = image_feat_all @ text_feat_all.t() / self.temp 324 | bsz = image_feat_all.shape[0] 325 | 326 | if idx is None: 327 | labels = torch.arange(bsz, device=image_feat.device) 328 | loss_i2t = F.cross_entropy(logits, labels) 329 | loss_t2i = F.cross_entropy(logits.t(), labels) 330 | return (loss_i2t + loss_t2i) / 2 331 | else: 332 | idx = idx.view(-1, 1) 333 | assert idx.size(0) == image_feat.size(0) 334 | idx_all = allgather(idx, torch.distributed.get_rank(), torch.distributed.get_world_size()) 335 | pos_idx = torch.eq(idx_all, idx_all.t()).float() 336 | labels = pos_idx / pos_idx.sum(1, keepdim=True) 337 | 338 | loss_i2t = -torch.sum(F.log_softmax(logits, dim=1) * labels, dim=1).mean() 339 | loss_t2i = -torch.sum(F.log_softmax(logits.t(), dim=1) * labels, dim=1).mean() 340 | return (loss_i2t + loss_t2i) / 2 341 | 342 | def 
get_matching_loss(self, image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=None): 343 | """ 344 | Matching Loss with hard negatives 345 | """ 346 | bs = image_embeds.size(0) 347 | 348 | image_feat = F.normalize(image_feat, dim=-1) 349 | text_feat = F.normalize(text_feat, dim=-1) 350 | 351 | with torch.no_grad(): 352 | sim_i2t = image_feat @ text_feat.t() / self.temp 353 | sim_t2i = text_feat @ image_feat.t() / self.temp 354 | 355 | weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-5 356 | weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-5 357 | 358 | if idx is None: 359 | weights_i2t.fill_diagonal_(0) 360 | weights_t2i.fill_diagonal_(0) 361 | else: 362 | idx = idx.view(-1, 1) 363 | assert idx.size(0) == bs 364 | mask = torch.eq(idx, idx.t()) 365 | weights_i2t.masked_fill_(mask, 0) 366 | weights_t2i.masked_fill_(mask, 0) 367 | 368 | image_embeds_neg = [] 369 | image_atts_neg = [] 370 | for b in range(bs): 371 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 372 | image_embeds_neg.append(image_embeds[neg_idx]) 373 | image_atts_neg.append(image_atts[neg_idx]) 374 | image_embeds_neg = torch.stack(image_embeds_neg, dim=0) 375 | image_atts_neg = torch.stack(image_atts_neg, dim=0) 376 | 377 | text_embeds_neg = [] 378 | text_atts_neg = [] 379 | for b in range(bs): 380 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 381 | text_embeds_neg.append(text_embeds[neg_idx]) 382 | text_atts_neg.append(text_atts[neg_idx]) 383 | text_embeds_neg = torch.stack(text_embeds_neg, dim=0) 384 | text_atts_neg = torch.stack(text_atts_neg, dim=0) 385 | 386 | text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) 387 | text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0) 388 | image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) 389 | image_atts_all = torch.cat([image_atts_neg, image_atts], dim=0) 390 | 391 | cross_pos = self.get_cross_embeds(image_embeds, image_atts, text_embeds=text_embeds, 392 | text_atts=text_atts)[:, 0, :] 393 | cross_neg = self.get_cross_embeds(image_embeds_all, image_atts_all, text_embeds=text_embeds_all, 394 | text_atts=text_atts_all)[:, 0, :] 395 | 396 | output = self.itm_head(torch.cat([cross_pos, cross_neg], dim=0)) 397 | itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), 398 | torch.zeros(2 * bs, dtype=torch.long)], dim=0).to(image_embeds.device) 399 | itm_loss = F.cross_entropy(output, itm_labels) 400 | 401 | return itm_loss 402 | 403 | def get_mlm_loss(self, text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids): 404 | return self.text_encoder(text_ids_masked, 405 | attention_mask=text_atts, 406 | encoder_hidden_states=image_embeds, 407 | encoder_attention_mask=image_atts, 408 | return_dict=True, 409 | labels=masked_ids, 410 | masked_pos=masked_pos).loss 411 | 412 | def label_smooth_loss(self, inputs, targets): 413 | bs = inputs.size(0) 414 | inputs_neg = [] 415 | targets_neg = [] 416 | for b in range(bs): 417 | if targets[b] != -1: 418 | inputs_neg.append(inputs[b]) 419 | targets_neg.append(targets[b]) 420 | if not inputs_neg: 421 | return 0 422 | inputs = torch.stack(inputs_neg, dim=0) 423 | targets = torch.stack(targets_neg, dim=0) 424 | return self.new_cross_entropy(inputs, targets) 425 | 426 | def get_contrastive_loss_attr(self, image_feat, text_feat, label): 427 | image_feat = F.normalize(image_feat, dim=-1) 428 | text_feat = F.normalize(text_feat, dim=-1) 429 | logits = image_feat @ text_feat.t() / self.temp 430 | l = 0 431 | for i in range(label.size(1)): 432 | left = 2 * i 433 | 
right = 2 * i + 2 434 | if self.add_label_smooth: 435 | l = l + self.label_smooth_loss(logits[:, left:right], label[:, i]) 436 | else: 437 | l = l + F.cross_entropy(logits[:, left:right], label[:, i], ignore_index=-1) 438 | 439 | return l / label.size(1) 440 | 441 | def get_matching_loss_attr(self, image_embeds, image_atts, text_embeds, text_atts, label): 442 | bs = image_embeds.size(0) 443 | 444 | labels = [] 445 | for i in range(label.size(1)): 446 | l = 1 - label[:, i] 447 | l = torch.where(l == 2, -1, l) 448 | labels.append(l) 449 | labels.append(label[:, i]) 450 | labels = torch.stack(labels, dim=1) 451 | 452 | r = random.sample(range(0, text_embeds.size(0)), 5) 453 | ll = 0 454 | for t in r: 455 | text_embeds_0 = text_embeds[t].repeat(bs, 1, 1) 456 | text_atts_0 = text_atts[t].repeat(bs, 1, 1) 457 | cross_0 = self.get_cross_embeds(image_embeds, image_atts, text_embeds=text_embeds_0, 458 | text_atts=text_atts_0)[:, 0, :] 459 | output_0 = self.itm_head(cross_0) 460 | if self.add_label_smooth: 461 | ll = ll + self.label_smooth_loss(output_0, labels[:, t]) 462 | else: 463 | ll = ll + F.cross_entropy(output_0, labels[:, t], ignore_index=-1) 464 | return ll / 5 465 | 466 | def get_mlm_loss_attr(self, text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids, label): 467 | 468 | labels = [] 469 | for i in range(label.size(1)): 470 | l = 1 - label[:, i] 471 | l = torch.where(l == 2, -1, l) 472 | labels.append(l) 473 | labels.append(label[:, i]) 474 | labels = torch.stack(labels, dim=1) 475 | 476 | image_embeds_pos = [] 477 | image_atts_pos = [] 478 | text_ids_masked_pos = [] 479 | text_atts_pos = [] 480 | masked_pos_pos = [] 481 | masked_ids_pos = [] 482 | for b in range(text_atts.size(0)): 483 | temp_label = labels[:, b] 484 | temp_label = torch.where(temp_label == -1, 0, temp_label) 485 | if torch.count_nonzero(temp_label).item() > 0: 486 | text_ids_masked_pos.append(text_ids_masked[b]) 487 | text_atts_pos.append(text_atts[b]) 488 | masked_pos_pos.append(masked_pos[b]) 489 | masked_ids_pos.append(masked_ids[b]) 490 | idx = torch.multinomial(temp_label.float(), 1).item() 491 | image_embeds_pos.append(image_embeds[idx]) 492 | image_atts_pos.append(image_atts[idx]) 493 | 494 | image_embeds_pos = torch.stack(image_embeds_pos, dim=0) 495 | image_atts_pos = torch.stack(image_atts_pos, dim=0) 496 | text_ids_masked_pos = torch.stack(text_ids_masked_pos, dim=0) 497 | text_atts_pos = torch.stack(text_atts_pos, dim=0) 498 | masked_pos_pos = torch.stack(masked_pos_pos, dim=0) 499 | masked_ids_pos = torch.stack(masked_ids_pos, dim=0) 500 | 501 | loss = self.text_encoder(text_ids_masked_pos, 502 | attention_mask=text_atts_pos, 503 | encoder_hidden_states=image_embeds_pos, 504 | encoder_attention_mask=image_atts_pos, 505 | return_dict=True, 506 | labels=masked_ids_pos, 507 | masked_pos=masked_pos_pos).loss 508 | return loss 509 | -------------------------------------------------------------------------------- /models/model_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import APTM, load_pretrained, AllGather 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class APTM_Retrieval(APTM): 8 | def __init__(self, config): 9 | super().__init__(config, load_vision_params=config['load_params'], load_text_params=config['load_params'], 10 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=config['mlm']) 11 | 12 | if not self.pa100k_only_img_classifier: 13 | self.mlm = 
config['mlm'] 14 | self.pa100k = config['pa100k'] 15 | if not self.pa100k: 16 | self.eda = config['eda'] 17 | if ('attr' in config.keys()) and config['attr']: 18 | self.attr = True 19 | else: 20 | self.attr = False 21 | 22 | def load_pretrained(self, ckpt_rpath, config, is_eval=False): 23 | state_dict = load_pretrained(ckpt_rpath, config, is_eval=is_eval, load_text=True) 24 | msg = self.load_state_dict(state_dict, strict=False) 25 | print('load checkpoint from %s' % ckpt_rpath) 26 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) 27 | print("vision_encoder missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' in p]) 28 | print("unexpected_keys: ", msg.unexpected_keys) 29 | 30 | def forward(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None, 31 | idx=None, attr_text_ids=None, attr_text_atts=None, attr_text_ids_masked=None, 32 | attr_masked_pos=None, attr_masked_ids=None, label=None, text_ids_eda=None, text_atts_eda=None): 33 | 34 | if self.pa100k_only_img_classifier: 35 | image_embeds = self.vision_encoder(image) 36 | outputs = self.img_cls(image_embeds[:, 0, :]) 37 | loss = self.criterion(outputs, label.float()) 38 | return loss 39 | 40 | if self.pa100k: 41 | image_embeds, image_atts = self.get_vision_embeds(image) 42 | text_embeds = self.get_text_embeds(text_ids, text_atts) 43 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 44 | loss_itc = self.get_contrastive_loss_attr(image_feat, text_feat, label) 45 | loss_itm = self.get_matching_loss_attr(image_embeds, image_atts, text_embeds, text_atts, label) 46 | if self.mlm: 47 | loss_mlm = self.get_mlm_loss_attr(text_ids_masked, text_atts, image_embeds, image_atts, 48 | masked_pos, masked_ids, label) 49 | return loss_itc, loss_itm, loss_mlm 50 | else: 51 | return loss_itc, loss_itm 52 | 53 | if self.attr: 54 | image_embeds, image_atts = self.get_vision_embeds(image) 55 | text_embeds = self.get_text_embeds(text_ids, text_atts) 56 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 57 | 58 | attr_text_embeds = self.get_text_embeds(attr_text_ids, attr_text_atts) 59 | attr_text_feat = self.get_features(text_embeds=attr_text_embeds) 60 | 61 | attr_loss_itc = self.get_contrastive_loss_attr(image_feat, attr_text_feat, label) 62 | attr_loss_itm = self.get_matching_loss_attr(image_embeds, image_atts, attr_text_embeds, attr_text_atts, 63 | label) 64 | 65 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx) 66 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, 67 | text_embeds, text_atts, text_feat, idx=idx) 68 | 69 | if self.mlm: 70 | attr_loss_mlm = self.get_mlm_loss_attr(attr_text_ids_masked, attr_text_atts, image_embeds, image_atts, 71 | attr_masked_pos, attr_masked_ids, label) 72 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, 73 | masked_ids) 74 | loss_attr = (attr_loss_itc + attr_loss_itm + attr_loss_mlm) / 3 75 | return loss_itc, loss_itm, loss_mlm, loss_attr 76 | else: 77 | loss_attr = (attr_loss_itc + attr_loss_itm) / 2 78 | return loss_itc, loss_itm, loss_attr 79 | 80 | image_embeds, image_atts = self.get_vision_embeds(image) 81 | text_embeds = self.get_text_embeds(text_ids, text_atts) 82 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 83 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx) 84 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, 85 | text_embeds, 
text_atts, text_feat, idx=idx) 86 | 87 | # eda 88 | if self.eda: 89 | text_embeds_eda = self.get_text_embeds(text_ids_eda, text_atts_eda) 90 | text_feat_eda = self.get_features(text_embeds=text_embeds_eda) 91 | loss_itc_eda = self.get_contrastive_loss(image_feat, text_feat_eda, idx=idx) 92 | loss_itm_eda = self.get_matching_loss(image_embeds, image_atts, image_feat, 93 | text_embeds_eda, text_atts_eda, text_feat_eda, idx=idx) 94 | loss_itc = loss_itc + 0.8 * loss_itc_eda 95 | loss_itm = loss_itm + 0.8 * loss_itm_eda 96 | 97 | if self.mlm: 98 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, 99 | masked_ids) 100 | return loss_itc, loss_itm, loss_mlm 101 | else: 102 | return loss_itc, loss_itm 103 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | from scipy import interpolate 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.utils.checkpoint as checkpoint 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = nn.Linear(in_features, hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | 36 | def window_partition(x, window_size): 37 | """ 38 | Args: 39 | x: (B, H, W, C) 40 | window_size (int): window size 41 | 42 | Returns: 43 | windows: (num_windows*B, window_size, window_size, C) 44 | """ 45 | B, H, W, C = x.shape 46 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 47 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 48 | return windows 49 | 50 | 51 | def window_reverse(windows, window_size, H, W): 52 | """ 53 | Args: 54 | windows: (num_windows*B, window_size, window_size, C) 55 | window_size (int): Window size 56 | H (int): Height of image 57 | W (int): Width of image 58 | 59 | Returns: 60 | x: (B, H, W, C) 61 | """ 62 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 63 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 64 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 65 | return x 66 | 67 | 68 | class WindowAttention(nn.Module): 69 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 70 | It supports both of shifted and non-shifted window. 71 | 72 | Args: 73 | dim (int): Number of input channels. 74 | window_size (tuple[int]): The height and width of the window. 75 | num_heads (int): Number of attention heads. 76 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 77 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 78 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 79 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 80 | """ 81 | 82 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 83 | 84 | super().__init__() 85 | self.dim = dim 86 | self.window_size = window_size # Wh, Ww 87 | self.num_heads = num_heads 88 | head_dim = dim // num_heads 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | # define a parameter table of relative position bias 92 | self.relative_position_bias_table = nn.Parameter( 93 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 94 | 95 | # get pair-wise relative position index for each token inside the window 96 | coords_h = torch.arange(self.window_size[0]) 97 | coords_w = torch.arange(self.window_size[1]) 98 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 99 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 100 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 101 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 102 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 103 | relative_coords[:, :, 1] += self.window_size[1] - 1 104 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 105 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 106 | self.register_buffer("relative_position_index", relative_position_index) 107 | 108 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | trunc_normal_(self.relative_position_bias_table, std=.02) 114 | self.softmax = nn.Softmax(dim=-1) 115 | 116 | def forward(self, x, mask=None): 117 | """ 118 | Args: 119 | x: input features with shape of (num_windows*B, N, C) 120 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 121 | """ 122 | B_, N, C = x.shape 123 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 124 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 125 | 126 | q = q * self.scale 127 | attn = (q @ k.transpose(-2, -1)) 128 | 129 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 130 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 131 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 132 | attn = attn + relative_position_bias.unsqueeze(0) 133 | 134 | if mask is not None: 135 | nW = mask.shape[0] 136 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 137 | attn = attn.view(-1, self.num_heads, N, N) 138 | attn = self.softmax(attn) 139 | else: 140 | attn = self.softmax(attn) 141 | 142 | attn = self.attn_drop(attn) 143 | 144 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 145 | x = self.proj(x) 146 | x = self.proj_drop(x) 147 | return x 148 | 149 | def extra_repr(self) -> str: 150 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 151 | 152 | def flops(self, N): 153 | # calculate flops for 1 window with token length of N 154 | flops = 0 155 | # qkv 
= self.qkv(x) 156 | flops += N * self.dim * 3 * self.dim 157 | # attn = (q @ k.transpose(-2, -1)) 158 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 159 | # x = (attn @ v) 160 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 161 | # x = self.proj(x) 162 | flops += N * self.dim * self.dim 163 | return flops 164 | 165 | 166 | class SwinTransformerBlock(nn.Module): 167 | r""" Swin Transformer Block. 168 | 169 | Args: 170 | dim (int): Number of input channels. 171 | input_resolution (tuple[int]): Input resulotion. 172 | num_heads (int): Number of attention heads. 173 | window_size (int): Window size. 174 | shift_size (int): Shift size for SW-MSA. 175 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 176 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 177 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 178 | drop (float, optional): Dropout rate. Default: 0.0 179 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 180 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 181 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 182 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 183 | """ 184 | 185 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 186 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 187 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 188 | super().__init__() 189 | self.dim = dim 190 | self.input_resolution = input_resolution 191 | self.num_heads = num_heads 192 | self.window_size = window_size 193 | self.shift_size = shift_size 194 | self.mlp_ratio = mlp_ratio 195 | if min(self.input_resolution) <= self.window_size: 196 | # if window size is larger than input resolution, we don't partition windows 197 | self.shift_size = 0 198 | self.window_size = min(self.input_resolution) 199 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 200 | 201 | self.norm1 = norm_layer(dim) 202 | self.attn = WindowAttention( 203 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 204 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 205 | 206 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 207 | self.norm2 = norm_layer(dim) 208 | mlp_hidden_dim = int(dim * mlp_ratio) 209 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 210 | 211 | if self.shift_size > 0: 212 | # calculate attention mask for SW-MSA 213 | H, W = self.input_resolution 214 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 215 | h_slices = (slice(0, -self.window_size), 216 | slice(-self.window_size, -self.shift_size), 217 | slice(-self.shift_size, None)) 218 | w_slices = (slice(0, -self.window_size), 219 | slice(-self.window_size, -self.shift_size), 220 | slice(-self.shift_size, None)) 221 | cnt = 0 222 | for h in h_slices: 223 | for w in w_slices: 224 | img_mask[:, h, w, :] = cnt 225 | cnt += 1 226 | 227 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 228 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 229 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 230 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 231 | else: 232 | attn_mask = None 233 | 234 | self.register_buffer("attn_mask", attn_mask) 235 | 236 | def forward(self, x): 237 | H, W = self.input_resolution 238 | B, L, C = x.shape 239 | assert L == H * W, "input feature has wrong size" 240 | 241 | shortcut = x 242 | x = self.norm1(x) 243 | x = x.view(B, H, W, C) 244 | 245 | # cyclic shift 246 | if self.shift_size > 0: 247 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 248 | else: 249 | shifted_x = x 250 | 251 | # partition windows 252 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 253 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 254 | 255 | # W-MSA/SW-MSA 256 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 257 | 258 | # merge windows 259 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 260 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 261 | 262 | # reverse cyclic shift 263 | if self.shift_size > 0: 264 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 265 | else: 266 | x = shifted_x 267 | x = x.view(B, H * W, C) 268 | 269 | # FFN 270 | x = shortcut + self.drop_path(x) 271 | x = x + self.drop_path(self.mlp(self.norm2(x))) 272 | 273 | return x 274 | 275 | def extra_repr(self) -> str: 276 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 277 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 278 | 279 | def flops(self): 280 | flops = 0 281 | H, W = self.input_resolution 282 | # norm1 283 | flops += self.dim * H * W 284 | # W-MSA/SW-MSA 285 | nW = H * W / self.window_size / self.window_size 286 | flops += nW * self.attn.flops(self.window_size * self.window_size) 287 | # mlp 288 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 289 | # norm2 290 | flops += self.dim * H * W 291 | return flops 292 | 293 | 294 | class PatchMerging(nn.Module): 295 | r""" Patch Merging Layer. 296 | 297 | Args: 298 | input_resolution (tuple[int]): Resolution of input feature. 299 | dim (int): Number of input channels. 300 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 301 | """ 302 | 303 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 304 | super().__init__() 305 | self.input_resolution = input_resolution 306 | self.dim = dim 307 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 308 | self.norm = norm_layer(4 * dim) 309 | 310 | def forward(self, x): 311 | """ 312 | x: B, H*W, C 313 | """ 314 | H, W = self.input_resolution 315 | B, L, C = x.shape 316 | assert L == H * W, "input feature has wrong size" 317 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 318 | 319 | x = x.view(B, H, W, C) 320 | 321 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 322 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 323 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 324 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 325 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 326 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 327 | 328 | x = self.norm(x) 329 | x = self.reduction(x) 330 | 331 | return x 332 | 333 | def extra_repr(self) -> str: 334 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 335 | 336 | def flops(self): 337 | H, W = self.input_resolution 338 | flops = H * W * self.dim 339 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 340 | return flops 341 | 342 | 343 | class BasicLayer(nn.Module): 344 | """ A basic Swin Transformer layer for one stage. 345 | 346 | Args: 347 | dim (int): Number of input channels. 348 | input_resolution (tuple[int]): Input resolution. 349 | depth (int): Number of blocks. 350 | num_heads (int): Number of attention heads. 351 | window_size (int): Local window size. 352 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 353 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 354 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 355 | drop (float, optional): Dropout rate. Default: 0.0 356 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 357 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 358 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 359 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 360 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
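    Example (illustrative shapes only; the sizes below are generic Swin-T
    stage-1 numbers, not the values used by this project's configs):
        >>> import torch
        >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
        ...                    num_heads=3, window_size=7)
        >>> x = torch.rand(2, 56 * 56, 96)   # B, H*W, C
        >>> layer(x).shape                   # no downsample layer, so shape is kept
        torch.Size([2, 3136, 96])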
361 | """ 362 | 363 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 364 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 365 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 366 | 367 | super().__init__() 368 | self.dim = dim 369 | self.input_resolution = input_resolution 370 | self.depth = depth 371 | self.use_checkpoint = use_checkpoint 372 | 373 | # build blocks 374 | self.blocks = nn.ModuleList([ 375 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 376 | num_heads=num_heads, window_size=window_size, 377 | shift_size=0 if (i % 2 == 0) else window_size // 2, 378 | mlp_ratio=mlp_ratio, 379 | qkv_bias=qkv_bias, qk_scale=qk_scale, 380 | drop=drop, attn_drop=attn_drop, 381 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 382 | norm_layer=norm_layer) 383 | for i in range(depth)]) 384 | 385 | # patch merging layer 386 | if downsample is not None: 387 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 388 | else: 389 | self.downsample = None 390 | 391 | def forward(self, x): 392 | for blk in self.blocks: 393 | if self.use_checkpoint: 394 | x = checkpoint.checkpoint(blk, x) 395 | else: 396 | x = blk(x) 397 | if self.downsample is not None: 398 | x = self.downsample(x) 399 | return x 400 | 401 | def extra_repr(self) -> str: 402 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 403 | 404 | def flops(self): 405 | flops = 0 406 | for blk in self.blocks: 407 | flops += blk.flops() 408 | if self.downsample is not None: 409 | flops += self.downsample.flops() 410 | return flops 411 | 412 | 413 | class PatchEmbed(nn.Module): 414 | r""" Image to Patch Embedding 415 | 416 | Args: 417 | img_size (int): Image size. Default: 224. 418 | patch_size (int): Patch token size. Default: 4. 419 | in_chans (int): Number of input image channels. Default: 3. 420 | embed_dim (int): Number of linear projection output channels. Default: 96. 421 | norm_layer (nn.Module, optional): Normalization layer. Default: None 422 | """ 423 | 424 | def __init__(self, img_size=224, h=224, w=224, 425 | patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 426 | super().__init__() 427 | img_size = (h, w) 428 | patch_size = to_2tuple(patch_size) 429 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 430 | self.img_size = img_size 431 | self.patch_size = patch_size 432 | self.patches_resolution = patches_resolution 433 | self.num_patches = patches_resolution[0] * patches_resolution[1] 434 | 435 | self.in_chans = in_chans 436 | self.embed_dim = embed_dim 437 | 438 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 439 | if norm_layer is not None: 440 | self.norm = norm_layer(embed_dim) 441 | else: 442 | self.norm = None 443 | 444 | def forward(self, x): 445 | B, C, H, W = x.shape 446 | # FIXME look at relaxing size constraints 447 | assert H == self.img_size[0] and W == self.img_size[1], \ 448 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
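# The Conv2d below uses kernel_size == stride == patch_size, so each non-overlapping patch
# is linearly projected to embed_dim; flatten(2).transpose(1, 2) then turns the resulting
# (B, embed_dim, Ph, Pw) feature map into a (B, Ph*Pw, embed_dim) token sequence.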
449 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 450 | if self.norm is not None: 451 | x = self.norm(x) 452 | return x 453 | 454 | def flops(self): 455 | Ho, Wo = self.patches_resolution 456 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 457 | if self.norm is not None: 458 | flops += Ho * Wo * self.embed_dim 459 | return flops 460 | 461 | 462 | class SwinTransformer(nn.Module): 463 | r""" Swin Transformer 464 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 465 | https://arxiv.org/pdf/2103.14030 466 | 467 | Args: 468 | img_size (int | tuple(int)): Input image size. Default 224 469 | patch_size (int | tuple(int)): Patch size. Default: 4 470 | in_chans (int): Number of input image channels. Default: 3 471 | num_classes (int): Number of classes for classification head. Default: 1000 472 | embed_dim (int): Patch embedding dimension. Default: 96 473 | depths (tuple(int)): Depth of each Swin Transformer layer. 474 | num_heads (tuple(int)): Number of attention heads in different layers. 475 | window_size (int): Window size. Default: 7 476 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 477 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 478 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 479 | drop_rate (float): Dropout rate. Default: 0 480 | attn_drop_rate (float): Attention dropout rate. Default: 0 481 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 482 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 483 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 484 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 485 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 486 | """ 487 | 488 | def __init__(self, img_size=224, h=224, w=224, 489 | patch_size=4, in_chans=3, num_classes=1000, 490 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 491 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 492 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 493 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 494 | use_checkpoint=False, **kwargs): 495 | super().__init__() 496 | 497 | self.num_classes = num_classes 498 | self.num_layers = len(depths) 499 | self.embed_dim = embed_dim 500 | self.ape = ape 501 | self.patch_norm = patch_norm 502 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 503 | self.mlp_ratio = mlp_ratio 504 | 505 | # split image into non-overlapping patches 506 | self.patch_embed = PatchEmbed( 507 | img_size=img_size, h=h, w=w, 508 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 509 | norm_layer=norm_layer if self.patch_norm else None) 510 | num_patches = self.patch_embed.num_patches 511 | patches_resolution = self.patch_embed.patches_resolution 512 | self.patches_resolution = patches_resolution 513 | 514 | # absolute position embedding 515 | if self.ape: 516 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 517 | trunc_normal_(self.absolute_pos_embed, std=.02) 518 | 519 | self.pos_drop = nn.Dropout(p=drop_rate) 520 | 521 | # stochastic depth 522 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 523 | 524 | # build layers 525 | self.layers = nn.ModuleList() 526 | for i_layer in range(self.num_layers): 527 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 528 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 529 | patches_resolution[1] // (2 ** i_layer)), 530 | depth=depths[i_layer], 531 | num_heads=num_heads[i_layer], 532 | window_size=window_size, 533 | mlp_ratio=self.mlp_ratio, 534 | qkv_bias=qkv_bias, qk_scale=qk_scale, 535 | drop=drop_rate, attn_drop=attn_drop_rate, 536 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 537 | norm_layer=norm_layer, 538 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 539 | use_checkpoint=use_checkpoint) 540 | self.layers.append(layer) 541 | 542 | self.norm = norm_layer(self.num_features) 543 | self.avgpool = nn.AdaptiveAvgPool1d(1) 544 | 545 | # shortcut block 1-->4 546 | # self.my_proj = nn.Conv2d(256, 1024, kernel_size=patch_size, stride=patch_size) 547 | 548 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 549 | 550 | self.apply(self._init_weights) 551 | 552 | def _init_weights(self, m): 553 | if isinstance(m, nn.Linear): 554 | trunc_normal_(m.weight, std=.02) 555 | if isinstance(m, nn.Linear) and m.bias is not None: 556 | nn.init.constant_(m.bias, 0) 557 | elif isinstance(m, nn.LayerNorm): 558 | nn.init.constant_(m.bias, 0) 559 | nn.init.constant_(m.weight, 1.0) 560 | 561 | @torch.jit.ignore 562 | def no_weight_decay(self): 563 | return {'absolute_pos_embed'} 564 | 565 | @torch.jit.ignore 566 | def no_weight_decay_keywords(self): 567 | return {'relative_position_bias_table'} 568 | 569 | def forward(self, x): 570 | x = self.patch_embed(x) 571 | if self.ape: 572 | x = x + self.absolute_pos_embed 573 | x = self.pos_drop(x) 574 | 575 | for i, layer in enumerate(self.layers): 576 | x = layer(x) 577 | 578 | x = self.norm(x) # B L C 579 | x_cls = self.avgpool(x.transpose(1, 2)) # B C 1 580 | x = torch.cat([x_cls.transpose(1, 2), x], dim=1) 581 | 
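# x now has shape (B, 1 + L, num_features): the average-pooled token is prepended to the
# final-stage patch tokens, so downstream retrieval code can read a global image embedding
# from index 0 (see vision_proj(image_embed[:, 0, :]) in reTools.py). With the default
# embed_dim=96, depths=[2, 2, 6, 2], img_size=224, patch_size=4 settings this is (B, 1 + 49, 768).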
return x 582 | 583 | def flops(self): 584 | flops = 0 585 | flops += self.patch_embed.flops() 586 | for i, layer in enumerate(self.layers): 587 | flops += layer.flops() 588 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 589 | flops += self.num_features * self.num_classes 590 | return flops 591 | 592 | 593 | def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''): 594 | # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348 595 | 596 | # rel_pos_bias: relative_position_bias_table 597 | src_num_pos, num_attn_heads = rel_pos_bias.size() 598 | 599 | num_extra_tokens = 0 600 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 601 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 602 | if src_size != dst_size: 603 | print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size)) 604 | 605 | # extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 606 | # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 607 | 608 | def geometric_progression(a, r, n): 609 | return a * (1.0 - r ** n) / (1.0 - r) 610 | 611 | left, right = 1.01, 1.5 612 | while right - left > 1e-6: 613 | q = (left + right) / 2.0 614 | gp = geometric_progression(1, q, src_size // 2) 615 | if gp > dst_size // 2: 616 | right = q 617 | else: 618 | left = q 619 | 620 | # if q > 1.090307: 621 | # q = 1.090307 622 | 623 | dis = [] 624 | cur = 1 625 | for i in range(src_size // 2): 626 | dis.append(cur) 627 | cur += q ** (i + 1) 628 | 629 | r_ids = [-_ for _ in reversed(dis)] 630 | 631 | x = r_ids + [0] + dis 632 | y = r_ids + [0] + dis 633 | 634 | t = dst_size // 2.0 635 | dx = np.arange(-t, t + 0.1, 1.0) 636 | dy = np.arange(-t, t + 0.1, 1.0) 637 | 638 | # print("Original positions = %s" % str(x)) 639 | # print("Target positions = %s" % str(dx)) 640 | 641 | all_rel_pos_bias = [] 642 | 643 | for i in range(num_attn_heads): 644 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 645 | f = interpolate.interp2d(x, y, z, kind='cubic') 646 | all_rel_pos_bias.append( 647 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 648 | 649 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 650 | 651 | return rel_pos_bias 652 | -------------------------------------------------------------------------------- /models/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
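# Rough usage sketch (illustrative; the vocab path below is an assumption, not part of this repo):
#   tokenizer = BertTokenizer(vocab_file='data/bert-base-uncased-vocab.txt')
#   tokenizer.tokenize('unaffable')   # -> ['un', '##aff', '##able'], per the WordpieceTokenizer docstring below
# In APTM the tokenizer is passed into the train/evaluation loops and called with padding and
# max_length=config['max_tokens'] (see reTools.py and train_pa100ks.py).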
15 | """Tokenization classes for Bert.""" 16 | 17 | 18 | import collections 19 | import os 20 | import unicodedata 21 | from typing import List, Optional, Tuple 22 | 23 | from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 24 | from transformers.utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | "vocab_file": { 33 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", 34 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", 35 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", 36 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", 37 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", 38 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", 39 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", 40 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", 41 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", 42 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", 43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", 45 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", 46 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", 47 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", 48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", 49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", 50 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", 51 | } 52 | } 53 | 54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 55 | "bert-base-uncased": 512, 56 | "bert-large-uncased": 512, 57 | "bert-base-cased": 512, 58 | "bert-large-cased": 512, 59 | "bert-base-multilingual-uncased": 512, 60 | "bert-base-multilingual-cased": 512, 61 | "bert-base-chinese": 512, 62 | "bert-base-german-cased": 512, 63 | "bert-large-uncased-whole-word-masking": 512, 64 | "bert-large-cased-whole-word-masking": 512, 65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 67 | "bert-base-cased-finetuned-mrpc": 512, 68 | "bert-base-german-dbmdz-cased": 512, 69 | "bert-base-german-dbmdz-uncased": 512, 70 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 72 | "wietsedv/bert-base-dutch-cased": 512, 73 | } 74 | 75 | PRETRAINED_INIT_CONFIGURATION = { 76 | "bert-base-uncased": 
{"do_lower_case": True}, 77 | "bert-large-uncased": {"do_lower_case": True}, 78 | "bert-base-cased": {"do_lower_case": False}, 79 | "bert-large-cased": {"do_lower_case": False}, 80 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 81 | "bert-base-multilingual-cased": {"do_lower_case": False}, 82 | "bert-base-chinese": {"do_lower_case": False}, 83 | "bert-base-german-cased": {"do_lower_case": False}, 84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 94 | } 95 | 96 | 97 | def load_vocab(vocab_file): 98 | """Loads a vocabulary file into a dictionary.""" 99 | vocab = collections.OrderedDict() 100 | with open(vocab_file, "r", encoding="utf-8") as reader: 101 | tokens = reader.readlines() 102 | for index, token in enumerate(tokens): 103 | token = token.rstrip("\n") 104 | vocab[token] = index 105 | return vocab 106 | 107 | 108 | def whitespace_tokenize(text): 109 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 110 | text = text.strip() 111 | if not text: 112 | return [] 113 | tokens = text.split() 114 | return tokens 115 | 116 | 117 | class BertTokenizer(PreTrainedTokenizer): 118 | r""" 119 | Construct a BERT tokenizer. Based on WordPiece. 120 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 121 | Users should refer to this superclass for more information regarding those methods. 122 | Args: 123 | vocab_file (:obj:`str`): 124 | File containing the vocabulary. 125 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 126 | Whether or not to lowercase the input when tokenizing. 127 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 128 | Whether or not to do basic tokenization before WordPiece. 129 | never_split (:obj:`Iterable`, `optional`): 130 | Collection of tokens which will never be split during tokenization. Only has an effect when 131 | :obj:`do_basic_tokenize=True` 132 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 133 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 134 | token instead. 135 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 136 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 137 | sequence classification or for a text and a question for question answering. It is also used as the last 138 | token of a sequence built with special tokens. 139 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 140 | The token used for padding, for example when batching sequences of different lengths. 141 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 142 | The classifier token which is used when doing sequence classification (classification of the whole sequence 143 | instead of per-token classification). 
It is the first token of the sequence when built with special tokens. 144 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 145 | The token used for masking values. This is the token used when training this model with masked language 146 | modeling. This is the token which the model will try to predict. 147 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 148 | Whether or not to tokenize Chinese characters. 149 | This should likely be deactivated for Japanese (see this `issue 150 | `__). 151 | strip_accents: (:obj:`bool`, `optional`): 152 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 153 | value for :obj:`lowercase` (as in the original BERT). 154 | """ 155 | 156 | vocab_files_names = VOCAB_FILES_NAMES 157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 160 | 161 | def __init__( 162 | self, 163 | vocab_file, 164 | do_lower_case=True, 165 | do_basic_tokenize=True, 166 | never_split=None, 167 | unk_token="[UNK]", 168 | sep_token="[SEP]", 169 | pad_token="[PAD]", 170 | cls_token="[CLS]", 171 | mask_token="[MASK]", 172 | tokenize_chinese_chars=True, 173 | strip_accents=None, 174 | **kwargs 175 | ): 176 | super().__init__( 177 | do_lower_case=do_lower_case, 178 | do_basic_tokenize=do_basic_tokenize, 179 | never_split=never_split, 180 | unk_token=unk_token, 181 | sep_token=sep_token, 182 | pad_token=pad_token, 183 | cls_token=cls_token, 184 | mask_token=mask_token, 185 | tokenize_chinese_chars=tokenize_chinese_chars, 186 | strip_accents=strip_accents, 187 | **kwargs, 188 | ) 189 | 190 | if not os.path.isfile(vocab_file): 191 | raise ValueError( 192 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 193 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 194 | ) 195 | self.vocab = load_vocab(vocab_file) 196 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 197 | self.do_basic_tokenize = do_basic_tokenize 198 | if do_basic_tokenize: 199 | self.basic_tokenizer = BasicTokenizer( 200 | do_lower_case=do_lower_case, 201 | never_split=never_split, 202 | tokenize_chinese_chars=tokenize_chinese_chars, 203 | strip_accents=strip_accents, 204 | ) 205 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 206 | 207 | @property 208 | def do_lower_case(self): 209 | return self.basic_tokenizer.do_lower_case 210 | 211 | @property 212 | def vocab_size(self): 213 | return len(self.vocab) 214 | 215 | def get_vocab(self): 216 | return dict(self.vocab, **self.added_tokens_encoder) 217 | 218 | def _tokenize(self, text): 219 | split_tokens = [] 220 | if self.do_basic_tokenize: 221 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 222 | 223 | # If the token is part of the never_split set 224 | if token in self.basic_tokenizer.never_split: 225 | split_tokens.append(token) 226 | else: 227 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 228 | else: 229 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 230 | return split_tokens 231 | 232 | def _convert_token_to_id(self, token): 233 | """ Converts a token (str) in an id using the vocab. 
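Tokens that are not in the vocabulary fall back to the id of ``unk_token`` ([UNK] by default).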
""" 234 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 235 | 236 | def _convert_id_to_token(self, index): 237 | """Converts an index (integer) in a token (str) using the vocab.""" 238 | return self.ids_to_tokens.get(index, self.unk_token) 239 | 240 | def convert_tokens_to_string(self, tokens): 241 | """ Converts a sequence of tokens (string) in a single string. """ 242 | out_string = " ".join(tokens).replace(" ##", "").strip() 243 | return out_string 244 | 245 | def build_inputs_with_special_tokens( 246 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 247 | ) -> List[int]: 248 | """ 249 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 250 | adding special tokens. A BERT sequence has the following format: 251 | - single sequence: ``[CLS] X `` 252 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 253 | Args: 254 | token_ids_0 (:obj:`List[int]`): 255 | List of IDs to which the special tokens will be added. 256 | token_ids_1 (:obj:`List[int]`, `optional`): 257 | Optional second list of IDs for sequence pairs. 258 | Returns: 259 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 260 | """ 261 | if token_ids_1 is None: 262 | return [self.cls_token_id] + token_ids_0 263 | cls = [self.cls_token_id] 264 | sep = [self.sep_token_id] 265 | return cls + token_ids_0 + sep + token_ids_1 + sep 266 | 267 | def get_special_tokens_mask( 268 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 269 | ) -> List[int]: 270 | """ 271 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 272 | special tokens using the tokenizer ``prepare_for_model`` method. 273 | Args: 274 | token_ids_0 (:obj:`List[int]`): 275 | List of IDs. 276 | token_ids_1 (:obj:`List[int]`, `optional`): 277 | Optional second list of IDs for sequence pairs. 278 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 279 | Whether or not the token list is already formatted with special tokens for the model. 280 | Returns: 281 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 282 | """ 283 | 284 | if already_has_special_tokens: 285 | if token_ids_1 is not None: 286 | raise ValueError( 287 | "You should not supply a second sequence if the provided sequence of " 288 | "ids is already formatted with special tokens for the model." 289 | ) 290 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 291 | 292 | if token_ids_1 is not None: 293 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 294 | return [1] + ([0] * len(token_ids_0)) + [1] 295 | 296 | def create_token_type_ids_from_sequences( 297 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 298 | ) -> List[int]: 299 | """ 300 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence 301 | pair mask has the following format: 302 | :: 303 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 304 | | first sequence | second sequence | 305 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 306 | Args: 307 | token_ids_0 (:obj:`List[int]`): 308 | List of IDs. 309 | token_ids_1 (:obj:`List[int]`, `optional`): 310 | Optional second list of IDs for sequence pairs. 
311 | Returns: 312 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 313 | sequence(s). 314 | """ 315 | sep = [self.sep_token_id] 316 | cls = [self.cls_token_id] 317 | if token_ids_1 is None: 318 | return len(cls + token_ids_0 + sep) * [0] 319 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 320 | 321 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 322 | index = 0 323 | if os.path.isdir(save_directory): 324 | vocab_file = os.path.join( 325 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 326 | ) 327 | else: 328 | vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory 329 | with open(vocab_file, "w", encoding="utf-8") as writer: 330 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 331 | if index != token_index: 332 | logger.warning( 333 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 334 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 335 | ) 336 | index = token_index 337 | writer.write(token + "\n") 338 | index += 1 339 | return (vocab_file,) 340 | 341 | 342 | class BasicTokenizer(object): 343 | """ 344 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 345 | Args: 346 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 347 | Whether or not to lowercase the input when tokenizing. 348 | never_split (:obj:`Iterable`, `optional`): 349 | Collection of tokens which will never be split during tokenization. Only has an effect when 350 | :obj:`do_basic_tokenize=True` 351 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 352 | Whether or not to tokenize Chinese characters. 353 | This should likely be deactivated for Japanese (see this `issue 354 | `__). 355 | strip_accents: (:obj:`bool`, `optional`): 356 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 357 | value for :obj:`lowercase` (as in the original BERT). 358 | """ 359 | 360 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): 361 | if never_split is None: 362 | never_split = [] 363 | self.do_lower_case = do_lower_case 364 | self.never_split = set(never_split) 365 | self.tokenize_chinese_chars = tokenize_chinese_chars 366 | self.strip_accents = strip_accents 367 | 368 | def tokenize(self, text, never_split=None): 369 | """ 370 | Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see 371 | WordPieceTokenizer. 372 | Args: 373 | **never_split**: (`optional`) list of str 374 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see 375 | :func:`PreTrainedTokenizer.tokenize`) List of token not to split. 376 | """ 377 | # union() returns a new set by concatenating the two sets. 378 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 379 | text = self._clean_text(text) 380 | 381 | # This was added on November 1st, 2018 for the multilingual and Chinese 382 | # models. 
This is also applied to the English models now, but it doesn't 383 | # matter since the English models were not trained on any Chinese data 384 | # and generally don't have any Chinese data in them (there are Chinese 385 | # characters in the vocabulary because Wikipedia does have some Chinese 386 | # words in the English Wikipedia.). 387 | if self.tokenize_chinese_chars: 388 | text = self._tokenize_chinese_chars(text) 389 | orig_tokens = whitespace_tokenize(text) 390 | split_tokens = [] 391 | for token in orig_tokens: 392 | if token not in never_split: 393 | if self.do_lower_case: 394 | token = token.lower() 395 | if self.strip_accents is not False: 396 | token = self._run_strip_accents(token) 397 | elif self.strip_accents: 398 | token = self._run_strip_accents(token) 399 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 400 | 401 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 402 | return output_tokens 403 | 404 | def _run_strip_accents(self, text): 405 | """Strips accents from a piece of text.""" 406 | text = unicodedata.normalize("NFD", text) 407 | output = [] 408 | for char in text: 409 | cat = unicodedata.category(char) 410 | if cat == "Mn": 411 | continue 412 | output.append(char) 413 | return "".join(output) 414 | 415 | def _run_split_on_punc(self, text, never_split=None): 416 | """Splits punctuation on a piece of text.""" 417 | if never_split is not None and text in never_split: 418 | return [text] 419 | chars = list(text) 420 | i = 0 421 | start_new_word = True 422 | output = [] 423 | while i < len(chars): 424 | char = chars[i] 425 | if _is_punctuation(char): 426 | output.append([char]) 427 | start_new_word = True 428 | else: 429 | if start_new_word: 430 | output.append([]) 431 | start_new_word = False 432 | output[-1].append(char) 433 | i += 1 434 | 435 | return ["".join(x) for x in output] 436 | 437 | def _tokenize_chinese_chars(self, text): 438 | """Adds whitespace around any CJK character.""" 439 | output = [] 440 | for char in text: 441 | cp = ord(char) 442 | if self._is_chinese_char(cp): 443 | output.append(" ") 444 | output.append(char) 445 | output.append(" ") 446 | else: 447 | output.append(char) 448 | return "".join(output) 449 | 450 | def _is_chinese_char(self, cp): 451 | """Checks whether CP is the codepoint of a CJK character.""" 452 | # This defines a "chinese character" as anything in the CJK Unicode block: 453 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 454 | # 455 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 456 | # despite its name. The modern Korean Hangul alphabet is a different block, 457 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 458 | # space-separated words, so they are not treated specially and handled 459 | # like the all of the other languages. 
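# Example: ord('中') == 0x4E2D falls in the first range below, so _tokenize_chinese_chars()
# wraps it in spaces, while Hangul and kana code points miss every range and are left untouched.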
460 | if ( 461 | (cp >= 0x4E00 and cp <= 0x9FFF) 462 | or (cp >= 0x3400 and cp <= 0x4DBF) # 463 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 464 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 465 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 466 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 467 | or (cp >= 0xF900 and cp <= 0xFAFF) 468 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 469 | ): # 470 | return True 471 | 472 | return False 473 | 474 | def _clean_text(self, text): 475 | """Performs invalid character removal and whitespace cleanup on text.""" 476 | output = [] 477 | for char in text: 478 | cp = ord(char) 479 | if cp == 0 or cp == 0xFFFD or _is_control(char): 480 | continue 481 | if _is_whitespace(char): 482 | output.append(" ") 483 | else: 484 | output.append(char) 485 | return "".join(output) 486 | 487 | 488 | class WordpieceTokenizer(object): 489 | """Runs WordPiece tokenization.""" 490 | 491 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 492 | self.vocab = vocab 493 | self.unk_token = unk_token 494 | self.max_input_chars_per_word = max_input_chars_per_word 495 | 496 | def tokenize(self, text): 497 | """ 498 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform 499 | tokenization using the given vocabulary. 500 | For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. 501 | Args: 502 | text: A single token or whitespace separated tokens. This should have 503 | already been passed through `BasicTokenizer`. 504 | Returns: 505 | A list of wordpiece tokens. 506 | """ 507 | 508 | output_tokens = [] 509 | for token in whitespace_tokenize(text): 510 | chars = list(token) 511 | if len(chars) > self.max_input_chars_per_word: 512 | output_tokens.append(self.unk_token) 513 | continue 514 | 515 | is_bad = False 516 | start = 0 517 | sub_tokens = [] 518 | while start < len(chars): 519 | end = len(chars) 520 | cur_substr = None 521 | while start < end: 522 | substr = "".join(chars[start:end]) 523 | if start > 0: 524 | substr = "##" + substr 525 | if substr in self.vocab: 526 | cur_substr = substr 527 | break 528 | end -= 1 529 | if cur_substr is None: 530 | is_bad = True 531 | break 532 | sub_tokens.append(cur_substr) 533 | start = end 534 | 535 | if is_bad: 536 | output_tokens.append(self.unk_token) 537 | else: 538 | output_tokens.extend(sub_tokens) 539 | return output_tokens -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW 2 | 3 | 4 | def create_optimizer(args, model): 5 | lr = args.lr 6 | wd = args.weight_decay 7 | lr_mult = getattr(args, 'lr_mult', 1) 8 | print("### lr: ", lr, " ### lr_mult: ", lr_mult, flush=True) 9 | 10 | optimizer_grouped_parameters = [ 11 | {"params": [], "weight_decay": wd, "lr": lr}, 12 | {"params": [], "weight_decay": 0.0, "lr": lr}, 13 | {"params": [], "weight_decay": wd, "lr": lr * lr_mult}, 14 | {"params": [], "weight_decay": 0.0, "lr": lr * lr_mult} 15 | ] 16 | 17 | no_decay = {"bias", 18 | "LayerNorm.bias", 19 | "LayerNorm.weight", 20 | "norm.bias", 21 | "norm.weight", 22 | "norm1.bias", 23 | "norm1.weight", 24 | "norm2.bias", 25 | "norm2.weight"} 26 | 27 | if hasattr(model, 'init_params'): 28 | large_lr = model.init_params 29 | print("### model has 'init_params', ", len(large_lr), flush=True) 30 | else: 31 | large_lr = {} 32 | 33 | for n, p in model.named_parameters(): 34 | if not p.requires_grad: 35 | continue # 
frozen weights 36 | 37 | if any(nd in n for nd in no_decay): 38 | if n in large_lr: 39 | optimizer_grouped_parameters[3]['params'].append(p) 40 | else: 41 | optimizer_grouped_parameters[1]['params'].append(p) 42 | else: # decay 43 | if n in large_lr: 44 | optimizer_grouped_parameters[2]['params'].append(p) 45 | else: 46 | optimizer_grouped_parameters[0]['params'].append(p) 47 | 48 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)) 49 | 50 | return optimizer 51 | -------------------------------------------------------------------------------- /reTools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import copy 4 | import numpy as np 5 | import time 6 | import datetime 7 | import json 8 | from pathlib import Path 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | from PIL import Image, ImageFont, ImageDraw 12 | from sklearn import metrics 13 | from easydict import EasyDict 14 | from prettytable import PrettyTable 15 | 16 | import torch 17 | import torch.distributed as dist 18 | import torch.nn.functional as F 19 | 20 | import utils 21 | 22 | 23 | @torch.no_grad() 24 | def evaluation_attr(model, data_loader, tokenizer, device, config, args): 25 | model.eval() 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | header = 'Evaluation:' 28 | print('Computing features for evaluation attr...') 29 | start_time = time.time() 30 | 31 | text = ['the person is a man', 'the person is a woman', 32 | 'the person is no more than 60 years old', 'the person is older than 60 years old', 33 | 'the person is a young or old one', 'the person is of mid age, between 18 and 60 years old', 34 | 'the person is older than 18', 'the person is a baby or a teenager, younger than 18', 35 | 36 | 'the picture is not the front of the person', 'the picture shows the front of the person', 37 | 'the picture is not the side of the person', 'the picture shows the side of the person', 38 | 'the picture is not the back of the person', 'the picture shows the back of the person', 39 | 'a person without a hat', 'a person with a hat', 40 | 41 | 'a person without a glasses', 'a person with a glasses', 42 | 'a person without a handbag', 'a person with a handbag', 43 | 'a person without a shoulder bag', 'a person with a shoulder bag', 44 | 'a person without a backpack', 'a person with a backpack', 45 | 46 | 'the person does not hold an object in front', 'the person hold an object in front', 47 | 'the person does not wear short sleeved upper clothes', 'the person wears short sleeved upper clothes', 48 | 'the person does not wear long sleeved upper clothes', 'the person wears long sleeved upper clothes', 49 | 'there is no stride on the upper clothes of the person', 50 | 'there is stride on the upper clothes of the person', 51 | 52 | 'there is no logo on the upper clothes of the person', 53 | 'there is logo on the upper clothes of the person', 54 | 'there is no plaid on the upper clothes of the person', 55 | 'there is plaid on the upper clothes of the person', 56 | 'there is no splice on the upper clothes of the person', 57 | 'there is splice on the upper clothes of the person', 58 | 'there is no stripe on the upper clothes of the person', 59 | 'there is stripe on the upper clothes of the person', 60 | 61 | 'there is no pattern on the lower part of the person', 62 | 'there is pattern on the lower part of the person', 63 | 'the person does not wear long coat', 'the person wears long coat', 64 | 'the person 
does not wear trousers', 'the person wears trousers', 65 | 'the person does not wear shorts', 'the person wears shorts', 66 | 67 | 'the person does not wear a skirt or a dress', 'the person wears a skirt or a dress', 68 | 'the person does not wear boots', 'the person wears boots', 69 | ] 70 | 71 | text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], 72 | return_tensors="pt").to(device) 73 | text_embeds = model.get_text_embeds(text_input.input_ids, text_input.attention_mask) 74 | text_atts = text_input.attention_mask 75 | 76 | image_embeds = [] 77 | for image, img_id in data_loader: 78 | image = image.to(device) 79 | image_embed, _ = model.get_vision_embeds(image) 80 | image_embeds.append(image_embed) 81 | image_embeds = torch.cat(image_embeds, dim=0) 82 | 83 | score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(text)), -1000.0).to(device) 84 | num_tasks = utils.get_world_size() 85 | rank = utils.get_rank() 86 | step = image_embeds.size(0) // num_tasks + 1 87 | start = rank * step 88 | end = min(image_embeds.size(0), start + step) 89 | 90 | for i, image_embed in enumerate(metric_logger.log_every(image_embeds[start:end], 50, header)): 91 | encoder_output = image_embed.repeat(len(text), 1, 1) 92 | encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device) 93 | output = model.get_cross_embeds(encoder_output, encoder_att, text_embeds=text_embeds, 94 | text_atts=text_atts)[:, 0, :] 95 | score = model.itm_head(output)[:, 1] 96 | score_matrix_i2t[start + i] = score 97 | if args.distributed: 98 | dist.barrier() 99 | torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM) 100 | total_time = time.time() - start_time 101 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 102 | print('Evaluation time {}'.format(total_time_str)) 103 | return score_matrix_i2t.cpu().numpy() 104 | 105 | 106 | @torch.no_grad() 107 | def evaluation_attr_only_img_classifier(model, data_loader, tokenizer, device, config, args): 108 | model.eval() 109 | metric_logger = utils.MetricLogger(delimiter=" ") 110 | header = 'Evaluation:' 111 | print('Computing features for evaluation attr...') 112 | start_time = time.time() 113 | 114 | image_embeds = [] 115 | outputs = [] 116 | for image, img_id in data_loader: 117 | image = image.to(device) 118 | image_embed = model.vision_encoder(image) 119 | output = model.img_cls(image_embed[:, 0, :]) 120 | output = torch.sigmoid(output) 121 | outputs.append(output) 122 | outputs = torch.cat(outputs, dim=0) 123 | orig_outputs = outputs.data.cpu().numpy() 124 | # transform raw outputs to attributes (binary codes) 125 | outputs = copy.deepcopy(orig_outputs) 126 | outputs[outputs < 0.5] = 0 127 | outputs[outputs >= 0.5] = 1 128 | return outputs 129 | 130 | 131 | @torch.no_grad() 132 | def accs(pred, y): 133 | print('Testing ... 
metrics') 134 | num_persons = pred.shape[0] 135 | print('num_persons', num_persons) 136 | ins_acc = 0 137 | ins_prec = 0 138 | ins_rec = 0 139 | mA_history = { 140 | 'correct_pos': 0, 141 | 'real_pos': 0, 142 | 'correct_neg': 0, 143 | 'real_neg': 0 144 | } 145 | 146 | # compute label-based metric 147 | outputs = pred 148 | attrs = y 149 | overlaps = outputs * attrs 150 | mA_history['correct_pos'] += overlaps.sum(0) 151 | mA_history['real_pos'] += attrs.sum(0) 152 | inv_overlaps = (1 - outputs) * (1 - attrs) 153 | mA_history['correct_neg'] += inv_overlaps.sum(0) 154 | mA_history['real_neg'] += (1 - attrs).sum(0) 155 | 156 | outputs = outputs.astype(bool) 157 | attrs = attrs.astype(bool) 158 | 159 | # compute instabce-based accuracy 160 | intersect = (outputs & attrs).astype(float) 161 | union = (outputs | attrs).astype(float) 162 | ins_acc += (intersect.sum(1) / union.sum(1)).sum() 163 | ins_prec += (intersect.sum(1) / outputs.astype(float).sum(1)).sum() 164 | ins_rec += (intersect.sum(1) / attrs.astype(float).sum(1)).sum() 165 | 166 | ins_acc /= num_persons 167 | ins_prec /= num_persons 168 | ins_rec /= num_persons 169 | ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec) 170 | 171 | term1 = mA_history['correct_pos'] / mA_history['real_pos'] 172 | term2 = mA_history['correct_neg'] / mA_history['real_neg'] 173 | label_mA_verbose = (term1 + term2) * 0.5 174 | label_mA = label_mA_verbose.mean() 175 | 176 | print('* Results *') 177 | print(' # test persons: {}'.format(num_persons)) 178 | print(' (label-based) mean accuracy: {:.2%}'.format(label_mA)) 179 | print(' (instance-based) accuracy: {:.2%}'.format(ins_acc)) 180 | print(' (instance-based) precition: {:.2%}'.format(ins_prec)) 181 | print(' (instance-based) recall: {:.2%}'.format(ins_rec)) 182 | print(' (instance-based) f1-score: {:.2%}'.format(ins_f1)) 183 | print(' mA for each attribute: {}'.format(label_mA_verbose)) 184 | return label_mA, ins_acc, ins_prec, ins_rec, ins_f1 185 | 186 | 187 | @torch.no_grad() 188 | def itm_eval_attr(scores_i2t, dataset): 189 | label = dataset.label 190 | pred = [] 191 | for i in range(label.shape[1]): 192 | a = np.argmax(scores_i2t[:, 2 * i: 2 * i + 2], axis=1) 193 | pred.append(a) 194 | 195 | label_mA, ins_acc, ins_prec, ins_rec, ins_f1 = accs(np.array(pred).T, label) 196 | print('############################################################\n') 197 | eval_result = {'label_mA': round(label_mA, 4), 198 | 'ins_acc': round(ins_acc, 4), 199 | 'ins_prec': round(ins_prec, 4), 200 | 'ins_rec': round(ins_rec, 4), 201 | 'ins_f1': round(ins_f1, 4), 202 | } 203 | return eval_result 204 | 205 | 206 | @torch.no_grad() 207 | def itm_eval_attr_only_img_classifier(scores_i2t, dataset): 208 | label = dataset.label 209 | pred = scores_i2t 210 | label_mA, ins_acc, ins_prec, ins_rec, ins_f1 = accs(pred, label) 211 | print('############################################################\n') 212 | eval_result = {'label_mA': round(label_mA, 4), 213 | 'ins_acc': round(ins_acc, 4), 214 | 'ins_prec': round(ins_prec, 4), 215 | 'ins_rec': round(ins_rec, 4), 216 | 'ins_f1': round(ins_f1, 4), 217 | } 218 | return eval_result 219 | 220 | 221 | @torch.no_grad() 222 | def evaluation(model, data_loader, tokenizer, device, config, args): 223 | model.eval() 224 | 225 | metric_logger = utils.MetricLogger(delimiter=" ") 226 | header = 'Evaluation:' 227 | 228 | print('Computing features for evaluation...') 229 | start_time = time.time() 230 | 231 | texts = data_loader.dataset.text 232 | num_text = len(texts) 233 | text_bs = 
config['batch_size_test_text'] # 256 234 | text_embeds = [] 235 | text_atts = [] 236 | text_feats = [] 237 | for i in range(0, num_text, text_bs): 238 | text = texts[i: min(num_text, i + text_bs)] 239 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'], 240 | return_tensors="pt").to(device) 241 | text_embed = model.get_text_embeds(text_input.input_ids, text_input.attention_mask) 242 | text_feat = model.text_proj(text_embed[:, 0, :]) 243 | text_feat = F.normalize(text_feat, dim=-1) 244 | 245 | text_embeds.append(text_embed) 246 | text_atts.append(text_input.attention_mask) 247 | text_feats.append(text_feat) 248 | text_embeds = torch.cat(text_embeds, dim=0) 249 | text_atts = torch.cat(text_atts, dim=0) 250 | text_feats = torch.cat(text_feats, dim=0) 251 | 252 | image_embeds = [] 253 | image_feats = [] 254 | for image, img_id in data_loader: 255 | image = image.to(device) 256 | image_embed, _ = model.get_vision_embeds(image) 257 | image_feat = model.vision_proj(image_embed[:, 0, :]) 258 | image_feat = F.normalize(image_feat, dim=-1) 259 | image_embeds.append(image_embed) 260 | image_feats.append(image_feat) 261 | image_embeds = torch.cat(image_embeds, dim=0) 262 | image_feats = torch.cat(image_feats, dim=0) 263 | sims_matrix = image_feats @ text_feats.t() 264 | sims_matrix = sims_matrix.t() 265 | score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), 1000.0).to(device) 266 | score_sim_t2i = sims_matrix 267 | 268 | num_tasks = utils.get_world_size() 269 | rank = utils.get_rank() 270 | step = sims_matrix.size(0) // num_tasks + 1 271 | start = rank * step 272 | end = min(sims_matrix.size(0), start + step) 273 | 274 | for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 275 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 276 | encoder_output = image_embeds[topk_idx] 277 | encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device) 278 | 279 | output = model.get_cross_embeds(encoder_output, encoder_att, 280 | text_embeds=text_embeds[start + i].repeat(config['k_test'], 1, 1), 281 | text_atts=text_atts[start + i].repeat(config['k_test'], 1))[:, 0, :] 282 | score = model.itm_head(output)[:, 1] 283 | score_matrix_t2i[start + i, topk_idx] = score 284 | score_sim_t2i[start + i, topk_idx] = topk_sim 285 | 286 | min_values, _ = torch.min(score_matrix_t2i, dim=1) 287 | replacement_tensor = min_values.view(-1, 1).expand(-1, score_matrix_t2i.size(1)) 288 | score_matrix_t2i[score_matrix_t2i == 1000.0] = replacement_tensor[score_matrix_t2i == 1000.0] 289 | score_sim_t2i = (score_sim_t2i - score_sim_t2i.min()) / (score_sim_t2i.max() - score_sim_t2i.min()) 290 | score_matrix_t2i = (score_matrix_t2i - score_matrix_t2i.min()) / (score_matrix_t2i.max() - score_matrix_t2i.min()) 291 | score_matrix_t2i = score_matrix_t2i + 0.002*score_sim_t2i 292 | 293 | if args.distributed: 294 | dist.barrier() 295 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) 296 | 297 | total_time = time.time() - start_time 298 | per_time = total_time / num_text 299 | print('total_time', total_time) 300 | print('per_time', per_time) 301 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 302 | print('Evaluation time {}'.format(total_time_str)) 303 | 304 | return score_matrix_t2i.cpu().numpy() 305 | 306 | 307 | def mAP(scores_t2i, g_pids, q_pids, table=None): 308 | similarity = torch.tensor(scores_t2i) 309 | indices = torch.argsort(similarity, dim=1, 
descending=True) 310 | g_pids = torch.tensor(g_pids) 311 | q_pids = torch.tensor(q_pids) 312 | pred_labels = g_pids[indices.cpu()] # q * k 313 | matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k 314 | 315 | all_cmc = matches[:, :10].cumsum(1) # cumulative sum 316 | all_cmc[all_cmc > 1] = 1 317 | all_cmc = all_cmc.float().mean(0) * 100 318 | # all_cmc = all_cmc[topk - 1] 319 | 320 | num_rel = matches.sum(1) # q 321 | tmp_cmc = matches.cumsum(1) # q * k 322 | 323 | inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)] 324 | mINP = torch.cat(inp).mean() * 100 325 | 326 | tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])] 327 | tmp_cmc = torch.stack(tmp_cmc, 1) * matches 328 | AP = tmp_cmc.sum(1) / num_rel # q 329 | mAP = AP.mean() * 100 330 | 331 | t2i_cmc, t2i_mAP, t2i_mINP, _ = all_cmc, mAP, mINP, indices 332 | t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy() 333 | 334 | if not table: 335 | table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"]) 336 | table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP]) 337 | table.custom_format["R1"] = lambda f, v: f"{v:.3f}" 338 | table.custom_format["R5"] = lambda f, v: f"{v:.3f}" 339 | table.custom_format["R10"] = lambda f, v: f"{v:.3f}" 340 | table.custom_format["mAP"] = lambda f, v: f"{v:.3f}" 341 | table.custom_format["mINP"] = lambda f, v: f"{v:.3f}" 342 | print(table) 343 | 344 | eval_result = {'R1': t2i_cmc[0], 345 | 'R5': t2i_cmc[4], 346 | 'R10': t2i_cmc[9], 347 | 'mAP': t2i_mAP, 348 | 'mINP': t2i_mINP, 349 | } 350 | return eval_result 351 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.9 2 | transformers==4.12.5 3 | ruamel_yaml 4 | opencv-python 5 | scikit-image 6 | matplotlib 7 | audtorch 8 | seaborn 9 | prettytable 10 | easydict 11 | nltk 12 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | # Set it correctly for distributed training across nodes 6 | NNODES = 1 # e.g. 1/2/3/4 7 | NPROC_PER_NODE = 4 # e.g. 4 gpus 8 | MASTER_ADDR = '127.0.0.1' 9 | MASTER_PORT = 3000 # 0~65536 10 | NODE_RANK = 0 # e.g. 
0/1/2 11 | 12 | print("NNODES, ", NNODES) 13 | print("NPROC_PER_NODE, ", NPROC_PER_NODE) 14 | print("MASTER_ADDR, ", MASTER_ADDR) 15 | print("MASTER_PORT, ", MASTER_PORT) 16 | print("NODE_RANK, ", NODE_RANK) 17 | 18 | 19 | def get_dist_launch(args): # some examples 20 | if args.dist == 'f4': 21 | return "CUDA_VISIBLE_DEVICES=0,1,2,3 WORLD_SIZE=4 python3 -m torch.distributed.launch --nproc_per_node=4 " \ 22 | "--nnodes=1 --master_port={:}".format(MASTER_PORT) 23 | 24 | elif args.dist == 'f2': 25 | return "CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 python3 -m torch.distributed.launch --nproc_per_node=2 " \ 26 | "--nnodes=1 --master_port={:}".format(MASTER_PORT) 27 | 28 | elif args.dist == 'l2': 29 | return "CUDA_VISIBLE_DEVICES=2,3 WORLD_SIZE=2 python3 -m torch.distributed.launch --nproc_per_node=2 " \ 30 | "--nnodes=1 --master_port={:}".format(MASTER_PORT) 31 | 32 | elif args.dist == 'f-0': 33 | return "CUDA_VISIBLE_DEVICES=1,2,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \ 34 | "--nnodes=1 " 35 | 36 | elif args.dist == 'f-1': 37 | return "CUDA_VISIBLE_DEVICES=0,2,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \ 38 | "--nnodes=1 " 39 | 40 | elif args.dist == 'f-2': 41 | return "CUDA_VISIBLE_DEVICES=0,1,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \ 42 | "--nnodes=1 " 43 | 44 | elif args.dist == 'f-3': 45 | return "CUDA_VISIBLE_DEVICES=0,1,2 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \ 46 | "--nnodes=1 " 47 | 48 | elif args.dist.startswith('gpu'): # use one gpu, --dist "gpu0" 49 | num = int(args.dist[3:]) 50 | assert 0 <= num <= 3 51 | return "CUDA_VISIBLE_DEVICES={:} WORLD_SIZE=1 python3 -m torch.distributed.launch --nproc_per_node=1 " \ 52 | "--nnodes=1 --master_port={:} ".format(num, MASTER_PORT) 53 | 54 | else: 55 | raise ValueError 56 | 57 | 58 | def run_retrieval(args): 59 | dist_launch = get_dist_launch(args) 60 | 61 | os.system(f"{dist_launch} " 62 | f"--use_env Retrieval.py --config {args.config} " 63 | f"--task {args.task} --output_dir {args.output_dir} --bs {args.bs} --epo {args.epo} --checkpoint {args.checkpoint} {'--evaluate' if args.evaluate else ''}") 64 | 65 | 66 | def run(args): 67 | if args.task not in ['itr_gene']: 68 | assert os.path.exists(args.checkpoint) 69 | 70 | if args.task == 'itr_cuhk': 71 | assert os.path.exists("images/CUHK-PEDES") 72 | args.config = 'configs/Retrieval_cuhk.yaml' 73 | run_retrieval(args) 74 | 75 | elif args.task == 'itr_icfg': 76 | assert os.path.exists("images/ICFG-PEDES") 77 | args.config = 'configs/Retrieval_icfg.yaml' 78 | run_retrieval(args) 79 | 80 | elif args.task == 'itr_rstp': 81 | assert os.path.exists("images/RSTPReid") 82 | args.config = 'configs/Retrieval_rstp.yaml' 83 | run_retrieval(args) 84 | 85 | elif args.task == 'itr_gene': 86 | assert os.path.exists("images/CUHK-PEDES") 87 | args.config = 'configs/Retrieval_gene.yaml' 88 | run_retrieval(args) 89 | 90 | elif args.task == 'itr_pa100k': 91 | assert os.path.exists("images/pa100k") 92 | args.config = 'configs/Retrieval_pa100k.yaml' 93 | run_retrieval(args) 94 | 95 | else: 96 | raise NotImplementedError(f"task == {args.task}") 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--task', type=str, required=True) 102 | parser.add_argument('--dist', type=str, required=True, help="see func get_dist_launch for details") 103 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus; ") 104 
| parser.add_argument('--epo', default=-1, type=int, help="epoch") 105 | parser.add_argument('--seed', default=42, type=int) 106 | parser.add_argument('--checkpoint', default='output/pretrain/checkpoint_31.pth', type=str, help="for fine-tuning") 107 | parser.add_argument('--output_dir', type=str, required=True, help='local path; ') 108 | parser.add_argument('--evaluate', action='store_true', help="evaluation on downstream tasks") 109 | args = parser.parse_args() 110 | 111 | assert os.path.exists(os.path.dirname(args.output_dir)) 112 | if not os.path.exists(args.output_dir): 113 | os.mkdir(args.output_dir) 114 | 115 | run(args) 116 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | def create_scheduler(args, optimizer): 5 | if 'num_training_steps' not in args: 6 | args['num_training_steps'] = args['epochs'] * args['step_per_epoch'] 7 | print("### num_training_steps, ", args['num_training_steps'], flush=True) 8 | 9 | if isinstance(args['num_warmup_steps'], float): 10 | assert 0 <= args['num_warmup_steps'] < 1 11 | args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps']) 12 | print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True) 13 | 14 | print('sched:', args.sched, flush=True) 15 | 16 | if args.sched == 'linear': 17 | def lr_lambda(current_step: int): 18 | if current_step < args.num_warmup_steps: 19 | return float(current_step) / float(max(1, args.num_warmup_steps)) 20 | return max( 21 | 0.0, float(args.num_training_steps - current_step) / float( 22 | max(1, args.num_training_steps - args.num_warmup_steps)) 23 | ) 24 | 25 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) 26 | 27 | elif args.sched == 'step': 28 | def lr_lambda(current_step: int): 29 | if current_step < args.num_warmup_steps: 30 | return float(current_step) / float(max(1, args.num_warmup_steps)) 31 | elif current_step < args.num_warmup_steps * 4: 32 | tt = 1 33 | elif current_step < args.num_warmup_steps * 7: 34 | tt = 0.5 35 | else: 36 | tt = 0.2 37 | 38 | return tt * max( 39 | 0.0, float(args.num_training_steps - current_step) / float( 40 | max(1, args.num_training_steps - args.num_warmup_steps)) 41 | ) 42 | 43 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) 44 | 45 | else: 46 | raise NotImplementedError(f"args.sched == {args.sched}") 47 | 48 | return lr_scheduler 49 | -------------------------------------------------------------------------------- /train_pa100ks.py: -------------------------------------------------------------------------------- 1 | import utils 2 | from train_tools import mlm 3 | 4 | def train_pa100k(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None): 5 | model.train() 6 | 7 | metric_logger = utils.MetricLogger(delimiter=" ") 8 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 9 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 10 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 11 | if config['mlm']: 12 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 13 | header = 'Train Epoch: [{}]'.format(epoch) 14 | print_freq = 50 15 | 16 | for i, (image, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 17 | image = 
image.to(device, non_blocking=True) 18 | label = label.to(device, non_blocking=True) 19 | 20 | text = ['the person is a man', 'the person is a woman', 21 | 'the person is no more than 60 years old', 'the person is older than 60 years old', 22 | 'the person is a young or old one', 'the person is of mid age, between 18 and 60 years old', 23 | 'the person is older than 18', 'the person is a baby or a teenager, younger than 18', 24 | 25 | 'the picture is not the front of the person', 'the picture shows the front of the person', 26 | 'the picture is not the side of the person', 'the picture shows the side of the person', 27 | 'the picture is not the back of the person', 'the picture shows the back of the person', 28 | 'a person without a hat', 'a person with a hat', 29 | 30 | 'a person without a glasses', 'a person with a glasses', 31 | 'a person without a handbag', 'a person with a handbag', 32 | 'a person without a shoulder bag', 'a person with a shoulder bag', 33 | 'a person without a backpack', 'a person with a backpack', 34 | 35 | 'the person does not hold an object in front', 'the person hold an object in front', 36 | 'the person does not wear short sleeved upper clothes', 'the person wears short sleeved upper clothes', 37 | 'the person does not wear long sleeved upper clothes', 'the person wears long sleeved upper clothes', 38 | 'there is no stride on the upper clothes of the person', 39 | 'there is stride on the upper clothes of the person', 40 | 41 | 'there is no logo on the upper clothes of the person', 42 | 'there is logo on the upper clothes of the person', 43 | 'there is no plaid on the upper clothes of the person', 44 | 'there is plaid on the upper clothes of the person', 45 | 'there is no splice on the upper clothes of the person', 46 | 'there is splice on the upper clothes of the person', 47 | 'there is no stripe on the upper clothes of the person', 48 | 'there is stripe on the upper clothes of the person', 49 | 50 | 'there is no pattern on the lower part of the person', 51 | 'there is pattern on the lower part of the person', 52 | 'the person does not wear long coat', 'the person wears long coat', 53 | 'the person does not wear trousers', 'the person wears trousers', 54 | 'the person does not wear shorts', 'the person wears shorts', 55 | 56 | 'the person does not wear a skirt or a dress', 'the person wears a skirt or a dress', 57 | 'the person does not wear boots', 'the person wears boots', 58 | ] 59 | 60 | text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], 61 | return_tensors="pt").to(device) 62 | 63 | # mlm loss 64 | if config['mlm']: 65 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, config, True) 66 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids, text_input.attention_mask, 67 | text_ids_masked=text_ids_masked, masked_pos=masked_pos, 68 | masked_ids=masked_ids, label=label) 69 | loss = loss_itc + loss_itm + loss_mlm 70 | else: 71 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, label=label) 72 | loss = loss_itc + loss_itm 73 | 74 | optimizer.zero_grad() 75 | loss.backward() 76 | optimizer.step() 77 | scheduler.step() 78 | 79 | metric_logger.update(loss_itc=loss_itc.item()) 80 | metric_logger.update(loss_itm=loss_itm.item()) 81 | if config['mlm']: 82 | metric_logger.update(loss_mlm=loss_mlm.item()) 83 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 84 | 85 | # gather the stats from all processes 86 | 
metric_logger.synchronize_between_processes() 87 | print("Averaged stats:", metric_logger.global_avg()) 88 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 89 | 90 | 91 | def train_pa100k_only_img_classifier(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, 92 | mask_generator=None): 93 | model.train() 94 | 95 | metric_logger = utils.MetricLogger(delimiter=" ") 96 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 97 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 98 | header = 'Train Epoch: [{}]'.format(epoch) 99 | print_freq = 50 100 | 101 | for i, (image, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 102 | image = image.to(device, non_blocking=True) 103 | label = label.to(device, non_blocking=True) 104 | 105 | loss = model(image, None, None, label=label) 106 | 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | scheduler.step() 111 | 112 | metric_logger.update(loss=loss.item()) 113 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 114 | 115 | # gather the stats from all processes 116 | metric_logger.synchronize_between_processes() 117 | print("Averaged stats:", metric_logger.global_avg()) 118 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /train_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def mlm(text, text_input, tokenizer, device, mask_generator, config, pa100k=False): 6 | if pa100k: 7 | text_masked = tokenizer(text, padding='longest', max_length=config['max_tokens'], 8 | return_tensors="pt").to(device) 9 | else: 10 | text_masked = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'], 11 | return_tensors="pt").to(device) 12 | text_ids_masked = text_masked.input_ids 13 | masked_pos = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.int64, device=device) 14 | masked_ids = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.long, device=device) 15 | for index, text_id in enumerate(text_ids_masked): 16 | text_ids_masked_, masked_pos_ = mask_generator(text_id) 17 | masked_ids_ = [text_input.input_ids[index][p].item() for p in masked_pos_] 18 | n_pad = config['max_masks'] - len(masked_ids_) 19 | masked_pos_ = masked_pos_ + [0] * n_pad 20 | masked_pos_ = torch.tensor(masked_pos_, dtype=torch.int64).to(device) 21 | masked_ids_ = masked_ids_ + [-100] * n_pad 22 | masked_ids_ = torch.tensor(masked_ids_, dtype=torch.long).to(device) 23 | masked_pos[index] = masked_pos_ 24 | masked_ids[index] = masked_ids_ 25 | return text_ids_masked, masked_pos, masked_ids 26 | -------------------------------------------------------------------------------- /trains.py: -------------------------------------------------------------------------------- 1 | import utils 2 | from train_tools import mlm 3 | import numpy as np 4 | 5 | 6 | def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None): 7 | model.train() 8 | 9 | metric_logger = utils.MetricLogger(delimiter=" ") 10 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 11 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 12 | 
metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 13 | if config['mlm']: 14 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 15 | header = 'Train Epoch: [{}]'.format(epoch) 16 | print_freq = 50 17 | 18 | if config['eda']: 19 | for i, (image, text, text_eda, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 20 | image = image.to(device, non_blocking=True) 21 | idx = idx.to(device, non_blocking=True) 22 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], 23 | # return_tensors="pt").to(device) 24 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'], 25 | return_tensors="pt").to(device) 26 | text_input_eda = tokenizer(text_eda, padding='max_length', truncation=True, max_length=config['max_tokens'], 27 | return_tensors="pt").to(device) 28 | if config['mlm']: 29 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, 30 | config) 31 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids, text_input.attention_mask, 32 | text_ids_masked=text_ids_masked, 33 | masked_pos=masked_pos, masked_ids=masked_ids, idx=idx, 34 | text_ids_eda=text_input_eda.input_ids, 35 | text_atts_eda=text_input_eda.attention_mask) 36 | loss = loss_itc + loss_itm + loss_mlm 37 | else: 38 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, idx=idx, 39 | text_ids_eda=text_input_eda.input_ids, 40 | text_atts_eda=text_input_eda.attention_mask) 41 | loss = loss_itc + loss_itm 42 | 43 | optimizer.zero_grad() 44 | loss.backward() 45 | optimizer.step() 46 | scheduler.step() 47 | 48 | metric_logger.update(loss_itc=loss_itc.item()) 49 | metric_logger.update(loss_itm=loss_itm.item()) 50 | if config['mlm']: 51 | metric_logger.update(loss_mlm=loss_mlm.item()) 52 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 53 | else: 54 | for i, (image, text, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 55 | image = image.to(device, non_blocking=True) 56 | idx = idx.to(device, non_blocking=True) 57 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], 58 | # return_tensors="pt").to(device) 59 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'], 60 | return_tensors="pt").to(device) 61 | # mlm loss 62 | if config['mlm']: 63 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, 64 | config) 65 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids, 66 | text_input.attention_mask, 67 | text_ids_masked=text_ids_masked, 68 | masked_pos=masked_pos, masked_ids=masked_ids, 69 | idx=idx) 70 | loss = loss_itc + loss_itm + loss_mlm 71 | else: 72 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, idx=idx) 73 | loss = loss_itc + loss_itm 74 | 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | scheduler.step() 79 | 80 | metric_logger.update(loss_itc=loss_itc.item()) 81 | metric_logger.update(loss_itm=loss_itm.item()) 82 | if config['mlm']: 83 | metric_logger.update(loss_mlm=loss_mlm.item()) 84 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 85 | 86 | # gather the stats from all processes 87 | metric_logger.synchronize_between_processes() 88 | print("Averaged stats:", metric_logger.global_avg()) 89 | return {k: 
"{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 90 | 91 | 92 | def train_attr(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None): 93 | model.train() 94 | 95 | metric_logger = utils.MetricLogger(delimiter=" ") 96 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 97 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 98 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 99 | if config['mlm']: 100 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 101 | metric_logger.add_meter('loss_attr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 102 | 103 | header = 'Train Epoch: [{}]'.format(epoch) 104 | print_freq = 50 105 | 106 | for i, (image, text, idx, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 107 | image = image.to(device, non_blocking=True) 108 | idx = idx.to(device, non_blocking=True) 109 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], 110 | # return_tensors="pt").to(device) 111 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'], 112 | return_tensors="pt").to(device) 113 | label = label.to(device, non_blocking=True) 114 | 115 | attr = ['the person is a woman', 'the person is a man', 116 | 'the person is younger than 18 years old', 'the person is older than 18 years old', 117 | 118 | 'the person with short hair', 'the person with long hair', 119 | 'the person with a hat', 'the person without a hat', 120 | 'the person with a backpack', 'the person without a backpack', 121 | 'the person with a handbag', 'the person without a handbag', 122 | 'the person with a bag', 'the person without a bag', 123 | 124 | 'the person wears long sleeved upper clothes', 'the person wears short sleeved upper clothes', 125 | 'the person wears long dress or long pants', 'the person wears short dress or short pants', 126 | 'the person wears dress or skirt', 'the person wears pants or shorts', 127 | 128 | 'the person wears black upper clothes', 'the person does not wear black upper clothes', 129 | 'the person wears white upper clothes', 'the person does not wear white upper clothes', 130 | 'the person wears red upper clothes', 'the person does not wear red upper clothes', 131 | 'the person wears purple upper clothes', 'the person does not wear purple upper clothes', 132 | 133 | 'the person wears yellow upper clothes', 'the person does not wear yellow upper clothes', 134 | 'the person wears blue upper clothes', 'the person does not wear blue upper clothes', 135 | 'the person wears green upper clothes', 'the person does not wear green upper clothes', 136 | 'the person wears gray upper clothes', 'the person does not wear gray upper clothes', 137 | 138 | 'the person wears black lower clothes', 'the person does not wear black lower clothes', 139 | 'the person wears white lower clothes', 'the person does not wear white lower clothes', 140 | 'the person wears purple lower clothes', 'the person does not wear purple lower clothes', 141 | 'the person wears yellow lower clothes', 'the person does not wear yellow lower clothes', 142 | 143 | 'the person wears blue lower clothes', 'the person does not wear blue lower clothes', 144 | 'the person wears green lower clothes', 'the person does not wear green lower clothes', 145 | 'the person wears pink lower clothes', 'the 
person does not wear pink lower clothes', 146 | 'the person wears gray lower clothes', 'the person does not wear gray lower clothes', 147 | 'the person wears brown lower clothes', 'the person does not wear brown lower clothes', 148 | 149 | ] 150 | attr_input = tokenizer(attr, padding='longest', max_length=config['max_tokens'], 151 | return_tensors="pt").to(device) 152 | 153 | # mlm loss 154 | if config['mlm']: 155 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, 156 | config) 157 | attr_text_ids_masked, attr_masked_pos, attr_masked_ids = mlm(attr, attr_input, tokenizer, device, 158 | mask_generator, config, 159 | True) 160 | 161 | loss_itc, loss_itm, loss_mlm, loss_attr = model(image, text_input.input_ids, text_input.attention_mask, 162 | text_ids_masked=text_ids_masked, masked_pos=masked_pos, 163 | masked_ids=masked_ids, idx=idx, 164 | attr_text_ids=attr_input.input_ids, 165 | attr_text_atts=attr_input.attention_mask, 166 | attr_text_ids_masked=attr_text_ids_masked, 167 | attr_masked_pos=attr_masked_pos, 168 | attr_masked_ids=attr_masked_ids, label=label) 169 | loss = loss_itc + loss_itm + loss_mlm + config['t'] * loss_attr 170 | else: 171 | loss_itc, loss_itm, loss_attr = model(image, text_input.input_ids, text_input.attention_mask, idx=idx, 172 | attr_text_ids=attr_input.input_ids, 173 | attr_text_atts=attr_input.attention_mask, 174 | label=label) 175 | loss = loss_itc + loss_itm + config['t'] * loss_attr 176 | 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | scheduler.step() 181 | 182 | metric_logger.update(loss_itc=loss_itc.item()) 183 | metric_logger.update(loss_itm=loss_itm.item()) 184 | if config['mlm']: 185 | metric_logger.update(loss_mlm=loss_mlm.item()) 186 | metric_logger.update(loss_attr=loss_attr.item()) 187 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 188 | 189 | # gather the stats from all processes 190 | metric_logger.synchronize_between_processes() 191 | print("Averaged stats:", metric_logger.global_avg()) 192 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from collections import defaultdict, deque, OrderedDict 5 | import datetime 6 | import numpy as np 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | class SmoothedValue(object): 13 | """Track a series of values and provide access to smoothed values over a 14 | window or the global series average. 15 | """ 16 | 17 | def __init__(self, window_size=20, fmt=None): 18 | if fmt is None: 19 | fmt = "{median:.4f} ({global_avg:.4f})" 20 | self.deque = deque(maxlen=window_size) 21 | self.total = 0.0 22 | self.count = 0 23 | self.fmt = fmt 24 | 25 | def update(self, value, n=1): 26 | self.deque.append(value) 27 | self.count += n 28 | self.total += value * n 29 | 30 | def synchronize_between_processes(self): 31 | """ 32 | Warning: does not synchronize the deque! 
33 | """ 34 | if not is_dist_avail_and_initialized(): 35 | return 36 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 37 | dist.barrier() 38 | dist.all_reduce(t) 39 | t = t.tolist() 40 | self.count = int(t[0]) 41 | self.total = t[1] 42 | 43 | @property 44 | def median(self): 45 | d = torch.tensor(list(self.deque)) 46 | return d.median().item() 47 | 48 | @property 49 | def avg(self): 50 | d = torch.tensor(list(self.deque), dtype=torch.float32) 51 | return d.mean().item() 52 | 53 | @property 54 | def global_avg(self): 55 | return self.total / self.count 56 | 57 | @property 58 | def max(self): 59 | return max(self.deque) 60 | 61 | @property 62 | def value(self): 63 | return self.deque[-1] 64 | 65 | def __str__(self): 66 | return self.fmt.format( 67 | median=self.median, 68 | avg=self.avg, 69 | global_avg=self.global_avg, 70 | max=self.max, 71 | value=self.value) 72 | 73 | 74 | class MetricLogger(object): 75 | def __init__(self, delimiter="\t"): 76 | self.meters = defaultdict(SmoothedValue) 77 | self.delimiter = delimiter 78 | 79 | def update(self, **kwargs): 80 | for k, v in kwargs.items(): 81 | if isinstance(v, torch.Tensor): 82 | v = v.item() 83 | assert isinstance(v, (float, int)) 84 | self.meters[k].update(v) 85 | 86 | def __getattr__(self, attr): 87 | if attr in self.meters: 88 | return self.meters[attr] 89 | if attr in self.__dict__: 90 | return self.__dict__[attr] 91 | raise AttributeError("'{}' object has no attribute '{}'".format( 92 | type(self).__name__, attr)) 93 | 94 | def __str__(self): 95 | loss_str = [] 96 | for name, meter in self.meters.items(): 97 | loss_str.append( 98 | "{}: {}".format(name, str(meter)) 99 | ) 100 | return self.delimiter.join(loss_str) 101 | 102 | def global_avg(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {:.4f}".format(name, meter.global_avg) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None, dataset_len=None, epoch_info=None): 118 | if not header: 119 | header = '' 120 | if not dataset_len: 121 | dataset_len = len(iterable) 122 | start_time = time.time() 123 | end = time.time() 124 | iter_time = SmoothedValue(fmt='{avg:.4f}') 125 | data_time = SmoothedValue(fmt='{avg:.4f}') 126 | space_fmt = ':' + str(len(str(dataset_len))) + 'd' 127 | 128 | _msg = [ 129 | '[{0' + space_fmt + '}/{1}]', 130 | 'eta: {eta}', 131 | '{meters}', 132 | 'time: {time}', 133 | 'data: {data}' 134 | ] 135 | if torch.cuda.is_available(): 136 | _msg.append('max mem: {memory:.0f}') 137 | _msg = self.delimiter.join(_msg) 138 | MB = 1024.0 * 1024.0 139 | iterable = iter(iterable) 140 | train_steps = dataset_len 141 | if epoch_info: 142 | start_epoch, end_epoch = epoch_info 143 | train_steps = (end_epoch - start_epoch) * dataset_len 144 | for i in range(train_steps): 145 | obj = next(iterable) 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if epoch_info: 150 | header = int(i / dataset_len) + start_epoch 151 | header = 'Train step: [{}]'.format(header) 152 | log_msg = header + " " + _msg 153 | if (i % dataset_len) % print_freq == 0 or i == dataset_len - 1: 154 | eta_seconds = iter_time.global_avg * (dataset_len - i % dataset_len) 155 | eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) 156 | if torch.cuda.is_available(): 157 | print(log_msg.format( 158 | i % dataset_len, dataset_len, eta=eta_string, 159 | meters=str(self), 160 | time=str(iter_time), data=str(data_time), 161 | memory=torch.cuda.max_memory_allocated() / MB)) 162 | else: 163 | print(log_msg.format( 164 | i % dataset_len, dataset_len, eta=eta_string, 165 | meters=str(self), 166 | time=str(iter_time), data=str(data_time))) 167 | 168 | end = time.time() 169 | total_time = time.time() - start_time 170 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 171 | print('{} Total time: {} ({:.4f} s / it)'.format( 172 | header, total_time_str, total_time / dataset_len)) 173 | 174 | 175 | class AttrDict(dict): 176 | def __init__(self, *args, **kwargs): 177 | super(AttrDict, self).__init__(*args, **kwargs) 178 | self.__dict__ = self 179 | 180 | 181 | def compute_acc(logits, label, reduction='mean'): 182 | ret = (torch.argmax(logits, dim=1) == label).float() 183 | if reduction == 'none': 184 | return ret.detach() 185 | elif reduction == 'mean': 186 | return ret.mean().item() 187 | 188 | 189 | def compute_n_params(model, return_str=True): 190 | tot = 0 191 | for p in model.parameters(): 192 | w = 1 193 | for x in p.shape: 194 | w *= x 195 | tot += w 196 | if return_str: 197 | if tot >= 1e6: 198 | return '{:.1f}M'.format(tot / 1e6) 199 | else: 200 | return '{:.1f}K'.format(tot / 1e3) 201 | else: 202 | return tot 203 | 204 | 205 | def setup_for_distributed(is_master): 206 | """ 207 | This function disables printing when not in master process 208 | """ 209 | import builtins as __builtin__ 210 | builtin_print = __builtin__.print 211 | 212 | def print(*args, **kwargs): 213 | force = kwargs.pop('force', False) 214 | if is_master or force: 215 | builtin_print(*args, **kwargs) 216 | 217 | __builtin__.print = print 218 | 219 | 220 | def is_dist_avail_and_initialized(): 221 | if not dist.is_available(): 222 | return False 223 | if not dist.is_initialized(): 224 | return False 225 | return True 226 | 227 | 228 | def get_world_size(): 229 | if not is_dist_avail_and_initialized(): 230 | return 1 231 | return dist.get_world_size() 232 | 233 | 234 | def get_rank(): 235 | if not is_dist_avail_and_initialized(): 236 | return 0 237 | return dist.get_rank() 238 | 239 | 240 | def is_main_process(): 241 | return get_rank() == 0 242 | 243 | 244 | def save_on_master(*args, **kwargs): 245 | if is_main_process(): 246 | torch.save(*args, **kwargs) 247 | 248 | 249 | def init_distributed_mode(args): 250 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 251 | args.rank = int(os.environ["RANK"]) 252 | args.world_size = int(os.environ['WORLD_SIZE']) 253 | args.gpu = int(os.environ['LOCAL_RANK']) 254 | elif 'SLURM_PROCID' in os.environ: 255 | args.rank = int(os.environ['SLURM_PROCID']) 256 | args.gpu = args.rank % torch.cuda.device_count() 257 | else: 258 | print('Not using distributed mode') 259 | args.distributed = False 260 | return 261 | 262 | args.distributed = True 263 | 264 | torch.cuda.set_device(args.gpu) 265 | args.dist_backend = 'nccl' 266 | print('| distributed init (rank {}): {}'.format( 267 | args.rank, args.dist_url), flush=True) 268 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 269 | world_size=args.world_size, rank=args.rank) 270 | torch.distributed.barrier() 271 | setup_for_distributed(args.rank == 0) 272 | 273 | 274 | def read_json(rpath): 275 | with open(rpath, 'r') as f: 276 | return json.load(f) 
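# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor addition, not part of the original
# utils/__init__.py): how the training loops in trains.py / train_pa100ks.py
# drive MetricLogger and SmoothedValue. `dummy_loader` is a hypothetical
# stand-in for a real DataLoader.
if __name__ == '__main__':
    logger = MetricLogger(delimiter="  ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter('loss_itc', SmoothedValue(window_size=1, fmt='{value:.4f}'))
    dummy_loader = [(torch.zeros(2, 3), torch.zeros(2)) for _ in range(100)]
    for image, label in logger.log_every(dummy_loader, 50, header='Demo:'):
        logger.update(loss_itc=0.5, lr=1e-4)   # one update per batch
    logger.synchronize_between_processes()     # no-op when torch.distributed is not initialized
    print("Averaged stats:", logger.global_avg())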
--------------------------------------------------------------------------------
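A closing note on how these pieces fit together: scheduler.py reads its settings through both dictionary access (args['epochs']) and attribute access (args.sched), which is exactly the interface utils.AttrDict provides. The sketch below is an editor addition with assumed field values (in the actual pipeline they come from the YAML files under configs/); the throwaway SGD optimizer over a single parameter is illustrative only and shows the fields create_scheduler expects.

import torch
from utils import AttrDict
from scheduler import create_scheduler

# Hypothetical config values; real runs read them from configs/*.yaml.
sched_cfg = AttrDict(sched='linear', epochs=30, step_per_epoch=100,
                     num_warmup_steps=0.1)             # float -> fraction of total steps
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
lr_scheduler = create_scheduler(sched_cfg, optimizer)  # prints the resolved step counts

for _ in range(5):                                     # one optimizer/scheduler step per batch
    optimizer.step()
    lr_scheduler.step()
print('lr after 5 warmup steps:', optimizer.param_groups[0]['lr'])

For context, run.py above follows the same composition pattern: get_dist_launch only builds the torch.distributed.launch prefix for the chosen GPU layout, and run_retrieval shells out to Retrieval.py with the task's YAML config, batch size, and checkpoint.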