├── .gitignore ├── .idea ├── .gitignore ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── tbps-clip.iml └── vcs.xml ├── README.md ├── config ├── config.yaml └── s.config.yaml ├── image └── intro.png ├── main.py ├── misc ├── build.py ├── caption_dataset.py ├── data.py ├── eval.py ├── lr_scheduler.py └── utils.py ├── model ├── __init__.py ├── base_transformer.py ├── eda.py ├── loss.py ├── mixgen.py ├── shared_modules.py ├── tbps_model.py ├── text_transformer.py └── visual_transformer.py ├── options.py ├── requirements.txt ├── shell └── train.sh └── text_utils ├── bpe_simple_vocab_16e6.txt.gz ├── mask_tokens.py ├── simple_tokenizer.py └── tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # IDEA 142 | .idea/ 143 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/tbps-clip.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # 【AAAI 2024 🔥】An Empirical Study of CLIP for Text-based Person Search 4 | [![Paper](http://img.shields.io/badge/Paper-AAAI_2024-Green.svg)](https://ojs.aaai.org/index.php/AAAI/article/view/27801) 5 | [![Paper](http://img.shields.io/badge/Paper-arxiv.2308.10045-FF6B6B.svg)](https://arxiv.org/abs/2308.10045) 6 |
7 |
8 | This repository offers the official implementation of [TBPS-CLIP](https://arxiv.org/abs/2308.10045) in PyTorch.
9 |
10 | In the meantime, check out our related papers if you are interested:
11 | + 【ACM MM 2023】 [Text-based Person Search without Parallel Image-Text Data](https://arxiv.org/abs/2305.12964)
12 | + 【IJCAI 2023】 [RaSa: Relation and Sensitivity Aware Representation Learning for Text-based Person Search](https://arxiv.org/abs/2305.13653)
13 | + 【ICASSP 2022】 [Learning Semantic-Aligned Feature Representation for Text-based Person Search](https://arxiv.org/abs/2112.06714)
14 |
15 | ## Note
16 | More experiments and implementation details are provided in the Appendix of the [arXiv](https://arxiv.org/abs/2308.10045) version.
17 |
18 |
19 | ## Overview
20 | By revisiting the critical designs of data augmentation and loss functions in [CLIP](https://arxiv.org/abs/2103.00020),
21 | we provide a strong baseline [TBPS-CLIP](https://arxiv.org/abs/2308.10045) for text-based person search.
22 |
23 |
24 | ![Overview of TBPS-CLIP](image/intro.png)
25 |
26 | ## Environment
27 |
28 | All experiments are conducted on 4 NVIDIA A40 (48GB) GPUs. The CUDA version is 11.7.
29 |
30 | The required packages are listed in `requirements.txt`. You can install them using:
31 |
32 | ```sh
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## Download
37 | 1. Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN), and the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset).
38 | 2. Download the annotation JSON files from [here](https://drive.google.com/file/d/1C5bgGCABtuzZMaa2n4Sc0qclUvZ-mqG9/view?usp=drive_link).
39 | 3. Download the pretrained CLIP checkpoint from [here](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt).
40 |
41 | ## Configuration
42 | In `config/config.yaml` and `config/s.config.yaml`, set the path to the annotation files (`anno_dir`), the image directory (`image_dir`), and the CLIP checkpoint (`model.checkpoint`).
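As a quick, illustrative pre-flight check (not part of this repository), you can load either config with the `parse_config` helper from `misc/utils.py` and confirm that the three paths resolve before launching training:

```python
# Illustrative sanity check, not an official script: load the config with the
# repo's own parse_config helper and verify that the user-set paths exist.
import os

from misc.utils import parse_config

config = parse_config('config/config.yaml')
paths = {
    'annotation dir (anno_dir)': config.anno_dir,
    'image dir (image_dir)': config.image_dir,
    'CLIP checkpoint (model.checkpoint)': config.model.checkpoint,
}
for name, path in paths.items():
    status = 'ok' if os.path.exists(path) else 'NOT FOUND'
    print(f'{name}: {path} [{status}]')
```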
43 |
44 |
45 | ## Training
46 |
47 | You can start training with PyTorch's `torchrun`:
48 |
49 | ```sh
50 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
51 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \
52 | main.py
53 | ```
54 |
55 | You can also run the simplified version using:
56 |
57 | ```sh
58 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
59 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \
60 | main.py --simplified
61 | ```
62 |
63 |
64 | ## Model Checkpoints
65 | | | **CUHK-PEDES** | **ICFG-PEDES** | **RSTPReid** |
66 | |:-----------------------------------:|:--------------:|:--------------:|:------------:|
67 | | **TBPS-CLIP (ViT-B/16)** | [Download](https://drive.google.com/file/d/1m_3pKanUWHQHeJ-zt-QeRXs7bmay-U5P/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1az4z5b_ADXR7DcysPB5giOl52LjWDCSu/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1qMUAsH-1lzkWUFQsUvUKTY0J6ZuGkYd6/view?usp=drive_link) |
68 | | **Simplified TBPS-CLIP (ViT-B/16)** | [Download](https://drive.google.com/file/d/1W5oFZK9WNHMfy0OOaYQBzPsP1LZR80bT/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1UoLd-MQ8tYJ7YPgCbh3nVSVYnJ9a_TG5/view?usp=drive_link) | [Download](https://drive.google.com/file/d/18zlc3q3Sze5rx3TqcfEeZEjrQXUTpcQF/view?usp=drive_link) |
69 |
70 |
71 | ## Acknowledgement
72 | + [CLIP](https://arxiv.org/abs/2103.00020): the model architecture that TBPS-CLIP builds on
73 |
74 | ## Citation
75 | If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:
76 | ```
77 | @inproceedings{cao2024empirical,
78 |   title={An Empirical Study of CLIP for Text-Based Person Search},
79 |   author={Cao, Min and Bai, Yang and Zeng, Ziyin and Ye, Mang and Zhang, Min},
80 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
81 |   volume={38},
82 |   number={1},
83 |   pages={465--473},
84 |   year={2024}
85 | }
86 | ```
87 |
88 |
89 | ## License
90 | This code is distributed under the MIT License.
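As a usage reference for the checkpoints listed under Model Checkpoints, here is a minimal single-GPU evaluation sketch. It assumes the downloaded file follows the `checkpoint_best.pth` layout written by `main.py` (a dict whose `'model'` entry holds the state dict) and that the paths in `config/config.yaml` have been filled in; `path/to/checkpoint_best.pth` is a placeholder. This is an illustration, not an official entry point.

```python
# Minimal evaluation-only sketch (illustrative, with the assumptions noted above).
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from misc.caption_dataset import ps_eval_dataset
from misc.eval import test
from misc.utils import parse_config
from model.tbps_model import clip_vitb

config = parse_config('config/config.yaml')
device = 'cuda:0'
config.device = device

# Same test-time preprocessing as build_pedes_data in misc/data.py.
val_transform = transforms.Compose([
    transforms.Resize(config.experiment.input_resolution,
                      interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_set = ps_eval_dataset(config.anno_dir, config.image_dir, val_transform, split='test', max_words=77)
test_loader = DataLoader(test_set, batch_size=config.data.test_batch_size, shuffle=False)

# The class count only matters for the optional ID-loss classifier; it is unused at test time.
model = clip_vitb(config, 11003)
state = torch.load('path/to/checkpoint_best.pth', map_location='cpu')
model.load_state_dict(state['model'], strict=False)
model.to(device)

metrics = test(model, test_loader, config.experiment.text_length, device)
print(metrics)  # {'r1': ..., 'r5': ..., 'r10': ..., 'r_mean': ..., 'mAP': ...}
```

If a released checkpoint uses a different layout, adjust the `state['model']` access accordingly.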
91 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | device: 5 2 | 3 | misc: 4 | seed: 1 5 | 6 | experiment: 7 | # image 8 | input_resolution: [224, 224] 9 | simclr_mlp: [512, 128, 512] 10 | simclr_temperature: 0.1 11 | # text 12 | dropout: 0.05 13 | eda_alpha: 0.05 14 | back_trans: true 15 | backtrans_p: 0.1 16 | text_length: 77 17 | # mix 18 | mixgen: false 19 | mixgen_type: cat # ori or cat 20 | mixgen_p: 0.1 21 | mixgen_ratio: 0.1 22 | mvs_image: true 23 | 24 | # loss 25 | nitc_ratio: 1.0 26 | #### 27 | ss: true 28 | ss_ratio: 0.4 29 | #### 30 | ritc: true 31 | ritc_eps: 1.0e-2 32 | ritc_ratio: 1.0 33 | #### 34 | mlm: false 35 | mlm_ratio: 1.0 36 | cmt_depth: 4 # cross modal transformer self attn layers 37 | #### 38 | citc: true 39 | citc_lambda1: 0.25 40 | citc_lambda2: 0.25 41 | citc_ratio: 0.1 42 | #### 43 | id: false 44 | id_ratio: 1.0 45 | 46 | schedule: 47 | lr: 1.0e-4 48 | epoch: 5 49 | epoch_warmup: 1 50 | lr_start: 1.0e-6 51 | lr_end: 5.0e-6 52 | weight_decay: 0.02 53 | betas: [0.9, 0.98] 54 | eps: 1.0e-8 55 | 56 | model: 57 | ckpt_type: original_clip # original_clip / saved 58 | saved_path: 'ckpts/baseline_224_224/CUHK-PEDES' 59 | checkpoint: 'CLIP checkpoint path' # e.g., '../../data/CLIP/ViT-B-16.pt' 60 | use_gather: true 61 | softlabel_ratio: 0.5 62 | embed_dim: 512 63 | vocab_size: 49408 64 | 65 | log: 66 | print_period: 50 67 | 68 | data: 69 | batch_size: 80 70 | test_batch_size: 256 71 | num_workers: 8 72 | 73 | distributed: 74 | backend: nccl 75 | url: 'env://' 76 | 77 | anno_dir: 'annotation json path' # e.g., 'data/CUHK-PEDES' 78 | image_dir: 'image path' # e.g., '../../datasets/cuhkpedes/imgs' -------------------------------------------------------------------------------- /config/s.config.yaml: -------------------------------------------------------------------------------- 1 | device: 5 2 | 3 | misc: 4 | seed: 0 5 | 6 | experiment: 7 | # image 8 | input_resolution: [224, 224] 9 | simclr_mlp: [512, 128, 512] 10 | simclr_temperature: 0.1 11 | # text 12 | dropout: 0.05 13 | eda_alpha: 0.05 14 | back_trans: true 15 | backtrans_p: 0.1 16 | text_length: 77 17 | # mix 18 | mixgen: false 19 | mixgen_type: cat # ori or cat 20 | mixgen_p: 0.1 21 | mixgen_ratio: 0.1 22 | mvs_image: false 23 | 24 | # loss 25 | nitc_ratio: 1.0 26 | #### 27 | ss: false 28 | ss_ratio: 0.4 29 | #### 30 | ritc: true 31 | ritc_eps: 1.0e-2 32 | ritc_ratio: 1.0 33 | #### 34 | mlm: false 35 | mlm_ratio: 1.0 36 | cmt_depth: 4 # cross modal transformer self attn layers 37 | #### 38 | citc: false 39 | citc_lambda1: 0.25 40 | citc_lambda2: 0.25 41 | citc_ratio: 0.1 42 | #### 43 | id: false 44 | id_ratio: 1.0 45 | 46 | schedule: 47 | lr: 1.0e-4 48 | epoch: 5 49 | epoch_warmup: 1 50 | lr_start: 1.0e-6 51 | lr_end: 5.0e-6 52 | weight_decay: 0.02 53 | betas: [0.9, 0.98] 54 | eps: 1.0e-8 55 | 56 | model: 57 | ckpt_type: original_clip # original_clip / saved 58 | saved_path: 'ckpts/s.baseline_224_224/CUHK-PEDES' 59 | checkpoint: 'CLIP checkpoint path' # e.g., '../../data/CLIP/ViT-B-16.pt' 60 | use_gather: true 61 | softlabel_ratio: 0.5 62 | embed_dim: 512 63 | vocab_size: 49408 64 | 65 | log: 66 | print_period: 50 67 | 68 | data: 69 | batch_size: 80 70 | test_batch_size: 256 71 | num_workers: 8 72 | 73 | distributed: 74 | backend: nccl 75 | url: 'env://' 76 | 77 | anno_dir: 'annotation json path' # e.g., 'data/CUHK-PEDES' 78 | image_dir: 'image path' # e.g., 
'../../datasets/cuhkpedes/imgs' -------------------------------------------------------------------------------- /image/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/image/intro.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from misc.build import load_checkpoint, cosine_scheduler, build_optimizer 9 | from misc.data import build_pedes_data 10 | from misc.eval import test 11 | from misc.utils import parse_config, init_distributed_mode, set_seed, is_master, is_using_distributed, \ 12 | AverageMeter 13 | from model.tbps_model import clip_vitb 14 | from options import get_args 15 | 16 | 17 | def run(config): 18 | print(config) 19 | 20 | # data 21 | dataloader = build_pedes_data(config) 22 | train_loader = dataloader['train_loader'] 23 | num_classes = len(train_loader.dataset.person2text) 24 | 25 | meters = { 26 | "loss": AverageMeter(), 27 | "nitc_loss": AverageMeter(), 28 | "ss_loss": AverageMeter(), 29 | "citc_loss": AverageMeter(), 30 | "ritc_loss": AverageMeter(), 31 | "mlm_loss": AverageMeter(), 32 | "id_loss": AverageMeter(), 33 | } 34 | best_rank_1 = 0.0 35 | best_epoch = 0 36 | 37 | # model 38 | model = clip_vitb(config, num_classes) 39 | model.to(config.device) 40 | 41 | model, load_result = load_checkpoint(model, config) 42 | 43 | if is_using_distributed(): 44 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.device], 45 | find_unused_parameters=True) 46 | 47 | # schedule 48 | config.schedule.niter_per_ep = len(train_loader) 49 | lr_schedule = cosine_scheduler(config) 50 | 51 | # optimizer 52 | optimizer = build_optimizer(config, model) 53 | 54 | # train 55 | it = 0 56 | scaler = torch.cuda.amp.GradScaler() 57 | for epoch in range(config.schedule.epoch): 58 | print() 59 | if is_using_distributed(): 60 | dataloader['train_sampler'].set_epoch(epoch) 61 | 62 | start_time = time.time() 63 | for meter in meters.values(): 64 | meter.reset() 65 | model.train() 66 | 67 | for i, batch in enumerate(train_loader): 68 | for param_group in optimizer.param_groups: 69 | param_group['lr'] = lr_schedule[it] * param_group['ratio'] 70 | 71 | if epoch == 0: 72 | alpha = config.model.softlabel_ratio * min(1.0, i / len(train_loader)) 73 | else: 74 | alpha = config.model.softlabel_ratio 75 | 76 | if config.experiment.mixgen: 77 | if random.random() < config.experiment.mixgen_p: 78 | import model.mixgen as mg 79 | if config.experiment.mixgen_type == 'cat': 80 | mixgen_func = mg.concatgen 81 | else: 82 | mixgen_func = mg.mixgen 83 | img, cap = mixgen_func(batch['image'], batch['caption'], 84 | num=int(config.experiment.mixgen_ratio * len(batch['caption']))) 85 | batch.update({ 86 | 'image': img, 87 | 'caption': cap, 88 | }) 89 | 90 | with torch.autocast(device_type='cuda'): 91 | ret = model(batch, alpha) 92 | loss = sum([v for k, v in ret.items() if "loss" in k]) 93 | 94 | batch_size = batch['image'].shape[0] 95 | meters['loss'].update(loss.item(), batch_size) 96 | meters['nitc_loss'].update(ret.get('nitc_loss', 0), batch_size) 97 | meters['ss_loss'].update(ret.get('ss_loss', 0), batch_size) 98 | meters['citc_loss'].update(ret.get('citc_loss', 0), batch_size) 99 | 
meters['ritc_loss'].update(ret.get('ritc_loss', 0), batch_size) 100 | meters['mlm_loss'].update(ret.get('mlm_loss', 0), batch_size) 101 | meters['id_loss'].update(ret.get('id_loss', 0), batch_size) 102 | 103 | scaler.scale(loss).backward() 104 | scaler.step(optimizer) 105 | scaler.update() 106 | model.zero_grad() 107 | optimizer.zero_grad() 108 | it += 1 109 | 110 | if (i + 1) % config.log.print_period == 0: 111 | info_str = f"Epoch[{epoch + 1}] Iteration[{i + 1}/{len(train_loader)}]" 112 | # log loss 113 | for k, v in meters.items(): 114 | if v.val != 0: 115 | info_str += f", {k}: {v.val:.4f}" 116 | info_str += f", Base Lr: {param_group['lr']:.2e}" 117 | print(info_str) 118 | 119 | if is_master(): 120 | end_time = time.time() 121 | time_per_batch = (end_time - start_time) / (i + 1) 122 | print("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 123 | .format(epoch + 1, time_per_batch, train_loader.batch_size / time_per_batch)) 124 | 125 | eval_result = test(model.module, dataloader['test_loader'], 77, config.device) 126 | rank_1, rank_5, rank_10, map = eval_result['r1'], eval_result['r5'], eval_result['r10'], eval_result['mAP'] 127 | print('Acc@1 {top1:.5f} Acc@5 {top5:.5f} Acc@10 {top10:.5f} mAP {mAP:.5f}'.format(top1=rank_1, top5=rank_5, 128 | top10=rank_10, mAP=map)) 129 | torch.cuda.empty_cache() 130 | if best_rank_1 < rank_1: 131 | best_rank_1 = rank_1 132 | best_epoch = epoch 133 | 134 | save_obj = { 135 | 'model': model.module.state_dict(), 136 | 'optimizer': optimizer.state_dict(), 137 | 'config': config, 138 | } 139 | torch.save(save_obj, os.path.join(config.model.saved_path, 'checkpoint_best.pth')) 140 | 141 | print(f"best Acc@1: {best_rank_1} at epoch {best_epoch + 1}") 142 | 143 | 144 | if __name__ == '__main__': 145 | config_path = 'config/config.yaml' 146 | 147 | args = get_args() 148 | if args.simplified: 149 | config_path = 'config/s.config.yaml' 150 | config = parse_config(config_path) 151 | 152 | Path(config.model.saved_path).mkdir(parents=True, exist_ok=True) 153 | 154 | init_distributed_mode(config) 155 | 156 | set_seed(config) 157 | 158 | run(config) 159 | -------------------------------------------------------------------------------- /misc/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import math 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize_pos_embed(posemb, posemb_new, hight, width): 9 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 10 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 11 | posemb = posemb.unsqueeze(0) 12 | posemb_new = posemb_new.unsqueeze(0) 13 | 14 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 15 | 16 | gs_old = int(math.sqrt(len(posemb_grid))) 17 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 18 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 19 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 20 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 21 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 22 | return posemb.squeeze(0) 23 | 24 | 25 | def interpolate_text(pos_embed_checkpoint, target_dim=77): 26 | # (n_ctx, n_feat) for pos_embed_checkpoint, including SOT and EOT 27 | if pos_embed_checkpoint.size(0) == target_dim: 28 | return pos_embed_checkpoint 29 | start_token = pos_embed_checkpoint[:1, :] 30 | end_token = pos_embed_checkpoint[-1:, :] 31 | pos_tokens = pos_embed_checkpoint[1:-1, :].unsqueeze(0).permute(0, 2, 1) 32 | pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=target_dim - 2, mode='linear') 33 | pos_tokens = pos_tokens.squeeze(0).t() 34 | pos_tokens = torch.cat([start_token, pos_tokens, end_token], dim=0) 35 | return pos_tokens 36 | 37 | 38 | def load_checkpoint(model, config): 39 | if config.model.ckpt_type == 'original_clip': 40 | with open(config.model.checkpoint, 'rb') as opened_file: 41 | model_tmp = torch.jit.load(opened_file, map_location="cpu") 42 | state = model_tmp.state_dict() 43 | for key in ["input_resolution", "context_length", "vocab_size"]: 44 | if key in state: 45 | del state[key] 46 | 47 | # 2 towers in new_state: visual, encode_text 48 | new_state = {} 49 | for name, params in state.items(): 50 | if name == 'visual.positional_embedding' and params.shape != model.visual.positional_embedding.shape: 51 | params = resize_pos_embed(params, model.visual.positional_embedding, model.visual.num_y, model.visual.num_x) 52 | 53 | if name == 'positional_embedding': 54 | new_state['encode_text.' + name] = interpolate_text(params, config.experiment.text_length) 55 | elif name.startswith('transformer') or name in ['positional_embedding', 'token_embedding.weight', 56 | 'ln_final.weight', 'ln_final.bias', 'text_projection']: 57 | new_state['encode_text.' 
+ name] = params 58 | else: 59 | new_state[name] = params 60 | elif config.model.ckpt_type == 'saved': 61 | ckpt = torch.load(os.path.join(config.model.saved_path, 'checkpoint_best.pth'), map_location='cpu') 62 | new_state = ckpt['model'] 63 | else: 64 | raise KeyError 65 | 66 | load_result = model.load_state_dict(new_state, strict=False) 67 | return model, load_result 68 | 69 | 70 | def cosine_scheduler(config): 71 | schedule_config = config.schedule 72 | base_value = schedule_config.lr 73 | start_warmup_value = schedule_config.lr_start 74 | final_value = schedule_config.lr_end 75 | epochs = schedule_config.epoch 76 | warmup_epochs = schedule_config.epoch_warmup 77 | niter_per_ep = schedule_config.niter_per_ep 78 | 79 | warmup_schedule = np.array([]) 80 | warmup_iters = warmup_epochs * niter_per_ep 81 | if warmup_epochs > 0: 82 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 83 | 84 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 85 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 86 | 87 | schedule = np.concatenate((warmup_schedule, schedule)) 88 | assert len(schedule) == epochs * niter_per_ep 89 | return schedule 90 | 91 | 92 | # def build_optimizer(config, model): 93 | # p_wd, p_non_wd = [], [] 94 | # for n, p in model.named_parameters(): 95 | # if not p.requires_grad: 96 | # continue # frozen weights 97 | # if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: 98 | # p_non_wd.append(p) 99 | # else: 100 | # p_wd.append(p) 101 | # 102 | # schedule_config = config.schedule 103 | # optim_params = [{"params": p_wd, "weight_decay": schedule_config.weight_decay, "ratio": 1.}, 104 | # {"params": p_non_wd, "weight_decay": 0, "ratio": 1.}] 105 | # 106 | # optimizer = torch.optim.AdamW(optim_params, lr=schedule_config.lr, betas=schedule_config.betas, 107 | # eps=schedule_config.eps, weight_decay=schedule_config.weight_decay) 108 | # return optimizer 109 | 110 | 111 | def build_optimizer(config, model): 112 | params = [] 113 | schedule_config = config.schedule 114 | for n, p in model.named_parameters(): 115 | if not p.requires_grad: 116 | continue # frozen weights 117 | weight_decay = schedule_config.weight_decay 118 | ratio = 1. 119 | 120 | if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: 121 | weight_decay = 0. 
122 | if "cross" in n or "classifier" in n or "mlm_head" in n: 123 | ratio = ratio * schedule_config.ratio_factor # default 5.0 124 | 125 | params += [{"params": [p], "weight_decay": weight_decay, "ratio": ratio}] 126 | 127 | optimizer = torch.optim.AdamW(params, lr=schedule_config.lr, betas=schedule_config.betas, 128 | eps=schedule_config.eps, weight_decay=schedule_config.weight_decay) 129 | return optimizer 130 | -------------------------------------------------------------------------------- /misc/caption_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from collections import defaultdict 5 | 6 | import torch 7 | from PIL import Image 8 | from PIL import ImageFile 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | from PIL import ImageFilter 12 | import random 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | Image.MAX_IMAGE_PIXELS = None 16 | 17 | 18 | class ps_train_dataset(Dataset): 19 | def __init__(self, ann_root, image_root, transform, aug_ss, split, max_words=30): 20 | ann_file = os.path.join(ann_root, split + '_reid.json') 21 | anns = json.load(open(ann_file)) 22 | self.transform = transform 23 | 24 | self.person2text = defaultdict(list) 25 | person_id2idx = {} 26 | n = 0 27 | self.pairs = [] 28 | 29 | for ann in anns: 30 | image_path = os.path.join(image_root, ann['file_path']) 31 | person_id = ann['id'] 32 | if person_id not in person_id2idx.keys(): 33 | person_id2idx[person_id] = n 34 | n += 1 35 | person_idx = person_id2idx[person_id] 36 | if 'captions_bt' not in ann: 37 | ann['captions_bt'] = [''] * len(ann['captions']) 38 | for caption, caption_bt in zip(ann['captions'], ann['captions_bt']): 39 | caption = pre_caption(caption, max_words) 40 | caption_bt = pre_caption(caption_bt, max_words) 41 | self.pairs.append((image_path, caption, caption_bt, person_idx)) 42 | self.person2text[person_idx].append(caption) 43 | 44 | self.augmentation_ss = aug_ss 45 | 46 | def __len__(self): 47 | return len(self.pairs) 48 | 49 | def __getitem__(self, index): 50 | image_path, caption, caption_bt, person = self.pairs[index] 51 | 52 | image_pil = Image.open(image_path) 53 | image = self.transform(image_pil.convert('RGB')) 54 | aug1 = self.transform(image_pil.convert('RGB')) 55 | aug_ss_1 = self.augmentation_ss(image_pil) 56 | aug_ss_2 = self.augmentation_ss(image_pil) 57 | return { 58 | 'image': image, 59 | 'caption': caption, 60 | 'caption_bt': caption_bt, 61 | 'id': person, 62 | 'aug1': aug1, 63 | 'aug_ss_1': aug_ss_1, 64 | 'aug_ss_2': aug_ss_2 65 | } 66 | 67 | 68 | class ps_eval_dataset(Dataset): 69 | def __init__(self, ann_root, image_root, transform, split, max_words=30): 70 | ann_file = os.path.join(ann_root, split + '_reid.json') 71 | anns = json.load(open(ann_file, 'r')) 72 | self.transform = transform 73 | 74 | self.text = [] 75 | self.image = [] 76 | self.txt2person = [] 77 | self.img2person = [] 78 | 79 | for ann in anns: 80 | image_path = os.path.join(image_root, ann['file_path']) 81 | self.image.append(image_path) 82 | 83 | person_id = ann['id'] 84 | self.img2person.append(person_id) 85 | for caption in ann['captions']: 86 | self.text.append(pre_caption(caption, max_words)) 87 | self.txt2person.append(person_id) 88 | 89 | self.txt2person = torch.tensor(self.txt2person, dtype=torch.long) 90 | self.img2person = torch.tensor(self.img2person, dtype=torch.long) 91 | 92 | def __len__(self): 93 | return len(self.image) 94 | 95 | def __getitem__(self, index): 
96 | image_path = self.image[index] 97 | image = Image.open(image_path).convert('RGB') 98 | image = self.transform(image) 99 | 100 | return image 101 | 102 | def pre_caption(caption, max_words=50): 103 | caption = re.sub( 104 | r"([.!\"()*#:;~])", 105 | ' ', 106 | caption.lower(), 107 | ) 108 | caption = re.sub( 109 | r"\s{2,}", 110 | ' ', 111 | caption, 112 | ) 113 | caption = caption.rstrip('\n') 114 | caption = caption.strip(' ') 115 | 116 | # truncate caption 117 | caption_words = caption.split(' ') 118 | if len(caption_words) > max_words: 119 | caption = ' '.join(caption_words[:max_words]) 120 | 121 | return caption -------------------------------------------------------------------------------- /misc/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from PIL import ImageFilter 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | 12 | from misc.caption_dataset import ps_train_dataset, ps_eval_dataset 13 | from misc.utils import is_using_distributed 14 | 15 | 16 | def get_self_supervised_augmentation(img_size): 17 | class GaussianBlur(object): 18 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 19 | 20 | def __init__(self, sigma=[.1, 2.]): 21 | self.sigma = sigma 22 | 23 | def __call__(self, x): 24 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 25 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 26 | return x 27 | 28 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]) 30 | 31 | aug = transforms.Compose([ 32 | transforms.RandomResizedCrop(img_size, scale=(0.2, 1.), antialias=True), 33 | transforms.RandomApply([ 34 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 35 | ], p=0.8), 36 | transforms.RandomGrayscale(p=0.2), 37 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 38 | transforms.RandomHorizontalFlip(), 39 | transforms.ToTensor(), 40 | normalize 41 | ]) 42 | return aug 43 | 44 | 45 | def pil_loader(path): 46 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 47 | with open(path, 'rb') as f: 48 | img = Image.open(f) 49 | return img.convert('RGB') 50 | 51 | 52 | class cuhkpedes_eval(torch.utils.data.Dataset): 53 | def __init__(self, ann_file, transform, image_root): 54 | self.ann = json.load(open(ann_file, 'r')) 55 | self.transform = transform 56 | self.image_root = image_root 57 | 58 | self.text = [] 59 | self.image = [] 60 | self.txt2img = {} 61 | self.img2txt = {} 62 | self.pid2txt, self.pid2img = {}, {} 63 | self.txt_ids, self.img_ids = [], [] 64 | 65 | txt_id = 0 66 | for img_id, ann in enumerate(self.ann): 67 | self.image.append(ann['image']) 68 | if ann['image_id'] not in self.pid2txt.keys(): 69 | self.pid2txt[ann['image_id']] = [] 70 | self.pid2img[ann['image_id']] = [] 71 | self.pid2img[ann['image_id']].append(img_id) 72 | self.img_ids.append(ann['image_id']) 73 | for i, caption in enumerate(ann['caption']): 74 | self.text.append(caption) 75 | self.pid2txt[ann['image_id']].append(txt_id) 76 | self.txt_ids.append(ann['image_id']) 77 | txt_id += 1 78 | 79 | for tid in range(len(self.text)): 80 | self.txt2img[tid] = self.pid2img[self.txt_ids[tid]] 81 | for iid in range(len(self.image)): 82 | self.img2txt[iid] = self.pid2txt[self.img_ids[iid]] 83 | 84 | def __len__(self): 85 | return len(self.image) 86 | 87 | def 
__getitem__(self, index): 88 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 89 | image = Image.open(image_path) 90 | image = self.transform(image) 91 | 92 | return image, index 93 | 94 | 95 | def build_pedes_data(config): 96 | size = config.experiment.input_resolution 97 | if isinstance(size, int): 98 | size = (size, size) 99 | 100 | normalize = transforms.Normalize( 101 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 102 | val_transform = transforms.Compose([ 103 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), 104 | transforms.ToTensor(), 105 | normalize 106 | ]) 107 | 108 | rand_from = [ 109 | transforms.ColorJitter(.1, .1, .1, 0), 110 | transforms.RandomRotation(15), 111 | transforms.RandomResizedCrop(size, (0.9, 1.0), antialias=True), 112 | transforms.RandomGrayscale(), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.RandomErasing(scale=(0.10, 0.20)), 115 | ] 116 | aug = Choose(rand_from, size) 117 | aug_ss = get_self_supervised_augmentation(size) 118 | 119 | train_dataset = ps_train_dataset(config.anno_dir, config.image_dir, aug, aug_ss, split='train', max_words=77) 120 | test_dataset = ps_eval_dataset(config.anno_dir, config.image_dir, val_transform, split='test', max_words=77) 121 | 122 | if is_using_distributed(): 123 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 124 | else: 125 | train_sampler = None 126 | test_sampler = None 127 | 128 | config_data = config.data 129 | train_loader = DataLoader( 130 | dataset=train_dataset, 131 | batch_size=config_data.batch_size, 132 | shuffle=train_sampler is None, 133 | num_workers=config_data.num_workers, 134 | pin_memory=True, 135 | sampler=train_sampler, 136 | drop_last=True, 137 | ) 138 | test_loader = DataLoader( 139 | dataset=test_dataset, 140 | batch_size=32, 141 | shuffle=False, 142 | sampler=test_sampler, 143 | drop_last=False, 144 | ) 145 | 146 | return { 147 | 'train_loader': train_loader, 148 | 'train_sampler': train_sampler, 149 | 'test_loader': test_loader, 150 | } 151 | 152 | 153 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 154 | 155 | 156 | class Choose: 157 | def __init__(self, rand_from, size): 158 | self.choose_from = rand_from 159 | self.size = size 160 | 161 | def __call__(self, image): 162 | aug_choice = np.random.choice(self.choose_from, 2) 163 | return transforms.Compose([ 164 | transforms.Resize(self.size), 165 | transforms.ToTensor(), 166 | *aug_choice, 167 | normalize 168 | ])(image) 169 | -------------------------------------------------------------------------------- /misc/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # import clip 4 | from text_utils.tokenizer import tokenize 5 | 6 | 7 | @torch.no_grad() 8 | def test(model, data_loader, max_length, device): 9 | # switch to evaluate mode 10 | model.eval() 11 | 12 | dataset = data_loader.dataset 13 | texts = dataset.text 14 | num_text = len(texts) 15 | text_bs = 256 16 | 17 | text_feats = [] 18 | for i in range(0, num_text, text_bs): 19 | text = texts[i: min(num_text, i + text_bs)] 20 | text = tokenize(text, context_length=max_length).to(device) 21 | text_feat = F.normalize(model.encode_text(text), dim=-1) 22 | text_feats.append(text_feat) 23 | text_feats = torch.cat(text_feats, dim=0) 24 | 25 | image_feats = [] 26 | for image in data_loader: 27 | image = image.to(device) 28 | image_feat = 
F.normalize(model.encode_image(image), dim=-1) 29 | image_feats.append(image_feat) 30 | image_feats = torch.cat(image_feats, dim=0) 31 | 32 | sims_matrix = text_feats @ image_feats.t() 33 | eval_result = metric_eval(sims_matrix, dataset.img2person, dataset.txt2person) 34 | 35 | return eval_result 36 | 37 | 38 | @torch.no_grad() 39 | def metric_eval(scores_t2i, img2person, txt2person): 40 | device = scores_t2i.device 41 | img2person = img2person.to(device) 42 | txt2person = txt2person.to(device) 43 | 44 | index = torch.argsort(scores_t2i, dim=-1, descending=True) 45 | pred_person = img2person[index] 46 | matches = (txt2person.view(-1, 1).eq(pred_person)).long() 47 | 48 | def acc_k(matches, k=1): 49 | matches_k = matches[:, :k].sum(dim=-1) 50 | matches_k = torch.sum((matches_k > 0)) 51 | return 100.0 * matches_k / matches.size(0) 52 | 53 | # Compute metrics 54 | ir1 = acc_k(matches, k=1).item() 55 | ir5 = acc_k(matches, k=5).item() 56 | ir10 = acc_k(matches, k=10).item() 57 | ir_mean = (ir1 + ir5 + ir10) / 3 58 | 59 | real_num = matches.sum(dim=-1) 60 | tmp_cmc = matches.cumsum(dim=-1).float() 61 | order = torch.arange(start=1, end=matches.size(1) + 1, dtype=torch.long).to(device) 62 | tmp_cmc /= order 63 | tmp_cmc *= matches 64 | AP = tmp_cmc.sum(dim=-1) / real_num 65 | mAP = AP.mean() * 100.0 66 | 67 | eval_result = {'r1': ir1, 68 | 'r5': ir5, 69 | 'r10': ir10, 70 | 'r_mean': ir_mean, 71 | 'mAP': mAP.item() 72 | } 73 | 74 | return eval_result 75 | -------------------------------------------------------------------------------- /misc/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from math import cos, pi 3 | 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | class LRSchedulerWithWarmup(_LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | mode="step", 14 | warmup_factor=1.0 / 3, 15 | warmup_epochs=10, 16 | warmup_method="linear", 17 | total_epochs=100, 18 | target_lr=0, 19 | power=0.9, 20 | last_epoch=-1, 21 | ): 22 | if not list(milestones) == sorted(milestones): 23 | raise ValueError( 24 | "Milestones should be a list of" 25 | " increasing integers. 
Got {}".format(milestones), 26 | ) 27 | if mode not in ("step", "exp", "poly", "cosine", "linear"): 28 | raise ValueError( 29 | "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted" 30 | "got {}".format(mode) 31 | ) 32 | if warmup_method not in ("constant", "linear"): 33 | raise ValueError( 34 | "Only 'constant' or 'linear' warmup_method accepted" 35 | "got {}".format(warmup_method) 36 | ) 37 | self.milestones = milestones 38 | self.mode = mode 39 | self.gamma = gamma 40 | self.warmup_factor = warmup_factor 41 | self.warmup_epochs = warmup_epochs 42 | self.warmup_method = warmup_method 43 | self.total_epochs = total_epochs 44 | self.target_lr = target_lr 45 | self.power = power 46 | super().__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | 50 | if self.last_epoch < self.warmup_epochs: 51 | if self.warmup_method == "constant": 52 | warmup_factor = self.warmup_factor 53 | elif self.warmup_method == "linear": 54 | alpha = self.last_epoch / self.warmup_epochs 55 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 56 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 57 | 58 | if self.mode == "step": 59 | return [ 60 | base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 61 | for base_lr in self.base_lrs 62 | ] 63 | 64 | epoch_ratio = (self.last_epoch - self.warmup_epochs) / ( 65 | self.total_epochs - self.warmup_epochs 66 | ) 67 | 68 | if self.mode == "exp": 69 | factor = epoch_ratio 70 | return [base_lr * self.power ** factor for base_lr in self.base_lrs] 71 | if self.mode == "linear": 72 | factor = 1 - epoch_ratio 73 | return [base_lr * factor for base_lr in self.base_lrs] 74 | 75 | if self.mode == "poly": 76 | factor = 1 - epoch_ratio 77 | return [ 78 | self.target_lr + (base_lr - self.target_lr) * self.power ** factor 79 | for base_lr in self.base_lrs 80 | ] 81 | if self.mode == "cosine": 82 | factor = 0.5 * (1 + cos(pi * epoch_ratio)) 83 | return [ 84 | self.target_lr + (base_lr - self.target_lr) * factor 85 | for base_lr in self.base_lrs 86 | ] 87 | raise NotImplementedError 88 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from easydict import EasyDict 4 | 5 | import yaml 6 | import os 7 | 8 | import torch 9 | import numpy as np 10 | import random 11 | import torch.distributed as dist 12 | 13 | 14 | def parse_config(config_path): 15 | with open(config_path) as f: 16 | config = yaml.load(f, Loader=yaml.FullLoader) 17 | config = EasyDict(config) 18 | return config 19 | 20 | 21 | def is_using_distributed(): 22 | return True 23 | 24 | 25 | def is_dist_avail_and_initialized(): 26 | if not dist.is_available(): 27 | return False 28 | if not dist.is_initialized(): 29 | return False 30 | return True 31 | 32 | 33 | def get_world_size(): 34 | if not is_dist_avail_and_initialized(): 35 | return 1 36 | return dist.get_world_size() 37 | 38 | 39 | def get_rank(): 40 | if not is_dist_avail_and_initialized(): 41 | return 0 42 | return dist.get_rank() 43 | 44 | 45 | def is_master(): 46 | return not is_using_distributed() or get_rank() == 0 47 | 48 | 49 | def wandb_record(): 50 | if not 'WANDB_PROJECT' in os.environ: 51 | return False 52 | return not is_using_distributed() or get_rank() == 0 53 | 54 | 55 | def init_distributed_mode(config): 56 | if is_using_distributed(): 57 | config.distributed.rank = int(os.environ['RANK']) 58 | config.distributed.world_size = 
int(os.environ['WORLD_SIZE']) 59 | config.distributed.local_rank = int(os.environ['LOCAL_RANK']) 60 | torch.distributed.init_process_group(backend=config.distributed.backend, 61 | init_method=config.distributed.url) 62 | used_for_printing(get_rank() == 0) 63 | 64 | if torch.cuda.is_available(): 65 | if is_using_distributed(): 66 | device = f'cuda:{get_rank()}' 67 | else: 68 | device = f'cuda:{d}' if str(d := config.device).isdigit() else d 69 | torch.cuda.set_device(device) 70 | else: 71 | device = 'cpu' 72 | config.device = device 73 | 74 | 75 | def used_for_printing(is_master): 76 | import builtins as __builtin__ 77 | builtin_print = __builtin__.print 78 | 79 | def print(*args, **kwargs): 80 | force = kwargs.pop('force', False) 81 | if is_master or force: 82 | builtin_print(*args, **kwargs) 83 | 84 | __builtin__.print = print 85 | 86 | 87 | def set_seed(config): 88 | seed = config.misc.seed 89 | 90 | torch.manual_seed(seed) 91 | np.random.seed(seed) 92 | random.seed(seed) 93 | os.environ["PYTHONHASHSEED"] = str(seed) 94 | 95 | if torch.cuda.is_available(): 96 | torch.cuda.manual_seed_all(seed) 97 | torch.backends.cudnn.deterministic = True 98 | torch.backends.cudnn.benchmark = False 99 | 100 | 101 | class AverageMeter(object): 102 | """Computes and stores the average and current value""" 103 | 104 | def __init__(self): 105 | self.val = 0 106 | self.avg = 0 107 | self.sum = 0 108 | self.count = 0 109 | 110 | def reset(self): 111 | self.val = 0 112 | self.avg = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def update(self, val, n=1): 117 | self.val = val 118 | self.sum += val * n 119 | self.count += n 120 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/model/__init__.py -------------------------------------------------------------------------------- /model/base_transformer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | global LAYER_NORM 8 | LAYER_NORM = True 9 | 10 | 11 | class LayerNorm(nn.LayerNorm): 12 | """Subclass torch's LayerNorm to handle fp16.""" 13 | 14 | def forward(self, x: torch.Tensor): 15 | if LAYER_NORM: 16 | ret = super().forward(x) 17 | else: 18 | ret = x 19 | return ret 20 | 21 | 22 | class QuickGELU(nn.Module): 23 | def forward(self, x: torch.Tensor): 24 | return x * torch.sigmoid(1.702 * x) 25 | 26 | 27 | class ResidualAttentionBlock(nn.Module): 28 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, dropout: float = 0.): 29 | super().__init__() 30 | 31 | self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 32 | self.ln_1 = LayerNorm(d_model) 33 | self.mlp = nn.Sequential(OrderedDict([ 34 | ("c_fc", nn.Linear(d_model, d_model * 4)), 35 | ("gelu", QuickGELU()), 36 | # ("dropout_1", nn.Dropout(dropout)), 37 | ("c_proj", nn.Linear(d_model * 4, d_model)), 38 | # ("dropout_2", nn.Dropout(dropout)) 39 | ])) 40 | self.ln_2 = LayerNorm(d_model) 41 | self.attn_mask = attn_mask 42 | 43 | def attention(self, x: torch.Tensor): 44 | self.attn_mask = self.attn_mask.to( 45 | dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 46 | return self.attn(x, x, x, 
need_weights=True, attn_mask=self.attn_mask)[0] 47 | 48 | def forward(self, x: torch.Tensor): 49 | x = x + self.attention(self.ln_1(x)) 50 | x = x + self.mlp(self.ln_2(x)) 51 | return x 52 | 53 | 54 | class Transformer(nn.Module): 55 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, checkpoint: bool = False, 56 | dropout: float = 0., emb_dropout: float = 0.): 57 | super().__init__() 58 | self.width = width 59 | self.layers = layers 60 | self.checkpoint = checkpoint 61 | self.dropout = nn.Dropout(emb_dropout) 62 | self.resblocks = nn.Sequential( 63 | *[ResidualAttentionBlock(width, heads, attn_mask, dropout=dropout) for _ in range(layers)]) 64 | 65 | def checkpoint_fwd(self, layer, input, segments=2): 66 | """checkpoint forward""" 67 | # Make sure that the input to checkpoint have requires_grad=True, so that 68 | # the autograd can take care of the checkpointed part of model 69 | if not input.requires_grad: 70 | input = input.detach() 71 | input.requires_grad = True 72 | return checkpoint_sequential(layer, segments, input) 73 | 74 | def forward(self, x: torch.Tensor): 75 | x = self.dropout(x) 76 | if self.checkpoint: 77 | return self.checkpoint_fwd(self.resblocks, x, self.layers) 78 | return self.resblocks(x) 79 | -------------------------------------------------------------------------------- /model/eda.py: -------------------------------------------------------------------------------- 1 | from nltk.corpus import wordnet, stopwords 2 | import random 3 | 4 | 5 | class EDA: 6 | """ 7 | This class is an implementation of the original EDA algorithm (2019) [1]. 8 | 9 | [1] Wei, J. and Zou, K., 2019, November. EDA: Easy Data Augmentation Techniques for Boosting Performance on 10 | Text Classification Tasks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing 11 | and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP) (pp. 6383-6389). 
12 | https://www.aclweb.org/anthology/D19-1670.pdf 13 | 14 | Example usage: :: 15 | >>> from textaugment import EDA 16 | >>> t = EDA() 17 | >>> t.synonym_replacement("John is going to town") 18 | John is give out to town 19 | >>> t.random_deletion("John is going to town", p=0.2) 20 | is going to town 21 | >>> t.random_swap("John is going to town") 22 | John town going to is 23 | >>> t.random_insertion("John is going to town") 24 | John is going to make up town 25 | """ 26 | 27 | @staticmethod 28 | def _get_synonyms(word): 29 | """Generate synonym""" 30 | synonyms = set() 31 | for syn in wordnet.synsets(word): 32 | for lemma in syn.lemmas(): 33 | synonym = lemma.name().replace("_", " ").replace("-", " ").lower() 34 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm']) 35 | synonyms.add(synonym) 36 | if word in synonyms: 37 | synonyms.remove(word) 38 | synonyms = sorted(synonyms) 39 | random.shuffle(synonyms) 40 | return synonyms 41 | 42 | @staticmethod 43 | def swap_word(new_words): 44 | """Swap words""" 45 | random_idx_1 = random.randint(0, len(new_words) - 1) 46 | random_idx_2 = random_idx_1 47 | counter = 0 48 | while random_idx_2 == random_idx_1: 49 | random_idx_2 = random.randint(0, len(new_words) - 1) 50 | counter += 1 51 | if counter > 3: 52 | return new_words 53 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] 54 | return new_words 55 | 56 | @staticmethod 57 | def validate(**kwargs): 58 | """Validate input data""" 59 | 60 | if 'p' in kwargs: 61 | if kwargs['p'] > 1 or kwargs['p'] < 0: 62 | raise TypeError("p must be a fraction between 0 and 1") 63 | if 'sentence' in kwargs: 64 | if not isinstance(kwargs['sentence'].strip(), str) or len(kwargs['sentence'].strip()) == 0: 65 | raise TypeError("sentence must be a valid sentence") 66 | if 'n' in kwargs: 67 | if not isinstance(kwargs['n'], int): 68 | raise TypeError("n must be a valid integer") 69 | 70 | def __init__(self, stop_words=None, random_state=1): 71 | """A method to initialize parameters 72 | 73 | :type random_state: int 74 | :param random_state: (optional) Seed 75 | :type stop_words: list 76 | :param stop_words: (optional) List of stopwords 77 | 78 | :rtype: None 79 | :return: Constructer do not return. 
80 | """ 81 | self.stopwords = stopwords.words('english') if stop_words is None else stop_words 82 | self.sentence = None 83 | self.p = None 84 | self.n = None 85 | # self.random_state = random_state 86 | # if isinstance(self.random_state, int): 87 | # random.seed(self.random_state) 88 | # else: 89 | # raise TypeError("random_state must have type int") 90 | 91 | def add_word(self, new_words): 92 | """Insert word""" 93 | synonyms = list() 94 | counter = 0 95 | while len(synonyms) < 1: 96 | random_word_list = list([word for word in new_words if word not in self.stopwords]) 97 | random_word = random_word_list[random.randint(0, len(random_word_list) - 1)] 98 | synonyms = self._get_synonyms(random_word) 99 | counter += 1 100 | if counter >= 10: 101 | return new_words # See Issue 14 for details 102 | random_synonym = synonyms[0] # TODO 103 | random_idx = random.randint(0, len(new_words) - 1) 104 | new_words.insert(random_idx, random_synonym) 105 | return new_words 106 | 107 | def synonym_replacement(self, sentence: str, n: int = 1): 108 | """Replace n words in the sentence with synonyms from wordnet 109 | 110 | :type sentence: str 111 | :param sentence: Sentence 112 | :type n: int 113 | :param n: Number of repetitions to replace 114 | 115 | :rtype: str 116 | :return: Augmented sentence 117 | """ 118 | self.validate(sentence=sentence, n=n) 119 | self.n = n 120 | self.sentence = sentence 121 | words = sentence.split() 122 | new_words = words.copy() 123 | random_word_list = sorted(list(set([word for word in words if word not in self.stopwords]))) 124 | random.shuffle(random_word_list) 125 | replaced = 0 126 | for random_word in random_word_list: 127 | synonyms = self._get_synonyms(random_word) 128 | if len(synonyms) > 0: 129 | synonym = random.choice(list(synonyms)) 130 | new_words = [synonym if word == random_word else word for word in new_words] 131 | replaced += 1 132 | if replaced >= self.n: 133 | break 134 | sentence = ' '.join(new_words) 135 | 136 | return sentence 137 | 138 | def random_deletion(self, sentence: str, p: float = 0.1): 139 | """Randomly delete words from the sentence with probability p 140 | 141 | :type sentence: str 142 | :param sentence: Sentence 143 | :type p: int 144 | :param p: Probability between 0 and 1 145 | 146 | :rtype: str 147 | :return: Augmented sentence 148 | """ 149 | self.validate(sentence=sentence, p=p) 150 | self.p = p 151 | self.sentence = sentence 152 | words = sentence.split() 153 | if len(words) == 1: 154 | return words 155 | new_words = list() 156 | for word in words: 157 | r = random.uniform(0, 1) 158 | if r > self.p: 159 | new_words.append(word) 160 | # if all words are deleted, just return a random word 161 | if len(new_words) == 0: 162 | return random.choice(words) 163 | 164 | return " ".join(new_words) 165 | 166 | def random_swap(self, sentence: str, n: int = 1): 167 | """Randomly swap two words in the sentence n times 168 | 169 | :type sentence: str 170 | :param sentence: Sentence 171 | :type n: int 172 | :param n: Number of repetitions to swap 173 | 174 | :rtype: str 175 | :return: Augmented sentence 176 | """ 177 | self.validate(sentence=sentence, n=n) 178 | self.n = n 179 | self.sentence = sentence 180 | words = sentence.split() 181 | new_words = words.copy() 182 | for _ in range(self.n): 183 | new_words = self.swap_word(new_words) 184 | return " ".join(new_words) 185 | 186 | def random_insertion(self, sentence: str, n: int = 1): 187 | """Randomly insert n words into the sentence 188 | 189 | :type sentence: str 190 | :param sentence: Sentence 191 
| :type n: int 192 | :param n: Number of words to insert 193 | 194 | :rtype: str 195 | :return: Augmented sentence 196 | """ 197 | self.validate(sentence=sentence, n=n) 198 | self.n = n 199 | self.sentence = sentence 200 | words = sentence.split() 201 | new_words = words.copy() 202 | for _ in range(self.n): 203 | new_words = self.add_word(new_words) 204 | return " ".join(new_words) 205 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_simclr_loss(logits_a, logits_b, logits_a_gathered, logits_b_gathered, labels, temperature): 6 | sim_aa = logits_a @ logits_a_gathered.t() / temperature 7 | sim_ab = logits_a @ logits_b_gathered.t() / temperature 8 | sim_ba = logits_b @ logits_a_gathered.t() / temperature 9 | sim_bb = logits_b @ logits_b_gathered.t() / temperature 10 | masks = torch.where(F.one_hot(labels, logits_a_gathered.size(0)) == 0, 0, float('-inf')) 11 | sim_aa += masks 12 | sim_bb += masks 13 | sim_a = torch.cat([sim_ab, sim_aa], 1) 14 | sim_b = torch.cat([sim_ba, sim_bb], 1) 15 | loss_a = F.cross_entropy(sim_a, labels) 16 | loss_b = F.cross_entropy(sim_b, labels) 17 | return (loss_a + loss_b) * 0.5 18 | -------------------------------------------------------------------------------- /model/mixgen.py: -------------------------------------------------------------------------------- 1 | """ 2 | MixGen: A New Multi-Modal Data Augmentation 3 | https://arxiv.org/abs/2206.08358 4 | Apache-2.0 License, Copyright 2022 Amazon 5 | """ 6 | import random 7 | import numpy as np 8 | import torch 9 | from torchvision import transforms 10 | 11 | 12 | def mixgen(image, text, num, lam=0.5): 13 | # default MixGen 14 | for i in range(num): 15 | # image mixup 16 | image[i,:] = lam * image[i,:] + (1 - lam) * image[i+num,:] 17 | # text concat 18 | text[i] = text[i] + " " + text[i+num] 19 | return image, text 20 | 21 | def concatgen(image, text, num, lam=0.5): 22 | for i in range(num): 23 | # image mixup 24 | img1 = transforms.functional.resize(image[i], (224, 112)) 25 | img2 = transforms.functional.resize(image[i+num], (224, 112)) 26 | image[i] = torch.cat((img1, img2), dim=2) 27 | image[i] = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image[i]) 28 | # text concat 29 | text[i] = text[i] + " " + text[i+num] 30 | return image, text 31 | 32 | 33 | def mixgen_batch(image, text, num, lam=0.5): 34 | batch_size = image.size()[0] 35 | index = np.random.permutation(batch_size) 36 | for i in range(batch_size): 37 | # image mixup 38 | image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 39 | # text concat 40 | text[i] = text[i] + " " + text[index[i]] 41 | return image, text 42 | -------------------------------------------------------------------------------- /model/shared_modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | class AllGather(torch.autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx, tensor): 12 | ctx.rank = int(os.environ['RANK']) 13 | ctx.world_size = int(os.environ['WORLD_SIZE']) 14 | 15 | # y = tensor.new(ctx.world_size, *tensor.size()) 16 | 17 | y = [tensor.new(*tensor.size()) for _ in range(ctx.world_size)] 18 | 19 | dist.all_gather(y, tensor.contiguous()) 20 | 21 | y = torch.cat(y, 0).view(-1, 
*tensor.size()) 22 | 23 | return y 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | in_grad = torch.zeros_like(grad_output) 28 | in_grad.copy_(grad_output) 29 | # sum grad for gathered tensor 30 | dist.all_reduce(in_grad.contiguous()) 31 | # split 32 | return in_grad[ctx.rank] 33 | 34 | -------------------------------------------------------------------------------- /model/tbps_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | import numpy as np 7 | import copy 8 | 9 | from misc import utils 10 | from misc.utils import is_using_distributed 11 | from text_utils.tokenizer import tokenize 12 | from .loss import compute_simclr_loss 13 | from .visual_transformer import visual_transformer 14 | from .text_transformer import text_transformers 15 | from .eda import EDA 16 | from .base_transformer import Transformer, LayerNorm, QuickGELU 17 | 18 | from .shared_modules import AllGather 19 | from collections import OrderedDict 20 | 21 | 22 | class CLIP(nn.Module): 23 | def __init__(self, config, image_encode, text_encode, num_classes=11003, eps=1e-2): 24 | super().__init__() 25 | self.visual = image_encode 26 | self.encode_text = text_encode 27 | self.embed_dim = config.model.embed_dim 28 | 29 | self.use_gather = config.model.use_gather 30 | self.logit_scale = nn.Parameter(torch.ones([])) 31 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 32 | self.config = config 33 | self.eda = EDA() 34 | self.eps = eps 35 | 36 | if config.experiment.ss: 37 | structure = config.experiment.simclr_mlp 38 | self.simclr_mlp = self._build_mlp(*structure) 39 | 40 | if config.experiment.id: 41 | self.classifier = nn.Linear(self.embed_dim, num_classes) 42 | nn.init.normal_(self.classifier.weight.data, std=0.001) 43 | nn.init.constant_(self.classifier.bias.data, val=0.0) 44 | 45 | if config.experiment.mlm: 46 | self.vocab_size = config.model.vocab_size 47 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, 48 | self.embed_dim // 64, 49 | batch_first=True) 50 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 51 | layers=config.experiment.cmt_depth, 52 | heads=self.embed_dim // 64) 53 | scale = self.cross_modal_transformer.width ** -0.5 54 | 55 | self.ln_pre_t = LayerNorm(self.embed_dim) 56 | self.ln_pre_i = LayerNorm(self.embed_dim) 57 | self.ln_post = LayerNorm(self.embed_dim) 58 | 59 | proj_std = scale * ((2 * self.cross_modal_transformer.layers) ** -0.5) 60 | attn_std = scale 61 | fc_std = (2 * self.cross_modal_transformer.width) ** -0.5 62 | for block in self.cross_modal_transformer.resblocks: 63 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 64 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 65 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 66 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 67 | 68 | # init cross attn 69 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std) 70 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std) 71 | 72 | self.mlm_head = nn.Sequential( 73 | OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)), 74 | ('gelu', QuickGELU()), 75 | ('ln', LayerNorm(self.embed_dim)), 76 | ('fc', nn.Linear(self.embed_dim, self.vocab_size))])) 77 | # init mlm head 78 | nn.init.normal_(self.mlm_head.dense.weight, std=fc_std) 79 | nn.init.normal_(self.mlm_head.fc.weight, std=proj_std) 80 | 81 | def forward(self, input, alpha): 82 | ret = dict() 83 | 84 
| images = input['image'].to(self.config.device) 85 | images_1 = input['aug1'].to(self.config.device) 86 | texts = input['caption'] 87 | texts_bt = input['caption_bt'] 88 | 89 | # back translation 90 | if self.config.experiment.back_trans: 91 | for i in range(len(texts)): 92 | if random.random() < self.config.experiment.backtrans_p: 93 | texts[i] = texts_bt[i] 94 | 95 | # random deletion 96 | cap_new = [] 97 | for text in texts: 98 | eda_alpha = self.config.experiment.eda_alpha 99 | cap_new.append(self.eda.random_deletion(text, eda_alpha)) 100 | texts = cap_new 101 | 102 | # MLM 103 | if self.config.experiment.mlm: 104 | text_tokens, mlm_labels = tokenize(texts, context_length=self.config.experiment.text_length, 105 | mask_type='MLM') 106 | text_tokens = text_tokens.to(self.config.device) 107 | mlm_labels = mlm_labels.to(self.config.device) 108 | else: 109 | text_tokens = tokenize(texts, context_length=self.config.experiment.text_length).to(self.config.device) 110 | ids = input['id'].to(self.config.device) 111 | 112 | image_features, image_seq_embeddings = self.encode_image(images, return_dense=True) 113 | text_features, text_seq_embeddings = self.encode_text(text_tokens, return_dense=True) 114 | image_features_norm = F.normalize(image_features) 115 | text_features_norm = F.normalize(text_features) 116 | image_features_norm_gathered = self.all_gather(image_features_norm) 117 | text_features_norm_gathered = self.all_gather(text_features_norm) 118 | 119 | # image ss 120 | if self.config.experiment.ss: 121 | aug1_embed = self.simclr_mlp(self.encode_image(input['aug_ss_1'].to(self.config.device))) 122 | aug2_embed = self.simclr_mlp(self.encode_image(input['aug_ss_2'].to(self.config.device))) 123 | q_a = F.normalize(aug1_embed, dim=-1, p=2) 124 | q_b = F.normalize(aug2_embed, dim=-1, p=2) 125 | local_batch_size = q_a.size(0) 126 | labels = local_batch_size * utils.get_rank() + torch.arange(local_batch_size, device=q_a.device) 127 | k_a = self.all_gather(q_a) 128 | k_b = self.all_gather(q_b) 129 | ss_loss = compute_simclr_loss(q_a, q_b, k_a, k_b, labels, self.config.experiment.simclr_temperature) 130 | ret['ss_loss'] = ss_loss * self.config.experiment.ss_ratio 131 | 132 | logit_scale = self.logit_scale.exp() 133 | logit_scale.data = torch.clamp(logit_scale.data, max=100) 134 | 135 | idx = ids.view(-1, 1) 136 | gathered_ids = self.all_gather(ids) 137 | idx_all = gathered_ids.view(1, -1) 138 | pos_idx = torch.eq(idx, idx_all).float() 139 | sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) 140 | 141 | with torch.no_grad(): 142 | image_features_s = self.encode_image(images).detach() 143 | text_features_s = self.encode_text(text_tokens).detach() 144 | image_features_s_norm = F.normalize(image_features_s) 145 | text_features_s_norm = F.normalize(text_features_s) 146 | image_features_s_norm_gathered = self.all_gather(image_features_s_norm) 147 | text_features_s_norm_gathered = self.all_gather(text_features_s_norm) 148 | nitc_loss = self.calc_contrastive(image_features_norm, text_features_norm, image_features_s_norm, 149 | text_features_s_norm, 150 | image_features_norm_gathered, text_features_norm_gathered, 151 | image_features_s_norm_gathered, text_features_s_norm_gathered, 152 | sim_targets, alpha, logit_scale) 153 | 154 | if self.config.experiment.mvs_image: 155 | image_1_features = self.encode_image(images_1) 156 | image_1_features_norm = F.normalize(image_1_features) 157 | image_1_features_norm_gathered = self.all_gather(image_1_features_norm) 158 | with torch.no_grad(): 159 | 
image_1_features_s = self.encode_image(images_1).detach() 160 | image_1_features_s_norm = F.normalize(image_1_features_s) 161 | image_1_features_s_norm_gathered = self.all_gather(image_1_features_s_norm) 162 | loss_img1_txt0 = self.calc_contrastive(image_1_features_norm, text_features_norm, image_1_features_s_norm, 163 | text_features_s_norm, 164 | image_1_features_norm_gathered, text_features_norm_gathered, 165 | image_1_features_s_norm_gathered, text_features_s_norm_gathered, 166 | sim_targets, alpha, logit_scale) 167 | nitc_loss = (nitc_loss + loss_img1_txt0) / 2 168 | 169 | ret['nitc_loss'] = nitc_loss * self.config.experiment.nitc_ratio 170 | 171 | if self.config.experiment.citc: 172 | logits_image_per_image = logit_scale * image_features_norm_gathered @ image_features_norm_gathered.t() 173 | logits_text_per_text = logit_scale * text_features_norm_gathered @ text_features_norm_gathered.t() 174 | inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / ( 175 | logit_scale * logit_scale) 176 | logits_text_per_image = logit_scale * image_features_norm_gathered @ text_features_norm_gathered.t() 177 | logits_image_per_text = logit_scale * text_features_norm_gathered @ image_features_norm_gathered.t() 178 | crossmodal_cyclic_loss = (logits_text_per_image - logits_image_per_text).square().mean() / ( 179 | logit_scale * logit_scale) 180 | citc_loss = self.config.experiment.citc_lambda1 * inmodal_cyclic_loss + self.config.experiment.citc_lambda2 * crossmodal_cyclic_loss 181 | ret['citc_loss'] = citc_loss * self.config.experiment.citc_ratio 182 | 183 | if self.config.experiment.ritc: 184 | logits_per_image_1 = logit_scale * image_features_norm @ text_features_norm_gathered.t() 185 | logits_per_text_1 = logit_scale * text_features_norm @ image_features_norm_gathered.t() 186 | img_log = F.log_softmax(logits_per_image_1, dim=1) 187 | txt_log = F.log_softmax(logits_per_text_1, dim=1) 188 | target_log = (sim_targets + self.eps).log() 189 | kl_img = F.kl_div(target_log, img_log, log_target=True, reduction='batchmean') 190 | kl_txt = F.kl_div(target_log, txt_log, log_target=True, reduction='batchmean') 191 | ritc_loss = 0.5 * (kl_img + kl_txt) 192 | ret['ritc_loss'] = ritc_loss * self.config.experiment.ritc_ratio 193 | 194 | if self.config.experiment.mlm: 195 | x = self.cross_former(text_seq_embeddings, image_seq_embeddings, image_seq_embeddings) 196 | x = self.mlm_head(x) 197 | scores = x.float().reshape(-1, self.vocab_size) 198 | mlm_labels = mlm_labels.reshape(-1) 199 | mlm_loss = F.cross_entropy(scores, mlm_labels) 200 | ret['mlm_loss'] = mlm_loss * self.config.experiment.mlm_ratio 201 | 202 | if self.config.experiment.id: 203 | image_logits = self.classifier(image_features) 204 | text_logits = self.classifier(text_features) 205 | id_loss = (F.cross_entropy(image_logits, ids) + F.cross_entropy(text_logits, ids)) / 2 206 | ret['id_loss'] = id_loss * self.config.experiment.id_ratio 207 | 208 | return ret 209 | 210 | def cross_former(self, q, k, v): 211 | x = self.cross_attn( 212 | self.ln_pre_t(q), 213 | self.ln_pre_i(k), 214 | self.ln_pre_i(v), 215 | need_weights=False)[0] 216 | x = x.permute(1, 0, 2) # NLD -> LND 217 | x = self.cross_modal_transformer(x) 218 | x = x.permute(1, 0, 2) # LND -> NLD 219 | 220 | x = self.ln_post(x) 221 | return x 222 | 223 | # input features are normed 224 | def calc_contrastive(self, image_features, text_features, image_features_s, text_features_s, 225 | image_features_gathered, text_features_gathered, image_features_s_gathered, 226 | 
text_features_s_gathered, 227 | sim_targets, alpha, logit_scale): 228 | with torch.no_grad(): 229 | sim_i2t_s = logit_scale * image_features_s @ text_features_s_gathered.t() 230 | sim_t2i_s = logit_scale * text_features_s @ image_features_s_gathered.t() 231 | sim_i2t_targets = alpha * F.softmax(sim_i2t_s, dim=1) + (1 - alpha) * sim_targets 232 | sim_t2i_targets = alpha * F.softmax(sim_t2i_s, dim=1) + (1 - alpha) * sim_targets # soft + hard 233 | sim_i2t = logit_scale * image_features @ text_features_gathered.t() 234 | sim_t2i = logit_scale * text_features @ image_features_gathered.t() 235 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean() 236 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean() 237 | loss_ita = (loss_i2t + loss_t2i) / 2 238 | return loss_ita 239 | 240 | def compute_simclr_loss(self, logits_a, logits_b, logits_a_gathered, logits_b_gathered, labels, temperature): 241 | sim_aa = logits_a @ logits_a_gathered.t() / temperature 242 | sim_ab = logits_a @ logits_b_gathered.t() / temperature 243 | sim_ba = logits_b @ logits_a_gathered.t() / temperature 244 | sim_bb = logits_b @ logits_b_gathered.t() / temperature 245 | masks = torch.where(F.one_hot(labels, logits_a_gathered.size(0)) == 0, 0, float('-inf')) 246 | sim_aa += masks 247 | sim_bb += masks 248 | sim_a = torch.cat([sim_ab, sim_aa], 1) 249 | sim_b = torch.cat([sim_ba, sim_bb], 1) 250 | loss_a = F.cross_entropy(sim_a, labels) 251 | loss_b = F.cross_entropy(sim_b, labels) 252 | return (loss_a + loss_b) * 0.5 253 | 254 | def _build_mlp(self, in_dim=512, mlp_dim=512, out_dim=512): 255 | return nn.Sequential( 256 | nn.Linear(in_dim, mlp_dim), 257 | nn.ReLU(inplace=True), 258 | nn.Linear(mlp_dim, out_dim) 259 | ) 260 | 261 | @property 262 | def dtype(self): 263 | try: 264 | return self.visual.conv1.weight.dtype 265 | except: 266 | try: 267 | return self.visual.head.weight.dtype 268 | except: 269 | try: 270 | return self.visual.stem[0].weight.dtype 271 | except: 272 | return self.encode_text.text_projection.weight.dtype 273 | 274 | def encode_image(self, image, return_dense=False): 275 | if return_dense: 276 | output = self.visual(image.type(self.dtype), return_dense=return_dense) 277 | return output 278 | output = self.visual(image.type(self.dtype)) 279 | return output 280 | 281 | def all_gather(self, input): 282 | if not self.use_gather or not is_using_distributed(): 283 | return input 284 | output = AllGather.apply(input) 285 | output = output.view(-1, *(output.shape[2:])) 286 | return output 287 | 288 | 289 | def clip_vitb(config, num_classes=11003): 290 | image_encode = visual_transformer(config) 291 | text_encode = text_transformers(config) 292 | model = CLIP(config, image_encode, text_encode, num_classes, config.experiment.ritc_eps) 293 | return model 294 | -------------------------------------------------------------------------------- /model/text_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from .base_transformer import Transformer, LayerNorm 6 | 7 | 8 | class TextTransformer(nn.Module): 9 | def __init__(self, config, 10 | embed_dim: int, 11 | context_length: int, 12 | transformer_width: int, 13 | transformer_heads: int, 14 | transformer_layers: int, 15 | positional_embedding_flag: bool, 16 | checkpoint: bool, 17 | bpe_path=None, 18 | ): 19 | super().__init__() 20 | self.config = config 21 | 
self.context_length = context_length 22 | self.positional_embedding_flag = positional_embedding_flag 23 | 24 | self.transformer = Transformer( 25 | width=transformer_width, 26 | layers=transformer_layers, 27 | heads=transformer_heads, 28 | attn_mask=self.build_attention_mask(), 29 | checkpoint=checkpoint, 30 | dropout=config.experiment.dropout 31 | ) 32 | self.token_embedding = nn.Embedding(49408, transformer_width) 33 | self.positional_embedding = nn.Parameter( 34 | torch.normal(mean=0, std=0.02, size=(self.context_length, transformer_width))) 35 | self.ln_final = LayerNorm(transformer_width) 36 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 37 | self.initialize_parameters() 38 | 39 | def train(self, mode=True): 40 | self.training = mode 41 | for module in self.children(): 42 | module.train(mode) 43 | return self 44 | 45 | def initialize_parameters(self): 46 | nn.init.normal_(self.token_embedding.weight, std=0.02) 47 | nn.init.normal_(self.positional_embedding, std=0.01) 48 | 49 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 50 | attn_std = self.transformer.width ** -0.5 51 | fc_std = (2 * self.transformer.width) ** -0.5 52 | for block in self.transformer.resblocks: 53 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 54 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 55 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 56 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 57 | if self.text_projection is not None: 58 | # nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) # todo 59 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 60 | 61 | @property 62 | def dtype(self): 63 | return self.positional_embedding.dtype 64 | 65 | def build_attention_mask(self): 66 | # lazily create causal attention mask, with full attention between the vision tokens 67 | # pytorch uses additive attention mask; fill with -inf 68 | mask = torch.empty(self.context_length, self.context_length) 69 | mask.fill_(float("-inf")) 70 | mask.triu_(1) # zero out the lower diagonal 71 | return mask 72 | 73 | def forward(self, texts, mask_type=None, return_dense=False): 74 | if mask_type is not None: 75 | texts, labels = texts 76 | x = self.token_embedding(texts).type(self.dtype) # [batch_size, n_ctx, d_model] 77 | if self.positional_embedding_flag: 78 | x = x + self.positional_embedding.type(self.dtype) # Fix!!! 
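# x: [batch_size, n_ctx, transformer_width] at this point; the transformer below runs in
# sequence-first (LND) layout, and the embedding at the <|endoftext|> position
# (selected via texts.argmax(dim=-1) further down) is projected into the joint embedding space.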
79 | x = x.permute(1, 0, 2) # NLD -> LND 80 | x = self.transformer(x) 81 | x = x.permute(1, 0, 2) # LND -> NLD 82 | x = self.ln_final(x).type(self.dtype) 83 | 84 | x = x @ self.text_projection 85 | 86 | if mask_type is not None or return_dense: 87 | words_feat = x 88 | 89 | x = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)] 90 | 91 | if mask_type is not None: 92 | return x, words_feat, labels 93 | 94 | if return_dense: 95 | return x, words_feat 96 | 97 | return x 98 | 99 | 100 | def text_transformers(config): 101 | model_config = config.model 102 | kwargs = { 103 | 'context_length': config.experiment.text_length, 104 | 'transformer_width': 512, 105 | 'transformer_heads': 8, 106 | 'transformer_layers': 12, 107 | 'positional_embedding_flag': True, 108 | 'checkpoint': False, 109 | 'embed_dim': model_config.embed_dim, 110 | } 111 | model = TextTransformer(config, **kwargs) 112 | return model 113 | -------------------------------------------------------------------------------- /model/visual_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .base_transformer import Transformer, LayerNorm 4 | from typing import Tuple, Union 5 | 6 | 7 | class VisualTransformer(nn.Module): 8 | def __init__(self, input_resolution: Union[int, Tuple[int, int]], patch_size: int, width: int, layers: int, heads: int, embed_dim: int, 9 | checkpoint: bool, dropout: float = 0, emb_dropout: float = 0): 10 | super().__init__() 11 | if isinstance(input_resolution, int): 12 | input_resolution = (input_resolution, input_resolution) 13 | self.input_resolution = input_resolution 14 | self.num_x = (input_resolution[1] - patch_size) // patch_size + 1 15 | self.num_y = (input_resolution[0] - patch_size) // patch_size + 1 16 | num_patches = self.num_x * self.num_y 17 | 18 | output_dim = embed_dim 19 | self.output_dim = output_dim 20 | self.freeze_conv1 = True 21 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, 22 | kernel_size=patch_size, stride=patch_size, bias=False) 23 | 24 | scale = width ** -0.5 25 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 26 | self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width)) 27 | self.ln_pre = LayerNorm(width) 28 | 29 | self.transformer = Transformer(width, layers, heads, checkpoint=checkpoint, dropout=dropout, 30 | emb_dropout=emb_dropout) 31 | 32 | self.ln_post = LayerNorm(width) 33 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 34 | self.initialize_parameters() 35 | 36 | def initialize_parameters(self): 37 | nn.init.normal_(self.positional_embedding, std=0.01) 38 | 39 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 40 | attn_std = self.transformer.width ** -0.5 41 | fc_std = (2 * self.transformer.width) ** -0.5 42 | for block in self.transformer.resblocks: 43 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 44 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 45 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 46 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 47 | 48 | def train(self, mode=True): 49 | self.training = mode 50 | for module in self.children(): 51 | module.train(mode) 52 | 53 | if self.freeze_conv1: 54 | for layer in [self.conv1]: 55 | layer.eval() 56 | for param in layer.parameters(): 57 | param.requires_grad = False 58 | return self 59 | 60 | def forward(self, x: torch.Tensor, return_dense=False, return_feature=False): 61 | x = 
self.conv1(x) # shape = [*, width, grid, grid] 62 | # shape = [*, width, grid ** 2] 63 | x = x.reshape(x.shape[0], x.shape[1], -1) 64 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 65 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) 66 | x = x + self.positional_embedding.to(x.dtype) 67 | x = self.ln_pre(x) 68 | 69 | x = x.permute(1, 0, 2) # NLD -> LND 70 | x = self.transformer(x) 71 | x = x.permute(1, 0, 2) # LND -> NLD 72 | 73 | # x = self.ln_post(x[:, 0, :]) 74 | x = self.ln_post(x) 75 | dense_feat = x 76 | 77 | if self.proj is not None: 78 | dense_feat = x @ self.proj 79 | x = dense_feat[:, 0, :] 80 | 81 | if return_dense: 82 | return x, dense_feat 83 | if return_feature: 84 | return dense_feat 85 | return x 86 | 87 | 88 | def visual_transformer(config): 89 | vision_width = 768 90 | vision_layers = 12 91 | vision_heads = vision_width // 64 92 | 93 | kwargs = { 94 | 'layers': vision_layers, 95 | 'heads': vision_heads, 96 | 'input_resolution': config.experiment.input_resolution, 97 | 'patch_size': 16, 98 | 'width': vision_width, 99 | 'checkpoint': False, 100 | 'embed_dim': config.model.embed_dim, 101 | } 102 | 103 | model = VisualTransformer(**kwargs) 104 | return model 105 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser(description="IRRA Args") 6 | ######################## mode ######################## 7 | parser.add_argument("--simplified", default=False, action='store_true') 8 | 9 | args = parser.parse_args() 10 | 11 | return args -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0 2 | torchvision==0.14.0 3 | torchaudio==0.13.0 4 | 5 | timm==0.6.11 6 | wandb==0.13.5 7 | ftfy==6.1.1 8 | 9 | regex 10 | easydict 11 | pyyaml 12 | textaugment 13 | ipdb 14 | 15 | torchmetrics 16 | matplotlib 17 | jupyter 18 | ipykernel 19 | -------------------------------------------------------------------------------- /shell/train.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=1 \ 2 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 3 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \ 4 | main.py --simplified -------------------------------------------------------------------------------- /text_utils/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/text_utils/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /text_utils/mask_tokens.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List 3 | 4 | 5 | def mask_tokens(inputs, special_tokens, mask_token, tokenizer_length, mlm_probability=0.15, special_tokens_mask=None) -> Tuple[torch.Tensor, torch.Tensor]: 6 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
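    Illustrative example (masking is stochastic, so the chosen positions vary): given
    inputs = [sot, 320, 1125, 9102, eot], one possible outcome is
    inputs -> [sot, <|mask|>, 1125, 9102, eot] and labels -> [-100, 320, -100, -100, -100];
    labels keep the original id only at masked positions and are -100 (ignored by
    cross_entropy) everywhere else, and tokens listed in `special_tokens` (sot/eot/mask)
    are never selected for masking. The ids 320/1125/9102 here are hypothetical.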
""" 7 | labels = inputs.clone() 8 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 9 | probability_matrix = torch.full(labels.shape, mlm_probability) 10 | if special_tokens_mask is None: 11 | special_tokens_mask = [1 if val in special_tokens else 0 for val in labels.tolist()] 12 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 13 | # if tokenizer._pad_token is not None: 14 | # padding_mask = labels.eq(tokenizer.pad_token_id) 15 | # probability_matrix.masked_fill_(padding_mask, value=0.0) 16 | masked_indices = torch.bernoulli(probability_matrix).bool() 17 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 18 | 19 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 20 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 21 | inputs[indices_replaced] = mask_token 22 | 23 | # 10% of the time, we replace masked input tokens with random word 24 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 25 | random_words = torch.randint(tokenizer_length, labels.shape, dtype=torch.long) 26 | inputs[indices_random] = random_words[indices_random] 27 | 28 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 29 | return inputs, labels 30 | 31 | 32 | def MaskTokens(tokens, mask_type, mask_token, special_tokens=None, tokenizer_length=None, sepcial_tokens_mask=None, special_tokens_mask=None): 33 | if mask_type == 'MLM': 34 | tokens, labels = mask_tokens(inputs=tokens, special_tokens=special_tokens, mask_token=mask_token, tokenizer_length=tokenizer_length, special_tokens_mask=special_tokens_mask) 35 | else: 36 | raise NotImplementedError(mask_type) 37 | return tokens, labels 38 | -------------------------------------------------------------------------------- /text_utils/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2 ** 8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2 ** 8 + n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | # Change: Extend <|mask|> tokenizer-size+=1 63 | class SimpleTokenizer(object): 64 | def __init__(self, bpe_path: str = default_bpe()): 65 | self.byte_encoder = bytes_to_unicode() 66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 68 | merges = merges[1:49152 - 256 - 2 + 1] 69 | merges = [tuple(merge.split()) for merge in merges] 70 | vocab = list(bytes_to_unicode().values()) 71 | vocab = vocab + [v + '' for v in vocab] 72 | for merge in merges: 73 | vocab.append(''.join(merge)) 74 | 75 | vocab.pop(-1) # remove last one in vocab(jekyll) to keep vocab_size unchanged 76 | vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 77 | # vocab.extend(['<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile( 83 | r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 84 | re.IGNORECASE) 85 | 86 | def bpe(self, token): 87 | if token in self.cache: 88 | return self.cache[token] 89 | word = tuple(token[:-1]) + (token[-1] + '',) 90 | pairs = get_pairs(word) 91 | 92 | if not pairs: 93 | return token + '' 94 | 95 | while True: 96 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 97 | if bigram not in self.bpe_ranks: 98 | break 99 | first, second = bigram 100 | new_word = [] 101 | i = 0 102 | while i < len(word): 103 | try: 104 | j = word.index(first, i) 105 | new_word.extend(word[i:j]) 106 | i = j 107 | except: 108 | new_word.extend(word[i:]) 109 | break 110 | 111 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 112 | new_word.append(first + second) 113 | i += 2 114 | else: 115 | new_word.append(word[i]) 116 | i += 1 117 | new_word = tuple(new_word) 118 | word = new_word 119 | if len(word) == 1: 120 | break 121 | else: 122 | pairs = get_pairs(word) 123 | word = ' '.join(word) 124 | self.cache[token] = word 125 | return word 126 | 127 | def encode(self, text): 128 | bpe_tokens = [] 129 | text = whitespace_clean(basic_clean(text)).lower() 130 | for token in re.findall(self.pat, text): 131 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 133 | return bpe_tokens 134 | 135 | def decode(self, tokens): 136 | text = ''.join([self.decoder[token] for token in tokens]) 137 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 138 | return text 139 | -------------------------------------------------------------------------------- /text_utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | import torch 4 | 
5 | from .mask_tokens import MaskTokens 6 | from text_utils.simple_tokenizer import SimpleTokenizer as _Tokenizer 7 | 8 | _tokenizer = _Tokenizer() 9 | 10 | 11 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, return_length: bool = False, 12 | mask_type=None): 13 | if isinstance(texts, str): 14 | texts = [texts] 15 | 16 | sot_token = _tokenizer.encoder["<|startoftext|>"] 17 | eot_token = _tokenizer.encoder["<|endoftext|>"] 18 | 19 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 20 | for i, tokens in enumerate(all_tokens): 21 | if len(tokens) > context_length: 22 | all_tokens[i] = [tokens[0]] + tokens[1:context_length - 1] + [tokens[-1]] 23 | all_tokens[i] = torch.Tensor(all_tokens[i]).long() 24 | 25 | if mask_type is not None: 26 | mask_token = _tokenizer.encoder["<|mask|>"] 27 | special_tokens = [sot_token, eot_token, mask_token] 28 | masked_tokens = [ 29 | MaskTokens(tokens, mask_type=mask_type, mask_token=mask_token, special_tokens=special_tokens, 30 | tokenizer_length=len(_tokenizer.encoder)) for tokens in all_tokens] 31 | all_tokens = [item[0] for item in masked_tokens] 32 | all_labels = [item[1] for item in masked_tokens] 33 | 34 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 35 | labels = torch.ones(len(all_tokens), context_length, dtype=torch.long) * -100 36 | token_lengths = torch.ones(len(all_tokens), dtype=torch.long) 37 | 38 | for i, tokens in enumerate(all_tokens): 39 | result[i, :len(tokens)] = tokens 40 | token_lengths[i] = min(len(tokens), context_length) 41 | if mask_type is not None: 42 | labels[i, :len(tokens)] = all_labels[i] 43 | 44 | if mask_type: 45 | # print(result[0], labels[0], '<< masking', flush=True) 46 | return result, labels 47 | if return_length: 48 | return result, token_lengths 49 | else: 50 | return result --------------------------------------------------------------------------------
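Usage sketch (not part of the repository): this mirrors how tokenize() is called from model/tbps_model.py, with a hypothetical caption and the default context length of 77.

    from text_utils.tokenizer import tokenize

    captions = ["a woman wearing a red coat and black trousers"]
    # contrastive branch: LongTensor of shape [len(captions), 77]
    tokens = tokenize(captions, context_length=77)
    # MLM branch: additionally returns per-position labels (-100 at unmasked positions)
    tokens, mlm_labels = tokenize(captions, context_length=77, mask_type='MLM')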