├── LICENSE
├── README.md
├── Retrieval.py
├── assets
│   ├── badsamples.jpg
│   ├── chart1.jpg
│   ├── cuhk_freq50.jpg
│   ├── examples.jpg
│   ├── framework.jpg
│   ├── mals_fre50.jpg
│   ├── mals_rst_30.jpg
│   ├── readme.txt
│   └── result.jpg
├── configs
│   ├── Retrieval_cuhk.yaml
│   ├── Retrieval_gene.yaml
│   ├── Retrieval_icfg.yaml
│   ├── Retrieval_pa100k.yaml
│   ├── Retrieval_rstp.yaml
│   ├── config_bert.json
│   └── config_swinB_384.json
├── dataset
│   ├── __init__.py
│   ├── eda.py
│   ├── randaugment.py
│   ├── random_erasing.py
│   ├── re_dataset.py
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── aptm.py
│   ├── bert.py
│   ├── model_retrieval.py
│   ├── swin_transformer.py
│   └── tokenization_bert.py
├── optim.py
├── reTools.py
├── requirements.txt
├── run.py
├── scheduler.py
├── train_pa100ks.py
├── train_tools.py
├── trains.py
└── utils
    └── __init__.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Shuyu Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # APTM
2 |
3 |
4 | [](https://paperswithcode.com/sota/nlp-based-person-retrival-on-cuhk-pedes?p=towards-unified-text-based-person-retrieval-a)
5 | [](https://paperswithcode.com/sota/text-based-person-retrieval-on-icfg-pedes?p=towards-unified-text-based-person-retrieval-a)
6 | [](https://paperswithcode.com/sota/text-based-person-retrieval-on-rstpreid-1?p=towards-unified-text-based-person-retrieval-a)
7 |
8 | **APTM (ACM MM 2023)** is a new joint **A**ttribute **P**rompt Learning and **T**ext **M**atching Learning framework that exploits the shared knowledge between attributes and text. As the name implies, APTM contains an attribute prompt learning stream and a text matching learning stream.
9 |
10 | We also present a large Multi-Attribute and Language Search dataset for text-based person retrieval, called **MALS**, and explore the feasibility of pre-training on both attribute recognition and image-text matching tasks within a single framework. In particular, MALS contains 1,510,330 image-text pairs, about 37.5× more than the prevailing CUHK-PEDES, and all images are annotated with 27 attributes.
11 |
12 | Extensive experiments validate the effectiveness of pre-training on MALS: APTM achieves state-of-the-art retrieval performance on three challenging real-world benchmarks. In particular, APTM improves Recall@1 accuracy by a clear margin of +6.60%, +7.39%, and +15.90% on the CUHK-PEDES, ICFG-PEDES, and RSTPReid datasets, respectively. More details can be found in our paper: [Towards Unified Text-based Person Retrieval: A Large-scale Multi-Attribute and Language Search Benchmark](https://arxiv.org/abs/2306.02898)
13 |
14 |
15 | ## News
16 | * The **OneDrive** link of the **MALS** dataset is released!
17 | * The **APTM** and the **MALS** dataset are released. Welcome to communicate!
18 |
19 | ## MALS
20 | MALS leverages generative models to build a large-scale dataset of 1.5M image-text pairs. Each image-text pair in MALS is annotated with one corresponding description and several appropriate attribute labels, so MALS is not only effective for text-image matching and attribute prompt learning, but also enables exploring pre-training on both attribute recognition and image-text matching within a single framework. **The dataset is released at [Baidu Yun](https://pan.baidu.com/s/1HMvNIIFlquI2w0R6f0G7Dg) [4kq0] and [OneDrive](https://1drv.ms/f/s!Ak2z-VJ5LcCvgdZGSTJbaHOMMFZi9A?e=gCBnv0) [mals].**
21 |
22 | **Note that MALS can only be used for research; any commercial usage is forbidden.**
23 |
24 | This is a comparison between MALS and other text-based person retrieval datasets.
25 | 
26 | These are examples from our MALS dataset and CUHK-PEDES.
27 | 
28 | Annotation format:
29 |
30 | ```
31 | [{"image": "gene_crop/c_g_a_0/0.jpg",
32 | "caption": "a young boy wearing a black hoodie leaning against a wall with his hands on his hips and his hands on his hips wearing jeans and a baseball cap",
33 | "image_id": "c_g_a_0_0",
34 | "label": [1, 0, ..., 1, 1]},
35 | ...
36 | {"image": "gene_crop/c_g_a_0/20217.jpg",
37 | "caption": "a woman in a white top and black pants posing for a picture in front of a brick wall with a pink carpet in front of her",
38 | "image_id": "c_g_a_0_20217",
39 | "label": [0, 1, ..., -1, -1]}]
40 | ```
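
For reference, here is a minimal sketch of how one might load and inspect an annotation file with Python's standard `json` module (the file name below is illustrative; substitute any MALS annotation json you downloaded):

```python
import json

# Load one MALS annotation file (illustrative path; see the data layout below).
with open('data/finetune/gene_attrs/g_c_g_a_0_attrs.json', 'r') as f:
    annotations = json.load(f)   # list of dicts, one per image-text pair

sample = annotations[0]
print(sample['image'])           # relative image path, e.g. "gene_crop/c_g_a_0/0.jpg"
print(sample['caption'])         # natural-language description
print(len(sample['label']))      # attribute label vector (27 attributes per the paper)
```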
41 |
42 | ## Models and Weights
43 |
44 | The checkpoints have been released at [Baidu Yun](https://pan.baidu.com/s/1oAkenOKaVEYWpNh2hznkGA) [b2l8] and [Google Drive](https://drive.google.com/drive/folders/1N1Lumvb4epP0awHLcJ3RzQmv5zwrAFBh?usp=sharing)
45 |
46 |
47 | ## Usage
48 |
49 | ### Install Requirements
50 |
51 | We use 4 A100 80G GPUs for training and evaluation.
52 |
53 | Create the conda environment:
54 |
55 | ```
56 | conda create -n aptm python=3.8
57 | conda activate aptm
58 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
59 | pip3 install -r requirements.txt
60 | ```
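
After installation, a quick optional sanity check that the expected PyTorch build sees the GPUs (standard `torch` calls only):

```python
import torch

print(torch.__version__)            # expected: 1.9.1+cu111
print(torch.cuda.is_available())    # should be True on a GPU node
print(torch.cuda.device_count())    # e.g. 4 for the setup described above
```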
61 |
62 | ### Datasets Prepare
63 |
64 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the PA-100K dataset from [here](https://github.com/xh-liu/HydraPlus-Net), the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset), and the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN). Download the processed json files of the above four datasets from [here](https://pan.baidu.com/s/1oAkenOKaVEYWpNh2hznkGA) [b2l8].
65 |
66 | Download pre-trained models for parameter initialization:
67 |
68 | image encoder: [swin-transformer-base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)
69 |
70 | text encoder: [bert-base](https://huggingface.co/bert-base-uncased/tree/main)
71 |
72 | Organize `data` folder as follows:
73 |
74 | ```
75 | |-- data/
76 | |    |-- bert-base-uncased/
77 | |    |-- finetune/
78 | |    |    |-- gene_attrs/
79 | |    |    |    |-- g_4x_attrs.json
80 | |    |    |    |-- g_c_g_a_0_attrs.json
81 | |    |    |    |-- ...
82 | |    |    |-- cuhk_train.json
83 | |    |    |-- ...
84 | |    |    |-- icfg_train.json
85 | |    |    |-- ...
86 | |    |    |-- rstp_train.json
87 | |    |    |-- ...
88 | |    |    |-- PA100K_train.json
89 | |    |    |-- ...
90 | |    |-- swin_base_patch4_window7_224_22k.pth
91 | ```
92 |
93 | Organize those datasets in the `images` folder as follows:
94 |
95 | ```
96 | |-- images/
97 | |    |-- CUHK-PEDES/
98 | |    |    |-- imgs/
99 | |    |    |    |-- cam_a/
100 | |    |    |    |-- cam_b/
101 | |    |    |    |-- ...
102 | |    |    |    |-- train_query/
103 | |    |    |-- gene_crop/
104 | |    |    |    |-- 4x/
105 | |    |    |    |-- c_g_a/
106 | |    |    |    |-- ...
107 | |    |    |    |-- i_g_a_43/
108 | |
109 | |    |-- ICFG-PEDES/
110 | |    |    |-- test/
111 | |    |    |-- train/
112 | |
113 | |    |-- pa100k/
114 | |    |    |-- release_data/
115 | |
116 | |    |-- RSTPReid/
117 | ```
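
The following optional snippet sanity-checks that the layout above is in place; the paths are taken from the configs in this repository, so adjust them if your directory names differ:

```python
import os

# Paths assumed from configs/Retrieval_*.yaml and configs/config_swinB_384.json.
expected = [
    'data/bert-base-uncased',
    'data/swin_base_patch4_window7_224_22k.pth',
    'data/finetune/cuhk_train.json',
    'images/CUHK-PEDES/imgs',
    'images/CUHK-PEDES/gene_crop',
    'images/ICFG-PEDES',
    'images/pa100k/release_data',
    'images/RSTPReid',
]
for path in expected:
    print(('ok  ' if os.path.exists(path) else 'MISS') + ' ' + path)
```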
118 |
119 | ### Pretraining
120 | We pretrain our APTM using MALS as follows:
121 |
122 | ```
123 | python3 run.py --task "itr_gene" --dist "f4" --output_dir "output/pretrained"
124 | ```
125 |
126 | ### Fine-tuning
127 | We fine-tune our APTM on existing text-based person ReID datasets. Performance can be improved by initializing the backbone with our pre-trained model. Taking CUHK-PEDES as an example:
128 |
129 | ```
130 | python3 run.py --task "itr_cuhk" --dist "f4" --output_dir "output/ft_cuhk" --checkpoint "output/pretrained/checkpoint_31.pth"
131 | ```
132 |
133 | ### Evaluation
134 |
135 | ```
136 | python3 run.py --task "itr_cuhk" --evaluate --dist "f4" --output_dir "output/ft_cuhk/test" --checkpoint "output/ft_cuhk/checkpoint_best.pth"
137 | ```
138 |
139 | ## Reference
140 | If you use APTM in your research, please cite it using the following BibTeX entry:
141 |
142 | ```bibtex
143 | @inproceedings{yang2023towards,
144 | title={Towards Unified Text-based Person Retrieval: A Large-scale Multi-Attribute and Language Search Benchmark},
145 | author={Yang, Shuyu and Zhou, Yinan and Wang, Yaxiong and Wu, Yujiao and Zhu, Li and Zheng, Zhedong},
146 | booktitle = {Proceedings of the 2023 {ACM} on Multimedia Conference},
147 | year={2023}
148 | }
149 |
150 | ```
151 |
--------------------------------------------------------------------------------
/Retrieval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import math
4 |
5 | import ruamel.yaml as yaml
6 | import numpy as np
7 | import random
8 | import time
9 | import datetime
10 | import json
11 | from pathlib import Path
12 | from prettytable import PrettyTable
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from models.model_retrieval import APTM_Retrieval
19 | from models.tokenization_bert import BertTokenizer
20 |
21 | import utils
22 | from dataset import create_dataset, create_sampler, create_loader
23 | from dataset.re_dataset import TextMaskingGenerator
24 | from scheduler import create_scheduler
25 | from optim import create_optimizer
26 |
27 | from trains import train, train_attr
28 | from train_pa100ks import train_pa100k, train_pa100k_only_img_classifier
29 |
30 | from reTools import evaluation, mAP
31 | from reTools import evaluation_attr, itm_eval_attr
32 | from reTools import evaluation_attr_only_img_classifier, itm_eval_attr_only_img_classifier
33 |
34 |
35 | def main(args, config):
36 | utils.init_distributed_mode(args)
37 | device = torch.device(args.device)
38 | world_size = utils.get_world_size()
39 |
40 | if args.bs > 0:
41 | config['batch_size_train'] = args.bs // world_size
42 | if args.epo > 0:
43 | config['schedular']['epochs'] = args.epo
44 |
45 | seed = args.seed + utils.get_rank()
46 | torch.manual_seed(seed)
47 | np.random.seed(seed)
48 | random.seed(seed)
49 | cudnn.benchmark = True
50 |
51 | print("Creating model", flush=True)
52 | tokenizer = BertTokenizer.from_pretrained(config['text_encoder'])
53 | model = APTM_Retrieval(config=config)
54 | if config['load_pretrained']:
55 | model.load_pretrained(args.checkpoint, config, is_eval=args.evaluate)
56 | model = model.to(device)
57 |
58 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad), flush=True)
59 |
60 | model_without_ddp = model
61 | if args.distributed:
62 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
63 | model_without_ddp = model.module
64 |
65 | print("Creating retrieval dataset", flush=True)
66 | if args.task == "itr_icfg":
67 | train_dataset, test_dataset = create_dataset('re_icfg', config, args.evaluate)
68 | elif args.task == "itr_rstp":
69 | train_dataset, val_dataset, test_dataset = create_dataset('re_rstp', config, args.evaluate)
70 | elif args.task == "itr_cuhk":
71 | train_dataset, val_dataset, test_dataset = create_dataset('re_cuhk', config, args.evaluate)
72 | elif args.task == "itr_pa100k":
73 | train_dataset, val_dataset, test_dataset = create_dataset('re_pa100k', config, args.evaluate)
74 | else:
75 | train_dataset, val_dataset, test_dataset = create_dataset('re_gene', config, args.evaluate)
76 |
77 | start_time = time.time()
78 | print("### output_dir, ", args.output_dir, flush=True)
79 |
80 | if args.evaluate:
81 | print("Start evaluating", flush=True)
82 | # if args.task not in ["itr_icfg", "itr_pa100k"]:
83 | # print("val_dataset", flush=True)
84 | # val_loader = create_loader([val_dataset], [None],
85 | # batch_size=[config['batch_size_test']],
86 | # num_workers=[4],
87 | # is_trains=[False],
88 | # collate_fns=[None])[0]
89 | # score_val_t2i = evaluation(model_without_ddp, val_loader,
90 | # tokenizer, device, config, args)
91 |
92 | print("test_dataset", flush=True)
93 | test_loader = create_loader([test_dataset], [None],
94 | batch_size=[config['batch_size_test']],
95 | num_workers=[4],
96 | is_trains=[False],
97 | collate_fns=[None])[0]
98 | if args.task == "itr_pa100k":
99 | if model_without_ddp.pa100k_only_img_classifier:
100 | score_test_i2t_attr = evaluation_attr_only_img_classifier(model_without_ddp, test_loader,
101 | tokenizer, device, config, args)
102 | else:
103 | score_test_i2t_attr = evaluation_attr(model_without_ddp, test_loader,
104 | tokenizer, device, config, args)
105 | else:
106 | score_test_t2i = evaluation(model_without_ddp, test_loader,
107 | tokenizer, device, config, args)
108 |
109 | if utils.is_main_process():
110 | # if args.task not in ["itr_icfg", "itr_pa100k"]:
111 | # print('val_result:', flush=True)
112 | # mAP(score_val_t2i, val_loader.dataset.g_pids, val_loader.dataset.q_pids)
113 | if args.task == "itr_pa100k":
114 | if model_without_ddp.pa100k_only_img_classifier:
115 | test_result_attr = itm_eval_attr_only_img_classifier(score_test_i2t_attr, test_loader.dataset)
116 | else:
117 | test_result_attr = itm_eval_attr(score_test_i2t_attr, test_loader.dataset)
118 | print('test_result_attr:', flush=True)
119 | print(test_result_attr, flush=True)
120 | else:
121 | print('test_result:', flush=True)
122 | mAP(score_test_t2i, test_loader.dataset.g_pids, test_loader.dataset.q_pids)
123 |
124 | dist.barrier()
125 |
126 | else:
127 | print("Start training", flush=True)
128 | train_dataset_size = len(train_dataset)
129 | if utils.is_main_process():
130 | print(f"### data {train_dataset_size}, batch size, {config['batch_size_train']} x {world_size}")
131 | if args.task == "itr_pa100k":
132 | table = PrettyTable(["epoch", "label_mA", "ins_acc", "ins_prec", "ins_rec", "ins_f1"])
133 | else:
134 | table = PrettyTable(["epoch", "R1", "R5", "R10", "mAP", "mINP"])
135 | table.custom_format["R1"] = lambda f, v: f"{v:.3f}"
136 | table.custom_format["R5"] = lambda f, v: f"{v:.3f}"
137 | table.custom_format["R10"] = lambda f, v: f"{v:.3f}"
138 | table.custom_format["mAP"] = lambda f, v: f"{v:.3f}"
139 | table.custom_format["mINP"] = lambda f, v: f"{v:.3f}"
140 | if args.distributed:
141 | num_tasks = utils.get_world_size()
142 | global_rank = utils.get_rank()
143 | if args.task == "itr_icfg":
144 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None]
145 | else:
146 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
147 | else:
148 | if args.task == "itr_icfg":
149 | samplers = [None, None]
150 | else:
151 | samplers = [None, None, None]
152 |
153 | if args.task == "itr_icfg":
154 | train_loader, test_loader = create_loader([train_dataset, test_dataset], samplers,
155 | batch_size=[config['batch_size_train']] + [
156 | config['batch_size_test']],
157 | num_workers=[4, 4], is_trains=[True, False],
158 | collate_fns=[None, None])
159 | else:
160 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers,
161 | batch_size=[config['batch_size_train']] + [
162 | config['batch_size_test']] * 2,
163 | num_workers=[4, 4, 4], is_trains=[True, False, False],
164 | collate_fns=[None, None, None])
165 |
166 | arg_opt = utils.AttrDict(config['optimizer'])
167 | optimizer = create_optimizer(arg_opt, model_without_ddp)
168 | arg_sche = utils.AttrDict(config['schedular'])
169 | arg_sche['step_per_epoch'] = math.ceil(train_dataset_size / (config['batch_size_train'] * world_size))
170 | lr_scheduler = create_scheduler(arg_sche, optimizer)
171 |
172 | max_epoch = config['schedular']['epochs']
173 | best = 0
174 | best_epoch = 0
175 |
176 | if config['mlm']:
177 | mask_generator = TextMaskingGenerator(tokenizer, config['mask_prob'], config['max_masks'],
178 | config['skipgram_prb'], config['skipgram_size'],
179 | config['mask_whole_word'])
180 | else:
181 | mask_generator = None
182 |
183 | for epoch in range(0, max_epoch):
184 | if args.distributed:
185 | train_loader.sampler.set_epoch(epoch)
186 |
187 | if args.task == "itr_pa100k":
188 | if model_without_ddp.pa100k_only_img_classifier:
189 | train_stats = train_pa100k_only_img_classifier(model, train_loader, optimizer, tokenizer, epoch,
190 | device, lr_scheduler, config, mask_generator)
191 | else:
192 | train_stats = train_pa100k(model, train_loader, optimizer, tokenizer, epoch,
193 | device, lr_scheduler, config, mask_generator)
194 | else:
195 | if ('attr' in config.keys()) and config['attr']:
196 | train_stats = train_attr(model, train_loader, optimizer, tokenizer, epoch,
197 | device, lr_scheduler, config, mask_generator)
198 | else:
199 | train_stats = train(model, train_loader, optimizer, tokenizer, epoch,
200 | device, lr_scheduler, config, mask_generator)
201 |
202 | if (epoch + 1) % 1 == 0:
203 | # if args.task not in ["itr_icfg", "itr_pa100k"]:
204 | # score_val_t2i = evaluation(model_without_ddp, val_loader, tokenizer,
205 | # device, config, args)
206 | if args.task == "itr_pa100k":
207 | if model_without_ddp.pa100k_only_img_classifier:
208 | score_test_i2t = evaluation_attr_only_img_classifier(model_without_ddp, test_loader,
209 | tokenizer, device, config, args)
210 | else:
211 | score_test_i2t = evaluation_attr(model_without_ddp, test_loader,
212 | tokenizer, device, config, args)
213 | else:
214 | score_test_t2i = evaluation(model_without_ddp, test_loader,
215 | tokenizer, device, config, args)
216 |
217 | if utils.is_main_process():
218 | # if args.task not in ["itr_icfg", "itr_pa100k"]:
219 | # val_result = mAP(score_val_t2i, val_loader.dataset.g_pids, val_loader.dataset.q_pids, table)
220 | if args.task == "itr_pa100k":
221 | if model_without_ddp.pa100k_only_img_classifier:
222 | test_result = itm_eval_attr_only_img_classifier(score_test_i2t, test_loader.dataset)
223 | else:
224 | test_result = itm_eval_attr(score_test_i2t, test_loader.dataset)
225 | table.add_row([epoch, test_result['label_mA'] * 100, test_result['ins_acc'] * 100,
226 | test_result['ins_prec'] * 100, test_result['ins_rec'] * 100,
227 | test_result['ins_f1'] * 100])
228 | test_result_log = test_result
229 | else:
230 | test_result = mAP(score_test_t2i, test_loader.dataset.g_pids, test_loader.dataset.q_pids, table)
231 | table.add_row([epoch, test_result['R1'], test_result['R5'], test_result['R10'],
232 | test_result['mAP'], test_result['mINP']])
233 | test_result_log = {}
234 | for k, v in test_result.items():
235 | test_result_log[k] = str(np.around(v, 3))
236 | print(table, flush=True)
237 |
238 | log_stats = {'e': epoch,
239 | **{k: v for k, v in test_result_log.items()},
240 | **{k: v for k, v in train_stats.items()},
241 | # **{f'val_{k}': v for k, v in val_result.items()},
242 | }
243 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
244 | f.write(json.dumps(log_stats) + "\n")
245 |
246 | if args.task == "itr_pa100k":
247 | result = test_result['label_mA']
248 | else:
249 | result = test_result['R1']
250 |
251 | if result > best:
252 | save_obj = {'model': model_without_ddp.state_dict(), 'config': config, }
253 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
254 | best = result
255 | best_epoch = epoch
256 | elif epoch >= max_epoch - 1:
257 | save_obj = {
258 | 'model': model_without_ddp.state_dict(),
259 | # 'optimizer': optimizer.state_dict(),
260 | # 'lr_scheduler': lr_scheduler.state_dict(),
261 | 'config': config,
262 | # 'epoch': epoch,
263 | }
264 | torch.save(save_obj, os.path.join(args.output_dir, f'checkpoint_{epoch}.pth'))
265 | dist.barrier()
266 | torch.cuda.empty_cache()
267 |
268 | if utils.is_main_process():
269 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
270 | f.write("best epoch: %d" % best_epoch)
271 |
272 | os.system(f"cat {args.output_dir}/log.txt")
273 |
274 | total_time = time.time() - start_time
275 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
276 | print(' ### Time {}'.format(total_time_str))
277 |
278 |
279 | if __name__ == '__main__':
280 | parser = argparse.ArgumentParser()
281 | parser.add_argument('--checkpoint', type=str)
282 | parser.add_argument('--config', type=str, required=True)
283 | parser.add_argument('--task', type=str, required=True)
284 | parser.add_argument('--output_dir', type=str, required=True)
285 | parser.add_argument('--device', default='cuda')
286 | parser.add_argument('--seed', default=42, type=int)
287 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
288 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
289 | parser.add_argument('--distributed', action='store_false')
290 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus")
291 | parser.add_argument('--epo', default=-1, type=int, help="epoch")
292 | parser.add_argument('--evaluate', action='store_true')
293 |
294 | args = parser.parse_args()
295 |
296 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
297 |
298 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
299 |
300 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
301 |
302 | main(args, config)
303 |
--------------------------------------------------------------------------------
/assets/badsamples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/badsamples.jpg
--------------------------------------------------------------------------------
/assets/chart1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/chart1.jpg
--------------------------------------------------------------------------------
/assets/cuhk_freq50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/cuhk_freq50.jpg
--------------------------------------------------------------------------------
/assets/examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/examples.jpg
--------------------------------------------------------------------------------
/assets/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/framework.jpg
--------------------------------------------------------------------------------
/assets/mals_fre50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/mals_fre50.jpg
--------------------------------------------------------------------------------
/assets/mals_rst_30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/mals_rst_30.jpg
--------------------------------------------------------------------------------
/assets/readme.txt:
--------------------------------------------------------------------------------
1 | assets
2 |
--------------------------------------------------------------------------------
/assets/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuyu-XJTU/APTM/4b59dc4a85ddc72be662e5e5aa9f058fe85fa9ee/assets/result.jpg
--------------------------------------------------------------------------------
/configs/Retrieval_cuhk.yaml:
--------------------------------------------------------------------------------
1 | image_root: 'images/CUHK-PEDES/'
2 | test_file: 'data/finetune/cuhk_test.json'
3 | val_file: 'data/finetune/cuhk_val.json'
4 | train_file: ['data/finetune/cuhk_train.json']
5 |
6 |
7 | ## Vision Encoder
8 | vision_config: 'configs/config_swinB_384.json'
9 | image_res: 384
10 | patch_size: 32
11 | h: 384
12 | w: 128
13 |
14 |
15 | ## Text Encoder
16 | text_config: 'configs/config_bert.json'
17 | text_encoder: 'data/bert-base-uncased'
18 |
19 |
20 | ## Training
21 | batch_size_train: 120
22 | batch_size_test: 150
23 | batch_size_test_text: 750
24 |
25 | max_tokens: 56
26 | max_words: 56
27 |
28 | embed_dim: 256
29 | temp: 0.07
30 | k_test: 128
31 |
32 |
33 | ## mlm loss
34 | mlm: True
35 | mask_prob: 0.25
36 | max_masks: 10
37 | skipgram_prb: 0.2
38 | skipgram_size: 3
39 | mask_whole_word: True
40 |
41 |
42 | ## Other Settings
43 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
44 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1}
45 |
46 | pa100k: False
47 | icfg_rstp: False
48 |
49 | lr_2: True
50 | load_params: False
51 | load_pretrained: True
52 |
53 | eda: True
54 | eda_p: 1
55 | erasing_p: 0.6
56 | LabelSmooth: 0
--------------------------------------------------------------------------------
/configs/Retrieval_gene.yaml:
--------------------------------------------------------------------------------
1 | image_root: 'images/CUHK-PEDES/'
2 | test_file: 'data/finetune/cuhk_test.json'
3 | val_file: 'data/finetune/cuhk_val.json'
4 | train_file: ['data/finetune/gene_attrs/g_4x_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_attrs.json',
5 | 'data/finetune/gene_attrs/g_c_g_a_0_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_1_attrs.json',
6 | 'data/finetune/gene_attrs/g_c_g_a_2_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_3_attrs.json',
7 | 'data/finetune/gene_attrs/g_c_g_a_4_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_5_attrs.json',
8 | 'data/finetune/gene_attrs/g_c_g_a_6_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_7_attrs.json',
9 | 'data/finetune/gene_attrs/g_c_g_a_8_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_9_attrs.json',
10 | 'data/finetune/gene_attrs/g_c_g_a_10_attrs.json','data/finetune/gene_attrs/g_c_g_a_11_attrs.json',
11 | 'data/finetune/gene_attrs/g_c_g_a_12_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_13_attrs.json',
12 | 'data/finetune/gene_attrs/g_c_g_a_14_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_15_attrs.json',
13 | 'data/finetune/gene_attrs/g_c_g_a_16_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_17_attrs.json',
14 | 'data/finetune/gene_attrs/g_c_g_a_18_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_19_attrs.json',
15 | 'data/finetune/gene_attrs/g_c_g_a_20_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_21_attrs.json',
16 | 'data/finetune/gene_attrs/g_c_g_a_22_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_23_attrs.json',
17 | 'data/finetune/gene_attrs/g_c_g_a_24_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_25_attrs.json',
18 | 'data/finetune/gene_attrs/g_c_g_a_26_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_27_attrs.json',
19 | 'data/finetune/gene_attrs/g_c_g_a_28_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_29_attrs.json',
20 | 'data/finetune/gene_attrs/g_c_g_a_30_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_31_attrs.json',
21 | 'data/finetune/gene_attrs/g_c_g_a_32_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_33_attrs.json',
22 | 'data/finetune/gene_attrs/g_c_g_a_34_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_35_attrs.json',
23 | 'data/finetune/gene_attrs/g_c_g_a_36_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_37_attrs.json',
24 | 'data/finetune/gene_attrs/g_c_g_a_38_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_39_attrs.json',
25 | 'data/finetune/gene_attrs/g_c_g_a_40_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_41_attrs.json',
26 | 'data/finetune/gene_attrs/g_c_g_a_42_attrs.json', 'data/finetune/gene_attrs/g_c_g_a_43_attrs.json',
27 | 'data/finetune/gene_attrs/g_i_g_a_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_0_attrs.json',
28 | 'data/finetune/gene_attrs/g_i_g_a_1_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_2_attrs.json',
29 | 'data/finetune/gene_attrs/g_i_g_a_3_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_4_attrs.json',
30 | 'data/finetune/gene_attrs/g_i_g_a_5_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_6_attrs.json',
31 | 'data/finetune/gene_attrs/g_i_g_a_7_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_8_attrs.json',
32 | 'data/finetune/gene_attrs/g_i_g_a_9_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_10_attrs.json',
33 | 'data/finetune/gene_attrs/g_i_g_a_11_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_12_attrs.json',
34 | 'data/finetune/gene_attrs/g_i_g_a_13_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_14_attrs.json',
35 | 'data/finetune/gene_attrs/g_i_g_a_15_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_16_attrs.json',
36 | 'data/finetune/gene_attrs/g_i_g_a_17_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_18_attrs.json',
37 | 'data/finetune/gene_attrs/g_i_g_a_19_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_20_attrs.json',
38 | 'data/finetune/gene_attrs/g_i_g_a_21_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_22_attrs.json',
39 | 'data/finetune/gene_attrs/g_i_g_a_23_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_24_attrs.json',
40 | 'data/finetune/gene_attrs/g_i_g_a_25_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_26_attrs.json',
41 | 'data/finetune/gene_attrs/g_i_g_a_27_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_28_attrs.json',
42 | 'data/finetune/gene_attrs/g_i_g_a_29_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_30_attrs.json',
43 | 'data/finetune/gene_attrs/g_i_g_a_31_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_32_attrs.json',
44 | 'data/finetune/gene_attrs/g_i_g_a_33_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_34_attrs.json',
45 | 'data/finetune/gene_attrs/g_i_g_a_35_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_36_attrs.json',
46 | 'data/finetune/gene_attrs/g_i_g_a_37_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_38_attrs.json',
47 | 'data/finetune/gene_attrs/g_i_g_a_39_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_40_attrs.json',
48 | 'data/finetune/gene_attrs/g_i_g_a_41_attrs.json', 'data/finetune/gene_attrs/g_i_g_a_42_attrs.json',
49 | 'data/finetune/gene_attrs/g_i_g_a_43_attrs.json']
50 |
51 |
52 | ## Vision Encoder
53 | vision_config: 'configs/config_swinB_384.json'
54 | image_res: 384
55 | patch_size: 32
56 | h: 384
57 | w: 128
58 |
59 | ## Text Encoder
60 | text_config: 'configs/config_bert.json'
61 | text_encoder: 'data/bert-base-uncased'
62 |
63 |
64 | ## Training
65 | batch_size_train: 150
66 | batch_size_test: 150
67 | batch_size_test_text: 750
68 |
69 | max_tokens: 56
70 | max_words: 56
71 |
72 | embed_dim: 256
73 | temp: 0.07
74 | k_test: 128
75 |
76 |
77 | ## mlm loss
78 | mlm: True
79 | mask_prob: 0.25
80 | max_masks: 10
81 | skipgram_prb: 0.2
82 | skipgram_size: 3
83 | mask_whole_word: True
84 |
85 |
86 | ## Other Settings
87 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
88 | schedular: {sched: linear, lr: 1e-4, epochs: 32, num_warmup_steps: 2600}
89 |
90 | pa100k: False
91 | icfg_rstp: False
92 |
93 | lr_2: False
94 | init_cross: True
95 | load_params: True
96 | load_pretrained: False
97 |
98 | erasing_p: 0.6
99 | eda: False
100 | eda_p: 1
101 |
102 | attr: True
103 | LabelSmooth: 0.4
104 | t: 0.8
105 |
--------------------------------------------------------------------------------
/configs/Retrieval_icfg.yaml:
--------------------------------------------------------------------------------
1 | image_root: 'images/ICFG-PEDES/'
2 | train_file: ['data/finetune/icfg_train.json']
3 | test_file: 'data/finetune/icfg_test.json'
4 |
5 |
6 | ## Vision Encoder
7 | vision_config: 'configs/config_swinB_384.json'
8 | image_res: 384
9 | patch_size: 32
10 | h: 384
11 | w: 128
12 |
13 |
14 | ## Text Encoder
15 | text_config: 'configs/config_bert.json'
16 | text_encoder: 'data/bert-base-uncased'
17 |
18 |
19 | ## Training
20 | batch_size_train: 120
21 | batch_size_test: 150
22 | batch_size_test_text: 750
23 |
24 | max_tokens: 56
25 | max_words: 56
26 |
27 | embed_dim: 256
28 | temp: 0.07
29 | k_test: 128
30 |
31 |
32 | ## mlm loss
33 | mlm: True
34 | mask_prob: 0.25
35 | max_masks: 10
36 | skipgram_prb: 0.2
37 | skipgram_size: 3
38 | mask_whole_word: True
39 |
40 |
41 | ## Other Settings
42 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
43 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1}
44 |
45 | pa100k: False
46 | icfg_rstp: True
47 |
48 | lr_2: True
49 | load_params: False
50 | load_pretrained: True
51 |
52 | erasing_p: 0.6
53 | eda: True
54 | eda_p: 1
55 | LabelSmooth: 0
--------------------------------------------------------------------------------
/configs/Retrieval_pa100k.yaml:
--------------------------------------------------------------------------------
1 | image_root: 'images/pa100k/release_data'
2 | test_file: 'data/finetune/PA100K_test.json'
3 | val_file: 'data/finetune/PA100K_val.json'
4 | train_file: ['data/finetune/PA100K_train.json']
5 |
6 |
7 | ## Vision Encoder
8 | vision_config: 'configs/config_swinB_384.json'
9 | image_res: 384
10 | patch_size: 32
11 | h: 384
12 | w: 128
13 |
14 | ## Text Encoder
15 | text_config: 'configs/config_bert.json'
16 | text_encoder: 'data/bert-base-uncased'
17 |
18 |
19 | ## Training
20 | batch_size_train: 200
21 | batch_size_test: 200
22 | batch_size_test_text: 1000
23 |
24 | max_tokens: 15
25 | max_words: 56
26 |
27 | embed_dim: 256
28 | temp: 0.07
29 | k_test: 128
30 |
31 |
32 | ## mlm loss
33 | mlm: True
34 | mask_prob: 0.25
35 | max_masks: 10
36 | skipgram_prb: 0.2
37 | skipgram_size: 3
38 | mask_whole_word: True
39 |
40 |
41 | ## Other Settings
42 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
43 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1}
44 |
45 | lr_2: True
46 | load_params: False
47 | load_pretrained: True
48 |
49 | pa100k: True
50 | #pa100k_only_img_classifier: True
51 | #dop: 0.1
52 |
53 | erasing_p: 0.6
54 | LabelSmooth: 0 # 0.4
--------------------------------------------------------------------------------
/configs/Retrieval_rstp.yaml:
--------------------------------------------------------------------------------
1 | image_root: 'images/RSTPReid/'
2 | train_file: ['data/finetune/rstp_train.json']
3 | val_file: 'data/finetune/rstp_val.json'
4 | test_file: 'data/finetune/rstp_test.json'
5 |
6 |
7 | ## Vision Encoder
8 | vision_config: 'configs/config_swinB_384.json'
9 | image_res: 384
10 | patch_size: 32
11 | h: 384
12 | w: 128
13 |
14 |
15 | ## Text Encoder
16 | text_config: 'configs/config_bert.json'
17 | text_encoder: 'data/bert-base-uncased'
18 |
19 |
20 | ## Training
21 | batch_size_train: 120
22 | batch_size_test: 150
23 | batch_size_test_text: 750
24 |
25 | max_tokens: 56
26 | max_words: 56
27 |
28 | embed_dim: 256
29 | temp: 0.07
30 | k_test: 128
31 |
32 |
33 | ## mlm loss
34 | mlm: True
35 | mask_prob: 0.25
36 | max_masks: 10
37 | skipgram_prb: 0.2
38 | skipgram_size: 3
39 | mask_whole_word: True
40 |
41 |
42 | ## Other Settings
43 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
44 | schedular: {sched: step, lr: 1e-4, epochs: 30, num_warmup_steps: 0.1}
45 |
46 | pa100k: False
47 | icfg_rstp: True
48 |
49 | lr_2: True
50 | load_params: False
51 | load_pretrained: True
52 |
53 | erasing_p: 0.6
54 | eda: True
55 | eda_p: 1
56 | LabelSmooth: 0
--------------------------------------------------------------------------------
/configs/config_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "fusion_layer": 6,
20 | "encoder_width": 1024
21 | }
22 |
--------------------------------------------------------------------------------
/configs/config_swinB_384.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "data/swin_base_patch4_window7_224_22k.pth",
3 | "vision_width": 1024,
4 | "image_res": 384,
5 | "h": 384,
6 | "w": 128,
7 | "window_size": 8,
8 | "embed_dim": 128,
9 | "depths": [ 2, 2, 18, 2 ],
10 | "num_heads": [ 4, 8, 16, 32 ]
11 | }
12 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torchvision import transforms
5 | from torchvision.transforms import InterpolationMode
6 | from PIL import Image
7 |
8 | from dataset.re_dataset import re_train_dataset, re_test_dataset, re_test_dataset_icfg, re_train_dataset_attr, \
9 | re_test_dataset_attr
10 | from dataset.randaugment import RandomAugment
11 | from dataset.random_erasing import RandomErasing
12 |
13 |
14 | def create_dataset(dataset, config, evaluate=False):
15 | # gene
16 | gene_norm = transforms.Normalize((0.4416847, 0.41812873, 0.4237452), (0.3088255, 0.29743394, 0.301009))
17 | # cuhk
18 | cuhk_norm = transforms.Normalize((0.38901278, 0.3651612, 0.34836376), (0.24344306, 0.23738699, 0.23368555))
19 | # icfg
20 | icfg_norm = transforms.Normalize((0.30941582, 0.28956893, 0.30347288), (0.25849792, 0.24547698, 0.2366199))
21 | # rstp
22 | rstp_norm = transforms.Normalize((0.27722597, 0.26065794, 0.3036557), (0.2609547, 0.2508087, 0.25293276))
23 | # pa100k
24 | pa100k_norm = transforms.Normalize((0.46485138, 0.45038012, 0.4632019), (0.25088054, 0.24609283, 0.24240193))
25 |
26 | if dataset == 're_cuhk':
27 | train_norm = cuhk_norm
28 | test_norm = cuhk_norm
29 | elif dataset == 're_icfg':
30 | train_norm = icfg_norm
31 | test_norm = icfg_norm
32 | elif dataset == 're_rstp':
33 | train_norm = rstp_norm
34 | test_norm = rstp_norm
35 | elif dataset == 're_gene':
36 | train_norm = gene_norm
37 | test_norm = cuhk_norm
38 | elif dataset == 're_pa100k':
39 | train_norm = pa100k_norm
40 | test_norm = pa100k_norm
41 |
42 | train_transform = transforms.Compose([
43 | # transforms.RandomResizedCrop((config['h'], config['h']),
44 | # scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC),
45 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC),
46 | transforms.RandomHorizontalFlip(),
47 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize',
48 | 'Brightness', 'Sharpness', 'ShearX',
49 | 'ShearY', 'TranslateX', 'TranslateY',
50 | 'Rotate']),
51 | transforms.ToTensor(),
52 | train_norm,
53 | RandomErasing(probability=config['erasing_p'], mean=[0.0, 0.0, 0.0])
54 | ])
55 |
56 | pre_transform = transforms.Compose([
57 | transforms.RandomResizedCrop((config['h'], config['h']),
58 | scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC),
59 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC),
60 | transforms.RandomHorizontalFlip(),
61 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize',
62 | 'Brightness', 'Sharpness', 'ShearX',
63 | 'ShearY', 'TranslateX', 'TranslateY',
64 | 'Rotate']),
65 | transforms.ToTensor(),
66 | train_norm,
67 | RandomErasing(probability=config['erasing_p'], mean=[0.0, 0.0, 0.0])
68 | ])
69 |
70 | test_transform = transforms.Compose([
71 | transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC),
72 | transforms.ToTensor(),
73 | test_norm,
74 | ])
75 |
76 | if dataset == 're_icfg':
77 | test_dataset = re_test_dataset_icfg(config, test_transform)
78 | if evaluate:
79 | return None, test_dataset
80 | train_dataset = re_train_dataset(config, train_transform, pre_transform)
81 | return train_dataset, test_dataset
82 | elif dataset == 're_pa100k':
83 | test_dataset = re_test_dataset_attr(config['test_file'], config, test_transform)
84 | val_dataset = re_test_dataset_attr(config['val_file'], config, test_transform)
85 | if evaluate:
86 | return None, val_dataset, test_dataset
87 | train_dataset = re_train_dataset_attr(config, train_transform)
88 | return train_dataset, val_dataset, test_dataset
89 | else:
90 | test_dataset = re_test_dataset(config['test_file'], config, test_transform)
91 | val_dataset = re_test_dataset(config['val_file'], config, test_transform)
92 | if evaluate:
93 | return None, val_dataset, test_dataset
94 | train_dataset = re_train_dataset(config, train_transform, pre_transform)
95 | return train_dataset, val_dataset, test_dataset
96 |
97 |
98 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
99 | samplers = []
100 | for dataset, shuffle in zip(datasets, shuffles):
101 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
102 | shuffle=shuffle)
103 | samplers.append(sampler)
104 | return samplers
105 |
106 |
107 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
108 | loaders = []
109 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers,
110 | is_trains, collate_fns):
111 | if is_train:
112 | shuffle = (sampler is None)
113 | drop_last = True
114 | else:
115 | shuffle = False
116 | drop_last = False
117 | loader = DataLoader(
118 | dataset,
119 | batch_size=bs,
120 | num_workers=n_worker,
121 | pin_memory=True,
122 | sampler=sampler,
123 | shuffle=shuffle,
124 | collate_fn=collate_fn,
125 | drop_last=drop_last,
126 | )
127 | loaders.append(loader)
128 |
129 | if len(loaders) <= 1:
130 | print(f"### be careful: func create_loader returns a list length of {len(loaders)}")
131 |
132 | return loaders
133 |
--------------------------------------------------------------------------------
/dataset/eda.py:
--------------------------------------------------------------------------------
1 | # Easy data augmentation techniques for text classification
2 | # Jason Wei and Kai Zou
3 |
4 | import random
5 | from random import shuffle
6 |
7 | random.seed(1)
8 |
9 | # stop words list
10 | stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
11 | 'ours', 'ourselves', 'you', 'your', 'yours',
12 | 'yourself', 'yourselves', 'he', 'him', 'his',
13 | 'himself', 'she', 'her', 'hers', 'herself',
14 | 'it', 'its', 'itself', 'they', 'them', 'their',
15 | 'theirs', 'themselves', 'what', 'which', 'who',
16 | 'whom', 'this', 'that', 'these', 'those', 'am',
17 | 'is', 'are', 'was', 'were', 'be', 'been', 'being',
18 | 'have', 'has', 'had', 'having', 'do', 'does', 'did',
19 | 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
20 | 'because', 'as', 'until', 'while', 'of', 'at',
21 | 'by', 'for', 'with', 'about', 'against', 'between',
22 | 'into', 'through', 'during', 'before', 'after',
23 | 'above', 'below', 'to', 'from', 'up', 'down', 'in',
24 | 'out', 'on', 'off', 'over', 'under', 'again',
25 | 'further', 'then', 'once', 'here', 'there', 'when',
26 | 'where', 'why', 'how', 'all', 'any', 'both', 'each',
27 | 'few', 'more', 'most', 'other', 'some', 'such', 'no',
28 | 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
29 | 'very', 's', 't', 'can', 'will', 'just', 'don',
30 | 'should', 'now', '']
31 |
32 | # cleaning up text
33 | import re
34 |
35 |
36 | def get_only_chars(line):
37 | clean_line = ""
38 |
39 | line = line.replace("’", "")
40 | line = line.replace("'", "")
41 | line = line.replace("-", " ") # replace hyphens with spaces
42 | line = line.replace("\t", " ")
43 | line = line.replace("\n", " ")
44 | line = line.lower()
45 |
46 | for char in line:
47 | if char in 'qwertyuiopasdfghjklzxcvbnm ':
48 | clean_line += char
49 | else:
50 | clean_line += ' '
51 |
52 | clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces
53 | if clean_line[0] == ' ':
54 | clean_line = clean_line[1:]
55 | return clean_line
56 |
57 |
58 | ########################################################################
59 | # Synonym replacement
60 | # Replace n words in the sentence with synonyms from wordnet
61 | ########################################################################
62 |
63 | # for the first time you use wordnet
64 | # import nltk
65 | # nltk.download('wordnet')
66 | from nltk.corpus import wordnet
67 |
68 |
69 | def synonym_replacement(words, n):
70 | new_words = words.copy()
71 | random_word_list = list(set([word for word in words if word not in stop_words]))
72 | random.shuffle(random_word_list)
73 | num_replaced = 0
74 | for random_word in random_word_list:
75 | synonyms = get_synonyms(random_word)
76 | if len(synonyms) >= 1:
77 | synonym = random.choice(list(synonyms))
78 | new_words = [synonym if word == random_word else word for word in new_words]
79 | # print("replaced", random_word, "with", synonym)
80 | num_replaced += 1
81 | if num_replaced >= n: # only replace up to n words
82 | break
83 |
84 | # this is stupid but we need it, trust me
85 | sentence = ' '.join(new_words)
86 | new_words = sentence.split(' ')
87 |
88 | return new_words
89 |
90 |
91 | def get_synonyms(word):
92 | synonyms = set()
93 | for syn in wordnet.synsets(word):
94 | for l in syn.lemmas():
95 | synonym = l.name().replace("_", " ").replace("-", " ").lower()
96 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
97 | synonyms.add(synonym)
98 | if word in synonyms:
99 | synonyms.remove(word)
100 | return list(synonyms)
101 |
102 |
103 | ########################################################################
104 | # Random deletion
105 | # Randomly delete words from the sentence with probability p
106 | ########################################################################
107 |
108 | def random_deletion(words, p):
109 | # obviously, if there's only one word, don't delete it
110 | if len(words) == 1:
111 | return words
112 |
113 | # randomly delete words with probability p
114 | new_words = []
115 | for word in words:
116 | r = random.uniform(0, 1)
117 | if r > p:
118 | new_words.append(word)
119 |
120 | # if you end up deleting all words, just return a random word
121 | if len(new_words) == 0:
122 | rand_int = random.randint(0, len(words) - 1)
123 | return [words[rand_int]]
124 |
125 | return new_words
126 |
127 |
128 | ########################################################################
129 | # Random swap
130 | # Randomly swap two words in the sentence n times
131 | ########################################################################
132 |
133 | def random_swap(words, n):
134 | new_words = words.copy()
135 | for _ in range(n):
136 | new_words = swap_word(new_words)
137 | return new_words
138 |
139 |
140 | def swap_word(new_words):
141 | random_idx_1 = random.randint(0, len(new_words) - 1)
142 | random_idx_2 = random_idx_1
143 | counter = 0
144 | while random_idx_2 == random_idx_1:
145 | random_idx_2 = random.randint(0, len(new_words) - 1)
146 | counter += 1
147 | if counter > 3:
148 | return new_words
149 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
150 | return new_words
151 |
152 |
153 | ########################################################################
154 | # Random insertion
155 | # Randomly insert n words into the sentence
156 | ########################################################################
157 |
158 | def random_insertion(words, n):
159 | new_words = words.copy()
160 | for _ in range(n):
161 | add_word(new_words)
162 | return new_words
163 |
164 |
165 | def add_word(new_words):
166 | synonyms = []
167 | counter = 0
168 | while len(synonyms) < 1:
169 | random_word = new_words[random.randint(0, len(new_words) - 1)]
170 | synonyms = get_synonyms(random_word)
171 | counter += 1
172 | if counter >= 10:
173 | return
174 | random_synonym = synonyms[0]
175 | random_idx = random.randint(0, len(new_words) - 1)
176 | new_words.insert(random_idx, random_synonym)
177 |
178 |
179 | ########################################################################
180 | # main data augmentation function
181 | ########################################################################
182 |
183 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
184 | sentence = get_only_chars(sentence)
185 | words = sentence.split(' ')
186 | words = [word for word in words if word != '']
187 | num_words = len(words)
188 |
189 | augmented_sentences = []
190 | num_new_per_technique = int(num_aug / 4) + 1
191 |
192 | # sr
193 | if (alpha_sr > 0):
194 | n_sr = max(1, int(alpha_sr * num_words))
195 | for _ in range(num_new_per_technique):
196 | a_words = synonym_replacement(words, n_sr)
197 | augmented_sentences.append(' '.join(a_words))
198 |
199 | # ri
200 | if (alpha_ri > 0):
201 | n_ri = max(1, int(alpha_ri * num_words))
202 | for _ in range(num_new_per_technique):
203 | a_words = random_insertion(words, n_ri)
204 | augmented_sentences.append(' '.join(a_words))
205 |
206 | # rs
207 | if (alpha_rs > 0):
208 | n_rs = max(1, int(alpha_rs * num_words))
209 | for _ in range(num_new_per_technique):
210 | a_words = random_swap(words, n_rs)
211 | augmented_sentences.append(' '.join(a_words))
212 |
213 | # rd
214 | if (p_rd > 0):
215 | for _ in range(num_new_per_technique):
216 | a_words = random_deletion(words, p_rd)
217 | augmented_sentences.append(' '.join(a_words))
218 |
219 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences]
220 | shuffle(augmented_sentences)
221 |
222 | # trim so that we have the desired number of augmented sentences
223 | if num_aug >= 1:
224 | augmented_sentences = augmented_sentences[:num_aug]
225 | else:
226 | keep_prob = num_aug / len(augmented_sentences)
227 | augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
228 |
229 | # append the original sentence
230 | augmented_sentences.append(sentence)
231 |
232 | return augmented_sentences
--------------------------------------------------------------------------------
/dataset/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 | return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 | '''
12 | same output as PIL.ImageOps.autocontrast
13 | '''
14 | n_bins = 256
15 |
16 | def tune_channel(ch):
17 | n = ch.size
18 | cut = cutoff * n // 100
19 | if cut == 0:
20 | high, low = ch.max(), ch.min()
21 | else:
22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 | low = np.argwhere(np.cumsum(hist) > cut)
24 | low = 0 if low.shape[0] == 0 else low[0]
25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 | if high <= low:
28 | table = np.arange(n_bins)
29 | else:
30 | scale = (n_bins - 1) / (high - low)
31 | offset = -low * scale
32 | table = np.arange(n_bins) * scale + offset
33 | table[table < 0] = 0
34 | table[table > n_bins - 1] = n_bins - 1
35 | table = table.clip(0, 255).astype(np.uint8)
36 | return table[ch]
37 |
38 | channels = [tune_channel(ch) for ch in cv2.split(img)]
39 | out = cv2.merge(channels)
40 | return out
41 |
42 |
43 | def equalize_func(img):
44 | '''
45 | same output as PIL.ImageOps.equalize
46 | PIL's implementation is different from cv2.equalize
47 | '''
48 | n_bins = 256
49 |
50 | def tune_channel(ch):
51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52 | non_zero_hist = hist[hist != 0].reshape(-1)
53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54 | if step == 0: return ch
55 | n = np.empty_like(hist)
56 | n[0] = step // 2
57 | n[1:] = hist[:-1]
58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59 | return table[ch]
60 |
61 | channels = [tune_channel(ch) for ch in cv2.split(img)]
62 | out = cv2.merge(channels)
63 | return out
64 |
65 |
66 | def rotate_func(img, degree, fill=(0, 0, 0)):
67 | '''
68 | like PIL, rotate by degree, not radians
69 | '''
70 | H, W = img.shape[0], img.shape[1]
71 | center = W / 2, H / 2
72 | M = cv2.getRotationMatrix2D(center, degree, 1)
73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74 | return out
75 |
76 |
77 | def solarize_func(img, thresh=128):
78 | '''
79 | same output as PIL.ImageOps.solarize
80 | '''
81 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
82 | table = table.clip(0, 255).astype(np.uint8)
83 | out = table[img]
84 | return out
85 |
86 |
87 | def color_func(img, factor):
88 | '''
89 | same output as PIL.ImageEnhance.Color
90 | '''
91 | ## implementation according to PIL definition, quite slow
92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93 | # out = blend(degenerate, img, factor)
94 | # M = (
95 | # np.eye(3) * factor
96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97 | # )[np.newaxis, np.newaxis, :]
98 | M = (
99 | np.float32([
100 | [0.886, -0.114, -0.114],
101 | [-0.587, 0.413, -0.587],
102 | [-0.299, -0.299, 0.701]]) * factor
103 | + np.float32([[0.114], [0.587], [0.299]])
104 | )
105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106 | return out
107 |
108 |
109 | def contrast_func(img, factor):
110 | """
111 | same output as PIL.ImageEnhance.Contrast
112 | """
113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114 | table = np.array([(
115 | el - mean) * factor + mean
116 | for el in range(256)
117 | ]).clip(0, 255).astype(np.uint8)
118 | out = table[img]
119 | return out
120 |
121 |
122 | def brightness_func(img, factor):
123 | '''
124 | same output as PIL.ImageEnhance.Brightness
125 | '''
126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127 | out = table[img]
128 | return out
129 |
130 |
131 | def sharpness_func(img, factor):
132 | '''
133 | The differences between this result and PIL are all on the 4 boundaries; the center
134 | areas are the same
135 | '''
136 | kernel = np.ones((3, 3), dtype=np.float32)
137 | kernel[1][1] = 5
138 | kernel /= 13
139 | degenerate = cv2.filter2D(img, -1, kernel)
140 | if factor == 0.0:
141 | out = degenerate
142 | elif factor == 1.0:
143 | out = img
144 | else:
145 | out = img.astype(np.float32)
146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148 | out = out.astype(np.uint8)
149 | return out
150 |
151 |
152 | def shear_x_func(img, factor, fill=(0, 0, 0)):
153 | H, W = img.shape[0], img.shape[1]
154 | M = np.float32([[1, factor, 0], [0, 1, 0]])
155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156 | return out
157 |
158 |
159 | def translate_x_func(img, offset, fill=(0, 0, 0)):
160 | '''
161 | same output as PIL.Image.transform
162 | '''
163 | H, W = img.shape[0], img.shape[1]
164 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166 | return out
167 |
168 |
169 | def translate_y_func(img, offset, fill=(0, 0, 0)):
170 | '''
171 | same output as PIL.Image.transform
172 | '''
173 | H, W = img.shape[0], img.shape[1]
174 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176 | return out
177 |
178 |
179 | def posterize_func(img, bits):
180 | '''
181 | same output as PIL.ImageOps.posterize
182 | '''
183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184 | return out
185 |
186 |
187 | def shear_y_func(img, factor, fill=(0, 0, 0)):
188 | H, W = img.shape[0], img.shape[1]
189 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191 | return out
192 |
193 |
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
195 | replace = np.array(replace, dtype=np.uint8)
196 | H, W = img.shape[0], img.shape[1]
197 | rh, rw = np.random.random(2)
198 | pad_size = pad_size // 2
199 | ch, cw = int(rh * H), int(rw * W)
200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202 | out = img.copy()
203 | out[x1:x2, y1:y2, :] = replace
204 | return out
205 |
206 |
207 | ### level to args
208 | def enhance_level_to_args(MAX_LEVEL):
209 | def level_to_args(level):
210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211 | return level_to_args
212 |
213 |
214 | def shear_level_to_args(MAX_LEVEL, replace_value):
215 | def level_to_args(level):
216 | level = (level / MAX_LEVEL) * 0.3
217 | if np.random.random() > 0.5: level = -level
218 | return (level, replace_value)
219 |
220 | return level_to_args
221 |
222 |
223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224 | def level_to_args(level):
225 | level = (level / MAX_LEVEL) * float(translate_const)
226 | if np.random.random() > 0.5: level = -level
227 | return (level, replace_value)
228 |
229 | return level_to_args
230 |
231 |
232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233 | def level_to_args(level):
234 | level = int((level / MAX_LEVEL) * cutout_const)
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def solarize_level_to_args(MAX_LEVEL):
241 | def level_to_args(level):
242 | level = int((level / MAX_LEVEL) * 256)
243 | return (level, )
244 | return level_to_args
245 |
246 |
247 | def none_level_to_args(level):
248 | return ()
249 |
250 |
251 | def posterize_level_to_args(MAX_LEVEL):
252 | def level_to_args(level):
253 | level = int((level / MAX_LEVEL) * 4)
254 | return (level, )
255 | return level_to_args
256 |
257 |
258 | def rotate_level_to_args(MAX_LEVEL, replace_value):
259 | def level_to_args(level):
260 | level = (level / MAX_LEVEL) * 30
261 | if np.random.random() < 0.5:
262 | level = -level
263 | return (level, replace_value)
264 |
265 | return level_to_args
266 |
267 |
268 | func_dict = {
269 | 'Identity': identity_func,
270 | 'AutoContrast': autocontrast_func,
271 | 'Equalize': equalize_func,
272 | 'Rotate': rotate_func,
273 | 'Solarize': solarize_func,
274 | 'Color': color_func,
275 | 'Contrast': contrast_func,
276 | 'Brightness': brightness_func,
277 | 'Sharpness': sharpness_func,
278 | 'ShearX': shear_x_func,
279 | 'TranslateX': translate_x_func,
280 | 'TranslateY': translate_y_func,
281 | 'Posterize': posterize_func,
282 | 'ShearY': shear_y_func,
283 | }
284 |
285 | translate_const = 10
286 | MAX_LEVEL = 10
287 | replace_value = (128, 128, 128)
288 | arg_dict = {
289 | 'Identity': none_level_to_args,
290 | 'AutoContrast': none_level_to_args,
291 | 'Equalize': none_level_to_args,
292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
294 | 'Color': enhance_level_to_args(MAX_LEVEL),
295 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
296 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299 | 'TranslateX': translate_level_to_args(
300 | translate_const, MAX_LEVEL, replace_value
301 | ),
302 | 'TranslateY': translate_level_to_args(
303 | translate_const, MAX_LEVEL, replace_value
304 | ),
305 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307 | }
308 |
309 |
310 | class RandomAugment(object):
311 |
312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313 | self.N = N
314 | self.M = M
315 | self.isPIL = isPIL
316 | if augs:
317 | self.augs = augs
318 | else:
319 | self.augs = list(arg_dict.keys())
320 |
321 | def get_random_ops(self):
322 | sampled_ops = np.random.choice(self.augs, self.N)
323 | return [(op, 0.5, self.M) for op in sampled_ops]
324 |
325 | def __call__(self, img):
326 | if self.isPIL:
327 | img = np.array(img)
328 | ops = self.get_random_ops()
329 | for name, prob, level in ops:
330 | if np.random.random() > prob:
331 | continue
332 | args = arg_dict[name](level)
333 | img = func_dict[name](img, *args)
334 | return img
335 |
336 |
337 | if __name__ == '__main__':
338 | a = RandomAugment()
339 | img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # uint8 input; the table-lookup ops require integer pixels
340 | a(img)
--------------------------------------------------------------------------------
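A minimal usage sketch of `RandomAugment` (illustrative only: the augmentation list and the `N`/`M` values actually used for training come from the config files, and `numpy`/`opencv-python` must be installed):

```python
import numpy as np
from PIL import Image
from dataset.randaugment import RandomAugment

# Sample N=2 ops per call at magnitude M=10; op names must be keys of func_dict / arg_dict.
aug = RandomAugment(N=2, M=10, isPIL=True,
                    augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness'])

img = Image.fromarray(np.random.randint(0, 256, (384, 128, 3), dtype=np.uint8))
out = aug(img)                # numpy uint8 array; each sampled op is applied with probability 0.5
print(out.shape, out.dtype)   # (384, 128, 3) uint8
```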
/dataset/random_erasing.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | #from torchvision.transforms import *
4 |
5 | #from PIL import Image
6 | import random
7 | import math
8 | #import numpy as np
9 |
10 | class RandomErasing(object):
11 | """ Randomly selects a rectangle region in an image and erases its pixels.
12 | 'Random Erasing Data Augmentation' by Zhong et al.
13 | See https://arxiv.org/pdf/1708.04896.pdf
14 | Args:
15 | probability: The probability that the Random Erasing operation will be performed.
16 | sl: Minimum proportion of erased area against input image.
17 | sh: Maximum proportion of erased area against input image.
18 | r1: Minimum aspect ratio of erased area.
19 | mean: Erasing value.
20 | """
21 |
22 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
23 | self.probability = probability
24 | self.mean = mean
25 | self.sl = sl
26 | self.sh = sh
27 | self.r1 = r1
28 |
29 | def __call__(self, img):
30 |
31 | if random.uniform(0, 1) > self.probability:
32 | return img
33 |
34 | for attempt in range(100):
35 | area = img.size()[1] * img.size()[2]
36 |
37 | target_area = random.uniform(self.sl, self.sh) * area
38 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
39 |
40 | h = int(round(math.sqrt(target_area * aspect_ratio)))
41 | w = int(round(math.sqrt(target_area / aspect_ratio)))
42 |
43 | if w < img.size()[2] and h < img.size()[1]:
44 | x1 = random.randint(0, img.size()[1] - h)
45 | y1 = random.randint(0, img.size()[2] - w)
46 | if img.size()[0] == 3:
47 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
48 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
49 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
50 | else:
51 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
52 | return img
53 |
54 | return img
55 |
56 |
57 | class RandomGrayscaleErasing(object):
58 | """ Randomly selects a rectangle region in an image and use grayscale image
59 | instead of its pixels.
60 | 'Local Grayscale Transfomation' by Yunpeng Gong.
61 | See https://arxiv.org/pdf/2101.08533.pdf
62 | Args:
63 | probability: The probability that the Random Grayscale Erasing operation will be performed.
64 | sl: Minimum proportion of erased area against input image.
65 | sh: Maximum proportion of erased area against input image.
66 | r1: Minimum aspect ratio of erased area.
67 | """
68 |
69 | def __init__(self, probability: float = 0.2, sl: float = 0.02, sh: float = 0.4, r1: float = 0.3):
70 | self.probability = probability
71 | self.sl = sl
72 | self.sh = sh
73 | self.r1 = r1
74 |
75 | def __call__(self, img):
76 | """
77 | Args:
78 | img: a Tensor of shape (C, H, W), i.e. the output of ToTensor() and Normalize()
79 | """
80 | if random.uniform(0, 1) > self.probability:
81 | return img
82 |
83 | height, width = img.size()[-2], img.size()[-1]
84 | area = height * width
85 |
86 | for _ in range(100):
87 |
88 | target_area = random.uniform(self.sl, self.sh) * area
89 | aspect_ratio = random.uniform(self.r1, 1/self.r1) # height / width
90 |
91 | h = int(round(math.sqrt(target_area * aspect_ratio)))
92 | w = int(round(math.sqrt(target_area / aspect_ratio)))
93 |
94 | if w < width and h < height:
95 | # tl
96 | x = random.randint(0, height - h)
97 | y = random.randint(0, width - w)
98 | # unbind channel dim
99 | r, g, b = img.unbind(dim=-3)
100 | # Weighted average method -> grayscale patch
101 | l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
102 | l_img = l_img.unsqueeze(dim=-3) # rebind channel
103 | # erasing
104 | img[0, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w]
105 | img[1, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w]
106 | img[2, y:y + h, x:x + w] = l_img[0, y:y + h, x:x + w]
107 |
108 | return img
109 |
110 | return img
--------------------------------------------------------------------------------
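A minimal sketch of both transforms applied to a normalized `(C, H, W)` tensor (probabilities forced to 1.0 here just so the erasing always fires; in the training pipeline they follow `ToTensor()`/`Normalize()`):

```python
import torch
from dataset.random_erasing import RandomErasing, RandomGrayscaleErasing

img = torch.rand(3, 384, 128)                        # stand-in for a normalized image tensor
img = RandomErasing(probability=1.0)(img)            # fill a random rectangle with the per-channel mean
img = RandomGrayscaleErasing(probability=1.0)(img)   # overwrite a random patch with its grayscale values
print(img.shape)                                     # torch.Size([3, 384, 128])
```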
/dataset/re_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | from random import randint, shuffle
7 | from random import random as rand
8 | from PIL import Image
9 | from PIL import ImageFile
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 |
14 | from dataset.utils import pre_caption
15 |
16 | ImageFile.LOAD_TRUNCATED_IMAGES = True
17 | Image.MAX_IMAGE_PIXELS = None
18 |
19 |
20 | class TextMaskingGenerator:
21 | def __init__(self, tokenizer, mask_prob, mask_max, skipgram_prb=0.2, skipgram_size=3, mask_whole_word=True,
22 | use_roberta=False):
23 | self.id2token = {i: w for w, i in tokenizer.get_vocab().items()}
24 | self.use_roberta = use_roberta
25 | for i in range(len(self.id2token)):
26 | assert i in self.id2token.keys() # check
27 | self.cls_token_id = tokenizer.cls_token_id
28 | self.mask_token_id = tokenizer.mask_token_id
29 | self.mask_max = mask_max
30 | self.mask_prob = mask_prob
31 | self.skipgram_prb = skipgram_prb
32 | self.skipgram_size = skipgram_size
33 | self.mask_whole_word = mask_whole_word
34 |
35 | print("len(tokenizer.id2token): ", len(self.id2token), " ---- cls_token_id: ", self.cls_token_id,
36 | " ---- mask_token_id: ", self.mask_token_id, flush=True)
37 |
38 | def get_random_word(self):
39 | i = randint(0, len(self.id2token) - 1)
40 | return i # self.id2token[i]
41 |
42 | def __call__(self, text_ids): # tokens: [CLS] + ...
43 | n_pred = min(self.mask_max, max(1, int(round(len(text_ids) * self.mask_prob))))
44 |
45 | # candidate positions of masked tokens
46 | assert text_ids[0] == self.cls_token_id
47 | special_pos = set([0]) # will not be masked
48 | cand_pos = list(range(1, len(text_ids)))
49 |
50 | shuffle(cand_pos)
51 | masked_pos = set()
52 | max_cand_pos = max(cand_pos)
53 | for pos in cand_pos:
54 | if len(masked_pos) >= n_pred:
55 | break
56 | if pos in masked_pos:
57 | continue
58 |
59 | def _expand_whole_word(st, end):
60 | new_st, new_end = st, end
61 |
62 | if self.use_roberta:
63 | while (new_st > 1) and (self.id2token[text_ids[new_st].item()][0] != 'Ġ'):
64 | new_st -= 1
65 | while (new_end < len(text_ids)) and (self.id2token[text_ids[new_end].item()][0] != 'Ġ'):
66 | new_end += 1
67 | else:
68 | # bert, WordPiece
69 | while (new_st >= 0) and self.id2token[text_ids[new_st].item()].startswith('##'):
70 | new_st -= 1
71 | while (new_end < len(text_ids)) and self.id2token[text_ids[new_end].item()].startswith('##'):
72 | new_end += 1
73 |
74 | return new_st, new_end
75 |
76 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb):
77 | # ngram
78 | cur_skipgram_size = randint(2, self.skipgram_size)
79 | if self.mask_whole_word:
80 | st_pos, end_pos = _expand_whole_word(
81 | pos, pos + cur_skipgram_size)
82 | else:
83 | st_pos, end_pos = pos, pos + cur_skipgram_size
84 | else:
85 | if self.mask_whole_word:
86 | st_pos, end_pos = _expand_whole_word(pos, pos + 1)
87 | else:
88 | st_pos, end_pos = pos, pos + 1
89 |
90 | for mp in range(st_pos, end_pos):
91 | if (0 < mp <= max_cand_pos) and (mp not in special_pos):
92 | masked_pos.add(mp)
93 | else:
94 | break
95 |
96 | masked_pos = list(masked_pos)
97 | n_real_pred = len(masked_pos)
98 | if n_real_pred > n_pred:
99 | shuffle(masked_pos)
100 | masked_pos = masked_pos[:n_pred]
101 |
102 | for pos in masked_pos:
103 | if rand() < 0.8: # 80%
104 | text_ids[pos] = self.mask_token_id
105 | elif rand() < 0.5: # 10%
106 | text_ids[pos] = self.get_random_word()
107 |
108 | return text_ids, masked_pos
109 |
110 |
111 | class re_train_dataset(Dataset):
112 | def __init__(self, config, transform, pre_transform):
113 | self.image_root = config['image_root']
114 | self.max_words = config['max_words']
115 | self.icfg_rstp = config['icfg_rstp']
116 | self.eda = config['eda']
117 | self.eda_p = config['eda_p']
118 | ann_file = config['train_file']
119 |
120 | if ('attr' in config.keys()) and config['attr']:
121 | self.attr = True
122 | else:
123 | self.attr = False
124 |
125 | self.transform = transform
126 | self.pre_transform = pre_transform
127 | self.ann = []
128 | for f in ann_file:
129 | self.ann += json.load(open(f, 'r'))
130 |
131 | self.img_ids = {}
132 |
133 | n = 1
134 | for ann in self.ann:
135 | img_id = ann['image_id']
136 | if img_id not in self.img_ids.keys():
137 | self.img_ids[img_id] = n
138 | n += 1
139 |
140 | def __len__(self):
141 | return len(self.ann)
142 |
143 | def __getitem__(self, index):
144 |
145 | ann = self.ann[index]
146 | try:
147 | image_path = os.path.join(self.image_root, ann['image'])
148 | except:
149 | print("self.image_root", self.image_root)
150 | print("ann['image']", ann['image'])
151 | image = Image.open(image_path).convert('RGB')
152 | image1 = self.transform(image)
153 |
154 | caption = pre_caption(ann['caption'], self.max_words)
155 | if self.eda:
156 | caption1 = pre_caption(ann['caption'], self.max_words, self.icfg_rstp, True, self.eda_p)
157 | return image1, caption, caption1, self.img_ids[ann['image_id']]
158 | elif self.attr:
159 | label = torch.tensor(ann['label'])
160 | return image1, caption, self.img_ids[ann['image_id']], label
161 | else:
162 | return image1, caption, self.img_ids[ann['image_id']]
163 |
164 |
165 | class re_test_dataset(Dataset):
166 | def __init__(self, ann_file, config, transform):
167 | self.ann = json.load(open(ann_file, 'r'))
168 | self.transform = transform
169 | self.image_root = config['image_root']
170 | self.max_words = config['max_words']
171 | self.icfg_rstp = config['icfg_rstp']
172 |
173 | self.text = []
174 | self.image = []
175 | self.txt2img = {}
176 | self.img2txt = {}
177 |
178 | self.g_pids = []
179 | self.q_pids = []
180 |
181 | txt_id = 0
182 | for img_id, ann in enumerate(self.ann):
183 | self.g_pids.append(ann['image_id'])
184 | self.image.append(ann['image'])
185 | self.img2txt[img_id] = []
186 |
187 | t = 0
188 | for i, caption in enumerate(ann['caption']):
189 | self.q_pids.append(ann['image_id'])
190 | self.text.append(pre_caption(caption, self.max_words, icfg_rstp=self.icfg_rstp))
191 | self.img2txt[img_id].append(txt_id)
192 | self.txt2img[txt_id] = []
193 | self.txt2img[txt_id].append(img_id)
194 | txt_id += 1
195 | t += 1
196 |
197 | txt_id1 = 0
198 | for img_id1, ann1 in enumerate(self.ann):
199 | for i1, caption1 in enumerate(ann1['caption']):
200 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1:
201 | self.img2txt[img_id].append(txt_id1)
202 | txt_id1 += 1
203 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1:
204 | for temp in range(t):
205 | self.txt2img[txt_id - 1 - temp].append(img_id1)
206 |
207 | def __len__(self):
208 | return len(self.image)
209 |
210 | def __getitem__(self, index):
211 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
212 | image = Image.open(image_path).convert('RGB')
213 | image = self.transform(image)
214 | return image, index
215 |
216 |
217 | class re_test_dataset_icfg(Dataset):
218 | def __init__(self, config, transform):
219 | ann_file = config['test_file']
220 | self.ann = json.load(open(ann_file, 'r'))
221 | self.transform = transform
222 | self.image_root = config['image_root']
223 | self.max_words = config['max_words']
224 |
225 | self.text = []
226 | self.image = []
227 | self.txt2img = {}
228 | self.img2txt = {}
229 |
230 | self.g_pids = []
231 | self.q_pids = []
232 |
233 | for img_id, ann in enumerate(self.ann):
234 | self.image.append(ann['image'])
235 | self.g_pids.append(ann['image_id'])
236 | self.img2txt[img_id] = []
237 | self.img2txt[img_id].append(img_id)
238 |
239 | self.text.append(pre_caption(ann['caption'][0], self.max_words, icfg_rstp=True))
240 | self.q_pids.append(ann['image_id'])
241 |
242 | self.txt2img[img_id] = []
243 | self.txt2img[img_id].append(img_id)
244 |
245 | for img_id1, ann1 in enumerate(self.ann):
246 | if ann['image_id'] == ann1['image_id'] and img_id != img_id1:
247 | self.txt2img[img_id].append(img_id1)
248 | self.img2txt[img_id].append(img_id1)
249 |
250 | def __len__(self):
251 | return len(self.image)
252 |
253 | def __getitem__(self, index):
254 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
255 | image = Image.open(image_path).convert('RGB')
256 | image = self.transform(image)
257 | return image, index
258 |
259 |
260 | class re_train_dataset_attr(Dataset):
261 | def __init__(self, config, transform):
262 | ann_file = config['train_file']
263 | self.ann = []
264 | for f in ann_file:
265 | self.ann += json.load(open(f, 'r'))
266 | self.transform = transform
267 | self.image_root = config['image_root']
268 | self.max_words = config['max_words']
269 |
270 | def __len__(self):
271 | return len(self.ann)
272 |
273 | def __getitem__(self, index):
274 | ann = self.ann[index]
275 | image_path = os.path.join(self.image_root, ann['image'])
276 | image = Image.open(image_path).convert('RGB')
277 | image = self.transform(image)
278 | label = torch.tensor(ann['label'])
279 | return image, label
280 |
281 |
282 | class re_test_dataset_attr(Dataset):
283 | def __init__(self, ann_file, config, transform):
284 | self.ann = json.load(open(ann_file, 'r'))
285 | self.transform = transform
286 | self.image_root = config['image_root']
287 | self.max_words = config['max_words']
288 |
289 | self.image = []
290 | self.label = []
291 | for img_id, ann in enumerate(self.ann):
292 | self.image.append(ann['image'])
293 | self.label.append(ann['label'])
294 | self.label = np.array(self.label)
295 |
296 | def __len__(self):
297 | return len(self.image)
298 |
299 | def __getitem__(self, index):
300 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
301 | image = Image.open(image_path).convert('RGB')
302 | image = self.transform(image)
303 | return image, index
304 |
--------------------------------------------------------------------------------
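A standalone sketch of the whole-word masking performed by `TextMaskingGenerator`, assuming a stock HuggingFace `bert-base-uncased` tokenizer (the repo ships its own `models/tokenization_bert.py`, so the tokenizer class used in training may differ) and the dataset dependencies (e.g. nltk for EDA) installed:

```python
import torch
from transformers import BertTokenizer
from dataset.re_dataset import TextMaskingGenerator

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
masker = TextMaskingGenerator(tokenizer, mask_prob=0.25, mask_max=10)

text_ids = tokenizer('a woman wearing a red dress and white shoes',
                     return_tensors='pt')['input_ids'][0]      # 1-D tensor starting with [CLS]
masked_ids, masked_pos = masker(text_ids.clone())              # ~25% of positions picked (80% -> [MASK], 10% -> random id, 10% kept)
print(tokenizer.convert_ids_to_tokens(masked_ids.tolist()))
print('masked positions:', masked_pos)
```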
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import os
4 | import random
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 | import torch.nn.functional as F
9 |
10 | import utils
11 | from tqdm import tqdm
12 | from dataset.eda import *
13 |
14 | def pre_caption(caption, max_words, icfg_rstp=False, is_eda=False, eda_p=0.5):
15 | if icfg_rstp:
16 | try:
17 | caption = re.sub(
18 | r'[^0-9a-z]+',
19 | ' ',
20 | caption.lower(),
21 | )
22 | except:
23 | print(caption)
24 | caption_words = caption.split()
25 | caption = ' '.join(caption_words)
26 |
27 | # eda
28 | if is_eda and random.random() < eda_p:
29 | caption = eda(caption, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=1)[0]
30 |
31 | # truncate caption
32 | caption_words = caption.split()
33 | if len(caption_words) > max_words:
34 | caption = ' '.join(caption_words[:max_words])
35 |
36 | if not len(caption):
37 | raise ValueError("pre_caption yields invalid text")
38 |
39 | return caption
40 |
--------------------------------------------------------------------------------
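For reference, a small sketch of `pre_caption` on a hypothetical caption (run from the repo root so `utils` and `dataset.eda` import cleanly; EDA only kicks in when `is_eda=True` and the random draw falls below `eda_p`):

```python
from dataset.utils import pre_caption

caption = "A man, wearing a BLACK jacket, blue jeans and white sneakers."
print(pre_caption(caption, max_words=56))                  # truncation to max_words only
print(pre_caption(caption, max_words=56, icfg_rstp=True))  # also lower-cases and strips punctuation
# -> 'a man wearing a black jacket blue jeans and white sneakers'
```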
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.aptm import APTM
2 | from models.aptm import load_pretrained
3 | from models.aptm import AllGather
--------------------------------------------------------------------------------
/models/aptm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.distributed as dist
9 | from torch.nn import init
10 | from timm.models.layers import trunc_normal_
11 | from functools import partial
12 |
13 | from models.swin_transformer import SwinTransformer, interpolate_relative_pos_embed
14 | from models.bert import BertConfig, BertForMaskedLM, BertModel
15 | from utils import read_json
16 |
17 |
18 | class CrossEntropyLabelSmooth(nn.Module):
19 | """Cross entropy loss with label smoothing regularizer.
20 | Reference:
21 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
22 | Equation: y = (1 - epsilon) * y + epsilon / K.
23 | Args:
24 | epsilon (float): weight.
25 | """
26 |
27 | def __init__(self, epsilon=0.1, use_gpu=True):
28 | super(CrossEntropyLabelSmooth, self).__init__()
29 | self.epsilon = epsilon
30 | self.use_gpu = use_gpu
31 | self.logsoftmax = nn.LogSoftmax(dim=1)
32 |
33 | def forward(self, inputs, targets):
34 | """
35 | Args:
36 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
37 | targets: ground truth labels with shape (batch_size)
38 | """
39 | _, num_classes = inputs.shape
40 | log_probs = self.logsoftmax(inputs)
41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
42 | if self.use_gpu: targets = targets.cuda()
43 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes
44 | loss = (- targets * log_probs).mean(0).sum()
45 | return loss
46 |
47 |
48 | class AllGather(torch.autograd.Function):
49 | """An autograd function that performs allgather on a tensor."""
50 |
51 | @staticmethod
52 | def forward(ctx, tensor, rank, world_size):
53 | output = [torch.empty_like(tensor) for _ in range(world_size)]
54 | dist.all_gather(output, tensor)
55 | ctx.rank = rank
56 | ctx.batch_size = tensor.shape[0]
57 | return torch.cat(output, 0)
58 |
59 | @staticmethod
60 | def backward(ctx, grad_output):
61 | return (
62 | grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
63 | None,
64 | None
65 | )
66 |
67 |
68 | allgather = AllGather.apply
69 |
70 |
71 | def build_vision_encoder(config, load_params=False):
72 | """
73 | Args: load_params: False when building fine-tuning models
74 | """
75 |
76 | print('use_swin')
77 | vision_config = read_json(config['vision_config'])
78 | assert config['image_res'] == vision_config['image_res']
79 | assert config['patch_size'] == 32
80 | vision_width = vision_config['vision_width']
81 |
82 | vision_encoder = SwinTransformer(img_size=vision_config['image_res'],
83 | h=vision_config['h'],
84 | w=vision_config['w'],
85 | patch_size=4,
86 | in_chans=3,
87 | embed_dim=vision_config['embed_dim'],
88 | depths=vision_config['depths'],
89 | num_heads=vision_config['num_heads'],
90 | window_size=vision_config['window_size'],
91 | mlp_ratio=4.,
92 | qkv_bias=True,
93 | drop_rate=0.0,
94 | drop_path_rate=0.1,
95 | ape=False,
96 | patch_norm=True,
97 | use_checkpoint=False)
98 |
99 | if load_params:
100 | # download from https://github.com/microsoft/Swin-Transformer
101 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
102 | window_size = vision_config['window_size']
103 |
104 | for k in list(state_dict.keys()):
105 | if 'relative_position_bias_table' in k:
106 | if 'layers.3' in k:
107 | window_size = 4
108 | dst_num_pos = (2 * window_size - 1) ** 2
109 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
110 | elif ('relative_position_index' in k) or ('attn_mask' in k):
111 | del state_dict[k]
112 | print("### build_vision_encoder: ", flush=True)
113 | msg = vision_encoder.load_state_dict(state_dict, strict=False)
114 | print("missing_keys: ", msg.missing_keys)
115 | print("unexpected_keys: ", msg.unexpected_keys)
116 |
117 | return vision_encoder, vision_width
118 |
119 |
120 | def build_text_encoder(config, vision_width, load_text_params=False, use_mlm_loss=False, config_text=None):
121 | init_params = [] # train from scratch with larger lr
122 |
123 | if config_text is None:
124 | config_text = BertConfig.from_json_file(config['text_config'])
125 |
126 | config_text.encoder_width = vision_width
127 |
128 | if use_mlm_loss:
129 | text_encoder, msg = BertForMaskedLM.from_pretrained(config['text_encoder'], config=config_text,
130 | output_loading_info=True)
131 | if ('init_cross' in config.keys()) and config['init_cross']:
132 | init_params.extend(['text_encoder.' + n for n in msg['missing_keys']]) # of cross attention
133 | print("### init_params.extend --> cross attention ###")
134 |
135 | if load_text_params:
136 | print("### build_text_encoder --> Load BERT: ")
137 | for k, v in msg.items():
138 | print(f"{k}: {sorted(v)}")
139 | return text_encoder, init_params
140 |
141 |
142 | def build_mlp(input_dim, output_dim):
143 | return nn.Sequential(
144 | nn.Linear(input_dim, input_dim * 2),
145 | nn.LayerNorm(input_dim * 2),
146 | nn.GELU(),
147 | nn.Linear(input_dim * 2, output_dim)
148 | )
149 |
150 |
151 | def attr_mlp(input_dim, inter_dim, output_dim, after_cross, dropout_p):
152 | if after_cross:
153 | new_mlp = nn.Sequential(
154 | nn.Flatten(),
155 | nn.Linear(input_dim, inter_dim),
156 | nn.LayerNorm(inter_dim),
157 | nn.Dropout(p=dropout_p),
158 | nn.Linear(inter_dim, output_dim)
159 | )
160 | else:
161 | new_mlp = nn.Sequential(
162 | nn.Flatten(),
163 | nn.Linear(input_dim, inter_dim),
164 | nn.BatchNorm1d(inter_dim),
165 | nn.Dropout(p=dropout_p),
166 | nn.Linear(inter_dim, output_dim)
167 | )
168 | init.normal_(new_mlp[1].weight.data, std=0.00001)
169 | init.constant_(new_mlp[1].bias.data, 0.0)
170 | init.normal_(new_mlp[4].weight.data, std=0.00001)
171 | init.constant_(new_mlp[4].bias.data, 0.0)
172 | return new_mlp
173 |
174 |
175 | def load_pretrained(ckpt_rpath, config, is_eval=False, load_text=False):
176 | checkpoint = torch.load(ckpt_rpath, map_location='cpu')
177 | state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
178 | if is_eval:
179 | return state_dict
180 |
181 | print("### Loading pretrained vision encoder", flush=True)
182 |
183 | if load_text:
184 | print("### Loading pretrained text encoder", flush=True)
185 | for key in list(state_dict.keys()):
186 | if 'text_encoder.' in key:
187 | if not config['mlm']:
188 | if 'bert.' in key:
189 | encoder_key = key.replace('bert.', '')
190 | state_dict[encoder_key] = state_dict[key]
191 | del state_dict[key]
192 | else:
193 | if 'bert.' not in key and 'cls' not in key:
194 | encoder_key = key.replace('text_encoder.', 'text_encoder.bert.')
195 | state_dict[encoder_key] = state_dict[key]
196 | del state_dict[key]
197 |
198 | return state_dict
199 |
200 |
201 | class APTM(nn.Module):
202 | def __init__(self, config=None, load_vision_params=False, load_text_params=False,
203 | use_contrastive_loss=False, use_matching_loss=False,
204 | use_mlm_loss=False, config_text=None):
205 | super().__init__()
206 | self.init_params = [] # train from scratch with larger lr
207 |
208 | self.vision_encoder, vision_width = build_vision_encoder(config, load_params=load_vision_params)
209 | self.vision_width = vision_width
210 |
211 | if ('pa100k_only_img_classifier' in config.keys()) and config['pa100k_only_img_classifier']:
212 | self.pa100k_only_img_classifier = config['pa100k_only_img_classifier']
213 | self.img_cls = attr_mlp(self.vision_width, config['embed_dim'], 26, False, config['dop'])
214 | self.criterion = nn.BCEWithLogitsLoss()
215 | self.criterion = self.criterion.cuda()
216 |
217 | else:
218 | self.pa100k_only_img_classifier = False
219 | # text & cross-modal
220 | self.text_encoder, init_params = build_text_encoder(config, vision_width=vision_width,
221 | load_text_params=load_text_params,
222 | use_mlm_loss=use_mlm_loss, config_text=config_text)
223 | self.text_width = self.text_encoder.config.hidden_size # i.e. cross_width
224 | self.init_params.extend(init_params)
225 | if 0 < config['LabelSmooth'] < 1:
226 | self.new_cross_entropy = CrossEntropyLabelSmooth(epsilon=config['LabelSmooth'])
227 | self.add_label_smooth = True
228 | else:
229 | self.add_label_smooth = False
230 |
231 | # lr * x
232 | if config['lr_2']:
233 | # vision encoder
234 | for i in range(2, 4):
235 | for name, param in self.vision_encoder.layers[i].named_parameters():
236 | # param.requires_grad = False
237 | self.init_params.extend(['vision_encoder.layers.' + str(i) + '.' + name])
238 | # text encoder
239 | if config['mlm']:
240 | self.init_params.extend(
241 | ['text_encoder.cls.' + n for n, _ in self.text_encoder.cls.named_parameters()])
242 | temp_name = 'text_encoder.bert.encoder.layer.'
243 | temp_encoder = self.text_encoder.bert
244 | else:
245 | temp_name = 'text_encoder.encoder.layer.'
246 | temp_encoder = self.text_encoder
247 | temp_list = [4, 5, 10, 11]
248 | for i in temp_list:
249 | for name, param in temp_encoder.encoder.layer[i].named_parameters():
250 | self.init_params.extend([temp_name + str(i) + '.' + name])
251 |
252 | if use_contrastive_loss:
253 | self.embed_dim = config['embed_dim']
254 | self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
255 | self.text_proj = nn.Linear(self.text_width, self.embed_dim)
256 | self.temp = nn.Parameter(torch.ones([]) * config['temp'])
257 | if config['lr_2']:
258 | self.init_params.extend(['vision_proj.' + n for n, _ in self.vision_proj.named_parameters()])
259 | self.init_params.extend(['text_proj.' + n for n, _ in self.text_proj.named_parameters()])
260 |
261 | if use_matching_loss:
262 | self.itm_head = build_mlp(input_dim=self.text_width, output_dim=2)
263 | if config['lr_2']:
264 | self.init_params.extend(['itm_head.' + n for n, _ in self.itm_head.named_parameters()])
265 |
266 | def load_pretrained(self, ckpt_rpath, config, is_eval=False):
267 | state_dict = load_pretrained(ckpt_rpath, config, is_eval=is_eval, load_text=True)
268 | msg = self.load_state_dict(state_dict, strict=False)
269 | print('load checkpoint from %s' % ckpt_rpath)
270 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p])
271 | print("unexpected_keys: ", msg.unexpected_keys)
272 |
273 | def get_vision_embeds(self, image):
274 | image_embeds = self.vision_encoder(image)
275 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
276 | return image_embeds, image_atts
277 |
278 | def get_text_embeds(self, text_ids, text_atts):
279 | encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder
280 | return encoder(text_ids, attention_mask=text_atts, return_dict=True, mode='text').last_hidden_state
281 |
282 | def get_cross_embeds(self, image_embeds, image_atts, text_ids=None, text_embeds=None, text_atts=None):
283 | assert text_atts is not None
284 | encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder
285 | if text_embeds is not None:
286 | return encoder(encoder_embeds=text_embeds,
287 | attention_mask=text_atts,
288 | encoder_hidden_states=image_embeds,
289 | encoder_attention_mask=image_atts,
290 | return_dict=True,
291 | mode='fusion',
292 | ).last_hidden_state
293 | elif text_ids is not None:
294 | return encoder(text_ids,
295 | attention_mask=text_atts,
296 | encoder_hidden_states=image_embeds,
297 | encoder_attention_mask=image_atts,
298 | return_dict=True,
299 | ).last_hidden_state
300 | else:
301 | raise ValueError
302 |
303 | def get_features(self, image_embeds=None, text_embeds=None):
304 | if image_embeds is None:
305 | text_feat = self.text_proj(text_embeds[:, 0, :])
306 | return text_feat
307 | elif text_embeds is None:
308 | image_feat = self.vision_proj(image_embeds[:, 0, :])
309 | return image_feat
310 | else:
311 | image_feat = self.vision_proj(image_embeds[:, 0, :])
312 | text_feat = self.text_proj(text_embeds[:, 0, :])
313 | return image_feat, text_feat
314 |
315 | def get_contrastive_loss(self, image_feat, text_feat, idx=None):
316 | assert image_feat.size(-1) == self.embed_dim
317 | assert text_feat.size(-1) == self.embed_dim
318 | image_feat = F.normalize(image_feat, dim=-1)
319 | text_feat = F.normalize(text_feat, dim=-1)
320 |
321 | image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size())
322 | text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size())
323 | logits = image_feat_all @ text_feat_all.t() / self.temp
324 | bsz = image_feat_all.shape[0]
325 |
326 | if idx is None:
327 | labels = torch.arange(bsz, device=image_feat.device)
328 | loss_i2t = F.cross_entropy(logits, labels)
329 | loss_t2i = F.cross_entropy(logits.t(), labels)
330 | return (loss_i2t + loss_t2i) / 2
331 | else:
332 | idx = idx.view(-1, 1)
333 | assert idx.size(0) == image_feat.size(0)
334 | idx_all = allgather(idx, torch.distributed.get_rank(), torch.distributed.get_world_size())
335 | pos_idx = torch.eq(idx_all, idx_all.t()).float()
336 | labels = pos_idx / pos_idx.sum(1, keepdim=True)
337 |
338 | loss_i2t = -torch.sum(F.log_softmax(logits, dim=1) * labels, dim=1).mean()
339 | loss_t2i = -torch.sum(F.log_softmax(logits.t(), dim=1) * labels, dim=1).mean()
340 | return (loss_i2t + loss_t2i) / 2
341 |
342 | def get_matching_loss(self, image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=None):
343 | """
344 | Matching Loss with hard negatives
345 | """
346 | bs = image_embeds.size(0)
347 |
348 | image_feat = F.normalize(image_feat, dim=-1)
349 | text_feat = F.normalize(text_feat, dim=-1)
350 |
351 | with torch.no_grad():
352 | sim_i2t = image_feat @ text_feat.t() / self.temp
353 | sim_t2i = text_feat @ image_feat.t() / self.temp
354 |
355 | weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-5
356 | weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-5
357 |
358 | if idx is None:
359 | weights_i2t.fill_diagonal_(0)
360 | weights_t2i.fill_diagonal_(0)
361 | else:
362 | idx = idx.view(-1, 1)
363 | assert idx.size(0) == bs
364 | mask = torch.eq(idx, idx.t())
365 | weights_i2t.masked_fill_(mask, 0)
366 | weights_t2i.masked_fill_(mask, 0)
367 |
368 | image_embeds_neg = []
369 | image_atts_neg = []
370 | for b in range(bs):
371 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
372 | image_embeds_neg.append(image_embeds[neg_idx])
373 | image_atts_neg.append(image_atts[neg_idx])
374 | image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
375 | image_atts_neg = torch.stack(image_atts_neg, dim=0)
376 |
377 | text_embeds_neg = []
378 | text_atts_neg = []
379 | for b in range(bs):
380 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
381 | text_embeds_neg.append(text_embeds[neg_idx])
382 | text_atts_neg.append(text_atts[neg_idx])
383 | text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
384 | text_atts_neg = torch.stack(text_atts_neg, dim=0)
385 |
386 | text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
387 | text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0)
388 | image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
389 | image_atts_all = torch.cat([image_atts_neg, image_atts], dim=0)
390 |
391 | cross_pos = self.get_cross_embeds(image_embeds, image_atts, text_embeds=text_embeds,
392 | text_atts=text_atts)[:, 0, :]
393 | cross_neg = self.get_cross_embeds(image_embeds_all, image_atts_all, text_embeds=text_embeds_all,
394 | text_atts=text_atts_all)[:, 0, :]
395 |
396 | output = self.itm_head(torch.cat([cross_pos, cross_neg], dim=0))
397 | itm_labels = torch.cat([torch.ones(bs, dtype=torch.long),
398 | torch.zeros(2 * bs, dtype=torch.long)], dim=0).to(image_embeds.device)
399 | itm_loss = F.cross_entropy(output, itm_labels)
400 |
401 | return itm_loss
402 |
403 | def get_mlm_loss(self, text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids):
404 | return self.text_encoder(text_ids_masked,
405 | attention_mask=text_atts,
406 | encoder_hidden_states=image_embeds,
407 | encoder_attention_mask=image_atts,
408 | return_dict=True,
409 | labels=masked_ids,
410 | masked_pos=masked_pos).loss
411 |
412 | def label_smooth_loss(self, inputs, targets):
413 | bs = inputs.size(0)
414 | inputs_neg = []
415 | targets_neg = []
416 | for b in range(bs):
417 | if targets[b] != -1:
418 | inputs_neg.append(inputs[b])
419 | targets_neg.append(targets[b])
420 | if not inputs_neg:
421 | return 0
422 | inputs = torch.stack(inputs_neg, dim=0)
423 | targets = torch.stack(targets_neg, dim=0)
424 | return self.new_cross_entropy(inputs, targets)
425 |
426 | def get_contrastive_loss_attr(self, image_feat, text_feat, label):
427 | image_feat = F.normalize(image_feat, dim=-1)
428 | text_feat = F.normalize(text_feat, dim=-1)
429 | logits = image_feat @ text_feat.t() / self.temp
430 | l = 0
431 | for i in range(label.size(1)):
432 | left = 2 * i
433 | right = 2 * i + 2
434 | if self.add_label_smooth:
435 | l = l + self.label_smooth_loss(logits[:, left:right], label[:, i])
436 | else:
437 | l = l + F.cross_entropy(logits[:, left:right], label[:, i], ignore_index=-1)
438 |
439 | return l / label.size(1)
440 |
441 | def get_matching_loss_attr(self, image_embeds, image_atts, text_embeds, text_atts, label):
442 | bs = image_embeds.size(0)
443 |
444 | labels = []
445 | for i in range(label.size(1)):
446 | l = 1 - label[:, i]
447 | l = torch.where(l == 2, -1, l)
448 | labels.append(l)
449 | labels.append(label[:, i])
450 | labels = torch.stack(labels, dim=1)
451 |
452 | r = random.sample(range(0, text_embeds.size(0)), 5)
453 | ll = 0
454 | for t in r:
455 | text_embeds_0 = text_embeds[t].repeat(bs, 1, 1)
456 | text_atts_0 = text_atts[t].repeat(bs, 1, 1)
457 | cross_0 = self.get_cross_embeds(image_embeds, image_atts, text_embeds=text_embeds_0,
458 | text_atts=text_atts_0)[:, 0, :]
459 | output_0 = self.itm_head(cross_0)
460 | if self.add_label_smooth:
461 | ll = ll + self.label_smooth_loss(output_0, labels[:, t])
462 | else:
463 | ll = ll + F.cross_entropy(output_0, labels[:, t], ignore_index=-1)
464 | return ll / 5
465 |
466 | def get_mlm_loss_attr(self, text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids, label):
467 |
468 | labels = []
469 | for i in range(label.size(1)):
470 | l = 1 - label[:, i]
471 | l = torch.where(l == 2, -1, l)
472 | labels.append(l)
473 | labels.append(label[:, i])
474 | labels = torch.stack(labels, dim=1)
475 |
476 | image_embeds_pos = []
477 | image_atts_pos = []
478 | text_ids_masked_pos = []
479 | text_atts_pos = []
480 | masked_pos_pos = []
481 | masked_ids_pos = []
482 | for b in range(text_atts.size(0)):
483 | temp_label = labels[:, b]
484 | temp_label = torch.where(temp_label == -1, 0, temp_label)
485 | if torch.count_nonzero(temp_label).item() > 0:
486 | text_ids_masked_pos.append(text_ids_masked[b])
487 | text_atts_pos.append(text_atts[b])
488 | masked_pos_pos.append(masked_pos[b])
489 | masked_ids_pos.append(masked_ids[b])
490 | idx = torch.multinomial(temp_label.float(), 1).item()
491 | image_embeds_pos.append(image_embeds[idx])
492 | image_atts_pos.append(image_atts[idx])
493 |
494 | image_embeds_pos = torch.stack(image_embeds_pos, dim=0)
495 | image_atts_pos = torch.stack(image_atts_pos, dim=0)
496 | text_ids_masked_pos = torch.stack(text_ids_masked_pos, dim=0)
497 | text_atts_pos = torch.stack(text_atts_pos, dim=0)
498 | masked_pos_pos = torch.stack(masked_pos_pos, dim=0)
499 | masked_ids_pos = torch.stack(masked_ids_pos, dim=0)
500 |
501 | loss = self.text_encoder(text_ids_masked_pos,
502 | attention_mask=text_atts_pos,
503 | encoder_hidden_states=image_embeds_pos,
504 | encoder_attention_mask=image_atts_pos,
505 | return_dict=True,
506 | labels=masked_ids_pos,
507 | masked_pos=masked_pos_pos).loss
508 | return loss
509 |
--------------------------------------------------------------------------------
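The identity-aware contrastive loss above spreads the target probability over every image-text pair that shares a person ID. A toy single-process sketch of the target construction in `get_contrastive_loss` (the real method first all-gathers features across GPUs and uses the learned temperature):

```python
import torch
import torch.nn.functional as F

temp = 0.07
image_feat = F.normalize(torch.randn(4, 256), dim=-1)
text_feat = F.normalize(torch.randn(4, 256), dim=-1)
idx = torch.tensor([0, 0, 1, 2]).view(-1, 1)       # person IDs: the first two samples share an identity

logits = image_feat @ text_feat.t() / temp
pos_idx = torch.eq(idx, idx.t()).float()           # 1 wherever identities match
labels = pos_idx / pos_idx.sum(1, keepdim=True)    # soft targets spread over all matching pairs

loss_i2t = -torch.sum(F.log_softmax(logits, dim=1) * labels, dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(logits.t(), dim=1) * labels, dim=1).mean()
print(((loss_i2t + loss_t2i) / 2).item())
```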
/models/model_retrieval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import APTM, load_pretrained, AllGather
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class APTM_Retrieval(APTM):
8 | def __init__(self, config):
9 | super().__init__(config, load_vision_params=config['load_params'], load_text_params=config['load_params'],
10 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=config['mlm'])
11 |
12 | if not self.pa100k_only_img_classifier:
13 | self.mlm = config['mlm']
14 | self.pa100k = config['pa100k']
15 | if not self.pa100k:
16 | self.eda = config['eda']
17 | if ('attr' in config.keys()) and config['attr']:
18 | self.attr = True
19 | else:
20 | self.attr = False
21 |
22 | def load_pretrained(self, ckpt_rpath, config, is_eval=False):
23 | state_dict = load_pretrained(ckpt_rpath, config, is_eval=is_eval, load_text=True)
24 | msg = self.load_state_dict(state_dict, strict=False)
25 | print('load checkpoint from %s' % ckpt_rpath)
26 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p])
27 | print("vision_encoder missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' in p])
28 | print("unexpected_keys: ", msg.unexpected_keys)
29 |
30 | def forward(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None,
31 | idx=None, attr_text_ids=None, attr_text_atts=None, attr_text_ids_masked=None,
32 | attr_masked_pos=None, attr_masked_ids=None, label=None, text_ids_eda=None, text_atts_eda=None):
33 |
34 | if self.pa100k_only_img_classifier:
35 | image_embeds = self.vision_encoder(image)
36 | outputs = self.img_cls(image_embeds[:, 0, :])
37 | loss = self.criterion(outputs, label.float())
38 | return loss
39 |
40 | if self.pa100k:
41 | image_embeds, image_atts = self.get_vision_embeds(image)
42 | text_embeds = self.get_text_embeds(text_ids, text_atts)
43 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
44 | loss_itc = self.get_contrastive_loss_attr(image_feat, text_feat, label)
45 | loss_itm = self.get_matching_loss_attr(image_embeds, image_atts, text_embeds, text_atts, label)
46 | if self.mlm:
47 | loss_mlm = self.get_mlm_loss_attr(text_ids_masked, text_atts, image_embeds, image_atts,
48 | masked_pos, masked_ids, label)
49 | return loss_itc, loss_itm, loss_mlm
50 | else:
51 | return loss_itc, loss_itm
52 |
53 | if self.attr:
54 | image_embeds, image_atts = self.get_vision_embeds(image)
55 | text_embeds = self.get_text_embeds(text_ids, text_atts)
56 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
57 |
58 | attr_text_embeds = self.get_text_embeds(attr_text_ids, attr_text_atts)
59 | attr_text_feat = self.get_features(text_embeds=attr_text_embeds)
60 |
61 | attr_loss_itc = self.get_contrastive_loss_attr(image_feat, attr_text_feat, label)
62 | attr_loss_itm = self.get_matching_loss_attr(image_embeds, image_atts, attr_text_embeds, attr_text_atts,
63 | label)
64 |
65 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx)
66 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat,
67 | text_embeds, text_atts, text_feat, idx=idx)
68 |
69 | if self.mlm:
70 | attr_loss_mlm = self.get_mlm_loss_attr(attr_text_ids_masked, attr_text_atts, image_embeds, image_atts,
71 | attr_masked_pos, attr_masked_ids, label)
72 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos,
73 | masked_ids)
74 | loss_attr = (attr_loss_itc + attr_loss_itm + attr_loss_mlm) / 3
75 | return loss_itc, loss_itm, loss_mlm, loss_attr
76 | else:
77 | loss_attr = (attr_loss_itc + attr_loss_itm) / 2
78 | return loss_itc, loss_itm, loss_attr
79 |
80 | image_embeds, image_atts = self.get_vision_embeds(image)
81 | text_embeds = self.get_text_embeds(text_ids, text_atts)
82 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
83 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx)
84 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat,
85 | text_embeds, text_atts, text_feat, idx=idx)
86 |
87 | # eda
88 | if self.eda:
89 | text_embeds_eda = self.get_text_embeds(text_ids_eda, text_atts_eda)
90 | text_feat_eda = self.get_features(text_embeds=text_embeds_eda)
91 | loss_itc_eda = self.get_contrastive_loss(image_feat, text_feat_eda, idx=idx)
92 | loss_itm_eda = self.get_matching_loss(image_embeds, image_atts, image_feat,
93 | text_embeds_eda, text_atts_eda, text_feat_eda, idx=idx)
94 | loss_itc = loss_itc + 0.8 * loss_itc_eda
95 | loss_itm = loss_itm + 0.8 * loss_itm_eda
96 |
97 | if self.mlm:
98 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos,
99 | masked_ids)
100 | return loss_itc, loss_itm, loss_mlm
101 | else:
102 | return loss_itc, loss_itm
103 |
--------------------------------------------------------------------------------
/models/swin_transformer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 | from scipy import interpolate
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.utils.checkpoint as checkpoint
14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = nn.Linear(in_features, hidden_features)
23 | self.act = act_layer()
24 | self.fc2 = nn.Linear(hidden_features, out_features)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 |
36 | def window_partition(x, window_size):
37 | """
38 | Args:
39 | x: (B, H, W, C)
40 | window_size (int): window size
41 |
42 | Returns:
43 | windows: (num_windows*B, window_size, window_size, C)
44 | """
45 | B, H, W, C = x.shape
46 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
47 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48 | return windows
49 |
50 |
51 | def window_reverse(windows, window_size, H, W):
52 | """
53 | Args:
54 | windows: (num_windows*B, window_size, window_size, C)
55 | window_size (int): Window size
56 | H (int): Height of image
57 | W (int): Width of image
58 |
59 | Returns:
60 | x: (B, H, W, C)
61 | """
62 | B = int(windows.shape[0] / (H * W / window_size / window_size))
63 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
64 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
65 | return x
66 |
67 |
68 | class WindowAttention(nn.Module):
69 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
70 | It supports both of shifted and non-shifted window.
71 |
72 | Args:
73 | dim (int): Number of input channels.
74 | window_size (tuple[int]): The height and width of the window.
75 | num_heads (int): Number of attention heads.
76 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80 | """
81 |
82 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
83 |
84 | super().__init__()
85 | self.dim = dim
86 | self.window_size = window_size # Wh, Ww
87 | self.num_heads = num_heads
88 | head_dim = dim // num_heads
89 | self.scale = qk_scale or head_dim ** -0.5
90 |
91 | # define a parameter table of relative position bias
92 | self.relative_position_bias_table = nn.Parameter(
93 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94 |
95 | # get pair-wise relative position index for each token inside the window
96 | coords_h = torch.arange(self.window_size[0])
97 | coords_w = torch.arange(self.window_size[1])
98 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
100 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
101 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
102 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
103 | relative_coords[:, :, 1] += self.window_size[1] - 1
104 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
105 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
106 | self.register_buffer("relative_position_index", relative_position_index)
107 |
108 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
109 | self.attn_drop = nn.Dropout(attn_drop)
110 | self.proj = nn.Linear(dim, dim)
111 | self.proj_drop = nn.Dropout(proj_drop)
112 |
113 | trunc_normal_(self.relative_position_bias_table, std=.02)
114 | self.softmax = nn.Softmax(dim=-1)
115 |
116 | def forward(self, x, mask=None):
117 | """
118 | Args:
119 | x: input features with shape of (num_windows*B, N, C)
120 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
121 | """
122 | B_, N, C = x.shape
123 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
125 |
126 | q = q * self.scale
127 | attn = (q @ k.transpose(-2, -1))
128 |
129 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
130 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
131 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
132 | attn = attn + relative_position_bias.unsqueeze(0)
133 |
134 | if mask is not None:
135 | nW = mask.shape[0]
136 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
137 | attn = attn.view(-1, self.num_heads, N, N)
138 | attn = self.softmax(attn)
139 | else:
140 | attn = self.softmax(attn)
141 |
142 | attn = self.attn_drop(attn)
143 |
144 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
145 | x = self.proj(x)
146 | x = self.proj_drop(x)
147 | return x
148 |
149 | def extra_repr(self) -> str:
150 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
151 |
152 | def flops(self, N):
153 | # calculate flops for 1 window with token length of N
154 | flops = 0
155 | # qkv = self.qkv(x)
156 | flops += N * self.dim * 3 * self.dim
157 | # attn = (q @ k.transpose(-2, -1))
158 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
159 | # x = (attn @ v)
160 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
161 | # x = self.proj(x)
162 | flops += N * self.dim * self.dim
163 | return flops
164 |
165 |
166 | class SwinTransformerBlock(nn.Module):
167 | r""" Swin Transformer Block.
168 |
169 | Args:
170 | dim (int): Number of input channels.
171 | input_resolution (tuple[int]): Input resolution.
172 | num_heads (int): Number of attention heads.
173 | window_size (int): Window size.
174 | shift_size (int): Shift size for SW-MSA.
175 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
176 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
177 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
178 | drop (float, optional): Dropout rate. Default: 0.0
179 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
180 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
181 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
182 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
183 | """
184 |
185 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
186 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
187 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
188 | super().__init__()
189 | self.dim = dim
190 | self.input_resolution = input_resolution
191 | self.num_heads = num_heads
192 | self.window_size = window_size
193 | self.shift_size = shift_size
194 | self.mlp_ratio = mlp_ratio
195 | if min(self.input_resolution) <= self.window_size:
196 | # if window size is larger than input resolution, we don't partition windows
197 | self.shift_size = 0
198 | self.window_size = min(self.input_resolution)
199 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
200 |
201 | self.norm1 = norm_layer(dim)
202 | self.attn = WindowAttention(
203 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
204 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
205 |
206 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
207 | self.norm2 = norm_layer(dim)
208 | mlp_hidden_dim = int(dim * mlp_ratio)
209 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
210 |
211 | if self.shift_size > 0:
212 | # calculate attention mask for SW-MSA
213 | H, W = self.input_resolution
214 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
215 | h_slices = (slice(0, -self.window_size),
216 | slice(-self.window_size, -self.shift_size),
217 | slice(-self.shift_size, None))
218 | w_slices = (slice(0, -self.window_size),
219 | slice(-self.window_size, -self.shift_size),
220 | slice(-self.shift_size, None))
221 | cnt = 0
222 | for h in h_slices:
223 | for w in w_slices:
224 | img_mask[:, h, w, :] = cnt
225 | cnt += 1
226 |
227 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
228 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
229 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
230 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
231 | else:
232 | attn_mask = None
233 |
234 | self.register_buffer("attn_mask", attn_mask)
235 |
236 | def forward(self, x):
237 | H, W = self.input_resolution
238 | B, L, C = x.shape
239 | assert L == H * W, "input feature has wrong size"
240 |
241 | shortcut = x
242 | x = self.norm1(x)
243 | x = x.view(B, H, W, C)
244 |
245 | # cyclic shift
246 | if self.shift_size > 0:
247 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
248 | else:
249 | shifted_x = x
250 |
251 | # partition windows
252 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
253 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
254 |
255 | # W-MSA/SW-MSA
256 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
257 |
258 | # merge windows
259 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
260 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
261 |
262 | # reverse cyclic shift
263 | if self.shift_size > 0:
264 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
265 | else:
266 | x = shifted_x
267 | x = x.view(B, H * W, C)
268 |
269 | # FFN
270 | x = shortcut + self.drop_path(x)
271 | x = x + self.drop_path(self.mlp(self.norm2(x)))
272 |
273 | return x
274 |
275 | def extra_repr(self) -> str:
276 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
277 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
278 |
279 | def flops(self):
280 | flops = 0
281 | H, W = self.input_resolution
282 | # norm1
283 | flops += self.dim * H * W
284 | # W-MSA/SW-MSA
285 | nW = H * W / self.window_size / self.window_size
286 | flops += nW * self.attn.flops(self.window_size * self.window_size)
287 | # mlp
288 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
289 | # norm2
290 | flops += self.dim * H * W
291 | return flops
292 |
293 |
294 | class PatchMerging(nn.Module):
295 | r""" Patch Merging Layer.
296 |
297 | Args:
298 | input_resolution (tuple[int]): Resolution of input feature.
299 | dim (int): Number of input channels.
300 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
301 | """
302 |
303 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
304 | super().__init__()
305 | self.input_resolution = input_resolution
306 | self.dim = dim
307 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
308 | self.norm = norm_layer(4 * dim)
309 |
310 | def forward(self, x):
311 | """
312 | x: B, H*W, C
313 | """
314 | H, W = self.input_resolution
315 | B, L, C = x.shape
316 | assert L == H * W, "input feature has wrong size"
317 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
318 |
319 | x = x.view(B, H, W, C)
320 |
321 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
322 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
323 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
324 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
325 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
326 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
327 |
328 | x = self.norm(x)
329 | x = self.reduction(x)
330 |
331 | return x
332 |
333 | def extra_repr(self) -> str:
334 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
335 |
336 | def flops(self):
337 | H, W = self.input_resolution
338 | flops = H * W * self.dim
339 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
340 | return flops
341 |
342 |
343 | class BasicLayer(nn.Module):
344 | """ A basic Swin Transformer layer for one stage.
345 |
346 | Args:
347 | dim (int): Number of input channels.
348 | input_resolution (tuple[int]): Input resolution.
349 | depth (int): Number of blocks.
350 | num_heads (int): Number of attention heads.
351 | window_size (int): Local window size.
352 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
353 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355 | drop (float, optional): Dropout rate. Default: 0.0
356 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
357 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
358 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
359 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
360 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361 | """
362 |
363 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
364 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
365 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
366 |
367 | super().__init__()
368 | self.dim = dim
369 | self.input_resolution = input_resolution
370 | self.depth = depth
371 | self.use_checkpoint = use_checkpoint
372 |
373 | # build blocks
374 | self.blocks = nn.ModuleList([
375 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
376 | num_heads=num_heads, window_size=window_size,
377 | shift_size=0 if (i % 2 == 0) else window_size // 2,
378 | mlp_ratio=mlp_ratio,
379 | qkv_bias=qkv_bias, qk_scale=qk_scale,
380 | drop=drop, attn_drop=attn_drop,
381 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
382 | norm_layer=norm_layer)
383 | for i in range(depth)])
384 |
385 | # patch merging layer
386 | if downsample is not None:
387 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
388 | else:
389 | self.downsample = None
390 |
391 | def forward(self, x):
392 | for blk in self.blocks:
393 | if self.use_checkpoint:
394 | x = checkpoint.checkpoint(blk, x)
395 | else:
396 | x = blk(x)
397 | if self.downsample is not None:
398 | x = self.downsample(x)
399 | return x
400 |
401 | def extra_repr(self) -> str:
402 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
403 |
404 | def flops(self):
405 | flops = 0
406 | for blk in self.blocks:
407 | flops += blk.flops()
408 | if self.downsample is not None:
409 | flops += self.downsample.flops()
410 | return flops
411 |
412 |
413 | class PatchEmbed(nn.Module):
414 | r""" Image to Patch Embedding
415 |
416 | Args:
417 | img_size (int): Image size. Default: 224.
418 | patch_size (int): Patch token size. Default: 4.
419 | in_chans (int): Number of input image channels. Default: 3.
420 | embed_dim (int): Number of linear projection output channels. Default: 96.
421 | norm_layer (nn.Module, optional): Normalization layer. Default: None
422 | """
423 |
424 | def __init__(self, img_size=224, h=224, w=224,
425 | patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
426 | super().__init__()
427 | img_size = (h, w)
428 | patch_size = to_2tuple(patch_size)
429 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
430 | self.img_size = img_size
431 | self.patch_size = patch_size
432 | self.patches_resolution = patches_resolution
433 | self.num_patches = patches_resolution[0] * patches_resolution[1]
434 |
435 | self.in_chans = in_chans
436 | self.embed_dim = embed_dim
437 |
438 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
439 | if norm_layer is not None:
440 | self.norm = norm_layer(embed_dim)
441 | else:
442 | self.norm = None
443 |
444 | def forward(self, x):
445 | B, C, H, W = x.shape
446 | # FIXME look at relaxing size constraints
447 | assert H == self.img_size[0] and W == self.img_size[1], \
448 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
449 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
450 | if self.norm is not None:
451 | x = self.norm(x)
452 | return x
453 |
454 | def flops(self):
455 | Ho, Wo = self.patches_resolution
456 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
457 | if self.norm is not None:
458 | flops += Ho * Wo * self.embed_dim
459 | return flops
460 |
461 |
462 | class SwinTransformer(nn.Module):
463 | r""" Swin Transformer
464 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
465 | https://arxiv.org/pdf/2103.14030
466 |
467 | Args:
468 | img_size (int | tuple(int)): Input image size. Default 224
469 | patch_size (int | tuple(int)): Patch size. Default: 4
470 | in_chans (int): Number of input image channels. Default: 3
471 | num_classes (int): Number of classes for classification head. Default: 1000
472 | embed_dim (int): Patch embedding dimension. Default: 96
473 | depths (tuple(int)): Depth of each Swin Transformer layer.
474 | num_heads (tuple(int)): Number of attention heads in different layers.
475 | window_size (int): Window size. Default: 7
476 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
477 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
478 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
479 | drop_rate (float): Dropout rate. Default: 0
480 | attn_drop_rate (float): Attention dropout rate. Default: 0
481 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
482 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
483 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
484 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
485 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
486 | """
487 |
488 | def __init__(self, img_size=224, h=224, w=224,
489 | patch_size=4, in_chans=3, num_classes=1000,
490 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
491 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
492 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
493 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
494 | use_checkpoint=False, **kwargs):
495 | super().__init__()
496 |
497 | self.num_classes = num_classes
498 | self.num_layers = len(depths)
499 | self.embed_dim = embed_dim
500 | self.ape = ape
501 | self.patch_norm = patch_norm
502 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
503 | self.mlp_ratio = mlp_ratio
504 |
505 | # split image into non-overlapping patches
506 | self.patch_embed = PatchEmbed(
507 | img_size=img_size, h=h, w=w,
508 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
509 | norm_layer=norm_layer if self.patch_norm else None)
510 | num_patches = self.patch_embed.num_patches
511 | patches_resolution = self.patch_embed.patches_resolution
512 | self.patches_resolution = patches_resolution
513 |
514 | # absolute position embedding
515 | if self.ape:
516 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
517 | trunc_normal_(self.absolute_pos_embed, std=.02)
518 |
519 | self.pos_drop = nn.Dropout(p=drop_rate)
520 |
521 | # stochastic depth
522 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
523 |
524 | # build layers
525 | self.layers = nn.ModuleList()
526 | for i_layer in range(self.num_layers):
527 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
528 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
529 | patches_resolution[1] // (2 ** i_layer)),
530 | depth=depths[i_layer],
531 | num_heads=num_heads[i_layer],
532 | window_size=window_size,
533 | mlp_ratio=self.mlp_ratio,
534 | qkv_bias=qkv_bias, qk_scale=qk_scale,
535 | drop=drop_rate, attn_drop=attn_drop_rate,
536 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
537 | norm_layer=norm_layer,
538 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
539 | use_checkpoint=use_checkpoint)
540 | self.layers.append(layer)
541 |
542 | self.norm = norm_layer(self.num_features)
543 | self.avgpool = nn.AdaptiveAvgPool1d(1)
544 |
545 | # shortcut block 1-->4
546 | # self.my_proj = nn.Conv2d(256, 1024, kernel_size=patch_size, stride=patch_size)
547 |
548 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
549 |
550 | self.apply(self._init_weights)
551 |
552 | def _init_weights(self, m):
553 | if isinstance(m, nn.Linear):
554 | trunc_normal_(m.weight, std=.02)
555 | if isinstance(m, nn.Linear) and m.bias is not None:
556 | nn.init.constant_(m.bias, 0)
557 | elif isinstance(m, nn.LayerNorm):
558 | nn.init.constant_(m.bias, 0)
559 | nn.init.constant_(m.weight, 1.0)
560 |
561 | @torch.jit.ignore
562 | def no_weight_decay(self):
563 | return {'absolute_pos_embed'}
564 |
565 | @torch.jit.ignore
566 | def no_weight_decay_keywords(self):
567 | return {'relative_position_bias_table'}
568 |
569 | def forward(self, x):
570 | x = self.patch_embed(x)
571 | if self.ape:
572 | x = x + self.absolute_pos_embed
573 | x = self.pos_drop(x)
574 |
575 | for i, layer in enumerate(self.layers):
576 | x = layer(x)
577 |
578 | x = self.norm(x) # B L C
579 | x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
580 | x = torch.cat([x_cls.transpose(1, 2), x], dim=1)
581 | return x
582 |
583 | def flops(self):
584 | flops = 0
585 | flops += self.patch_embed.flops()
586 | for i, layer in enumerate(self.layers):
587 | flops += layer.flops()
588 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
589 | flops += self.num_features * self.num_classes
590 | return flops
591 |
592 |
593 | def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''):
594 | # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
595 |
596 | # rel_pos_bias: relative_position_bias_table
597 | src_num_pos, num_attn_heads = rel_pos_bias.size()
598 |
599 | num_extra_tokens = 0
600 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
601 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
602 | if src_size != dst_size:
603 | print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size))
604 |
605 | # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
606 | # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
607 |
608 | def geometric_progression(a, r, n):
609 | return a * (1.0 - r ** n) / (1.0 - r)
610 |
611 | left, right = 1.01, 1.5
612 | while right - left > 1e-6:
613 | q = (left + right) / 2.0
614 | gp = geometric_progression(1, q, src_size // 2)
615 | if gp > dst_size // 2:
616 | right = q
617 | else:
618 | left = q
619 |
620 | # if q > 1.090307:
621 | # q = 1.090307
622 |
623 | dis = []
624 | cur = 1
625 | for i in range(src_size // 2):
626 | dis.append(cur)
627 | cur += q ** (i + 1)
628 |
629 | r_ids = [-_ for _ in reversed(dis)]
630 |
631 | x = r_ids + [0] + dis
632 | y = r_ids + [0] + dis
633 |
634 | t = dst_size // 2.0
635 | dx = np.arange(-t, t + 0.1, 1.0)
636 | dy = np.arange(-t, t + 0.1, 1.0)
637 |
638 | # print("Original positions = %s" % str(x))
639 | # print("Target positions = %s" % str(dx))
640 |
641 | all_rel_pos_bias = []
642 |
643 | for i in range(num_attn_heads):
644 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
645 | f = interpolate.interp2d(x, y, z, kind='cubic')
646 | all_rel_pos_bias.append(
647 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
648 |
649 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
650 |
651 | return rel_pos_bias
652 |
--------------------------------------------------------------------------------
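
A minimal usage sketch for the `SwinTransformer` defined above (not a repository file). It assumes the default constructor arguments shown in the class signature and that `models.swin_transformer` imports cleanly with the listed requirements installed; the batch size and printed shape are illustrative only.

```python
import torch
from models.swin_transformer import SwinTransformer

# Defaults: img_size = h = w = 224, patch_size = 4, embed_dim = 96, depths = [2, 2, 6, 2],
# window_size = 7 -> final feature map is 7x7 and num_features = 96 * 2**3 = 768.
model = SwinTransformer()
images = torch.randn(2, 3, 224, 224)   # B, C, H, W must match PatchEmbed.img_size
with torch.no_grad():
    out = model(images)
print(out.shape)                        # torch.Size([2, 50, 768]): pooled token + 49 patch tokens
```

As the last lines of `forward` show, the average-pooled token is concatenated in front of the patch tokens, so the sequence length is `1 + H/32 * W/32`.
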
/models/tokenization_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for Bert."""
16 |
17 |
18 | import collections
19 | import os
20 | import unicodedata
21 | from typing import List, Optional, Tuple
22 |
23 | from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | "vocab_file": {
33 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
34 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
35 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
36 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
37 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
38 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
39 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
40 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
41 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
42 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
45 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
46 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
47 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
50 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
51 | }
52 | }
53 |
54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55 | "bert-base-uncased": 512,
56 | "bert-large-uncased": 512,
57 | "bert-base-cased": 512,
58 | "bert-large-cased": 512,
59 | "bert-base-multilingual-uncased": 512,
60 | "bert-base-multilingual-cased": 512,
61 | "bert-base-chinese": 512,
62 | "bert-base-german-cased": 512,
63 | "bert-large-uncased-whole-word-masking": 512,
64 | "bert-large-cased-whole-word-masking": 512,
65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67 | "bert-base-cased-finetuned-mrpc": 512,
68 | "bert-base-german-dbmdz-cased": 512,
69 | "bert-base-german-dbmdz-uncased": 512,
70 | "TurkuNLP/bert-base-finnish-cased-v1": 512,
71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72 | "wietsedv/bert-base-dutch-cased": 512,
73 | }
74 |
75 | PRETRAINED_INIT_CONFIGURATION = {
76 | "bert-base-uncased": {"do_lower_case": True},
77 | "bert-large-uncased": {"do_lower_case": True},
78 | "bert-base-cased": {"do_lower_case": False},
79 | "bert-large-cased": {"do_lower_case": False},
80 | "bert-base-multilingual-uncased": {"do_lower_case": True},
81 | "bert-base-multilingual-cased": {"do_lower_case": False},
82 | "bert-base-chinese": {"do_lower_case": False},
83 | "bert-base-german-cased": {"do_lower_case": False},
84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94 | }
95 |
96 |
97 | def load_vocab(vocab_file):
98 | """Loads a vocabulary file into a dictionary."""
99 | vocab = collections.OrderedDict()
100 | with open(vocab_file, "r", encoding="utf-8") as reader:
101 | tokens = reader.readlines()
102 | for index, token in enumerate(tokens):
103 | token = token.rstrip("\n")
104 | vocab[token] = index
105 | return vocab
106 |
107 |
108 | def whitespace_tokenize(text):
109 | """Runs basic whitespace cleaning and splitting on a piece of text."""
110 | text = text.strip()
111 | if not text:
112 | return []
113 | tokens = text.split()
114 | return tokens
115 |
116 |
117 | class BertTokenizer(PreTrainedTokenizer):
118 | r"""
119 | Construct a BERT tokenizer. Based on WordPiece.
120 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
121 | Users should refer to this superclass for more information regarding those methods.
122 | Args:
123 | vocab_file (:obj:`str`):
124 | File containing the vocabulary.
125 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
126 | Whether or not to lowercase the input when tokenizing.
127 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
128 | Whether or not to do basic tokenization before WordPiece.
129 | never_split (:obj:`Iterable`, `optional`):
130 | Collection of tokens which will never be split during tokenization. Only has an effect when
131 | :obj:`do_basic_tokenize=True`
132 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
133 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
134 | token instead.
135 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
136 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
137 | sequence classification or for a text and a question for question answering. It is also used as the last
138 | token of a sequence built with special tokens.
139 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
140 | The token used for padding, for example when batching sequences of different lengths.
141 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
142 | The classifier token which is used when doing sequence classification (classification of the whole sequence
143 | instead of per-token classification). It is the first token of the sequence when built with special tokens.
144 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
145 | The token used for masking values. This is the token used when training this model with masked language
146 | modeling. This is the token which the model will try to predict.
147 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
148 | Whether or not to tokenize Chinese characters.
149 | This should likely be deactivated for Japanese (see this `issue
150 | `__).
151 | strip_accents: (:obj:`bool`, `optional`):
152 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the
153 | value for :obj:`lowercase` (as in the original BERT).
154 | """
155 |
156 | vocab_files_names = VOCAB_FILES_NAMES
157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160 |
161 | def __init__(
162 | self,
163 | vocab_file,
164 | do_lower_case=True,
165 | do_basic_tokenize=True,
166 | never_split=None,
167 | unk_token="[UNK]",
168 | sep_token="[SEP]",
169 | pad_token="[PAD]",
170 | cls_token="[CLS]",
171 | mask_token="[MASK]",
172 | tokenize_chinese_chars=True,
173 | strip_accents=None,
174 | **kwargs
175 | ):
176 | super().__init__(
177 | do_lower_case=do_lower_case,
178 | do_basic_tokenize=do_basic_tokenize,
179 | never_split=never_split,
180 | unk_token=unk_token,
181 | sep_token=sep_token,
182 | pad_token=pad_token,
183 | cls_token=cls_token,
184 | mask_token=mask_token,
185 | tokenize_chinese_chars=tokenize_chinese_chars,
186 | strip_accents=strip_accents,
187 | **kwargs,
188 | )
189 |
190 | if not os.path.isfile(vocab_file):
191 | raise ValueError(
192 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
193 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
194 | )
195 | self.vocab = load_vocab(vocab_file)
196 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
197 | self.do_basic_tokenize = do_basic_tokenize
198 | if do_basic_tokenize:
199 | self.basic_tokenizer = BasicTokenizer(
200 | do_lower_case=do_lower_case,
201 | never_split=never_split,
202 | tokenize_chinese_chars=tokenize_chinese_chars,
203 | strip_accents=strip_accents,
204 | )
205 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
206 |
207 | @property
208 | def do_lower_case(self):
209 | return self.basic_tokenizer.do_lower_case
210 |
211 | @property
212 | def vocab_size(self):
213 | return len(self.vocab)
214 |
215 | def get_vocab(self):
216 | return dict(self.vocab, **self.added_tokens_encoder)
217 |
218 | def _tokenize(self, text):
219 | split_tokens = []
220 | if self.do_basic_tokenize:
221 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
222 |
223 | # If the token is part of the never_split set
224 | if token in self.basic_tokenizer.never_split:
225 | split_tokens.append(token)
226 | else:
227 | split_tokens += self.wordpiece_tokenizer.tokenize(token)
228 | else:
229 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
230 | return split_tokens
231 |
232 | def _convert_token_to_id(self, token):
233 |         """ Converts a token (str) to an id using the vocab. """
234 | return self.vocab.get(token, self.vocab.get(self.unk_token))
235 |
236 | def _convert_id_to_token(self, index):
237 |         """Converts an index (integer) to a token (str) using the vocab."""
238 | return self.ids_to_tokens.get(index, self.unk_token)
239 |
240 | def convert_tokens_to_string(self, tokens):
241 |         """ Converts a sequence of tokens (string) to a single string. """
242 | out_string = " ".join(tokens).replace(" ##", "").strip()
243 | return out_string
244 |
245 | def build_inputs_with_special_tokens(
246 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
247 | ) -> List[int]:
248 | """
249 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
250 | adding special tokens. A BERT sequence has the following format:
251 | - single sequence: ``[CLS] X ``
252 | - pair of sequences: ``[CLS] A [SEP] B [SEP]``
253 | Args:
254 | token_ids_0 (:obj:`List[int]`):
255 | List of IDs to which the special tokens will be added.
256 | token_ids_1 (:obj:`List[int]`, `optional`):
257 | Optional second list of IDs for sequence pairs.
258 | Returns:
259 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
260 | """
261 | if token_ids_1 is None:
262 | return [self.cls_token_id] + token_ids_0
263 | cls = [self.cls_token_id]
264 | sep = [self.sep_token_id]
265 | return cls + token_ids_0 + sep + token_ids_1 + sep
266 |
267 | def get_special_tokens_mask(
268 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
269 | ) -> List[int]:
270 | """
271 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
272 | special tokens using the tokenizer ``prepare_for_model`` method.
273 | Args:
274 | token_ids_0 (:obj:`List[int]`):
275 | List of IDs.
276 | token_ids_1 (:obj:`List[int]`, `optional`):
277 | Optional second list of IDs for sequence pairs.
278 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
279 | Whether or not the token list is already formatted with special tokens for the model.
280 | Returns:
281 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
282 | """
283 |
284 | if already_has_special_tokens:
285 | if token_ids_1 is not None:
286 | raise ValueError(
287 | "You should not supply a second sequence if the provided sequence of "
288 | "ids is already formatted with special tokens for the model."
289 | )
290 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
291 |
292 | if token_ids_1 is not None:
293 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
294 | return [1] + ([0] * len(token_ids_0)) + [1]
295 |
296 | def create_token_type_ids_from_sequences(
297 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
298 | ) -> List[int]:
299 | """
300 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
301 | pair mask has the following format:
302 | ::
303 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
304 | | first sequence | second sequence |
305 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
306 | Args:
307 | token_ids_0 (:obj:`List[int]`):
308 | List of IDs.
309 | token_ids_1 (:obj:`List[int]`, `optional`):
310 | Optional second list of IDs for sequence pairs.
311 | Returns:
312 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
313 | sequence(s).
314 | """
315 | sep = [self.sep_token_id]
316 | cls = [self.cls_token_id]
317 | if token_ids_1 is None:
318 | return len(cls + token_ids_0 + sep) * [0]
319 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
320 |
321 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
322 | index = 0
323 | if os.path.isdir(save_directory):
324 | vocab_file = os.path.join(
325 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
326 | )
327 | else:
328 | vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
329 | with open(vocab_file, "w", encoding="utf-8") as writer:
330 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
331 | if index != token_index:
332 | logger.warning(
333 | "Saving vocabulary to {}: vocabulary indices are not consecutive."
334 | " Please check that the vocabulary is not corrupted!".format(vocab_file)
335 | )
336 | index = token_index
337 | writer.write(token + "\n")
338 | index += 1
339 | return (vocab_file,)
340 |
341 |
342 | class BasicTokenizer(object):
343 | """
344 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
345 | Args:
346 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
347 | Whether or not to lowercase the input when tokenizing.
348 | never_split (:obj:`Iterable`, `optional`):
349 | Collection of tokens which will never be split during tokenization. Only has an effect when
350 | :obj:`do_basic_tokenize=True`
351 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
352 | Whether or not to tokenize Chinese characters.
353 | This should likely be deactivated for Japanese (see this `issue
354 | `__).
355 | strip_accents: (:obj:`bool`, `optional`):
356 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the
357 | value for :obj:`lowercase` (as in the original BERT).
358 | """
359 |
360 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
361 | if never_split is None:
362 | never_split = []
363 | self.do_lower_case = do_lower_case
364 | self.never_split = set(never_split)
365 | self.tokenize_chinese_chars = tokenize_chinese_chars
366 | self.strip_accents = strip_accents
367 |
368 | def tokenize(self, text, never_split=None):
369 | """
370 | Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
371 | WordPieceTokenizer.
372 | Args:
373 | **never_split**: (`optional`) list of str
374 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see
375 | :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
376 | """
377 | # union() returns a new set by concatenating the two sets.
378 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
379 | text = self._clean_text(text)
380 |
381 | # This was added on November 1st, 2018 for the multilingual and Chinese
382 | # models. This is also applied to the English models now, but it doesn't
383 | # matter since the English models were not trained on any Chinese data
384 | # and generally don't have any Chinese data in them (there are Chinese
385 | # characters in the vocabulary because Wikipedia does have some Chinese
386 | # words in the English Wikipedia.).
387 | if self.tokenize_chinese_chars:
388 | text = self._tokenize_chinese_chars(text)
389 | orig_tokens = whitespace_tokenize(text)
390 | split_tokens = []
391 | for token in orig_tokens:
392 | if token not in never_split:
393 | if self.do_lower_case:
394 | token = token.lower()
395 | if self.strip_accents is not False:
396 | token = self._run_strip_accents(token)
397 | elif self.strip_accents:
398 | token = self._run_strip_accents(token)
399 | split_tokens.extend(self._run_split_on_punc(token, never_split))
400 |
401 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
402 | return output_tokens
403 |
404 | def _run_strip_accents(self, text):
405 | """Strips accents from a piece of text."""
406 | text = unicodedata.normalize("NFD", text)
407 | output = []
408 | for char in text:
409 | cat = unicodedata.category(char)
410 | if cat == "Mn":
411 | continue
412 | output.append(char)
413 | return "".join(output)
414 |
415 | def _run_split_on_punc(self, text, never_split=None):
416 | """Splits punctuation on a piece of text."""
417 | if never_split is not None and text in never_split:
418 | return [text]
419 | chars = list(text)
420 | i = 0
421 | start_new_word = True
422 | output = []
423 | while i < len(chars):
424 | char = chars[i]
425 | if _is_punctuation(char):
426 | output.append([char])
427 | start_new_word = True
428 | else:
429 | if start_new_word:
430 | output.append([])
431 | start_new_word = False
432 | output[-1].append(char)
433 | i += 1
434 |
435 | return ["".join(x) for x in output]
436 |
437 | def _tokenize_chinese_chars(self, text):
438 | """Adds whitespace around any CJK character."""
439 | output = []
440 | for char in text:
441 | cp = ord(char)
442 | if self._is_chinese_char(cp):
443 | output.append(" ")
444 | output.append(char)
445 | output.append(" ")
446 | else:
447 | output.append(char)
448 | return "".join(output)
449 |
450 | def _is_chinese_char(self, cp):
451 | """Checks whether CP is the codepoint of a CJK character."""
452 | # This defines a "chinese character" as anything in the CJK Unicode block:
453 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
454 | #
455 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
456 | # despite its name. The modern Korean Hangul alphabet is a different block,
457 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
458 | # space-separated words, so they are not treated specially and handled
459 |         # like all of the other languages.
460 | if (
461 | (cp >= 0x4E00 and cp <= 0x9FFF)
462 | or (cp >= 0x3400 and cp <= 0x4DBF) #
463 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
464 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
465 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
466 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
467 | or (cp >= 0xF900 and cp <= 0xFAFF)
468 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
469 | ): #
470 | return True
471 |
472 | return False
473 |
474 | def _clean_text(self, text):
475 | """Performs invalid character removal and whitespace cleanup on text."""
476 | output = []
477 | for char in text:
478 | cp = ord(char)
479 | if cp == 0 or cp == 0xFFFD or _is_control(char):
480 | continue
481 | if _is_whitespace(char):
482 | output.append(" ")
483 | else:
484 | output.append(char)
485 | return "".join(output)
486 |
487 |
488 | class WordpieceTokenizer(object):
489 | """Runs WordPiece tokenization."""
490 |
491 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
492 | self.vocab = vocab
493 | self.unk_token = unk_token
494 | self.max_input_chars_per_word = max_input_chars_per_word
495 |
496 | def tokenize(self, text):
497 | """
498 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
499 | tokenization using the given vocabulary.
500 |         For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
501 | Args:
502 | text: A single token or whitespace separated tokens. This should have
503 | already been passed through `BasicTokenizer`.
504 | Returns:
505 | A list of wordpiece tokens.
506 | """
507 |
508 | output_tokens = []
509 | for token in whitespace_tokenize(text):
510 | chars = list(token)
511 | if len(chars) > self.max_input_chars_per_word:
512 | output_tokens.append(self.unk_token)
513 | continue
514 |
515 | is_bad = False
516 | start = 0
517 | sub_tokens = []
518 | while start < len(chars):
519 | end = len(chars)
520 | cur_substr = None
521 | while start < end:
522 | substr = "".join(chars[start:end])
523 | if start > 0:
524 | substr = "##" + substr
525 | if substr in self.vocab:
526 | cur_substr = substr
527 | break
528 | end -= 1
529 | if cur_substr is None:
530 | is_bad = True
531 | break
532 | sub_tokens.append(cur_substr)
533 | start = end
534 |
535 | if is_bad:
536 | output_tokens.append(self.unk_token)
537 | else:
538 | output_tokens.extend(sub_tokens)
539 | return output_tokens
--------------------------------------------------------------------------------
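
A minimal sketch (not a repository file) of the greedy longest-match-first WordPiece algorithm implemented by `WordpieceTokenizer` above, using the `"unaffable"` example from its docstring. The tiny vocabulary is made up purely for illustration, and the import assumes the repository's `transformers` dependency is installed.

```python
from models.tokenization_bert import WordpieceTokenizer

vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}   # toy vocabulary
wp = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

print(wp.tokenize("unaffable"))   # ['un', '##aff', '##able']  (greedy longest match)
print(wp.tokenize("xyz"))         # ['[UNK]']  (no piece matches, falls back to unk_token)
```
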
/optim.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW
2 |
3 |
4 | def create_optimizer(args, model):
5 | lr = args.lr
6 | wd = args.weight_decay
7 | lr_mult = getattr(args, 'lr_mult', 1)
8 | print("### lr: ", lr, " ### lr_mult: ", lr_mult, flush=True)
9 |
10 | optimizer_grouped_parameters = [
11 | {"params": [], "weight_decay": wd, "lr": lr},
12 | {"params": [], "weight_decay": 0.0, "lr": lr},
13 | {"params": [], "weight_decay": wd, "lr": lr * lr_mult},
14 | {"params": [], "weight_decay": 0.0, "lr": lr * lr_mult}
15 | ]
16 |
17 | no_decay = {"bias",
18 | "LayerNorm.bias",
19 | "LayerNorm.weight",
20 | "norm.bias",
21 | "norm.weight",
22 | "norm1.bias",
23 | "norm1.weight",
24 | "norm2.bias",
25 | "norm2.weight"}
26 |
27 | if hasattr(model, 'init_params'):
28 | large_lr = model.init_params
29 | print("### model has 'init_params', ", len(large_lr), flush=True)
30 | else:
31 | large_lr = {}
32 |
33 | for n, p in model.named_parameters():
34 | if not p.requires_grad:
35 | continue # frozen weights
36 |
37 | if any(nd in n for nd in no_decay):
38 | if n in large_lr:
39 | optimizer_grouped_parameters[3]['params'].append(p)
40 | else:
41 | optimizer_grouped_parameters[1]['params'].append(p)
42 | else: # decay
43 | if n in large_lr:
44 | optimizer_grouped_parameters[2]['params'].append(p)
45 | else:
46 | optimizer_grouped_parameters[0]['params'].append(p)
47 |
48 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98))
49 |
50 | return optimizer
51 |
--------------------------------------------------------------------------------
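
A minimal sketch (not a repository file) of how `create_optimizer` above routes parameters: names matching the `no_decay` patterns go to the zero-weight-decay groups, and only parameters listed in `model.init_params` (absent here) receive the `lr_mult` learning rate. The toy module and the `Namespace` fields are illustrative only.

```python
import argparse
import torch.nn as nn
from optim import create_optimizer

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)     # fc.weight -> decay group, fc.bias -> no-decay group
        self.norm = nn.LayerNorm(8)   # norm.weight / norm.bias -> no-decay group

args = argparse.Namespace(lr=1e-4, weight_decay=0.01, lr_mult=2.0)
optimizer = create_optimizer(args, Toy())

for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'], group['lr'])
# 1 0.01 0.0001   (decay, base lr)
# 3 0.0 0.0001    (no decay, base lr)
# 0 0.01 0.0002   (decay, lr * lr_mult -- empty without model.init_params)
# 0 0.0 0.0002    (no decay, lr * lr_mult -- empty without model.init_params)
```
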
/reTools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import copy
4 | import numpy as np
5 | import time
6 | import datetime
7 | import json
8 | from pathlib import Path
9 | from matplotlib import pyplot as plt
10 | import seaborn as sns
11 | from PIL import Image, ImageFont, ImageDraw
12 | from sklearn import metrics
13 | from easydict import EasyDict
14 | from prettytable import PrettyTable
15 |
16 | import torch
17 | import torch.distributed as dist
18 | import torch.nn.functional as F
19 |
20 | import utils
21 |
22 |
23 | @torch.no_grad()
24 | def evaluation_attr(model, data_loader, tokenizer, device, config, args):
25 | model.eval()
26 | metric_logger = utils.MetricLogger(delimiter=" ")
27 | header = 'Evaluation:'
28 | print('Computing features for evaluation attr...')
29 | start_time = time.time()
30 |
31 | text = ['the person is a man', 'the person is a woman',
32 | 'the person is no more than 60 years old', 'the person is older than 60 years old',
33 | 'the person is a young or old one', 'the person is of mid age, between 18 and 60 years old',
34 | 'the person is older than 18', 'the person is a baby or a teenager, younger than 18',
35 |
36 | 'the picture is not the front of the person', 'the picture shows the front of the person',
37 | 'the picture is not the side of the person', 'the picture shows the side of the person',
38 | 'the picture is not the back of the person', 'the picture shows the back of the person',
39 | 'a person without a hat', 'a person with a hat',
40 |
41 | 'a person without a glasses', 'a person with a glasses',
42 | 'a person without a handbag', 'a person with a handbag',
43 | 'a person without a shoulder bag', 'a person with a shoulder bag',
44 | 'a person without a backpack', 'a person with a backpack',
45 |
46 | 'the person does not hold an object in front', 'the person hold an object in front',
47 | 'the person does not wear short sleeved upper clothes', 'the person wears short sleeved upper clothes',
48 | 'the person does not wear long sleeved upper clothes', 'the person wears long sleeved upper clothes',
49 | 'there is no stride on the upper clothes of the person',
50 | 'there is stride on the upper clothes of the person',
51 |
52 | 'there is no logo on the upper clothes of the person',
53 | 'there is logo on the upper clothes of the person',
54 | 'there is no plaid on the upper clothes of the person',
55 | 'there is plaid on the upper clothes of the person',
56 | 'there is no splice on the upper clothes of the person',
57 | 'there is splice on the upper clothes of the person',
58 | 'there is no stripe on the upper clothes of the person',
59 | 'there is stripe on the upper clothes of the person',
60 |
61 | 'there is no pattern on the lower part of the person',
62 | 'there is pattern on the lower part of the person',
63 | 'the person does not wear long coat', 'the person wears long coat',
64 | 'the person does not wear trousers', 'the person wears trousers',
65 | 'the person does not wear shorts', 'the person wears shorts',
66 |
67 | 'the person does not wear a skirt or a dress', 'the person wears a skirt or a dress',
68 | 'the person does not wear boots', 'the person wears boots',
69 | ]
70 |
71 | text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'],
72 | return_tensors="pt").to(device)
73 | text_embeds = model.get_text_embeds(text_input.input_ids, text_input.attention_mask)
74 | text_atts = text_input.attention_mask
75 |
76 | image_embeds = []
77 | for image, img_id in data_loader:
78 | image = image.to(device)
79 | image_embed, _ = model.get_vision_embeds(image)
80 | image_embeds.append(image_embed)
81 | image_embeds = torch.cat(image_embeds, dim=0)
82 |
83 | score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(text)), -1000.0).to(device)
84 | num_tasks = utils.get_world_size()
85 | rank = utils.get_rank()
86 | step = image_embeds.size(0) // num_tasks + 1
87 | start = rank * step
88 | end = min(image_embeds.size(0), start + step)
89 |
90 | for i, image_embed in enumerate(metric_logger.log_every(image_embeds[start:end], 50, header)):
91 | encoder_output = image_embed.repeat(len(text), 1, 1)
92 | encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
93 | output = model.get_cross_embeds(encoder_output, encoder_att, text_embeds=text_embeds,
94 | text_atts=text_atts)[:, 0, :]
95 | score = model.itm_head(output)[:, 1]
96 | score_matrix_i2t[start + i] = score
97 | if args.distributed:
98 | dist.barrier()
99 | torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
100 | total_time = time.time() - start_time
101 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
102 | print('Evaluation time {}'.format(total_time_str))
103 | return score_matrix_i2t.cpu().numpy()
104 |
105 |
106 | @torch.no_grad()
107 | def evaluation_attr_only_img_classifier(model, data_loader, tokenizer, device, config, args):
108 | model.eval()
109 | metric_logger = utils.MetricLogger(delimiter=" ")
110 | header = 'Evaluation:'
111 | print('Computing features for evaluation attr...')
112 | start_time = time.time()
113 |
114 | image_embeds = []
115 | outputs = []
116 | for image, img_id in data_loader:
117 | image = image.to(device)
118 | image_embed = model.vision_encoder(image)
119 | output = model.img_cls(image_embed[:, 0, :])
120 | output = torch.sigmoid(output)
121 | outputs.append(output)
122 | outputs = torch.cat(outputs, dim=0)
123 | orig_outputs = outputs.data.cpu().numpy()
124 | # transform raw outputs to attributes (binary codes)
125 | outputs = copy.deepcopy(orig_outputs)
126 | outputs[outputs < 0.5] = 0
127 | outputs[outputs >= 0.5] = 1
128 | return outputs
129 |
130 |
131 | @torch.no_grad()
132 | def accs(pred, y):
133 | print('Testing ... metrics')
134 | num_persons = pred.shape[0]
135 | print('num_persons', num_persons)
136 | ins_acc = 0
137 | ins_prec = 0
138 | ins_rec = 0
139 | mA_history = {
140 | 'correct_pos': 0,
141 | 'real_pos': 0,
142 | 'correct_neg': 0,
143 | 'real_neg': 0
144 | }
145 |
146 | # compute label-based metric
147 | outputs = pred
148 | attrs = y
149 | overlaps = outputs * attrs
150 | mA_history['correct_pos'] += overlaps.sum(0)
151 | mA_history['real_pos'] += attrs.sum(0)
152 | inv_overlaps = (1 - outputs) * (1 - attrs)
153 | mA_history['correct_neg'] += inv_overlaps.sum(0)
154 | mA_history['real_neg'] += (1 - attrs).sum(0)
155 |
156 | outputs = outputs.astype(bool)
157 | attrs = attrs.astype(bool)
158 |
159 |     # compute instance-based accuracy
160 | intersect = (outputs & attrs).astype(float)
161 | union = (outputs | attrs).astype(float)
162 | ins_acc += (intersect.sum(1) / union.sum(1)).sum()
163 | ins_prec += (intersect.sum(1) / outputs.astype(float).sum(1)).sum()
164 | ins_rec += (intersect.sum(1) / attrs.astype(float).sum(1)).sum()
165 |
166 | ins_acc /= num_persons
167 | ins_prec /= num_persons
168 | ins_rec /= num_persons
169 | ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)
170 |
171 | term1 = mA_history['correct_pos'] / mA_history['real_pos']
172 | term2 = mA_history['correct_neg'] / mA_history['real_neg']
173 | label_mA_verbose = (term1 + term2) * 0.5
174 | label_mA = label_mA_verbose.mean()
175 |
176 | print('* Results *')
177 | print(' # test persons: {}'.format(num_persons))
178 | print(' (label-based) mean accuracy: {:.2%}'.format(label_mA))
179 | print(' (instance-based) accuracy: {:.2%}'.format(ins_acc))
180 | print(' (instance-based) precition: {:.2%}'.format(ins_prec))
181 | print(' (instance-based) recall: {:.2%}'.format(ins_rec))
182 | print(' (instance-based) f1-score: {:.2%}'.format(ins_f1))
183 | print(' mA for each attribute: {}'.format(label_mA_verbose))
184 | return label_mA, ins_acc, ins_prec, ins_rec, ins_f1
185 |
186 |
187 | @torch.no_grad()
188 | def itm_eval_attr(scores_i2t, dataset):
189 | label = dataset.label
190 | pred = []
191 | for i in range(label.shape[1]):
192 | a = np.argmax(scores_i2t[:, 2 * i: 2 * i + 2], axis=1)
193 | pred.append(a)
194 |
195 | label_mA, ins_acc, ins_prec, ins_rec, ins_f1 = accs(np.array(pred).T, label)
196 | print('############################################################\n')
197 | eval_result = {'label_mA': round(label_mA, 4),
198 | 'ins_acc': round(ins_acc, 4),
199 | 'ins_prec': round(ins_prec, 4),
200 | 'ins_rec': round(ins_rec, 4),
201 | 'ins_f1': round(ins_f1, 4),
202 | }
203 | return eval_result
204 |
205 |
206 | @torch.no_grad()
207 | def itm_eval_attr_only_img_classifier(scores_i2t, dataset):
208 | label = dataset.label
209 | pred = scores_i2t
210 | label_mA, ins_acc, ins_prec, ins_rec, ins_f1 = accs(pred, label)
211 | print('############################################################\n')
212 | eval_result = {'label_mA': round(label_mA, 4),
213 | 'ins_acc': round(ins_acc, 4),
214 | 'ins_prec': round(ins_prec, 4),
215 | 'ins_rec': round(ins_rec, 4),
216 | 'ins_f1': round(ins_f1, 4),
217 | }
218 | return eval_result
219 |
220 |
221 | @torch.no_grad()
222 | def evaluation(model, data_loader, tokenizer, device, config, args):
223 | model.eval()
224 |
225 | metric_logger = utils.MetricLogger(delimiter=" ")
226 | header = 'Evaluation:'
227 |
228 | print('Computing features for evaluation...')
229 | start_time = time.time()
230 |
231 | texts = data_loader.dataset.text
232 | num_text = len(texts)
233 | text_bs = config['batch_size_test_text'] # 256
234 | text_embeds = []
235 | text_atts = []
236 | text_feats = []
237 | for i in range(0, num_text, text_bs):
238 | text = texts[i: min(num_text, i + text_bs)]
239 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
240 | return_tensors="pt").to(device)
241 | text_embed = model.get_text_embeds(text_input.input_ids, text_input.attention_mask)
242 | text_feat = model.text_proj(text_embed[:, 0, :])
243 | text_feat = F.normalize(text_feat, dim=-1)
244 |
245 | text_embeds.append(text_embed)
246 | text_atts.append(text_input.attention_mask)
247 | text_feats.append(text_feat)
248 | text_embeds = torch.cat(text_embeds, dim=0)
249 | text_atts = torch.cat(text_atts, dim=0)
250 | text_feats = torch.cat(text_feats, dim=0)
251 |
252 | image_embeds = []
253 | image_feats = []
254 | for image, img_id in data_loader:
255 | image = image.to(device)
256 | image_embed, _ = model.get_vision_embeds(image)
257 | image_feat = model.vision_proj(image_embed[:, 0, :])
258 | image_feat = F.normalize(image_feat, dim=-1)
259 | image_embeds.append(image_embed)
260 | image_feats.append(image_feat)
261 | image_embeds = torch.cat(image_embeds, dim=0)
262 | image_feats = torch.cat(image_feats, dim=0)
263 | sims_matrix = image_feats @ text_feats.t()
264 | sims_matrix = sims_matrix.t()
265 | score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), 1000.0).to(device)
266 | score_sim_t2i = sims_matrix
267 |
268 | num_tasks = utils.get_world_size()
269 | rank = utils.get_rank()
270 | step = sims_matrix.size(0) // num_tasks + 1
271 | start = rank * step
272 | end = min(sims_matrix.size(0), start + step)
273 |
274 | for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
275 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
276 | encoder_output = image_embeds[topk_idx]
277 | encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
278 |
279 | output = model.get_cross_embeds(encoder_output, encoder_att,
280 | text_embeds=text_embeds[start + i].repeat(config['k_test'], 1, 1),
281 | text_atts=text_atts[start + i].repeat(config['k_test'], 1))[:, 0, :]
282 | score = model.itm_head(output)[:, 1]
283 | score_matrix_t2i[start + i, topk_idx] = score
284 | score_sim_t2i[start + i, topk_idx] = topk_sim
285 |
286 | min_values, _ = torch.min(score_matrix_t2i, dim=1)
287 | replacement_tensor = min_values.view(-1, 1).expand(-1, score_matrix_t2i.size(1))
288 | score_matrix_t2i[score_matrix_t2i == 1000.0] = replacement_tensor[score_matrix_t2i == 1000.0]
289 | score_sim_t2i = (score_sim_t2i - score_sim_t2i.min()) / (score_sim_t2i.max() - score_sim_t2i.min())
290 | score_matrix_t2i = (score_matrix_t2i - score_matrix_t2i.min()) / (score_matrix_t2i.max() - score_matrix_t2i.min())
291 | score_matrix_t2i = score_matrix_t2i + 0.002*score_sim_t2i
292 |
293 | if args.distributed:
294 | dist.barrier()
295 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
296 |
297 | total_time = time.time() - start_time
298 | per_time = total_time / num_text
299 | print('total_time', total_time)
300 | print('per_time', per_time)
301 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
302 | print('Evaluation time {}'.format(total_time_str))
303 |
304 | return score_matrix_t2i.cpu().numpy()
305 |
306 |
307 | def mAP(scores_t2i, g_pids, q_pids, table=None):
308 | similarity = torch.tensor(scores_t2i)
309 | indices = torch.argsort(similarity, dim=1, descending=True)
310 | g_pids = torch.tensor(g_pids)
311 | q_pids = torch.tensor(q_pids)
312 | pred_labels = g_pids[indices.cpu()] # q * k
313 | matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k
314 |
315 | all_cmc = matches[:, :10].cumsum(1) # cumulative sum
316 | all_cmc[all_cmc > 1] = 1
317 | all_cmc = all_cmc.float().mean(0) * 100
318 | # all_cmc = all_cmc[topk - 1]
319 |
320 | num_rel = matches.sum(1) # q
321 | tmp_cmc = matches.cumsum(1) # q * k
322 |
323 | inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)]
324 | mINP = torch.cat(inp).mean() * 100
325 |
326 | tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])]
327 | tmp_cmc = torch.stack(tmp_cmc, 1) * matches
328 | AP = tmp_cmc.sum(1) / num_rel # q
329 | mAP = AP.mean() * 100
330 |
331 | t2i_cmc, t2i_mAP, t2i_mINP, _ = all_cmc, mAP, mINP, indices
332 | t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy()
333 |
334 | if not table:
335 | table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"])
336 | table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP])
337 | table.custom_format["R1"] = lambda f, v: f"{v:.3f}"
338 | table.custom_format["R5"] = lambda f, v: f"{v:.3f}"
339 | table.custom_format["R10"] = lambda f, v: f"{v:.3f}"
340 | table.custom_format["mAP"] = lambda f, v: f"{v:.3f}"
341 | table.custom_format["mINP"] = lambda f, v: f"{v:.3f}"
342 | print(table)
343 |
344 | eval_result = {'R1': t2i_cmc[0],
345 | 'R5': t2i_cmc[4],
346 | 'R10': t2i_cmc[9],
347 | 'mAP': t2i_mAP,
348 | 'mINP': t2i_mINP,
349 | }
350 | return eval_result
351 |
--------------------------------------------------------------------------------
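
A minimal sketch (not a repository file) of the `mAP` helper above on toy text-to-image scores: two queries against ten gallery images with distinct person IDs, where query 0 is ranked correctly at position 1 and query 1 at position 2, giving R1 = 50, R5 = R10 = 100 and mAP = mINP = 75. The IDs and scores are made up; the import assumes the repository requirements (including `prettytable`) and the local `utils` package are importable.

```python
import numpy as np
from reTools import mAP

g_pids = list(range(10))      # gallery person IDs (all distinct)
q_pids = [3, 7]               # person IDs of the two text queries

scores_t2i = np.zeros((2, 10))
scores_t2i[0, 3] = 1.0        # query 0: the true match gets the highest score
scores_t2i[1, 2] = 0.9        # query 1: a wrong image ranks first ...
scores_t2i[1, 7] = 0.5        # ... and the true match ranks second

print(mAP(scores_t2i, g_pids, q_pids))
# prints the PrettyTable and returns R1 = 50, R5 = 100, R10 = 100, mAP = 75, mINP = 75
```
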
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.9
2 | transformers==4.12.5
3 | ruamel_yaml
4 | opencv-python
5 | scikit-image
6 | matplotlib
7 | audtorch
8 | seaborn
9 | prettytable
10 | easydict
11 | nltk
12 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 |
5 | # Set it correctly for distributed training across nodes
6 | NNODES = 1 # e.g. 1/2/3/4
7 | NPROC_PER_NODE = 4 # e.g. 4 gpus
8 | MASTER_ADDR = '127.0.0.1'
9 | MASTER_PORT = 3000 # 0~65535
10 | NODE_RANK = 0 # e.g. 0/1/2
11 |
12 | print("NNODES, ", NNODES)
13 | print("NPROC_PER_NODE, ", NPROC_PER_NODE)
14 | print("MASTER_ADDR, ", MASTER_ADDR)
15 | print("MASTER_PORT, ", MASTER_PORT)
16 | print("NODE_RANK, ", NODE_RANK)
17 |
18 |
19 | def get_dist_launch(args): # some examples
20 | if args.dist == 'f4':
21 | return "CUDA_VISIBLE_DEVICES=0,1,2,3 WORLD_SIZE=4 python3 -m torch.distributed.launch --nproc_per_node=4 " \
22 | "--nnodes=1 --master_port={:}".format(MASTER_PORT)
23 |
24 | elif args.dist == 'f2':
25 | return "CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 python3 -m torch.distributed.launch --nproc_per_node=2 " \
26 | "--nnodes=1 --master_port={:}".format(MASTER_PORT)
27 |
28 | elif args.dist == 'l2':
29 | return "CUDA_VISIBLE_DEVICES=2,3 WORLD_SIZE=2 python3 -m torch.distributed.launch --nproc_per_node=2 " \
30 | "--nnodes=1 --master_port={:}".format(MASTER_PORT)
31 |
32 | elif args.dist == 'f-0':
33 | return "CUDA_VISIBLE_DEVICES=1,2,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \
34 | "--nnodes=1 "
35 |
36 | elif args.dist == 'f-1':
37 | return "CUDA_VISIBLE_DEVICES=0,2,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \
38 | "--nnodes=1 "
39 |
40 | elif args.dist == 'f-2':
41 | return "CUDA_VISIBLE_DEVICES=0,1,3 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \
42 | "--nnodes=1 "
43 |
44 | elif args.dist == 'f-3':
45 | return "CUDA_VISIBLE_DEVICES=0,1,2 WORLD_SIZE=3 python3 -m torch.distributed.launch --nproc_per_node=3 " \
46 | "--nnodes=1 "
47 |
48 | elif args.dist.startswith('gpu'): # use one gpu, --dist "gpu0"
49 | num = int(args.dist[3:])
50 | assert 0 <= num <= 3
51 | return "CUDA_VISIBLE_DEVICES={:} WORLD_SIZE=1 python3 -m torch.distributed.launch --nproc_per_node=1 " \
52 | "--nnodes=1 --master_port={:} ".format(num, MASTER_PORT)
53 |
54 | else:
55 | raise ValueError
56 |
57 |
58 | def run_retrieval(args):
59 | dist_launch = get_dist_launch(args)
60 |
61 | os.system(f"{dist_launch} "
62 | f"--use_env Retrieval.py --config {args.config} "
63 | f"--task {args.task} --output_dir {args.output_dir} --bs {args.bs} --epo {args.epo} --checkpoint {args.checkpoint} {'--evaluate' if args.evaluate else ''}")
64 |
65 |
66 | def run(args):
67 | if args.task not in ['itr_gene']:
68 | assert os.path.exists(args.checkpoint)
69 |
70 | if args.task == 'itr_cuhk':
71 | assert os.path.exists("images/CUHK-PEDES")
72 | args.config = 'configs/Retrieval_cuhk.yaml'
73 | run_retrieval(args)
74 |
75 | elif args.task == 'itr_icfg':
76 | assert os.path.exists("images/ICFG-PEDES")
77 | args.config = 'configs/Retrieval_icfg.yaml'
78 | run_retrieval(args)
79 |
80 | elif args.task == 'itr_rstp':
81 | assert os.path.exists("images/RSTPReid")
82 | args.config = 'configs/Retrieval_rstp.yaml'
83 | run_retrieval(args)
84 |
85 | elif args.task == 'itr_gene':
86 | assert os.path.exists("images/CUHK-PEDES")
87 | args.config = 'configs/Retrieval_gene.yaml'
88 | run_retrieval(args)
89 |
90 | elif args.task == 'itr_pa100k':
91 | assert os.path.exists("images/pa100k")
92 | args.config = 'configs/Retrieval_pa100k.yaml'
93 | run_retrieval(args)
94 |
95 | else:
96 | raise NotImplementedError(f"task == {args.task}")
97 |
98 |
99 | if __name__ == '__main__':
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('--task', type=str, required=True)
102 | parser.add_argument('--dist', type=str, required=True, help="see func get_dist_launch for details")
103 |     parser.add_argument('--bs', default=-1, type=int, help="total batch size; per-GPU batch size = bs // num_gpus")
104 |     parser.add_argument('--epo', default=-1, type=int, help="number of training epochs")
105 | parser.add_argument('--seed', default=42, type=int)
106 | parser.add_argument('--checkpoint', default='output/pretrain/checkpoint_31.pth', type=str, help="for fine-tuning")
107 |     parser.add_argument('--output_dir', type=str, required=True, help='local output directory; its parent must already exist')
108 | parser.add_argument('--evaluate', action='store_true', help="evaluation on downstream tasks")
109 | args = parser.parse_args()
110 |
111 | assert os.path.exists(os.path.dirname(args.output_dir))
112 | if not os.path.exists(args.output_dir):
113 | os.mkdir(args.output_dir)
114 |
115 | run(args)
116 |
--------------------------------------------------------------------------------
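For reference, `run.py` only composes the `torch.distributed.launch` prefix selected by `--dist` and forwards the remaining flags to `Retrieval.py`. A minimal sketch of typical invocations is shown below; the checkpoint and output paths are placeholders, the corresponding `images/...` dataset directory must exist, and the parent of `--output_dir` must already exist.

```python
# Hypothetical invocations of run.py (paths are placeholders).
import os

# Fine-tune on CUHK-PEDES with 4 GPUs (--dist f4).
os.system("python3 run.py --task itr_cuhk --dist f4 "
          "--output_dir output/ft_cuhk --checkpoint output/pretrain/checkpoint_31.pth")

# Evaluate a fine-tuned model on a single GPU (--dist gpu0).
os.system("python3 run.py --task itr_cuhk --dist gpu0 --evaluate "
          "--output_dir output/ft_cuhk_eval --checkpoint output/ft_cuhk/checkpoint_best.pth")
```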
/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import LambdaLR
2 |
3 |
4 | def create_scheduler(args, optimizer):
5 | if 'num_training_steps' not in args:
6 | args['num_training_steps'] = args['epochs'] * args['step_per_epoch']
7 | print("### num_training_steps, ", args['num_training_steps'], flush=True)
8 |
9 | if isinstance(args['num_warmup_steps'], float):
10 | assert 0 <= args['num_warmup_steps'] < 1
11 | args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps'])
12 | print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True)
13 |
14 | print('sched:', args.sched, flush=True)
15 |
16 | if args.sched == 'linear':
17 | def lr_lambda(current_step: int):
18 | if current_step < args.num_warmup_steps:
19 | return float(current_step) / float(max(1, args.num_warmup_steps))
20 | return max(
21 | 0.0, float(args.num_training_steps - current_step) / float(
22 | max(1, args.num_training_steps - args.num_warmup_steps))
23 | )
24 |
25 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
26 |
27 | elif args.sched == 'step':
28 | def lr_lambda(current_step: int):
29 | if current_step < args.num_warmup_steps:
30 | return float(current_step) / float(max(1, args.num_warmup_steps))
31 | elif current_step < args.num_warmup_steps * 4:
32 | tt = 1
33 | elif current_step < args.num_warmup_steps * 7:
34 | tt = 0.5
35 | else:
36 | tt = 0.2
37 |
38 | return tt * max(
39 | 0.0, float(args.num_training_steps - current_step) / float(
40 | max(1, args.num_training_steps - args.num_warmup_steps))
41 | )
42 |
43 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
44 |
45 | else:
46 | raise NotImplementedError(f"args.sched == {args.sched}")
47 |
48 | return lr_scheduler
49 |
--------------------------------------------------------------------------------
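To see what `create_scheduler` produces, here is a minimal sketch that builds the warmup-plus-linear-decay schedule on a dummy optimizer. `AttrDict` comes from `utils/__init__.py`; the hyper-parameter values are illustrative stand-ins for what the `configs/*.yaml` files provide.

```python
import torch
from scheduler import create_scheduler
from utils import AttrDict

# Illustrative values; the real ones come from configs/Retrieval_*.yaml.
sched_args = AttrDict({'sched': 'linear', 'lr': 1e-4, 'epochs': 2,
                       'step_per_epoch': 10, 'num_warmup_steps': 0.1})

model = torch.nn.Linear(4, 2)  # stand-in for the APTM model
optimizer = torch.optim.AdamW(model.parameters(), lr=sched_args.lr)
lr_scheduler = create_scheduler(sched_args, optimizer)  # fills in num_training_steps = 2 * 10

for step in range(sched_args.num_training_steps):
    optimizer.step()      # the real training step would go here
    lr_scheduler.step()
    print(step, lr_scheduler.get_last_lr())  # warms up for 2 steps, then decays linearly towards 0
```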
/train_pa100ks.py:
--------------------------------------------------------------------------------
1 | import utils
2 | from train_tools import mlm
3 |
4 | def train_pa100k(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None):
5 | model.train()
6 |
7 | metric_logger = utils.MetricLogger(delimiter=" ")
8 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
9 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
10 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
11 | if config['mlm']:
12 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
13 | header = 'Train Epoch: [{}]'.format(epoch)
14 | print_freq = 50
15 |
16 | for i, (image, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
17 | image = image.to(device, non_blocking=True)
18 | label = label.to(device, non_blocking=True)
19 |         # PA-100K attribute prompts: a pair of contrasting sentences for each of the 26 binary attributes (52 in total)
20 | text = ['the person is a man', 'the person is a woman',
21 | 'the person is no more than 60 years old', 'the person is older than 60 years old',
22 | 'the person is a young or old one', 'the person is of mid age, between 18 and 60 years old',
23 | 'the person is older than 18', 'the person is a baby or a teenager, younger than 18',
24 |
25 | 'the picture is not the front of the person', 'the picture shows the front of the person',
26 | 'the picture is not the side of the person', 'the picture shows the side of the person',
27 | 'the picture is not the back of the person', 'the picture shows the back of the person',
28 | 'a person without a hat', 'a person with a hat',
29 |
30 | 'a person without a glasses', 'a person with a glasses',
31 | 'a person without a handbag', 'a person with a handbag',
32 | 'a person without a shoulder bag', 'a person with a shoulder bag',
33 | 'a person without a backpack', 'a person with a backpack',
34 |
35 | 'the person does not hold an object in front', 'the person hold an object in front',
36 | 'the person does not wear short sleeved upper clothes', 'the person wears short sleeved upper clothes',
37 | 'the person does not wear long sleeved upper clothes', 'the person wears long sleeved upper clothes',
38 | 'there is no stride on the upper clothes of the person',
39 | 'there is stride on the upper clothes of the person',
40 |
41 | 'there is no logo on the upper clothes of the person',
42 | 'there is logo on the upper clothes of the person',
43 | 'there is no plaid on the upper clothes of the person',
44 | 'there is plaid on the upper clothes of the person',
45 | 'there is no splice on the upper clothes of the person',
46 | 'there is splice on the upper clothes of the person',
47 | 'there is no stripe on the upper clothes of the person',
48 | 'there is stripe on the upper clothes of the person',
49 |
50 | 'there is no pattern on the lower part of the person',
51 | 'there is pattern on the lower part of the person',
52 | 'the person does not wear long coat', 'the person wears long coat',
53 | 'the person does not wear trousers', 'the person wears trousers',
54 | 'the person does not wear shorts', 'the person wears shorts',
55 |
56 | 'the person does not wear a skirt or a dress', 'the person wears a skirt or a dress',
57 | 'the person does not wear boots', 'the person wears boots',
58 | ]
59 |
60 | text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'],
61 | return_tensors="pt").to(device)
62 |
63 | # mlm loss
64 | if config['mlm']:
65 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, config, True)
66 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids, text_input.attention_mask,
67 | text_ids_masked=text_ids_masked, masked_pos=masked_pos,
68 | masked_ids=masked_ids, label=label)
69 | loss = loss_itc + loss_itm + loss_mlm
70 | else:
71 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, label=label)
72 | loss = loss_itc + loss_itm
73 |
74 | optimizer.zero_grad()
75 | loss.backward()
76 | optimizer.step()
77 | scheduler.step()
78 |
79 | metric_logger.update(loss_itc=loss_itc.item())
80 | metric_logger.update(loss_itm=loss_itm.item())
81 | if config['mlm']:
82 | metric_logger.update(loss_mlm=loss_mlm.item())
83 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
84 |
85 | # gather the stats from all processes
86 | metric_logger.synchronize_between_processes()
87 | print("Averaged stats:", metric_logger.global_avg())
88 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
89 |
90 |
91 | def train_pa100k_only_img_classifier(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config,
92 | mask_generator=None):
93 | model.train()
94 |
95 | metric_logger = utils.MetricLogger(delimiter=" ")
96 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
97 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
98 | header = 'Train Epoch: [{}]'.format(epoch)
99 | print_freq = 50
100 |
101 | for i, (image, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
102 | image = image.to(device, non_blocking=True)
103 | label = label.to(device, non_blocking=True)
104 |
105 | loss = model(image, None, None, label=label)
106 |
107 | optimizer.zero_grad()
108 | loss.backward()
109 | optimizer.step()
110 | scheduler.step()
111 |
112 | metric_logger.update(loss=loss.item())
113 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
114 |
115 | # gather the stats from all processes
116 | metric_logger.synchronize_between_processes()
117 | print("Averaged stats:", metric_logger.global_avg())
118 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/train_tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | # Build the masked input ids and the masked positions / target ids consumed by the MLM loss.
5 | def mlm(text, text_input, tokenizer, device, mask_generator, config, pa100k=False):
6 | if pa100k:
7 | text_masked = tokenizer(text, padding='longest', max_length=config['max_tokens'],
8 | return_tensors="pt").to(device)
9 | else:
10 | text_masked = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
11 | return_tensors="pt").to(device)
12 | text_ids_masked = text_masked.input_ids
13 | masked_pos = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.int64, device=device)
14 | masked_ids = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.long, device=device)
15 | for index, text_id in enumerate(text_ids_masked):
16 | text_ids_masked_, masked_pos_ = mask_generator(text_id)
17 | masked_ids_ = [text_input.input_ids[index][p].item() for p in masked_pos_]
18 | n_pad = config['max_masks'] - len(masked_ids_)
19 | masked_pos_ = masked_pos_ + [0] * n_pad
20 | masked_pos_ = torch.tensor(masked_pos_, dtype=torch.int64).to(device)
21 | masked_ids_ = masked_ids_ + [-100] * n_pad
22 | masked_ids_ = torch.tensor(masked_ids_, dtype=torch.long).to(device)
23 | masked_pos[index] = masked_pos_
24 | masked_ids[index] = masked_ids_
25 | return text_ids_masked, masked_pos, masked_ids
26 |
--------------------------------------------------------------------------------
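The `mask_generator` passed into `mlm()` is constructed elsewhere in the repo; it is expected to mask tokens of the given id tensor in place and return the masked positions. The sketch below uses a toy stand-in for it, and the HuggingFace `BertTokenizer` as a stand-in for `models/tokenization_bert.py`, just to show the padded shapes `mlm()` returns; `max_tokens` and `max_masks` are illustrative values.

```python
import torch
from transformers import BertTokenizer  # stand-in for models/tokenization_bert.py
from train_tools import mlm

device = 'cpu'
config = {'max_tokens': 56, 'max_masks': 8}  # illustrative values
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def toy_mask_generator(text_ids):
    # Toy stand-in: mask the first token after [CLS] in place and report its position.
    text_ids[1] = tokenizer.mask_token_id
    return text_ids, [1]

text = ['a woman wearing a red coat and black trousers']
text_input = tokenizer(text, padding='max_length', truncation=True,
                       max_length=config['max_tokens'], return_tensors='pt').to(device)

text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer,
                                              device, toy_mask_generator, config)
print(text_ids_masked.shape)   # (1, max_tokens)
print(masked_pos[0])           # masked positions, padded with 0 up to max_masks
print(masked_ids[0])           # original token ids at those positions, padded with -100
```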
/trains.py:
--------------------------------------------------------------------------------
1 | import utils
2 | from train_tools import mlm
3 | import numpy as np
4 |
5 |
6 | def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None):
7 | model.train()
8 |
9 | metric_logger = utils.MetricLogger(delimiter=" ")
10 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
11 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
12 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
13 | if config['mlm']:
14 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
15 | header = 'Train Epoch: [{}]'.format(epoch)
16 | print_freq = 50
17 |
18 | if config['eda']:
19 | for i, (image, text, text_eda, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
20 | image = image.to(device, non_blocking=True)
21 | idx = idx.to(device, non_blocking=True)
22 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'],
23 | # return_tensors="pt").to(device)
24 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
25 | return_tensors="pt").to(device)
26 | text_input_eda = tokenizer(text_eda, padding='max_length', truncation=True, max_length=config['max_tokens'],
27 | return_tensors="pt").to(device)
28 | if config['mlm']:
29 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator,
30 | config)
31 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids, text_input.attention_mask,
32 | text_ids_masked=text_ids_masked,
33 | masked_pos=masked_pos, masked_ids=masked_ids, idx=idx,
34 | text_ids_eda=text_input_eda.input_ids,
35 | text_atts_eda=text_input_eda.attention_mask)
36 | loss = loss_itc + loss_itm + loss_mlm
37 | else:
38 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, idx=idx,
39 | text_ids_eda=text_input_eda.input_ids,
40 | text_atts_eda=text_input_eda.attention_mask)
41 | loss = loss_itc + loss_itm
42 |
43 | optimizer.zero_grad()
44 | loss.backward()
45 | optimizer.step()
46 | scheduler.step()
47 |
48 | metric_logger.update(loss_itc=loss_itc.item())
49 | metric_logger.update(loss_itm=loss_itm.item())
50 | if config['mlm']:
51 | metric_logger.update(loss_mlm=loss_mlm.item())
52 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
53 | else:
54 | for i, (image, text, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
55 | image = image.to(device, non_blocking=True)
56 | idx = idx.to(device, non_blocking=True)
57 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'],
58 | # return_tensors="pt").to(device)
59 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
60 | return_tensors="pt").to(device)
61 | # mlm loss
62 | if config['mlm']:
63 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator,
64 | config)
65 | loss_itc, loss_itm, loss_mlm = model(image, text_input.input_ids,
66 | text_input.attention_mask,
67 | text_ids_masked=text_ids_masked,
68 | masked_pos=masked_pos, masked_ids=masked_ids,
69 | idx=idx)
70 | loss = loss_itc + loss_itm + loss_mlm
71 | else:
72 | loss_itc, loss_itm = model(image, text_input.input_ids, text_input.attention_mask, idx=idx)
73 | loss = loss_itc + loss_itm
74 |
75 | optimizer.zero_grad()
76 | loss.backward()
77 | optimizer.step()
78 | scheduler.step()
79 |
80 | metric_logger.update(loss_itc=loss_itc.item())
81 | metric_logger.update(loss_itm=loss_itm.item())
82 | if config['mlm']:
83 | metric_logger.update(loss_mlm=loss_mlm.item())
84 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
85 |
86 | # gather the stats from all processes
87 | metric_logger.synchronize_between_processes()
88 | print("Averaged stats:", metric_logger.global_avg())
89 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
90 |
91 |
92 | def train_attr(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config, mask_generator=None):
93 | model.train()
94 |
95 | metric_logger = utils.MetricLogger(delimiter=" ")
96 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
97 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
98 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
99 | if config['mlm']:
100 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
101 | metric_logger.add_meter('loss_attr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
102 |
103 | header = 'Train Epoch: [{}]'.format(epoch)
104 | print_freq = 50
105 |
106 | for i, (image, text, idx, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
107 | image = image.to(device, non_blocking=True)
108 | idx = idx.to(device, non_blocking=True)
109 | # text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'],
110 | # return_tensors="pt").to(device)
111 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
112 | return_tensors="pt").to(device)
113 | label = label.to(device, non_blocking=True)
114 |         # MALS attribute prompts: a pair of contrasting sentences for each of the 27 attributes (54 in total)
115 | attr = ['the person is a woman', 'the person is a man',
116 | 'the person is younger than 18 years old', 'the person is older than 18 years old',
117 |
118 | 'the person with short hair', 'the person with long hair',
119 | 'the person with a hat', 'the person without a hat',
120 | 'the person with a backpack', 'the person without a backpack',
121 | 'the person with a handbag', 'the person without a handbag',
122 | 'the person with a bag', 'the person without a bag',
123 |
124 | 'the person wears long sleeved upper clothes', 'the person wears short sleeved upper clothes',
125 | 'the person wears long dress or long pants', 'the person wears short dress or short pants',
126 | 'the person wears dress or skirt', 'the person wears pants or shorts',
127 |
128 | 'the person wears black upper clothes', 'the person does not wear black upper clothes',
129 | 'the person wears white upper clothes', 'the person does not wear white upper clothes',
130 | 'the person wears red upper clothes', 'the person does not wear red upper clothes',
131 | 'the person wears purple upper clothes', 'the person does not wear purple upper clothes',
132 |
133 | 'the person wears yellow upper clothes', 'the person does not wear yellow upper clothes',
134 | 'the person wears blue upper clothes', 'the person does not wear blue upper clothes',
135 | 'the person wears green upper clothes', 'the person does not wear green upper clothes',
136 | 'the person wears gray upper clothes', 'the person does not wear gray upper clothes',
137 |
138 | 'the person wears black lower clothes', 'the person does not wear black lower clothes',
139 | 'the person wears white lower clothes', 'the person does not wear white lower clothes',
140 | 'the person wears purple lower clothes', 'the person does not wear purple lower clothes',
141 | 'the person wears yellow lower clothes', 'the person does not wear yellow lower clothes',
142 |
143 | 'the person wears blue lower clothes', 'the person does not wear blue lower clothes',
144 | 'the person wears green lower clothes', 'the person does not wear green lower clothes',
145 | 'the person wears pink lower clothes', 'the person does not wear pink lower clothes',
146 | 'the person wears gray lower clothes', 'the person does not wear gray lower clothes',
147 | 'the person wears brown lower clothes', 'the person does not wear brown lower clothes',
148 |
149 | ]
150 | attr_input = tokenizer(attr, padding='longest', max_length=config['max_tokens'],
151 | return_tensors="pt").to(device)
152 |
153 | # mlm loss
154 | if config['mlm']:
155 | text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator,
156 | config)
157 | attr_text_ids_masked, attr_masked_pos, attr_masked_ids = mlm(attr, attr_input, tokenizer, device,
158 | mask_generator, config,
159 | True)
160 |
161 | loss_itc, loss_itm, loss_mlm, loss_attr = model(image, text_input.input_ids, text_input.attention_mask,
162 | text_ids_masked=text_ids_masked, masked_pos=masked_pos,
163 | masked_ids=masked_ids, idx=idx,
164 | attr_text_ids=attr_input.input_ids,
165 | attr_text_atts=attr_input.attention_mask,
166 | attr_text_ids_masked=attr_text_ids_masked,
167 | attr_masked_pos=attr_masked_pos,
168 | attr_masked_ids=attr_masked_ids, label=label)
169 | loss = loss_itc + loss_itm + loss_mlm + config['t'] * loss_attr
170 | else:
171 | loss_itc, loss_itm, loss_attr = model(image, text_input.input_ids, text_input.attention_mask, idx=idx,
172 | attr_text_ids=attr_input.input_ids,
173 | attr_text_atts=attr_input.attention_mask,
174 | label=label)
175 | loss = loss_itc + loss_itm + config['t'] * loss_attr
176 |
177 | optimizer.zero_grad()
178 | loss.backward()
179 | optimizer.step()
180 | scheduler.step()
181 |
182 | metric_logger.update(loss_itc=loss_itc.item())
183 | metric_logger.update(loss_itm=loss_itm.item())
184 | if config['mlm']:
185 | metric_logger.update(loss_mlm=loss_mlm.item())
186 | metric_logger.update(loss_attr=loss_attr.item())
187 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
188 |
189 | # gather the stats from all processes
190 | metric_logger.synchronize_between_processes()
191 | print("Averaged stats:", metric_logger.global_avg())
192 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
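For orientation, these are the config keys read by `train()` / `train_attr()` above and by `mlm()` in `train_tools.py`. The values shown are illustrative placeholders; the real settings live in the `configs/Retrieval_*.yaml` files.

```python
# Illustrative values only; see configs/Retrieval_*.yaml for the real settings.
config = {
    'eda': True,        # data loader also yields an EDA-augmented caption (first branch of train())
    'mlm': True,        # add the masked-language-modelling loss
    'max_tokens': 56,   # tokenizer padding / truncation length
    'max_masks': 8,     # masked positions kept per caption (used by mlm())
    't': 0.5,           # weight of the attribute loss in train_attr()
}
```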
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | from collections import defaultdict, deque, OrderedDict
5 | import datetime
6 | import numpy as np
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | class SmoothedValue(object):
13 | """Track a series of values and provide access to smoothed values over a
14 | window or the global series average.
15 | """
16 |
17 | def __init__(self, window_size=20, fmt=None):
18 | if fmt is None:
19 | fmt = "{median:.4f} ({global_avg:.4f})"
20 | self.deque = deque(maxlen=window_size)
21 | self.total = 0.0
22 | self.count = 0
23 | self.fmt = fmt
24 |
25 | def update(self, value, n=1):
26 | self.deque.append(value)
27 | self.count += n
28 | self.total += value * n
29 |
30 | def synchronize_between_processes(self):
31 | """
32 | Warning: does not synchronize the deque!
33 | """
34 | if not is_dist_avail_and_initialized():
35 | return
36 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
37 | dist.barrier()
38 | dist.all_reduce(t)
39 | t = t.tolist()
40 | self.count = int(t[0])
41 | self.total = t[1]
42 |
43 | @property
44 | def median(self):
45 | d = torch.tensor(list(self.deque))
46 | return d.median().item()
47 |
48 | @property
49 | def avg(self):
50 | d = torch.tensor(list(self.deque), dtype=torch.float32)
51 | return d.mean().item()
52 |
53 | @property
54 | def global_avg(self):
55 | return self.total / self.count
56 |
57 | @property
58 | def max(self):
59 | return max(self.deque)
60 |
61 | @property
62 | def value(self):
63 | return self.deque[-1]
64 |
65 | def __str__(self):
66 | return self.fmt.format(
67 | median=self.median,
68 | avg=self.avg,
69 | global_avg=self.global_avg,
70 | max=self.max,
71 | value=self.value)
72 |
73 |
74 | class MetricLogger(object):
75 | def __init__(self, delimiter="\t"):
76 | self.meters = defaultdict(SmoothedValue)
77 | self.delimiter = delimiter
78 |
79 | def update(self, **kwargs):
80 | for k, v in kwargs.items():
81 | if isinstance(v, torch.Tensor):
82 | v = v.item()
83 | assert isinstance(v, (float, int))
84 | self.meters[k].update(v)
85 |
86 | def __getattr__(self, attr):
87 | if attr in self.meters:
88 | return self.meters[attr]
89 | if attr in self.__dict__:
90 | return self.__dict__[attr]
91 | raise AttributeError("'{}' object has no attribute '{}'".format(
92 | type(self).__name__, attr))
93 |
94 | def __str__(self):
95 | loss_str = []
96 | for name, meter in self.meters.items():
97 | loss_str.append(
98 | "{}: {}".format(name, str(meter))
99 | )
100 | return self.delimiter.join(loss_str)
101 |
102 | def global_avg(self):
103 | loss_str = []
104 | for name, meter in self.meters.items():
105 | loss_str.append(
106 | "{}: {:.4f}".format(name, meter.global_avg)
107 | )
108 | return self.delimiter.join(loss_str)
109 |
110 | def synchronize_between_processes(self):
111 | for meter in self.meters.values():
112 | meter.synchronize_between_processes()
113 |
114 | def add_meter(self, name, meter):
115 | self.meters[name] = meter
116 |
117 | def log_every(self, iterable, print_freq, header=None, dataset_len=None, epoch_info=None):
118 | if not header:
119 | header = ''
120 | if not dataset_len:
121 | dataset_len = len(iterable)
122 | start_time = time.time()
123 | end = time.time()
124 | iter_time = SmoothedValue(fmt='{avg:.4f}')
125 | data_time = SmoothedValue(fmt='{avg:.4f}')
126 | space_fmt = ':' + str(len(str(dataset_len))) + 'd'
127 |
128 | _msg = [
129 | '[{0' + space_fmt + '}/{1}]',
130 | 'eta: {eta}',
131 | '{meters}',
132 | 'time: {time}',
133 | 'data: {data}'
134 | ]
135 | if torch.cuda.is_available():
136 | _msg.append('max mem: {memory:.0f}')
137 | _msg = self.delimiter.join(_msg)
138 | MB = 1024.0 * 1024.0
139 | iterable = iter(iterable)
140 | train_steps = dataset_len
141 | if epoch_info:
142 | start_epoch, end_epoch = epoch_info
143 | train_steps = (end_epoch - start_epoch) * dataset_len
144 | for i in range(train_steps):
145 | obj = next(iterable)
146 | data_time.update(time.time() - end)
147 | yield obj
148 | iter_time.update(time.time() - end)
149 | if epoch_info:
150 |                 header = i // dataset_len + start_epoch
151 | header = 'Train step: [{}]'.format(header)
152 | log_msg = header + " " + _msg
153 | if (i % dataset_len) % print_freq == 0 or i == dataset_len - 1:
154 | eta_seconds = iter_time.global_avg * (dataset_len - i % dataset_len)
155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
156 | if torch.cuda.is_available():
157 | print(log_msg.format(
158 | i % dataset_len, dataset_len, eta=eta_string,
159 | meters=str(self),
160 | time=str(iter_time), data=str(data_time),
161 | memory=torch.cuda.max_memory_allocated() / MB))
162 | else:
163 | print(log_msg.format(
164 | i % dataset_len, dataset_len, eta=eta_string,
165 | meters=str(self),
166 | time=str(iter_time), data=str(data_time)))
167 |
168 | end = time.time()
169 | total_time = time.time() - start_time
170 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
171 | print('{} Total time: {} ({:.4f} s / it)'.format(
172 | header, total_time_str, total_time / dataset_len))
173 |
174 |
175 | class AttrDict(dict):
176 | def __init__(self, *args, **kwargs):
177 | super(AttrDict, self).__init__(*args, **kwargs)
178 | self.__dict__ = self
179 |
180 |
181 | def compute_acc(logits, label, reduction='mean'):
182 | ret = (torch.argmax(logits, dim=1) == label).float()
183 | if reduction == 'none':
184 | return ret.detach()
185 | elif reduction == 'mean':
186 | return ret.mean().item()
187 |
188 |
189 | def compute_n_params(model, return_str=True):
190 | tot = 0
191 | for p in model.parameters():
192 | w = 1
193 | for x in p.shape:
194 | w *= x
195 | tot += w
196 | if return_str:
197 | if tot >= 1e6:
198 | return '{:.1f}M'.format(tot / 1e6)
199 | else:
200 | return '{:.1f}K'.format(tot / 1e3)
201 | else:
202 | return tot
203 |
204 |
205 | def setup_for_distributed(is_master):
206 | """
207 | This function disables printing when not in master process
208 | """
209 | import builtins as __builtin__
210 | builtin_print = __builtin__.print
211 |
212 | def print(*args, **kwargs):
213 | force = kwargs.pop('force', False)
214 | if is_master or force:
215 | builtin_print(*args, **kwargs)
216 |
217 | __builtin__.print = print
218 |
219 |
220 | def is_dist_avail_and_initialized():
221 | if not dist.is_available():
222 | return False
223 | if not dist.is_initialized():
224 | return False
225 | return True
226 |
227 |
228 | def get_world_size():
229 | if not is_dist_avail_and_initialized():
230 | return 1
231 | return dist.get_world_size()
232 |
233 |
234 | def get_rank():
235 | if not is_dist_avail_and_initialized():
236 | return 0
237 | return dist.get_rank()
238 |
239 |
240 | def is_main_process():
241 | return get_rank() == 0
242 |
243 |
244 | def save_on_master(*args, **kwargs):
245 | if is_main_process():
246 | torch.save(*args, **kwargs)
247 |
248 |
249 | def init_distributed_mode(args):
250 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
251 | args.rank = int(os.environ["RANK"])
252 | args.world_size = int(os.environ['WORLD_SIZE'])
253 | args.gpu = int(os.environ['LOCAL_RANK'])
254 | elif 'SLURM_PROCID' in os.environ:
255 | args.rank = int(os.environ['SLURM_PROCID'])
256 | args.gpu = args.rank % torch.cuda.device_count()
257 | else:
258 | print('Not using distributed mode')
259 | args.distributed = False
260 | return
261 |
262 | args.distributed = True
263 |
264 | torch.cuda.set_device(args.gpu)
265 | args.dist_backend = 'nccl'
266 | print('| distributed init (rank {}): {}'.format(
267 | args.rank, args.dist_url), flush=True)
268 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
269 | world_size=args.world_size, rank=args.rank)
270 | torch.distributed.barrier()
271 | setup_for_distributed(args.rank == 0)
272 |
273 |
274 | def read_json(rpath):
275 | with open(rpath, 'r') as f:
276 | return json.load(f)
--------------------------------------------------------------------------------
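As a quick sanity check of the logging utilities, the following minimal sketch drives `MetricLogger` / `SmoothedValue` over a dummy loop. No distributed setup is required: `synchronize_between_processes` is a no-op when `torch.distributed` is not initialised.

```python
import torch
import utils

logger = utils.MetricLogger(delimiter="  ")
logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

for step in logger.log_every(range(200), print_freq=50, header='Demo Epoch: [0]'):
    fake_loss = torch.rand(1)              # stand-in for a real training loss
    logger.update(loss=fake_loss.item(), lr=1e-4)

logger.synchronize_between_processes()     # no-op outside a distributed run
print("Averaged stats:", logger.global_avg())
```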