├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── tbps-clip.iml
│   └── vcs.xml
├── README.md
├── config
│   ├── config.yaml
│   └── s.config.yaml
├── image
│   └── intro.png
├── main.py
├── misc
│   ├── build.py
│   ├── caption_dataset.py
│   ├── data.py
│   ├── eval.py
│   ├── lr_scheduler.py
│   └── utils.py
├── model
│   ├── __init__.py
│   ├── base_transformer.py
│   ├── eda.py
│   ├── loss.py
│   ├── mixgen.py
│   ├── shared_modules.py
│   ├── tbps_model.py
│   ├── text_transformer.py
│   └── visual_transformer.py
├── options.py
├── requirements.txt
├── shell
│   └── train.sh
└── text_utils
    ├── bpe_simple_vocab_16e6.txt.gz
    ├── mask_tokens.py
    ├── simple_tokenizer.py
    └── tokenizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python template
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # pytype static type analyzer
136 | .pytype/
137 |
138 | # Cython debug symbols
139 | cython_debug/
140 |
141 | # IDEA
142 | .idea/
143 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/tbps-clip.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 【AAAI 2024 🔥】An Empirical Study of CLIP for Text-based Person Search
4 | [Paper (AAAI 2024)](https://ojs.aaai.org/index.php/AAAI/article/view/27801)
5 | [arXiv](https://arxiv.org/abs/2308.10045)
6 |
7 |
8 | This repository offers the official implementation of [TBPS-CLIP](https://arxiv.org/abs/2308.10045) in PyTorch.
9 |
10 | In the meantime, check out our related papers if you are interested:
11 | + 【ACM MM 2023】 [Text-based Person Search without Parallel Image-Text Data](https://arxiv.org/abs/2305.12964)
12 | + 【IJCAI 2023】 [RaSa: Relation and Sensitivity Aware Representation Learning for Text-based Person Search](https://arxiv.org/abs/2305.13653)
13 | + 【ICASSP 2022】 [Learning Semantic-Aligned Feature Representation for Text-based Person Search](https://arxiv.org/abs/2112.06714)
14 |
15 | ## Note
16 | More experiments and implementation details are provided in the Appendix of the [arXiv](https://arxiv.org/abs/2308.10045) version.
17 |
18 |
19 | ## Overview
20 | By revisiting the critical design of data augmentation and loss function in [CLIP](https://arxiv.org/abs/2103.00020),
21 | we provide a strong baseline [TBPS-CLIP](https://arxiv.org/abs/2308.10045) for text-based person search.
22 |
23 |
24 |
25 |
26 | ## Environment
27 |
28 | All the experiments are conducted on 4 Nvidia A40 (48GB) GPUs. The CUDA version is 11.7.
29 |
30 | The required packages are listed in `requirements.txt`. You can install them using:
31 |
32 | ```sh
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## Download
37 | 1. Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN) and the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset).
38 | 2. Download the annotation json files from [here](https://drive.google.com/file/d/1C5bgGCABtuzZMaa2n4Sc0qclUvZ-mqG9/view?usp=drive_link).
39 | 3. Download the pretrained CLIP checkpoint from [here](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt).
40 |
41 | ## Configuration
42 | In `config/config.yaml` and `config/s.config.yaml`, set the annotation directory (`anno_dir`), the image directory (`image_dir`) and the CLIP checkpoint path (`model.checkpoint`).
43 |
44 |
45 | ## Training
46 |
47 | You can start training with PyTorch's `torchrun`:
48 |
49 | ```sh
50 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
51 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \
52 | main.py
53 | ```
54 |
55 | You can also run the simplified version using:
56 |
57 | ```sh
58 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
59 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \
60 | main.py --simplified
61 | ```
62 |
63 |
64 | ## Model Checkpoints
65 | | | **CUHK-PEDES** | **ICFG-PEDES** | **RSTPReid** |
66 | |:-----------------------------------:|:-------------------------------------------------------------------------------------------------:|:--------------:|:------------:|
67 | | **TBPS-CLIP (ViT-B/16)** | [Download](https://drive.google.com/file/d/1m_3pKanUWHQHeJ-zt-QeRXs7bmay-U5P/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1az4z5b_ADXR7DcysPB5giOl52LjWDCSu/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1qMUAsH-1lzkWUFQsUvUKTY0J6ZuGkYd6/view?usp=drive_link) |
68 | | **Simplified TBPS-CLIP (ViT-B/16)** | [Download](https://drive.google.com/file/d/1W5oFZK9WNHMfy0OOaYQBzPsP1LZR80bT/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1UoLd-MQ8tYJ7YPgCbh3nVSVYnJ9a_TG5/view?usp=drive_link) | [Download](https://drive.google.com/file/d/18zlc3q3Sze5rx3TqcfEeZEjrQXUTpcQF/view?usp=drive_link) |
69 |
70 |
71 | ## Acknowledgement
72 | + [CLIP](https://arxiv.org/abs/2103.00020): the model architecture on which TBPS-CLIP is built
73 |
74 | ## Citation
75 | If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:
76 | ```
77 | @inproceedings{cao2024empirical,
78 | title={An Empirical Study of CLIP for Text-Based Person Search},
79 | author={Cao, Min and Bai, Yang and Zeng, Ziyin and Ye, Mang and Zhang, Min},
80 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
81 | volume={38},
82 | number={1},
83 | pages={465--473},
84 | year={2024}
85 | }
86 | ```
87 |
88 |
89 | ## License
90 | This code is distributed under the MIT License.
91 |
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | device: 5
2 |
3 | misc:
4 | seed: 1
5 |
6 | experiment:
7 | # image
8 | input_resolution: [224, 224]
9 | simclr_mlp: [512, 128, 512]
10 | simclr_temperature: 0.1
11 | # text
12 | dropout: 0.05
13 | eda_alpha: 0.05
14 | back_trans: true
15 | backtrans_p: 0.1
16 | text_length: 77
17 | # mix
18 | mixgen: false
19 | mixgen_type: cat # ori or cat
20 | mixgen_p: 0.1
21 | mixgen_ratio: 0.1
22 | mvs_image: true
23 |
24 | # loss
25 | nitc_ratio: 1.0
26 | ####
27 | ss: true
28 | ss_ratio: 0.4
29 | ####
30 | ritc: true
31 | ritc_eps: 1.0e-2
32 | ritc_ratio: 1.0
33 | ####
34 | mlm: false
35 | mlm_ratio: 1.0
36 | cmt_depth: 4 # cross modal transformer self attn layers
37 | ####
38 | citc: true
39 | citc_lambda1: 0.25
40 | citc_lambda2: 0.25
41 | citc_ratio: 0.1
42 | ####
43 | id: false
44 | id_ratio: 1.0
45 |
46 | schedule:
47 | lr: 1.0e-4
48 | epoch: 5
49 | epoch_warmup: 1
50 | lr_start: 1.0e-6
51 | lr_end: 5.0e-6
52 | weight_decay: 0.02
53 | betas: [0.9, 0.98]
54 | eps: 1.0e-8
55 |
56 | model:
57 | ckpt_type: original_clip # original_clip / saved
58 | saved_path: 'ckpts/baseline_224_224/CUHK-PEDES'
59 | checkpoint: 'CLIP checkpoint path' # e.g., '../../data/CLIP/ViT-B-16.pt'
60 | use_gather: true
61 | softlabel_ratio: 0.5
62 | embed_dim: 512
63 | vocab_size: 49408
64 |
65 | log:
66 | print_period: 50
67 |
68 | data:
69 | batch_size: 80
70 | test_batch_size: 256
71 | num_workers: 8
72 |
73 | distributed:
74 | backend: nccl
75 | url: 'env://'
76 |
77 | anno_dir: 'annotation json path' # e.g., 'data/CUHK-PEDES'
78 | image_dir: 'image path' # e.g., '../../datasets/cuhkpedes/imgs'
--------------------------------------------------------------------------------
/config/s.config.yaml:
--------------------------------------------------------------------------------
1 | device: 5
2 |
3 | misc:
4 | seed: 0
5 |
6 | experiment:
7 | # image
8 | input_resolution: [224, 224]
9 | simclr_mlp: [512, 128, 512]
10 | simclr_temperature: 0.1
11 | # text
12 | dropout: 0.05
13 | eda_alpha: 0.05
14 | back_trans: true
15 | backtrans_p: 0.1
16 | text_length: 77
17 | # mix
18 | mixgen: false
19 | mixgen_type: cat # ori or cat
20 | mixgen_p: 0.1
21 | mixgen_ratio: 0.1
22 | mvs_image: false
23 |
24 | # loss
25 | nitc_ratio: 1.0
26 | ####
27 | ss: false
28 | ss_ratio: 0.4
29 | ####
30 | ritc: true
31 | ritc_eps: 1.0e-2
32 | ritc_ratio: 1.0
33 | ####
34 | mlm: false
35 | mlm_ratio: 1.0
36 | cmt_depth: 4 # cross modal transformer self attn layers
37 | ####
38 | citc: false
39 | citc_lambda1: 0.25
40 | citc_lambda2: 0.25
41 | citc_ratio: 0.1
42 | ####
43 | id: false
44 | id_ratio: 1.0
45 |
46 | schedule:
47 | lr: 1.0e-4
48 | epoch: 5
49 | epoch_warmup: 1
50 | lr_start: 1.0e-6
51 | lr_end: 5.0e-6
52 | weight_decay: 0.02
53 | betas: [0.9, 0.98]
54 | eps: 1.0e-8
55 |
56 | model:
57 | ckpt_type: original_clip # original_clip / saved
58 | saved_path: 'ckpts/s.baseline_224_224/CUHK-PEDES'
59 | checkpoint: 'CLIP checkpoint path' # e.g., '../../data/CLIP/ViT-B-16.pt'
60 | use_gather: true
61 | softlabel_ratio: 0.5
62 | embed_dim: 512
63 | vocab_size: 49408
64 |
65 | log:
66 | print_period: 50
67 |
68 | data:
69 | batch_size: 80
70 | test_batch_size: 256
71 | num_workers: 8
72 |
73 | distributed:
74 | backend: nccl
75 | url: 'env://'
76 |
77 | anno_dir: 'annotation json path' # e.g., 'data/CUHK-PEDES'
78 | image_dir: 'image path' # e.g., '../../datasets/cuhkpedes/imgs'
--------------------------------------------------------------------------------
/image/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/image/intro.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | from pathlib import Path
5 |
6 | import torch
7 |
8 | from misc.build import load_checkpoint, cosine_scheduler, build_optimizer
9 | from misc.data import build_pedes_data
10 | from misc.eval import test
11 | from misc.utils import parse_config, init_distributed_mode, set_seed, is_master, is_using_distributed, \
12 | AverageMeter
13 | from model.tbps_model import clip_vitb
14 | from options import get_args
15 |
16 |
17 | def run(config):
18 | print(config)
19 |
20 | # data
21 | dataloader = build_pedes_data(config)
22 | train_loader = dataloader['train_loader']
23 | num_classes = len(train_loader.dataset.person2text)
24 |
25 | meters = {
26 | "loss": AverageMeter(),
27 | "nitc_loss": AverageMeter(),
28 | "ss_loss": AverageMeter(),
29 | "citc_loss": AverageMeter(),
30 | "ritc_loss": AverageMeter(),
31 | "mlm_loss": AverageMeter(),
32 | "id_loss": AverageMeter(),
33 | }
34 | best_rank_1 = 0.0
35 | best_epoch = 0
36 |
37 | # model
38 | model = clip_vitb(config, num_classes)
39 | model.to(config.device)
40 |
41 | model, load_result = load_checkpoint(model, config)
42 |
43 | if is_using_distributed():
44 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.device],
45 | find_unused_parameters=True)
46 |
47 | # schedule
48 | config.schedule.niter_per_ep = len(train_loader)
49 | lr_schedule = cosine_scheduler(config)
50 |
51 | # optimizer
52 | optimizer = build_optimizer(config, model)
53 |
54 | # train
55 | it = 0
56 | scaler = torch.cuda.amp.GradScaler()
57 | for epoch in range(config.schedule.epoch):
58 | print()
59 | if is_using_distributed():
60 | dataloader['train_sampler'].set_epoch(epoch)
61 |
62 | start_time = time.time()
63 | for meter in meters.values():
64 | meter.reset()
65 | model.train()
66 |
67 | for i, batch in enumerate(train_loader):
68 | for param_group in optimizer.param_groups:
69 | param_group['lr'] = lr_schedule[it] * param_group['ratio']
70 |
71 | if epoch == 0:
72 | alpha = config.model.softlabel_ratio * min(1.0, i / len(train_loader))
73 | else:
74 | alpha = config.model.softlabel_ratio
75 |
76 | if config.experiment.mixgen:
77 | if random.random() < config.experiment.mixgen_p:
78 | import model.mixgen as mg
79 | if config.experiment.mixgen_type == 'cat':
80 | mixgen_func = mg.concatgen
81 | else:
82 | mixgen_func = mg.mixgen
83 | img, cap = mixgen_func(batch['image'], batch['caption'],
84 | num=int(config.experiment.mixgen_ratio * len(batch['caption'])))
85 | batch.update({
86 | 'image': img,
87 | 'caption': cap,
88 | })
89 |
90 | with torch.autocast(device_type='cuda'):
91 | ret = model(batch, alpha)
92 | loss = sum([v for k, v in ret.items() if "loss" in k])
93 |
94 | batch_size = batch['image'].shape[0]
95 | meters['loss'].update(loss.item(), batch_size)
96 | meters['nitc_loss'].update(ret.get('nitc_loss', 0), batch_size)
97 | meters['ss_loss'].update(ret.get('ss_loss', 0), batch_size)
98 | meters['citc_loss'].update(ret.get('citc_loss', 0), batch_size)
99 | meters['ritc_loss'].update(ret.get('ritc_loss', 0), batch_size)
100 | meters['mlm_loss'].update(ret.get('mlm_loss', 0), batch_size)
101 | meters['id_loss'].update(ret.get('id_loss', 0), batch_size)
102 |
103 | scaler.scale(loss).backward()
104 | scaler.step(optimizer)
105 | scaler.update()
106 | model.zero_grad()
107 | optimizer.zero_grad()
108 | it += 1
109 |
110 | if (i + 1) % config.log.print_period == 0:
111 | info_str = f"Epoch[{epoch + 1}] Iteration[{i + 1}/{len(train_loader)}]"
112 | # log loss
113 | for k, v in meters.items():
114 | if v.val != 0:
115 | info_str += f", {k}: {v.val:.4f}"
116 | info_str += f", Base Lr: {param_group['lr']:.2e}"
117 | print(info_str)
118 |
119 | if is_master():
120 | end_time = time.time()
121 | time_per_batch = (end_time - start_time) / (i + 1)
122 | print("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
123 | .format(epoch + 1, time_per_batch, train_loader.batch_size / time_per_batch))
124 |
125 | eval_result = test(model.module, dataloader['test_loader'], 77, config.device)
126 | rank_1, rank_5, rank_10, map = eval_result['r1'], eval_result['r5'], eval_result['r10'], eval_result['mAP']
127 | print('Acc@1 {top1:.5f} Acc@5 {top5:.5f} Acc@10 {top10:.5f} mAP {mAP:.5f}'.format(top1=rank_1, top5=rank_5,
128 | top10=rank_10, mAP=map))
129 | torch.cuda.empty_cache()
130 | if best_rank_1 < rank_1:
131 | best_rank_1 = rank_1
132 | best_epoch = epoch
133 |
134 | save_obj = {
135 | 'model': model.module.state_dict(),
136 | 'optimizer': optimizer.state_dict(),
137 | 'config': config,
138 | }
139 | torch.save(save_obj, os.path.join(config.model.saved_path, 'checkpoint_best.pth'))
140 |
141 | print(f"best Acc@1: {best_rank_1} at epoch {best_epoch + 1}")
142 |
143 |
144 | if __name__ == '__main__':
145 | config_path = 'config/config.yaml'
146 |
147 | args = get_args()
148 | if args.simplified:
149 | config_path = 'config/s.config.yaml'
150 | config = parse_config(config_path)
151 |
152 | Path(config.model.saved_path).mkdir(parents=True, exist_ok=True)
153 |
154 | init_distributed_mode(config)
155 |
156 | set_seed(config)
157 |
158 | run(config)
159 |
--------------------------------------------------------------------------------
/misc/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import math
5 | import torch.nn.functional as F
6 |
7 |
8 | def resize_pos_embed(posemb, posemb_new, hight, width):
9 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
10 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
11 | posemb = posemb.unsqueeze(0)
12 | posemb_new = posemb_new.unsqueeze(0)
13 |
14 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
15 |
16 | gs_old = int(math.sqrt(len(posemb_grid)))
17 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
18 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
19 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
20 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
21 | posemb = torch.cat([posemb_token, posemb_grid], dim=1)
22 | return posemb.squeeze(0)
23 |
24 |
25 | def interpolate_text(pos_embed_checkpoint, target_dim=77):
26 | # (n_ctx, n_feat) for pos_embed_checkpoint, including SOT and EOT
27 | if pos_embed_checkpoint.size(0) == target_dim:
28 | return pos_embed_checkpoint
29 | start_token = pos_embed_checkpoint[:1, :]
30 | end_token = pos_embed_checkpoint[-1:, :]
31 | pos_tokens = pos_embed_checkpoint[1:-1, :].unsqueeze(0).permute(0, 2, 1)
32 | pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=target_dim - 2, mode='linear')
33 | pos_tokens = pos_tokens.squeeze(0).t()
34 | pos_tokens = torch.cat([start_token, pos_tokens, end_token], dim=0)
35 | return pos_tokens
36 |
37 |
38 | def load_checkpoint(model, config):
39 | if config.model.ckpt_type == 'original_clip':
40 | with open(config.model.checkpoint, 'rb') as opened_file:
41 | model_tmp = torch.jit.load(opened_file, map_location="cpu")
42 | state = model_tmp.state_dict()
43 | for key in ["input_resolution", "context_length", "vocab_size"]:
44 | if key in state:
45 | del state[key]
46 |
47 | # 2 towers in new_state: visual, encode_text
48 | new_state = {}
49 | for name, params in state.items():
50 | if name == 'visual.positional_embedding' and params.shape != model.visual.positional_embedding.shape:
51 | params = resize_pos_embed(params, model.visual.positional_embedding, model.visual.num_y, model.visual.num_x)
52 |
53 | if name == 'positional_embedding':
54 | new_state['encode_text.' + name] = interpolate_text(params, config.experiment.text_length)
55 | elif name.startswith('transformer') or name in ['positional_embedding', 'token_embedding.weight',
56 | 'ln_final.weight', 'ln_final.bias', 'text_projection']:
57 | new_state['encode_text.' + name] = params
58 | else:
59 | new_state[name] = params
60 | elif config.model.ckpt_type == 'saved':
61 | ckpt = torch.load(os.path.join(config.model.saved_path, 'checkpoint_best.pth'), map_location='cpu')
62 | new_state = ckpt['model']
63 | else:
64 | raise KeyError
65 |
66 | load_result = model.load_state_dict(new_state, strict=False)
67 | return model, load_result
68 |
69 |
70 | def cosine_scheduler(config):
71 | schedule_config = config.schedule
72 | base_value = schedule_config.lr
73 | start_warmup_value = schedule_config.lr_start
74 | final_value = schedule_config.lr_end
75 | epochs = schedule_config.epoch
76 | warmup_epochs = schedule_config.epoch_warmup
77 | niter_per_ep = schedule_config.niter_per_ep
78 |
79 | warmup_schedule = np.array([])
80 | warmup_iters = warmup_epochs * niter_per_ep
81 | if warmup_epochs > 0:
82 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
83 |
84 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
85 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
86 |
87 | schedule = np.concatenate((warmup_schedule, schedule))
88 | assert len(schedule) == epochs * niter_per_ep
89 | return schedule
90 |
91 |
92 | # def build_optimizer(config, model):
93 | # p_wd, p_non_wd = [], []
94 | # for n, p in model.named_parameters():
95 | # if not p.requires_grad:
96 | # continue # frozen weights
97 | # if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
98 | # p_non_wd.append(p)
99 | # else:
100 | # p_wd.append(p)
101 | #
102 | # schedule_config = config.schedule
103 | # optim_params = [{"params": p_wd, "weight_decay": schedule_config.weight_decay, "ratio": 1.},
104 | # {"params": p_non_wd, "weight_decay": 0, "ratio": 1.}]
105 | #
106 | # optimizer = torch.optim.AdamW(optim_params, lr=schedule_config.lr, betas=schedule_config.betas,
107 | # eps=schedule_config.eps, weight_decay=schedule_config.weight_decay)
108 | # return optimizer
109 |
110 |
111 | def build_optimizer(config, model):
112 | params = []
113 | schedule_config = config.schedule
114 | for n, p in model.named_parameters():
115 | if not p.requires_grad:
116 | continue # frozen weights
117 | weight_decay = schedule_config.weight_decay
118 | ratio = 1.
119 |
120 | if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
121 | weight_decay = 0.
122 | if "cross" in n or "classifier" in n or "mlm_head" in n:
123 | ratio = ratio * schedule_config.ratio_factor # default 5.0
124 |
125 | params += [{"params": [p], "weight_decay": weight_decay, "ratio": ratio}]
126 |
127 | optimizer = torch.optim.AdamW(params, lr=schedule_config.lr, betas=schedule_config.betas,
128 | eps=schedule_config.eps, weight_decay=schedule_config.weight_decay)
129 | return optimizer
130 |
--------------------------------------------------------------------------------
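
For reference, the following is a minimal sketch (not part of the repository) of how `cosine_scheduler` above could be exercised on its own, assuming it is run from the repository root with the dependencies installed. The `EasyDict` values are made-up stand-ins for the `schedule` section of `config/config.yaml`; in `main.py`, `niter_per_ep` is set to `len(train_loader)`.

```python
from easydict import EasyDict

from misc.build import cosine_scheduler

# Toy schedule config; the real values live in config/config.yaml.
config = EasyDict({'schedule': {
    'lr': 1.0e-4,          # peak LR reached at the end of warmup
    'lr_start': 1.0e-6,    # warmup start value
    'lr_end': 5.0e-6,      # final value of the cosine decay
    'epoch': 5,
    'epoch_warmup': 1,
    'niter_per_ep': 100,   # normally len(train_loader)
}})

schedule = cosine_scheduler(config)
assert len(schedule) == 5 * 100                   # one LR value per training iteration
print(schedule[0], schedule[99], schedule[-1])    # ~lr_start, ~lr, ~lr_end
```
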
/misc/caption_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | from collections import defaultdict
5 |
6 | import torch
7 | from PIL import Image
8 | from PIL import ImageFile
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 | from PIL import ImageFilter
12 | import random
13 |
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | Image.MAX_IMAGE_PIXELS = None
16 |
17 |
18 | class ps_train_dataset(Dataset):
19 | def __init__(self, ann_root, image_root, transform, aug_ss, split, max_words=30):
20 | ann_file = os.path.join(ann_root, split + '_reid.json')
21 | anns = json.load(open(ann_file))
22 | self.transform = transform
23 |
24 | self.person2text = defaultdict(list)
25 | person_id2idx = {}
26 | n = 0
27 | self.pairs = []
28 |
29 | for ann in anns:
30 | image_path = os.path.join(image_root, ann['file_path'])
31 | person_id = ann['id']
32 | if person_id not in person_id2idx.keys():
33 | person_id2idx[person_id] = n
34 | n += 1
35 | person_idx = person_id2idx[person_id]
36 | if 'captions_bt' not in ann:
37 | ann['captions_bt'] = [''] * len(ann['captions'])
38 | for caption, caption_bt in zip(ann['captions'], ann['captions_bt']):
39 | caption = pre_caption(caption, max_words)
40 | caption_bt = pre_caption(caption_bt, max_words)
41 | self.pairs.append((image_path, caption, caption_bt, person_idx))
42 | self.person2text[person_idx].append(caption)
43 |
44 | self.augmentation_ss = aug_ss
45 |
46 | def __len__(self):
47 | return len(self.pairs)
48 |
49 | def __getitem__(self, index):
50 | image_path, caption, caption_bt, person = self.pairs[index]
51 |
52 | image_pil = Image.open(image_path)
53 | image = self.transform(image_pil.convert('RGB'))
54 | aug1 = self.transform(image_pil.convert('RGB'))
55 | aug_ss_1 = self.augmentation_ss(image_pil)
56 | aug_ss_2 = self.augmentation_ss(image_pil)
57 | return {
58 | 'image': image,
59 | 'caption': caption,
60 | 'caption_bt': caption_bt,
61 | 'id': person,
62 | 'aug1': aug1,
63 | 'aug_ss_1': aug_ss_1,
64 | 'aug_ss_2': aug_ss_2
65 | }
66 |
67 |
68 | class ps_eval_dataset(Dataset):
69 | def __init__(self, ann_root, image_root, transform, split, max_words=30):
70 | ann_file = os.path.join(ann_root, split + '_reid.json')
71 | anns = json.load(open(ann_file, 'r'))
72 | self.transform = transform
73 |
74 | self.text = []
75 | self.image = []
76 | self.txt2person = []
77 | self.img2person = []
78 |
79 | for ann in anns:
80 | image_path = os.path.join(image_root, ann['file_path'])
81 | self.image.append(image_path)
82 |
83 | person_id = ann['id']
84 | self.img2person.append(person_id)
85 | for caption in ann['captions']:
86 | self.text.append(pre_caption(caption, max_words))
87 | self.txt2person.append(person_id)
88 |
89 | self.txt2person = torch.tensor(self.txt2person, dtype=torch.long)
90 | self.img2person = torch.tensor(self.img2person, dtype=torch.long)
91 |
92 | def __len__(self):
93 | return len(self.image)
94 |
95 | def __getitem__(self, index):
96 | image_path = self.image[index]
97 | image = Image.open(image_path).convert('RGB')
98 | image = self.transform(image)
99 |
100 | return image
101 |
102 | def pre_caption(caption, max_words=50):
103 | caption = re.sub(
104 | r"([.!\"()*#:;~])",
105 | ' ',
106 | caption.lower(),
107 | )
108 | caption = re.sub(
109 | r"\s{2,}",
110 | ' ',
111 | caption,
112 | )
113 | caption = caption.rstrip('\n')
114 | caption = caption.strip(' ')
115 |
116 | # truncate caption
117 | caption_words = caption.split(' ')
118 | if len(caption_words) > max_words:
119 | caption = ' '.join(caption_words[:max_words])
120 |
121 | return caption
--------------------------------------------------------------------------------
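
A small illustration (not part of the repository) of what `pre_caption` above does to a raw caption: lowercasing, removing the listed punctuation, collapsing whitespace and truncating to `max_words`. The caption string is a made-up example.

```python
from misc.caption_dataset import pre_caption

raw = 'The man wears a RED jacket!! He is walking; carrying a "black" bag.'
print(pre_caption(raw, max_words=8))
# -> 'the man wears a red jacket he is'
```
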
/misc/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from PIL import ImageFilter
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 |
12 | from misc.caption_dataset import ps_train_dataset, ps_eval_dataset
13 | from misc.utils import is_using_distributed
14 |
15 |
16 | def get_self_supervised_augmentation(img_size):
17 | class GaussianBlur(object):
18 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
19 |
20 | def __init__(self, sigma=[.1, 2.]):
21 | self.sigma = sigma
22 |
23 | def __call__(self, x):
24 | sigma = random.uniform(self.sigma[0], self.sigma[1])
25 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
26 | return x
27 |
28 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
29 | std=[0.229, 0.224, 0.225])
30 |
31 | aug = transforms.Compose([
32 | transforms.RandomResizedCrop(img_size, scale=(0.2, 1.), antialias=True),
33 | transforms.RandomApply([
34 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
35 | ], p=0.8),
36 | transforms.RandomGrayscale(p=0.2),
37 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.ToTensor(),
40 | normalize
41 | ])
42 | return aug
43 |
44 |
45 | def pil_loader(path):
46 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
47 | with open(path, 'rb') as f:
48 | img = Image.open(f)
49 | return img.convert('RGB')
50 |
51 |
52 | class cuhkpedes_eval(torch.utils.data.Dataset):
53 | def __init__(self, ann_file, transform, image_root):
54 | self.ann = json.load(open(ann_file, 'r'))
55 | self.transform = transform
56 | self.image_root = image_root
57 |
58 | self.text = []
59 | self.image = []
60 | self.txt2img = {}
61 | self.img2txt = {}
62 | self.pid2txt, self.pid2img = {}, {}
63 | self.txt_ids, self.img_ids = [], []
64 |
65 | txt_id = 0
66 | for img_id, ann in enumerate(self.ann):
67 | self.image.append(ann['image'])
68 | if ann['image_id'] not in self.pid2txt.keys():
69 | self.pid2txt[ann['image_id']] = []
70 | self.pid2img[ann['image_id']] = []
71 | self.pid2img[ann['image_id']].append(img_id)
72 | self.img_ids.append(ann['image_id'])
73 | for i, caption in enumerate(ann['caption']):
74 | self.text.append(caption)
75 | self.pid2txt[ann['image_id']].append(txt_id)
76 | self.txt_ids.append(ann['image_id'])
77 | txt_id += 1
78 |
79 | for tid in range(len(self.text)):
80 | self.txt2img[tid] = self.pid2img[self.txt_ids[tid]]
81 | for iid in range(len(self.image)):
82 | self.img2txt[iid] = self.pid2txt[self.img_ids[iid]]
83 |
84 | def __len__(self):
85 | return len(self.image)
86 |
87 | def __getitem__(self, index):
88 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
89 | image = Image.open(image_path)
90 | image = self.transform(image)
91 |
92 | return image, index
93 |
94 |
95 | def build_pedes_data(config):
96 | size = config.experiment.input_resolution
97 | if isinstance(size, int):
98 | size = (size, size)
99 |
100 | normalize = transforms.Normalize(
101 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
102 | val_transform = transforms.Compose([
103 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
104 | transforms.ToTensor(),
105 | normalize
106 | ])
107 |
108 | rand_from = [
109 | transforms.ColorJitter(.1, .1, .1, 0),
110 | transforms.RandomRotation(15),
111 | transforms.RandomResizedCrop(size, (0.9, 1.0), antialias=True),
112 | transforms.RandomGrayscale(),
113 | transforms.RandomHorizontalFlip(),
114 | transforms.RandomErasing(scale=(0.10, 0.20)),
115 | ]
116 | aug = Choose(rand_from, size)
117 | aug_ss = get_self_supervised_augmentation(size)
118 |
119 | train_dataset = ps_train_dataset(config.anno_dir, config.image_dir, aug, aug_ss, split='train', max_words=77)
120 | test_dataset = ps_eval_dataset(config.anno_dir, config.image_dir, val_transform, split='test', max_words=77)
121 |
122 | if is_using_distributed():
123 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
124 | else:
125 | train_sampler = None
126 | test_sampler = None
127 |
128 | config_data = config.data
129 | train_loader = DataLoader(
130 | dataset=train_dataset,
131 | batch_size=config_data.batch_size,
132 | shuffle=train_sampler is None,
133 | num_workers=config_data.num_workers,
134 | pin_memory=True,
135 | sampler=train_sampler,
136 | drop_last=True,
137 | )
138 | test_loader = DataLoader(
139 | dataset=test_dataset,
140 | batch_size=32,
141 | shuffle=False,
142 | sampler=test_sampler,
143 | drop_last=False,
144 | )
145 |
146 | return {
147 | 'train_loader': train_loader,
148 | 'train_sampler': train_sampler,
149 | 'test_loader': test_loader,
150 | }
151 |
152 |
153 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
154 |
155 |
156 | class Choose:
157 | def __init__(self, rand_from, size):
158 | self.choose_from = rand_from
159 | self.size = size
160 |
161 | def __call__(self, image):
162 | aug_choice = np.random.choice(self.choose_from, 2)
163 | return transforms.Compose([
164 | transforms.Resize(self.size),
165 | transforms.ToTensor(),
166 | *aug_choice,
167 | normalize
168 | ])(image)
169 |
--------------------------------------------------------------------------------
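
A minimal sketch (not part of the repository) of the `Choose` wrapper above, which resizes an image, converts it to a tensor and then applies two transforms picked at random from the provided list. The transform list and the synthetic image below are illustrative only.

```python
import numpy as np
from PIL import Image
from torchvision import transforms

from misc.data import Choose

rand_from = [
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
]
aug = Choose(rand_from, size=(224, 224))

# Synthetic pedestrian-shaped image (H=384, W=128).
img = Image.fromarray((np.random.rand(384, 128, 3) * 255).astype(np.uint8))
out = aug(img)
print(out.shape)   # torch.Size([3, 224, 224])
```
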
/misc/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | # import clip
4 | from text_utils.tokenizer import tokenize
5 |
6 |
7 | @torch.no_grad()
8 | def test(model, data_loader, max_length, device):
9 | # switch to evaluate mode
10 | model.eval()
11 |
12 | dataset = data_loader.dataset
13 | texts = dataset.text
14 | num_text = len(texts)
15 | text_bs = 256
16 |
17 | text_feats = []
18 | for i in range(0, num_text, text_bs):
19 | text = texts[i: min(num_text, i + text_bs)]
20 | text = tokenize(text, context_length=max_length).to(device)
21 | text_feat = F.normalize(model.encode_text(text), dim=-1)
22 | text_feats.append(text_feat)
23 | text_feats = torch.cat(text_feats, dim=0)
24 |
25 | image_feats = []
26 | for image in data_loader:
27 | image = image.to(device)
28 | image_feat = F.normalize(model.encode_image(image), dim=-1)
29 | image_feats.append(image_feat)
30 | image_feats = torch.cat(image_feats, dim=0)
31 |
32 | sims_matrix = text_feats @ image_feats.t()
33 | eval_result = metric_eval(sims_matrix, dataset.img2person, dataset.txt2person)
34 |
35 | return eval_result
36 |
37 |
38 | @torch.no_grad()
39 | def metric_eval(scores_t2i, img2person, txt2person):
40 | device = scores_t2i.device
41 | img2person = img2person.to(device)
42 | txt2person = txt2person.to(device)
43 |
44 | index = torch.argsort(scores_t2i, dim=-1, descending=True)
45 | pred_person = img2person[index]
46 | matches = (txt2person.view(-1, 1).eq(pred_person)).long()
47 |
48 | def acc_k(matches, k=1):
49 | matches_k = matches[:, :k].sum(dim=-1)
50 | matches_k = torch.sum((matches_k > 0))
51 | return 100.0 * matches_k / matches.size(0)
52 |
53 | # Compute metrics
54 | ir1 = acc_k(matches, k=1).item()
55 | ir5 = acc_k(matches, k=5).item()
56 | ir10 = acc_k(matches, k=10).item()
57 | ir_mean = (ir1 + ir5 + ir10) / 3
58 |
59 | real_num = matches.sum(dim=-1)
60 | tmp_cmc = matches.cumsum(dim=-1).float()
61 | order = torch.arange(start=1, end=matches.size(1) + 1, dtype=torch.long).to(device)
62 | tmp_cmc /= order
63 | tmp_cmc *= matches
64 | AP = tmp_cmc.sum(dim=-1) / real_num
65 | mAP = AP.mean() * 100.0
66 |
67 | eval_result = {'r1': ir1,
68 | 'r5': ir5,
69 | 'r10': ir10,
70 | 'r_mean': ir_mean,
71 | 'mAP': mAP.item()
72 | }
73 |
74 | return eval_result
75 |
--------------------------------------------------------------------------------
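
A toy example (not part of the repository) exercising `metric_eval` above on a hand-made text-to-image similarity matrix; all scores and person ids below are made up. Since every query ranks an image of its correct person first, all metrics come out to 100.

```python
import torch

from misc.eval import metric_eval

# 3 query texts, 4 gallery images.
scores_t2i = torch.tensor([[0.9, 0.1, 0.2, 0.3],
                           [0.2, 0.8, 0.1, 0.4],
                           [0.1, 0.3, 0.7, 0.2]])
img2person = torch.tensor([0, 1, 2, 1])   # person id of each gallery image
txt2person = torch.tensor([0, 1, 2])      # person id of each query text

print(metric_eval(scores_t2i, img2person, txt2person))
# {'r1': 100.0, 'r5': 100.0, 'r10': 100.0, 'r_mean': 100.0, 'mAP': 100.0}
```
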
/misc/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_right
2 | from math import cos, pi
3 |
4 | from torch.optim.lr_scheduler import _LRScheduler
5 |
6 |
7 | class LRSchedulerWithWarmup(_LRScheduler):
8 | def __init__(
9 | self,
10 | optimizer,
11 | milestones,
12 | gamma=0.1,
13 | mode="step",
14 | warmup_factor=1.0 / 3,
15 | warmup_epochs=10,
16 | warmup_method="linear",
17 | total_epochs=100,
18 | target_lr=0,
19 | power=0.9,
20 | last_epoch=-1,
21 | ):
22 | if not list(milestones) == sorted(milestones):
23 | raise ValueError(
24 | "Milestones should be a list of"
25 | " increasing integers. Got {}".format(milestones),
26 | )
27 | if mode not in ("step", "exp", "poly", "cosine", "linear"):
28 | raise ValueError(
29 | "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted"
30 | "got {}".format(mode)
31 | )
32 | if warmup_method not in ("constant", "linear"):
33 | raise ValueError(
34 | "Only 'constant' or 'linear' warmup_method accepted"
35 | "got {}".format(warmup_method)
36 | )
37 | self.milestones = milestones
38 | self.mode = mode
39 | self.gamma = gamma
40 | self.warmup_factor = warmup_factor
41 | self.warmup_epochs = warmup_epochs
42 | self.warmup_method = warmup_method
43 | self.total_epochs = total_epochs
44 | self.target_lr = target_lr
45 | self.power = power
46 | super().__init__(optimizer, last_epoch)
47 |
48 | def get_lr(self):
49 |
50 | if self.last_epoch < self.warmup_epochs:
51 | if self.warmup_method == "constant":
52 | warmup_factor = self.warmup_factor
53 | elif self.warmup_method == "linear":
54 | alpha = self.last_epoch / self.warmup_epochs
55 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
56 | return [base_lr * warmup_factor for base_lr in self.base_lrs]
57 |
58 | if self.mode == "step":
59 | return [
60 | base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
61 | for base_lr in self.base_lrs
62 | ]
63 |
64 | epoch_ratio = (self.last_epoch - self.warmup_epochs) / (
65 | self.total_epochs - self.warmup_epochs
66 | )
67 |
68 | if self.mode == "exp":
69 | factor = epoch_ratio
70 | return [base_lr * self.power ** factor for base_lr in self.base_lrs]
71 | if self.mode == "linear":
72 | factor = 1 - epoch_ratio
73 | return [base_lr * factor for base_lr in self.base_lrs]
74 |
75 | if self.mode == "poly":
76 | factor = 1 - epoch_ratio
77 | return [
78 | self.target_lr + (base_lr - self.target_lr) * self.power ** factor
79 | for base_lr in self.base_lrs
80 | ]
81 | if self.mode == "cosine":
82 | factor = 0.5 * (1 + cos(pi * epoch_ratio))
83 | return [
84 | self.target_lr + (base_lr - self.target_lr) * factor
85 | for base_lr in self.base_lrs
86 | ]
87 | raise NotImplementedError
88 |
--------------------------------------------------------------------------------
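
Note that `main.py` drives its learning rate with `misc.build.cosine_scheduler` rather than this class. For completeness, here is a minimal sketch (not part of the repository) of attaching `LRSchedulerWithWarmup` to a dummy optimizer; the model, milestones and epoch counts are arbitrary placeholders.

```python
import torch

from misc.lr_scheduler import LRSchedulerWithWarmup

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = LRSchedulerWithWarmup(optimizer, milestones=[3, 4], mode="cosine",
                                  warmup_epochs=1, total_epochs=5)

for epoch in range(5):
    # ... one training epoch would run here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())
```
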
/misc/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from easydict import EasyDict
4 |
5 | import yaml
6 | import os
7 |
8 | import torch
9 | import numpy as np
10 | import random
11 | import torch.distributed as dist
12 |
13 |
14 | def parse_config(config_path):
15 | with open(config_path) as f:
16 | config = yaml.load(f, Loader=yaml.FullLoader)
17 | config = EasyDict(config)
18 | return config
19 |
20 |
21 | def is_using_distributed():
22 | return True
23 |
24 |
25 | def is_dist_avail_and_initialized():
26 | if not dist.is_available():
27 | return False
28 | if not dist.is_initialized():
29 | return False
30 | return True
31 |
32 |
33 | def get_world_size():
34 | if not is_dist_avail_and_initialized():
35 | return 1
36 | return dist.get_world_size()
37 |
38 |
39 | def get_rank():
40 | if not is_dist_avail_and_initialized():
41 | return 0
42 | return dist.get_rank()
43 |
44 |
45 | def is_master():
46 | return not is_using_distributed() or get_rank() == 0
47 |
48 |
49 | def wandb_record():
50 | if not 'WANDB_PROJECT' in os.environ:
51 | return False
52 | return not is_using_distributed() or get_rank() == 0
53 |
54 |
55 | def init_distributed_mode(config):
56 | if is_using_distributed():
57 | config.distributed.rank = int(os.environ['RANK'])
58 | config.distributed.world_size = int(os.environ['WORLD_SIZE'])
59 | config.distributed.local_rank = int(os.environ['LOCAL_RANK'])
60 | torch.distributed.init_process_group(backend=config.distributed.backend,
61 | init_method=config.distributed.url)
62 | used_for_printing(get_rank() == 0)
63 |
64 | if torch.cuda.is_available():
65 | if is_using_distributed():
66 | device = f'cuda:{get_rank()}'
67 | else:
68 | device = f'cuda:{d}' if str(d := config.device).isdigit() else d
69 | torch.cuda.set_device(device)
70 | else:
71 | device = 'cpu'
72 | config.device = device
73 |
74 |
75 | def used_for_printing(is_master):
76 | import builtins as __builtin__
77 | builtin_print = __builtin__.print
78 |
79 | def print(*args, **kwargs):
80 | force = kwargs.pop('force', False)
81 | if is_master or force:
82 | builtin_print(*args, **kwargs)
83 |
84 | __builtin__.print = print
85 |
86 |
87 | def set_seed(config):
88 | seed = config.misc.seed
89 |
90 | torch.manual_seed(seed)
91 | np.random.seed(seed)
92 | random.seed(seed)
93 | os.environ["PYTHONHASHSEED"] = str(seed)
94 |
95 | if torch.cuda.is_available():
96 | torch.cuda.manual_seed_all(seed)
97 | torch.backends.cudnn.deterministic = True
98 | torch.backends.cudnn.benchmark = False
99 |
100 |
101 | class AverageMeter(object):
102 | """Computes and stores the average and current value"""
103 |
104 | def __init__(self):
105 | self.val = 0
106 | self.avg = 0
107 | self.sum = 0
108 | self.count = 0
109 |
110 | def reset(self):
111 | self.val = 0
112 | self.avg = 0
113 | self.sum = 0
114 | self.count = 0
115 |
116 | def update(self, val, n=1):
117 | self.val = val
118 | self.sum += val * n
119 | self.count += n
120 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
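
As a small illustration (not part of the repository), `parse_config` above turns the YAML files under `config/` into an `EasyDict`, so nested keys can be read with attribute access. The fields printed below are ones defined in `config/config.yaml`, and the script is assumed to run from the repository root.

```python
from misc.utils import parse_config

config = parse_config('config/config.yaml')
print(config.schedule.lr, config.data.batch_size, config.model.embed_dim)
# e.g. 0.0001 80 512
```
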
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/model/__init__.py
--------------------------------------------------------------------------------
/model/base_transformer.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.utils.checkpoint import checkpoint_sequential
6 |
7 | global LAYER_NORM
8 | LAYER_NORM = True
9 |
10 |
11 | class LayerNorm(nn.LayerNorm):
12 | """Subclass torch's LayerNorm to handle fp16."""
13 |
14 | def forward(self, x: torch.Tensor):
15 | if LAYER_NORM:
16 | ret = super().forward(x)
17 | else:
18 | ret = x
19 | return ret
20 |
21 |
22 | class QuickGELU(nn.Module):
23 | def forward(self, x: torch.Tensor):
24 | return x * torch.sigmoid(1.702 * x)
25 |
26 |
27 | class ResidualAttentionBlock(nn.Module):
28 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, dropout: float = 0.):
29 | super().__init__()
30 |
31 | self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
32 | self.ln_1 = LayerNorm(d_model)
33 | self.mlp = nn.Sequential(OrderedDict([
34 | ("c_fc", nn.Linear(d_model, d_model * 4)),
35 | ("gelu", QuickGELU()),
36 | # ("dropout_1", nn.Dropout(dropout)),
37 | ("c_proj", nn.Linear(d_model * 4, d_model)),
38 | # ("dropout_2", nn.Dropout(dropout))
39 | ]))
40 | self.ln_2 = LayerNorm(d_model)
41 | self.attn_mask = attn_mask
42 |
43 | def attention(self, x: torch.Tensor):
44 | self.attn_mask = self.attn_mask.to(
45 | dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
46 | return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)[0]
47 |
48 | def forward(self, x: torch.Tensor):
49 | x = x + self.attention(self.ln_1(x))
50 | x = x + self.mlp(self.ln_2(x))
51 | return x
52 |
53 |
54 | class Transformer(nn.Module):
55 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, checkpoint: bool = False,
56 | dropout: float = 0., emb_dropout: float = 0.):
57 | super().__init__()
58 | self.width = width
59 | self.layers = layers
60 | self.checkpoint = checkpoint
61 | self.dropout = nn.Dropout(emb_dropout)
62 | self.resblocks = nn.Sequential(
63 | *[ResidualAttentionBlock(width, heads, attn_mask, dropout=dropout) for _ in range(layers)])
64 |
65 | def checkpoint_fwd(self, layer, input, segments=2):
66 | """checkpoint forward"""
67 | # Make sure that the input to checkpoint have requires_grad=True, so that
68 | # the autograd can take care of the checkpointed part of model
69 | if not input.requires_grad:
70 | input = input.detach()
71 | input.requires_grad = True
72 | return checkpoint_sequential(layer, segments, input)
73 |
74 | def forward(self, x: torch.Tensor):
75 | x = self.dropout(x)
76 | if self.checkpoint:
77 | return self.checkpoint_fwd(self.resblocks, x, self.layers)
78 | return self.resblocks(x)
79 |
--------------------------------------------------------------------------------
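
A minimal sketch (not part of the repository) of the tensor layout the `Transformer` above expects: as in CLIP, inputs are (sequence_length, batch, width), because `nn.MultiheadAttention` is used without `batch_first`. The sizes below are illustrative.

```python
import torch

from model.base_transformer import Transformer

blocks = Transformer(width=512, layers=2, heads=8)
x = torch.randn(77, 4, 512)   # (tokens, batch, width)
y = blocks(x)
print(y.shape)                # torch.Size([77, 4, 512])
```
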
/model/eda.py:
--------------------------------------------------------------------------------
1 | from nltk.corpus import wordnet, stopwords
2 | import random
3 |
4 |
5 | class EDA:
6 | """
7 | This class is an implementation of the original EDA algorithm (2019) [1].
8 |
9 | [1] Wei, J. and Zou, K., 2019, November. EDA: Easy Data Augmentation Techniques for Boosting Performance on
10 | Text Classification Tasks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing
11 | and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP) (pp. 6383-6389).
12 | https://www.aclweb.org/anthology/D19-1670.pdf
13 |
14 | Example usage: ::
15 | >>> from textaugment import EDA
16 | >>> t = EDA()
17 | >>> t.synonym_replacement("John is going to town")
18 | John is give out to town
19 | >>> t.random_deletion("John is going to town", p=0.2)
20 | is going to town
21 | >>> t.random_swap("John is going to town")
22 | John town going to is
23 | >>> t.random_insertion("John is going to town")
24 | John is going to make up town
25 | """
26 |
27 | @staticmethod
28 | def _get_synonyms(word):
29 | """Generate synonym"""
30 | synonyms = set()
31 | for syn in wordnet.synsets(word):
32 | for lemma in syn.lemmas():
33 | synonym = lemma.name().replace("_", " ").replace("-", " ").lower()
34 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
35 | synonyms.add(synonym)
36 | if word in synonyms:
37 | synonyms.remove(word)
38 | synonyms = sorted(synonyms)
39 | random.shuffle(synonyms)
40 | return synonyms
41 |
42 | @staticmethod
43 | def swap_word(new_words):
44 | """Swap words"""
45 | random_idx_1 = random.randint(0, len(new_words) - 1)
46 | random_idx_2 = random_idx_1
47 | counter = 0
48 | while random_idx_2 == random_idx_1:
49 | random_idx_2 = random.randint(0, len(new_words) - 1)
50 | counter += 1
51 | if counter > 3:
52 | return new_words
53 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
54 | return new_words
55 |
56 | @staticmethod
57 | def validate(**kwargs):
58 | """Validate input data"""
59 |
60 | if 'p' in kwargs:
61 | if kwargs['p'] > 1 or kwargs['p'] < 0:
62 | raise TypeError("p must be a fraction between 0 and 1")
63 | if 'sentence' in kwargs:
64 | if not isinstance(kwargs['sentence'].strip(), str) or len(kwargs['sentence'].strip()) == 0:
65 | raise TypeError("sentence must be a valid sentence")
66 | if 'n' in kwargs:
67 | if not isinstance(kwargs['n'], int):
68 | raise TypeError("n must be a valid integer")
69 |
70 | def __init__(self, stop_words=None, random_state=1):
71 | """A method to initialize parameters
72 |
73 | :type random_state: int
74 | :param random_state: (optional) Seed
75 | :type stop_words: list
76 | :param stop_words: (optional) List of stopwords
77 |
78 | :rtype: None
79 | :return: Constructer do not return.
80 | """
81 | self.stopwords = stopwords.words('english') if stop_words is None else stop_words
82 | self.sentence = None
83 | self.p = None
84 | self.n = None
85 | # self.random_state = random_state
86 | # if isinstance(self.random_state, int):
87 | # random.seed(self.random_state)
88 | # else:
89 | # raise TypeError("random_state must have type int")
90 |
91 | def add_word(self, new_words):
92 | """Insert word"""
93 | synonyms = list()
94 | counter = 0
95 | while len(synonyms) < 1:
96 | random_word_list = list([word for word in new_words if word not in self.stopwords])
97 | random_word = random_word_list[random.randint(0, len(random_word_list) - 1)]
98 | synonyms = self._get_synonyms(random_word)
99 | counter += 1
100 | if counter >= 10:
101 | return new_words # See Issue 14 for details
102 | random_synonym = synonyms[0] # TODO
103 | random_idx = random.randint(0, len(new_words) - 1)
104 | new_words.insert(random_idx, random_synonym)
105 | return new_words
106 |
107 | def synonym_replacement(self, sentence: str, n: int = 1):
108 | """Replace n words in the sentence with synonyms from wordnet
109 |
110 | :type sentence: str
111 | :param sentence: Sentence
112 | :type n: int
113 | :param n: Number of repetitions to replace
114 |
115 | :rtype: str
116 | :return: Augmented sentence
117 | """
118 | self.validate(sentence=sentence, n=n)
119 | self.n = n
120 | self.sentence = sentence
121 | words = sentence.split()
122 | new_words = words.copy()
123 | random_word_list = sorted(list(set([word for word in words if word not in self.stopwords])))
124 | random.shuffle(random_word_list)
125 | replaced = 0
126 | for random_word in random_word_list:
127 | synonyms = self._get_synonyms(random_word)
128 | if len(synonyms) > 0:
129 | synonym = random.choice(list(synonyms))
130 | new_words = [synonym if word == random_word else word for word in new_words]
131 | replaced += 1
132 | if replaced >= self.n:
133 | break
134 | sentence = ' '.join(new_words)
135 |
136 | return sentence
137 |
138 | def random_deletion(self, sentence: str, p: float = 0.1):
139 | """Randomly delete words from the sentence with probability p
140 |
141 | :type sentence: str
142 | :param sentence: Sentence
143 | :type p: int
144 | :param p: Probability between 0 and 1
145 |
146 | :rtype: str
147 | :return: Augmented sentence
148 | """
149 | self.validate(sentence=sentence, p=p)
150 | self.p = p
151 | self.sentence = sentence
152 | words = sentence.split()
153 | if len(words) == 1:
154 | return words
155 | new_words = list()
156 | for word in words:
157 | r = random.uniform(0, 1)
158 | if r > self.p:
159 | new_words.append(word)
160 | # if all words are deleted, just return a random word
161 | if len(new_words) == 0:
162 | return random.choice(words)
163 |
164 | return " ".join(new_words)
165 |
166 | def random_swap(self, sentence: str, n: int = 1):
167 | """Randomly swap two words in the sentence n times
168 |
169 | :type sentence: str
170 | :param sentence: Sentence
171 | :type n: int
172 | :param n: Number of repetitions to swap
173 |
174 | :rtype: str
175 | :return: Augmented sentence
176 | """
177 | self.validate(sentence=sentence, n=n)
178 | self.n = n
179 | self.sentence = sentence
180 | words = sentence.split()
181 | new_words = words.copy()
182 | for _ in range(self.n):
183 | new_words = self.swap_word(new_words)
184 | return " ".join(new_words)
185 |
186 | def random_insertion(self, sentence: str, n: int = 1):
187 | """Randomly insert n words into the sentence
188 |
189 | :type sentence: str
190 | :param sentence: Sentence
191 | :type n: int
192 | :param n: Number of words to insert
193 |
194 | :rtype: str
195 | :return: Augmented sentence
196 | """
197 | self.validate(sentence=sentence, n=n)
198 | self.n = n
199 | self.sentence = sentence
200 | words = sentence.split()
201 | new_words = words.copy()
202 | for _ in range(self.n):
203 | new_words = self.add_word(new_words)
204 | return " ".join(new_words)
205 |
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def compute_simclr_loss(logits_a, logits_b, logits_a_gathered, logits_b_gathered, labels, temperature):
6 | sim_aa = logits_a @ logits_a_gathered.t() / temperature
7 | sim_ab = logits_a @ logits_b_gathered.t() / temperature
8 | sim_ba = logits_b @ logits_a_gathered.t() / temperature
9 | sim_bb = logits_b @ logits_b_gathered.t() / temperature
10 | masks = torch.where(F.one_hot(labels, logits_a_gathered.size(0)) == 0, 0, float('-inf'))
11 | sim_aa += masks
12 | sim_bb += masks
13 | sim_a = torch.cat([sim_ab, sim_aa], 1)
14 | sim_b = torch.cat([sim_ba, sim_bb], 1)
15 | loss_a = F.cross_entropy(sim_a, labels)
16 | loss_b = F.cross_entropy(sim_b, labels)
17 | return (loss_a + loss_b) * 0.5
18 |
--------------------------------------------------------------------------------
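
A toy sketch (not part of the repository) of the shapes `compute_simclr_loss` above works with. During training the `*_gathered` arguments are the features all-gathered across GPUs; here a single process is assumed, so they are simply the local features and `labels` matches each sample with itself.

```python
import torch
import torch.nn.functional as F

from model.loss import compute_simclr_loss

batch = 8
feats_a = F.normalize(torch.randn(batch, 128), dim=-1)   # projections of view 1
feats_b = F.normalize(torch.randn(batch, 128), dim=-1)   # projections of view 2
labels = torch.arange(batch)                             # positives sit on the diagonal

loss = compute_simclr_loss(feats_a, feats_b, feats_a, feats_b, labels, temperature=0.1)
print(loss.item())
```
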
/model/mixgen.py:
--------------------------------------------------------------------------------
1 | """
2 | MixGen: A New Multi-Modal Data Augmentation
3 | https://arxiv.org/abs/2206.08358
4 | Apache-2.0 License, Copyright 2022 Amazon
5 | """
6 | import random
7 | import numpy as np
8 | import torch
9 | from torchvision import transforms
10 |
11 |
12 | def mixgen(image, text, num, lam=0.5):
13 | # default MixGen
14 | for i in range(num):
15 | # image mixup
16 | image[i,:] = lam * image[i,:] + (1 - lam) * image[i+num,:]
17 | # text concat
18 | text[i] = text[i] + " " + text[i+num]
19 | return image, text
20 |
21 | def concatgen(image, text, num, lam=0.5):
22 | for i in range(num):
23 | # image mixup
24 | img1 = transforms.functional.resize(image[i], (224, 112))
25 | img2 = transforms.functional.resize(image[i+num], (224, 112))
26 | image[i] = torch.cat((img1, img2), dim=2)
27 | image[i] = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image[i])
28 | # text concat
29 | text[i] = text[i] + " " + text[i+num]
30 | return image, text
31 |
32 |
33 | def mixgen_batch(image, text, num, lam=0.5):
34 | batch_size = image.size()[0]
35 | index = np.random.permutation(batch_size)
36 | for i in range(batch_size):
37 | # image mixup
38 | image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
39 | # text concat
40 | text[i] = text[i] + " " + text[index[i]]
41 | return image, text
42 |
--------------------------------------------------------------------------------
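
A minimal sketch (not part of the repository) of applying `mixgen` above to a toy batch: the first `num` images are mixed with the images `num` positions later, and the corresponding captions are concatenated. `concatgen` follows the same indexing but additionally assumes 3x224x224 inputs, matching the training resolution.

```python
import torch

from model.mixgen import mixgen

images = torch.rand(8, 3, 224, 224)
captions = [f"caption {i}" for i in range(8)]

images, captions = mixgen(images, captions, num=2, lam=0.5)
print(captions[0])   # "caption 0 caption 2"
```
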
/model/shared_modules.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | import torch.distributed as dist
6 |
7 |
8 | class AllGather(torch.autograd.Function):
9 |
10 | @staticmethod
11 | def forward(ctx, tensor):
12 | ctx.rank = int(os.environ['RANK'])
13 | ctx.world_size = int(os.environ['WORLD_SIZE'])
14 |
15 | # y = tensor.new(ctx.world_size, *tensor.size())
16 |
17 | y = [tensor.new(*tensor.size()) for _ in range(ctx.world_size)]
18 |
19 | dist.all_gather(y, tensor.contiguous())
20 |
21 | y = torch.cat(y, 0).view(-1, *tensor.size())
22 |
23 | return y
24 |
25 | @staticmethod
26 | def backward(ctx, grad_output):
27 | in_grad = torch.zeros_like(grad_output)
28 | in_grad.copy_(grad_output)
29 | # sum grad for gathered tensor
30 | dist.all_reduce(in_grad.contiguous())
31 | # split
32 | return in_grad[ctx.rank]
33 |
34 |
--------------------------------------------------------------------------------
/model/tbps_model.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | import numpy as np
7 | import copy
8 |
9 | from misc import utils
10 | from misc.utils import is_using_distributed
11 | from text_utils.tokenizer import tokenize
12 | from .loss import compute_simclr_loss
13 | from .visual_transformer import visual_transformer
14 | from .text_transformer import text_transformers
15 | from .eda import EDA
16 | from .base_transformer import Transformer, LayerNorm, QuickGELU
17 |
18 | from .shared_modules import AllGather
19 | from collections import OrderedDict
20 |
21 |
22 | class CLIP(nn.Module):
23 | def __init__(self, config, image_encode, text_encode, num_classes=11003, eps=1e-2):
24 | super().__init__()
25 | self.visual = image_encode
26 | self.encode_text = text_encode
27 | self.embed_dim = config.model.embed_dim
28 |
29 | self.use_gather = config.model.use_gather
30 | self.logit_scale = nn.Parameter(torch.ones([]))
31 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
32 | self.config = config
33 | self.eda = EDA()
34 | self.eps = eps
35 |
36 | if config.experiment.ss:
37 | structure = config.experiment.simclr_mlp
38 | self.simclr_mlp = self._build_mlp(*structure)
39 |
40 | if config.experiment.id:
41 | self.classifier = nn.Linear(self.embed_dim, num_classes)
42 | nn.init.normal_(self.classifier.weight.data, std=0.001)
43 | nn.init.constant_(self.classifier.bias.data, val=0.0)
44 |
45 | if config.experiment.mlm:
46 | self.vocab_size = config.model.vocab_size
47 | self.cross_attn = nn.MultiheadAttention(self.embed_dim,
48 | self.embed_dim // 64,
49 | batch_first=True)
50 | self.cross_modal_transformer = Transformer(width=self.embed_dim,
51 | layers=config.experiment.cmt_depth,
52 | heads=self.embed_dim // 64)
53 | scale = self.cross_modal_transformer.width ** -0.5
54 |
55 | self.ln_pre_t = LayerNorm(self.embed_dim)
56 | self.ln_pre_i = LayerNorm(self.embed_dim)
57 | self.ln_post = LayerNorm(self.embed_dim)
58 |
59 | proj_std = scale * ((2 * self.cross_modal_transformer.layers) ** -0.5)
60 | attn_std = scale
61 | fc_std = (2 * self.cross_modal_transformer.width) ** -0.5
62 | for block in self.cross_modal_transformer.resblocks:
63 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
64 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
65 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
66 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
67 |
68 | # init cross attn
69 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std)
70 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std)
71 |
72 | self.mlm_head = nn.Sequential(
73 | OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)),
74 | ('gelu', QuickGELU()),
75 | ('ln', LayerNorm(self.embed_dim)),
76 | ('fc', nn.Linear(self.embed_dim, self.vocab_size))]))
77 | # init mlm head
78 | nn.init.normal_(self.mlm_head.dense.weight, std=fc_std)
79 | nn.init.normal_(self.mlm_head.fc.weight, std=proj_std)
80 |
81 | def forward(self, input, alpha):
82 | ret = dict()
83 |
84 | images = input['image'].to(self.config.device)
85 | images_1 = input['aug1'].to(self.config.device)
86 | texts = input['caption']
87 | texts_bt = input['caption_bt']
88 |
89 | # back translation
90 | if self.config.experiment.back_trans:
91 | for i in range(len(texts)):
92 | if random.random() < self.config.experiment.backtrans_p:
93 | texts[i] = texts_bt[i]
94 |
95 | # random deletion
96 | cap_new = []
97 | for text in texts:
98 | eda_alpha = self.config.experiment.eda_alpha
99 | cap_new.append(self.eda.random_deletion(text, eda_alpha))
100 | texts = cap_new
101 |
102 | # MLM
103 | if self.config.experiment.mlm:
104 | text_tokens, mlm_labels = tokenize(texts, context_length=self.config.experiment.text_length,
105 | mask_type='MLM')
106 | text_tokens = text_tokens.to(self.config.device)
107 | mlm_labels = mlm_labels.to(self.config.device)
108 | else:
109 | text_tokens = tokenize(texts, context_length=self.config.experiment.text_length).to(self.config.device)
110 | ids = input['id'].to(self.config.device)
111 |
112 | image_features, image_seq_embeddings = self.encode_image(images, return_dense=True)
113 | text_features, text_seq_embeddings = self.encode_text(text_tokens, return_dense=True)
114 | image_features_norm = F.normalize(image_features)
115 | text_features_norm = F.normalize(text_features)
116 | image_features_norm_gathered = self.all_gather(image_features_norm)
117 | text_features_norm_gathered = self.all_gather(text_features_norm)
118 |
119 | # image ss
120 | if self.config.experiment.ss:
121 | aug1_embed = self.simclr_mlp(self.encode_image(input['aug_ss_1'].to(self.config.device)))
122 | aug2_embed = self.simclr_mlp(self.encode_image(input['aug_ss_2'].to(self.config.device)))
123 | q_a = F.normalize(aug1_embed, dim=-1, p=2)
124 | q_b = F.normalize(aug2_embed, dim=-1, p=2)
125 | local_batch_size = q_a.size(0)
126 | labels = local_batch_size * utils.get_rank() + torch.arange(local_batch_size, device=q_a.device)
127 | k_a = self.all_gather(q_a)
128 | k_b = self.all_gather(q_b)
129 | ss_loss = compute_simclr_loss(q_a, q_b, k_a, k_b, labels, self.config.experiment.simclr_temperature)
130 | ret['ss_loss'] = ss_loss * self.config.experiment.ss_ratio
131 |
132 | logit_scale = self.logit_scale.exp()
133 | logit_scale.data = torch.clamp(logit_scale.data, max=100)
134 |
135 | idx = ids.view(-1, 1)
136 | gathered_ids = self.all_gather(ids)
137 | idx_all = gathered_ids.view(1, -1)
138 | pos_idx = torch.eq(idx, idx_all).float()
139 | sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
140 |
141 | with torch.no_grad():
142 | image_features_s = self.encode_image(images).detach()
143 | text_features_s = self.encode_text(text_tokens).detach()
144 | image_features_s_norm = F.normalize(image_features_s)
145 | text_features_s_norm = F.normalize(text_features_s)
146 | image_features_s_norm_gathered = self.all_gather(image_features_s_norm)
147 | text_features_s_norm_gathered = self.all_gather(text_features_s_norm)
148 | nitc_loss = self.calc_contrastive(image_features_norm, text_features_norm, image_features_s_norm,
149 | text_features_s_norm,
150 | image_features_norm_gathered, text_features_norm_gathered,
151 | image_features_s_norm_gathered, text_features_s_norm_gathered,
152 | sim_targets, alpha, logit_scale)
153 |
154 | if self.config.experiment.mvs_image:
155 | image_1_features = self.encode_image(images_1)
156 | image_1_features_norm = F.normalize(image_1_features)
157 | image_1_features_norm_gathered = self.all_gather(image_1_features_norm)
158 | with torch.no_grad():
159 | image_1_features_s = self.encode_image(images_1).detach()
160 | image_1_features_s_norm = F.normalize(image_1_features_s)
161 | image_1_features_s_norm_gathered = self.all_gather(image_1_features_s_norm)
162 | loss_img1_txt0 = self.calc_contrastive(image_1_features_norm, text_features_norm, image_1_features_s_norm,
163 | text_features_s_norm,
164 | image_1_features_norm_gathered, text_features_norm_gathered,
165 | image_1_features_s_norm_gathered, text_features_s_norm_gathered,
166 | sim_targets, alpha, logit_scale)
167 | nitc_loss = (nitc_loss + loss_img1_txt0) / 2
168 |
169 | ret['nitc_loss'] = nitc_loss * self.config.experiment.nitc_ratio
170 |
171 | if self.config.experiment.citc:
172 | logits_image_per_image = logit_scale * image_features_norm_gathered @ image_features_norm_gathered.t()
173 | logits_text_per_text = logit_scale * text_features_norm_gathered @ text_features_norm_gathered.t()
174 | inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (
175 | logit_scale * logit_scale)
176 | logits_text_per_image = logit_scale * image_features_norm_gathered @ text_features_norm_gathered.t()
177 | logits_image_per_text = logit_scale * text_features_norm_gathered @ image_features_norm_gathered.t()
178 | crossmodal_cyclic_loss = (logits_text_per_image - logits_image_per_text).square().mean() / (
179 | logit_scale * logit_scale)
180 | citc_loss = self.config.experiment.citc_lambda1 * inmodal_cyclic_loss + self.config.experiment.citc_lambda2 * crossmodal_cyclic_loss
181 | ret['citc_loss'] = citc_loss * self.config.experiment.citc_ratio
182 |
183 | if self.config.experiment.ritc:
184 | logits_per_image_1 = logit_scale * image_features_norm @ text_features_norm_gathered.t()
185 | logits_per_text_1 = logit_scale * text_features_norm @ image_features_norm_gathered.t()
186 | img_log = F.log_softmax(logits_per_image_1, dim=1)
187 | txt_log = F.log_softmax(logits_per_text_1, dim=1)
188 | target_log = (sim_targets + self.eps).log()
189 | kl_img = F.kl_div(target_log, img_log, log_target=True, reduction='batchmean')
190 | kl_txt = F.kl_div(target_log, txt_log, log_target=True, reduction='batchmean')
191 | ritc_loss = 0.5 * (kl_img + kl_txt)
192 | ret['ritc_loss'] = ritc_loss * self.config.experiment.ritc_ratio
193 |
194 | if self.config.experiment.mlm:
195 | x = self.cross_former(text_seq_embeddings, image_seq_embeddings, image_seq_embeddings)
196 | x = self.mlm_head(x)
197 | scores = x.float().reshape(-1, self.vocab_size)
198 | mlm_labels = mlm_labels.reshape(-1)
199 | mlm_loss = F.cross_entropy(scores, mlm_labels)
200 | ret['mlm_loss'] = mlm_loss * self.config.experiment.mlm_ratio
201 |
202 | if self.config.experiment.id:
203 | image_logits = self.classifier(image_features)
204 | text_logits = self.classifier(text_features)
205 | id_loss = (F.cross_entropy(image_logits, ids) + F.cross_entropy(text_logits, ids)) / 2
206 | ret['id_loss'] = id_loss * self.config.experiment.id_ratio
207 |
208 | return ret
209 |
210 | def cross_former(self, q, k, v):
211 | x = self.cross_attn(
212 | self.ln_pre_t(q),
213 | self.ln_pre_i(k),
214 | self.ln_pre_i(v),
215 | need_weights=False)[0]
216 | x = x.permute(1, 0, 2) # NLD -> LND
217 | x = self.cross_modal_transformer(x)
218 | x = x.permute(1, 0, 2) # LND -> NLD
219 |
220 | x = self.ln_post(x)
221 | return x
222 |
223 | # input features are normed
224 | def calc_contrastive(self, image_features, text_features, image_features_s, text_features_s,
225 | image_features_gathered, text_features_gathered, image_features_s_gathered,
226 | text_features_s_gathered,
227 | sim_targets, alpha, logit_scale):
228 | with torch.no_grad():
229 | sim_i2t_s = logit_scale * image_features_s @ text_features_s_gathered.t()
230 | sim_t2i_s = logit_scale * text_features_s @ image_features_s_gathered.t()
231 | sim_i2t_targets = alpha * F.softmax(sim_i2t_s, dim=1) + (1 - alpha) * sim_targets
232 | sim_t2i_targets = alpha * F.softmax(sim_t2i_s, dim=1) + (1 - alpha) * sim_targets # soft + hard
233 | sim_i2t = logit_scale * image_features @ text_features_gathered.t()
234 | sim_t2i = logit_scale * text_features @ image_features_gathered.t()
235 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
236 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
237 | loss_ita = (loss_i2t + loss_t2i) / 2
238 | return loss_ita
239 |
240 | def compute_simclr_loss(self, logits_a, logits_b, logits_a_gathered, logits_b_gathered, labels, temperature):
241 | sim_aa = logits_a @ logits_a_gathered.t() / temperature
242 | sim_ab = logits_a @ logits_b_gathered.t() / temperature
243 | sim_ba = logits_b @ logits_a_gathered.t() / temperature
244 | sim_bb = logits_b @ logits_b_gathered.t() / temperature
245 | masks = torch.where(F.one_hot(labels, logits_a_gathered.size(0)) == 0, 0, float('-inf'))
246 | sim_aa += masks
247 | sim_bb += masks
248 | sim_a = torch.cat([sim_ab, sim_aa], 1)
249 | sim_b = torch.cat([sim_ba, sim_bb], 1)
250 | loss_a = F.cross_entropy(sim_a, labels)
251 | loss_b = F.cross_entropy(sim_b, labels)
252 | return (loss_a + loss_b) * 0.5
253 |
254 | def _build_mlp(self, in_dim=512, mlp_dim=512, out_dim=512):
255 | return nn.Sequential(
256 | nn.Linear(in_dim, mlp_dim),
257 | nn.ReLU(inplace=True),
258 | nn.Linear(mlp_dim, out_dim)
259 | )
260 |
261 | @property
262 | def dtype(self):
263 |         try:
264 |             return self.visual.conv1.weight.dtype
265 |         except AttributeError:
266 |             try:
267 |                 return self.visual.head.weight.dtype
268 |             except AttributeError:
269 |                 try:
270 |                     return self.visual.stem[0].weight.dtype
271 |                 except AttributeError:
272 |                     return self.encode_text.text_projection.dtype
273 |
274 | def encode_image(self, image, return_dense=False):
275 | if return_dense:
276 | output = self.visual(image.type(self.dtype), return_dense=return_dense)
277 | return output
278 | output = self.visual(image.type(self.dtype))
279 | return output
280 |
281 | def all_gather(self, input):
282 | if not self.use_gather or not is_using_distributed():
283 | return input
284 | output = AllGather.apply(input)
285 | output = output.view(-1, *(output.shape[2:]))
286 | return output
287 |
288 |
289 | def clip_vitb(config, num_classes=11003):
290 | image_encode = visual_transformer(config)
291 | text_encode = text_transformers(config)
292 | model = CLIP(config, image_encode, text_encode, num_classes, config.experiment.ritc_eps)
293 | return model
294 |
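A hedged sketch of how a training loop might consume the loss dictionary returned by CLIP.forward; training_step is a hypothetical helper, model / batch / optimizer / alpha stand in for objects built elsewhere (clip_vitb(config), the data loader, ...), and every entry in the dictionary is already weighted by its *_ratio from the config:

def training_step(model, batch, optimizer, alpha):
    # batch is assumed to carry the keys read in CLIP.forward:
    # 'image', 'aug1', 'caption', 'caption_bt', 'id' (plus 'aug_ss_1'/'aug_ss_2' when ss is on)
    losses = model(batch, alpha)   # e.g. {'nitc_loss': ..., 'ritc_loss': ..., ...}
    total = sum(losses.values())
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return {name: value.item() for name, value in losses.items()}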
--------------------------------------------------------------------------------
/model/text_transformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from .base_transformer import Transformer, LayerNorm
6 |
7 |
8 | class TextTransformer(nn.Module):
9 | def __init__(self, config,
10 | embed_dim: int,
11 | context_length: int,
12 | transformer_width: int,
13 | transformer_heads: int,
14 | transformer_layers: int,
15 | positional_embedding_flag: bool,
16 | checkpoint: bool,
17 | bpe_path=None,
18 | ):
19 | super().__init__()
20 | self.config = config
21 | self.context_length = context_length
22 | self.positional_embedding_flag = positional_embedding_flag
23 |
24 | self.transformer = Transformer(
25 | width=transformer_width,
26 | layers=transformer_layers,
27 | heads=transformer_heads,
28 | attn_mask=self.build_attention_mask(),
29 | checkpoint=checkpoint,
30 | dropout=config.experiment.dropout
31 | )
32 | self.token_embedding = nn.Embedding(49408, transformer_width)
33 | self.positional_embedding = nn.Parameter(
34 | torch.normal(mean=0, std=0.02, size=(self.context_length, transformer_width)))
35 | self.ln_final = LayerNorm(transformer_width)
36 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
37 | self.initialize_parameters()
38 |
39 | def train(self, mode=True):
40 | self.training = mode
41 | for module in self.children():
42 | module.train(mode)
43 | return self
44 |
45 | def initialize_parameters(self):
46 | nn.init.normal_(self.token_embedding.weight, std=0.02)
47 | nn.init.normal_(self.positional_embedding, std=0.01)
48 |
49 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
50 | attn_std = self.transformer.width ** -0.5
51 | fc_std = (2 * self.transformer.width) ** -0.5
52 | for block in self.transformer.resblocks:
53 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
54 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
55 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
56 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
57 | if self.text_projection is not None:
58 | # nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) # todo
59 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
60 |
61 | @property
62 | def dtype(self):
63 | return self.positional_embedding.dtype
64 |
65 | def build_attention_mask(self):
66 |         # lazily create a causal attention mask for the text tokens
67 | # pytorch uses additive attention mask; fill with -inf
68 | mask = torch.empty(self.context_length, self.context_length)
69 | mask.fill_(float("-inf"))
70 |         mask.triu_(1)  # zero out the diagonal and lower triangle so each token attends only to itself and earlier tokens
71 | return mask
72 |
73 | def forward(self, texts, mask_type=None, return_dense=False):
74 | if mask_type is not None:
75 | texts, labels = texts
76 | x = self.token_embedding(texts).type(self.dtype) # [batch_size, n_ctx, d_model]
77 | if self.positional_embedding_flag:
78 | x = x + self.positional_embedding.type(self.dtype) # Fix!!!
79 | x = x.permute(1, 0, 2) # NLD -> LND
80 | x = self.transformer(x)
81 | x = x.permute(1, 0, 2) # LND -> NLD
82 | x = self.ln_final(x).type(self.dtype)
83 |
84 | x = x @ self.text_projection
85 |
86 | if mask_type is not None or return_dense:
87 | words_feat = x
88 |
89 | x = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)]
90 |
91 | if mask_type is not None:
92 | return x, words_feat, labels
93 |
94 | if return_dense:
95 | return x, words_feat
96 |
97 | return x
98 |
99 |
100 | def text_transformers(config):
101 | model_config = config.model
102 | kwargs = {
103 | 'context_length': config.experiment.text_length,
104 | 'transformer_width': 512,
105 | 'transformer_heads': 8,
106 | 'transformer_layers': 12,
107 | 'positional_embedding_flag': True,
108 | 'checkpoint': False,
109 | 'embed_dim': model_config.embed_dim,
110 | }
111 | model = TextTransformer(config, **kwargs)
112 | return model
113 |
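A construction sketch under assumed config values (embed_dim=512, text_length=77, dropout=0.0; the real values come from config/config.yaml):

from easydict import EasyDict

from model.text_transformer import text_transformers
from text_utils.tokenizer import tokenize

cfg = EasyDict({'model': {'embed_dim': 512},
                'experiment': {'text_length': 77, 'dropout': 0.0}})

encoder = text_transformers(cfg)
tokens = tokenize(["a man in a red jacket and blue jeans"], context_length=77)
pooled = encoder(tokens)                                 # (1, 512): feature at the EOT token
pooled, per_token = encoder(tokens, return_dense=True)   # per_token: (1, 77, 512)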
--------------------------------------------------------------------------------
/model/visual_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .base_transformer import Transformer, LayerNorm
4 | from typing import Tuple, Union
5 |
6 |
7 | class VisualTransformer(nn.Module):
8 | def __init__(self, input_resolution: Union[int, Tuple[int, int]], patch_size: int, width: int, layers: int, heads: int, embed_dim: int,
9 | checkpoint: bool, dropout: float = 0, emb_dropout: float = 0):
10 | super().__init__()
11 | if isinstance(input_resolution, int):
12 | input_resolution = (input_resolution, input_resolution)
13 | self.input_resolution = input_resolution
14 | self.num_x = (input_resolution[1] - patch_size) // patch_size + 1
15 | self.num_y = (input_resolution[0] - patch_size) // patch_size + 1
16 | num_patches = self.num_x * self.num_y
17 |
18 | output_dim = embed_dim
19 | self.output_dim = output_dim
20 | self.freeze_conv1 = True
21 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
22 | kernel_size=patch_size, stride=patch_size, bias=False)
23 |
24 | scale = width ** -0.5
25 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
26 | self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
27 | self.ln_pre = LayerNorm(width)
28 |
29 | self.transformer = Transformer(width, layers, heads, checkpoint=checkpoint, dropout=dropout,
30 | emb_dropout=emb_dropout)
31 |
32 | self.ln_post = LayerNorm(width)
33 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
34 | self.initialize_parameters()
35 |
36 | def initialize_parameters(self):
37 | nn.init.normal_(self.positional_embedding, std=0.01)
38 |
39 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
40 | attn_std = self.transformer.width ** -0.5
41 | fc_std = (2 * self.transformer.width) ** -0.5
42 | for block in self.transformer.resblocks:
43 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
44 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
45 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
46 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
47 |
48 | def train(self, mode=True):
49 | self.training = mode
50 | for module in self.children():
51 | module.train(mode)
52 |
53 | if self.freeze_conv1:
54 | for layer in [self.conv1]:
55 | layer.eval()
56 | for param in layer.parameters():
57 | param.requires_grad = False
58 | return self
59 |
60 | def forward(self, x: torch.Tensor, return_dense=False, return_feature=False):
61 | x = self.conv1(x) # shape = [*, width, grid, grid]
62 | # shape = [*, width, grid ** 2]
63 | x = x.reshape(x.shape[0], x.shape[1], -1)
64 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
65 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
66 | x = x + self.positional_embedding.to(x.dtype)
67 | x = self.ln_pre(x)
68 |
69 | x = x.permute(1, 0, 2) # NLD -> LND
70 | x = self.transformer(x)
71 | x = x.permute(1, 0, 2) # LND -> NLD
72 |
73 | # x = self.ln_post(x[:, 0, :])
74 | x = self.ln_post(x)
75 | dense_feat = x
76 |
77 | if self.proj is not None:
78 | dense_feat = x @ self.proj
79 | x = dense_feat[:, 0, :]
80 |
81 | if return_dense:
82 | return x, dense_feat
83 | if return_feature:
84 | return dense_feat
85 | return x
86 |
87 |
88 | def visual_transformer(config):
89 | vision_width = 768
90 | vision_layers = 12
91 | vision_heads = vision_width // 64
92 |
93 | kwargs = {
94 | 'layers': vision_layers,
95 | 'heads': vision_heads,
96 | 'input_resolution': config.experiment.input_resolution,
97 | 'patch_size': 16,
98 | 'width': vision_width,
99 | 'checkpoint': False,
100 | 'embed_dim': config.model.embed_dim,
101 | }
102 |
103 | model = VisualTransformer(**kwargs)
104 | return model
105 |
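A construction sketch with an assumed person-crop resolution of (384, 128); the real value comes from config.experiment.input_resolution in config/config.yaml:

import torch
from easydict import EasyDict

from model.visual_transformer import visual_transformer

cfg = EasyDict({'model': {'embed_dim': 512},
                'experiment': {'input_resolution': (384, 128)}})   # assumed (H, W)

vit = visual_transformer(cfg)
images = torch.rand(2, 3, 384, 128)
pooled = vit(images)                             # (2, 512): projected CLS feature
pooled, dense = vit(images, return_dense=True)   # dense: (2, 1 + 24 * 8, 512) token features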
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 |     parser = argparse.ArgumentParser(description="TBPS-CLIP Args")
6 | ######################## mode ########################
7 | parser.add_argument("--simplified", default=False, action='store_true')
8 |
9 | args = parser.parse_args()
10 |
11 | return args
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | torchvision==0.14.0
3 | torchaudio==0.13.0
4 |
5 | timm==0.6.11
6 | wandb==0.13.5
7 | ftfy==6.1.1
8 |
9 | regex
10 | easydict
11 | pyyaml
12 | textaugment
13 | ipdb
14 |
15 | torchmetrics
16 | matplotlib
17 | jupyter
18 | ipykernel
19 |
--------------------------------------------------------------------------------
/shell/train.sh:
--------------------------------------------------------------------------------
1 | OMP_NUM_THREADS=1 \
2 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
3 | torchrun --rdzv_id=3 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=4 \
4 | main.py --simplified
--------------------------------------------------------------------------------
/text_utils/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Flame-Chasers/TBPS-CLIP/6160a877af99229bbf39077b1047d96cf7fda64c/text_utils/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/text_utils/mask_tokens.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple, List
3 |
4 |
5 | def mask_tokens(inputs, special_tokens, mask_token, tokenizer_length, mlm_probability=0.15, special_tokens_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
6 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
7 | labels = inputs.clone()
8 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
9 | probability_matrix = torch.full(labels.shape, mlm_probability)
10 | if special_tokens_mask is None:
11 | special_tokens_mask = [1 if val in special_tokens else 0 for val in labels.tolist()]
12 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
13 | # if tokenizer._pad_token is not None:
14 | # padding_mask = labels.eq(tokenizer.pad_token_id)
15 | # probability_matrix.masked_fill_(padding_mask, value=0.0)
16 | masked_indices = torch.bernoulli(probability_matrix).bool()
17 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
18 |
19 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
20 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
21 | inputs[indices_replaced] = mask_token
22 |
23 | # 10% of the time, we replace masked input tokens with random word
24 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
25 | random_words = torch.randint(tokenizer_length, labels.shape, dtype=torch.long)
26 | inputs[indices_random] = random_words[indices_random]
27 |
28 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
29 | return inputs, labels
30 |
31 |
32 | def MaskTokens(tokens, mask_type, mask_token, special_tokens=None, tokenizer_length=None, special_tokens_mask=None):
33 | if mask_type == 'MLM':
34 | tokens, labels = mask_tokens(inputs=tokens, special_tokens=special_tokens, mask_token=mask_token, tokenizer_length=tokenizer_length, special_tokens_mask=special_tokens_mask)
35 | else:
36 | raise NotImplementedError(mask_type)
37 | return tokens, labels
38 |
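A toy example of the 80/10/10 masking on made-up token ids (0 stands in for a special token, 99 for the <|mask|> id):

import torch

from text_utils.mask_tokens import mask_tokens

tokens = torch.tensor([0, 5, 12, 7, 3, 0], dtype=torch.long)
inputs, labels = mask_tokens(tokens.clone(), special_tokens=[0], mask_token=99,
                             tokenizer_length=100, mlm_probability=0.15)
# labels keeps the original id only at positions sampled for prediction (-100 elsewhere);
# inputs has those positions replaced by 99 (80%), a random id (10%), or left unchanged (10%)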
--------------------------------------------------------------------------------
/text_utils/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2 ** 8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2 ** 8 + n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | # Change: add a <|mask|> token to the vocab; one merged entry is popped below so the total vocab size stays 49408
63 | class SimpleTokenizer(object):
64 | def __init__(self, bpe_path: str = default_bpe()):
65 | self.byte_encoder = bytes_to_unicode()
66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
68 | merges = merges[1:49152 - 256 - 2 + 1]
69 | merges = [tuple(merge.split()) for merge in merges]
70 | vocab = list(bytes_to_unicode().values())
71 |         vocab = vocab + [v + '</w>' for v in vocab]
72 | for merge in merges:
73 | vocab.append(''.join(merge))
74 |
75 | vocab.pop(-1) # remove last one in vocab(jekyll) to keep vocab_size unchanged
76 | vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>']) # vocab_size 49408
77 | # vocab.extend(['<|startoftext|>', '<|endoftext|>']) # vocab_size 49408
78 | self.encoder = dict(zip(vocab, range(len(vocab))))
79 | self.decoder = {v: k for k, v in self.encoder.items()}
80 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'}
82 | self.pat = re.compile(
83 | r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
84 | re.IGNORECASE)
85 |
86 | def bpe(self, token):
87 | if token in self.cache:
88 | return self.cache[token]
89 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
90 | pairs = get_pairs(word)
91 |
92 | if not pairs:
93 |             return token + '</w>'
94 |
95 | while True:
96 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
97 | if bigram not in self.bpe_ranks:
98 | break
99 | first, second = bigram
100 | new_word = []
101 | i = 0
102 | while i < len(word):
103 | try:
104 | j = word.index(first, i)
105 | new_word.extend(word[i:j])
106 | i = j
107 | except:
108 | new_word.extend(word[i:])
109 | break
110 |
111 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
112 | new_word.append(first + second)
113 | i += 2
114 | else:
115 | new_word.append(word[i])
116 | i += 1
117 | new_word = tuple(new_word)
118 | word = new_word
119 | if len(word) == 1:
120 | break
121 | else:
122 | pairs = get_pairs(word)
123 | word = ' '.join(word)
124 | self.cache[token] = word
125 | return word
126 |
127 | def encode(self, text):
128 | bpe_tokens = []
129 | text = whitespace_clean(basic_clean(text)).lower()
130 | for token in re.findall(self.pat, text):
131 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
133 | return bpe_tokens
134 |
135 | def decode(self, tokens):
136 | text = ''.join([self.decoder[token] for token in tokens])
137 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
138 | return text
139 |
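Round-trip sketch, assuming the BPE archive is available at text_utils/bpe_simple_vocab_16e6.txt.gz:

from text_utils.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()        # loads the default BPE vocab next to this module
ids = tokenizer.encode("A man with a black backpack.")
text = tokenizer.decode(ids)         # lower-cased, whitespace-normalised round trip
print(ids)
print(text)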
--------------------------------------------------------------------------------
/text_utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 | import torch
4 |
5 | from .mask_tokens import MaskTokens
6 | from text_utils.simple_tokenizer import SimpleTokenizer as _Tokenizer
7 |
8 | _tokenizer = _Tokenizer()
9 |
10 |
11 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, return_length: bool = False,
12 | mask_type=None):
13 | if isinstance(texts, str):
14 | texts = [texts]
15 |
16 | sot_token = _tokenizer.encoder["<|startoftext|>"]
17 | eot_token = _tokenizer.encoder["<|endoftext|>"]
18 |
19 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
20 | for i, tokens in enumerate(all_tokens):
21 | if len(tokens) > context_length:
22 | all_tokens[i] = [tokens[0]] + tokens[1:context_length - 1] + [tokens[-1]]
23 | all_tokens[i] = torch.Tensor(all_tokens[i]).long()
24 |
25 | if mask_type is not None:
26 | mask_token = _tokenizer.encoder["<|mask|>"]
27 | special_tokens = [sot_token, eot_token, mask_token]
28 | masked_tokens = [
29 | MaskTokens(tokens, mask_type=mask_type, mask_token=mask_token, special_tokens=special_tokens,
30 | tokenizer_length=len(_tokenizer.encoder)) for tokens in all_tokens]
31 | all_tokens = [item[0] for item in masked_tokens]
32 | all_labels = [item[1] for item in masked_tokens]
33 |
34 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
35 | labels = torch.ones(len(all_tokens), context_length, dtype=torch.long) * -100
36 | token_lengths = torch.ones(len(all_tokens), dtype=torch.long)
37 |
38 | for i, tokens in enumerate(all_tokens):
39 | result[i, :len(tokens)] = tokens
40 | token_lengths[i] = min(len(tokens), context_length)
41 | if mask_type is not None:
42 | labels[i, :len(tokens)] = all_labels[i]
43 |
44 | if mask_type:
45 | # print(result[0], labels[0], '<< masking', flush=True)
46 | return result, labels
47 | if return_length:
48 | return result, token_lengths
49 | else:
50 | return result
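Usage sketch of the three return modes of tokenize (plain ids, ids plus lengths, and MLM-masked ids plus labels); the caption is made up:

from text_utils.tokenizer import tokenize

caption = "the woman wears a white t-shirt and black trousers"
ids = tokenize(caption, context_length=77)                         # (1, 77) LongTensor
ids, lengths = tokenize([caption], context_length=77, return_length=True)
masked_ids, mlm_labels = tokenize([caption], context_length=77, mask_type='MLM')
# masked_ids contain <|mask|>/random replacements; mlm_labels are -100 except at masked positions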
--------------------------------------------------------------------------------