├── .DS_Store ├── .gitattributes ├── .gitignore ├── LICENSE ├── Pretrain.py ├── README.md ├── Retrieval.py ├── assets ├── image-20230814214836481.png └── pipline.png ├── configs ├── Retrieval_rsicd.yaml ├── Retrieval_rsitmd.yaml ├── config_bert.json └── config_swinT_224.json ├── dataset ├── __init__.py ├── dist_dataset.py ├── grounding_dataset.py ├── nlvr_dataset.py ├── pretrain_dataset.py ├── randaugment.py ├── re_dataset.py └── utils.py ├── fix_data └── rsitmd_precomp │ ├── test_caps.txt │ └── test_filename.txt ├── models ├── __init__.py ├── bert.py ├── model_retrieval.py ├── mytools.py ├── pir.py ├── resnet.py ├── swin_transformer.py ├── tokenization_bert.py └── vit.py ├── mytools.py ├── optim.py ├── requirements.txt ├── run.py ├── scheduler.py └── utils ├── .DS_Store ├── __init__.py ├── checkpointer.py ├── cider └── pyciderevalcap │ ├── __init__.py │ ├── cider │ ├── __init__.py │ ├── cider.py │ └── cider_scorer.py │ └── ciderD │ ├── __init__.py │ ├── ciderD.py │ └── ciderD_scorer.py ├── hdfs_io.py └── torch_io.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zjut-MultimediaPlus/PIR-pytorch/aeadf5ff0fba18ceb496495058c30501b2b5bb34/.DS_Store -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 kinshingpoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pretrain.py: -------------------------------------------------------------------------------- 1 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 2 | # Github: https://github.com/zengyan-97/X-VLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 
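# Pre-training entry point (built on X-VLM): train() consumes plain image-text batches
# (ITC + ITM + MLM losses) and, with probability config['regions']['iter_perc'], draws a
# region-text batch that adds bbox and GIoU losses; checkpoints are written via Checkpointer.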
5 | 6 | import argparse 7 | import os 8 | import sys 9 | 10 | import ruamel.yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | import math 17 | 18 | import torch 19 | from torch.utils.data import DataLoader 20 | import torch.backends.cudnn as cudnn 21 | import torch.distributed as dist 22 | from torch.optim import Optimizer 23 | 24 | 25 | import utils 26 | from dataset import create_dataset 27 | from scheduler import create_scheduler 28 | from optim import create_optimizer 29 | 30 | from utils.checkpointer import Checkpointer 31 | from utils.hdfs_io import hmkdir, hcopy 32 | from accelerators.apex_ddp_accelerator import ApexDDPAccelerator 33 | 34 | 35 | def reinit_scheduler_properties_mysched(optimizer: Optimizer, scheduler, cfg) -> None: 36 | """ 37 | with ApexDDP, do re-init to avoid lr_scheduler warning. 38 | issue: https://github.com/pytorch/pytorch/issues/27595 39 | issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/841 40 | """ 41 | args = cfg 42 | 43 | if scheduler.optimizer == optimizer: 44 | # from transformers import get_linear_schedule_with_warmup 45 | def lr_lambda(current_step: int): 46 | if current_step < args.num_warmup_steps: 47 | return float(current_step) / float(max(1, args.num_warmup_steps)) 48 | return max( 49 | 0.0, float(args.num_training_steps - current_step) / float( 50 | max(1, args.num_training_steps - args.num_warmup_steps)) 51 | ) 52 | 53 | scheduler.__init__(optimizer, lr_lambda, last_epoch=-1) 54 | 55 | 56 | def train(model, general_loader, region_loader, optimizer, epoch_info, device, scheduler, config, accelerator, checkpointer): 57 | model.train() 58 | start_epoch, _ = epoch_info 59 | metric_logger = utils.MetricLogger(delimiter=" ") 60 | metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 61 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 62 | metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 63 | metric_logger.add_meter('loss_bbox', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 64 | metric_logger.add_meter('loss_giou', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 65 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 66 | metric_logger.add_meter('lr_large', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 67 | 68 | header = 'Train step: [{}]'.format(start_epoch) 69 | assert start_epoch == 0 70 | print_freq = 50 71 | 72 | world_size = utils.get_world_size() 73 | step_per_epoch = math.ceil(config['train_dataset_size']/(config['batch_size']*world_size)) 74 | assert step_per_epoch > 1 75 | global_step = 0 # start from 0 76 | 77 | subarea_iter = iter(region_loader) 78 | 79 | for i, batch in enumerate(metric_logger.log_every(general_loader, print_freq, header, step_per_epoch, epoch_info)): 80 | 81 | if random.random() < config['regions']['iter_perc']: 82 | try: 83 | region_batch = next(subarea_iter) 84 | except StopIteration: 85 | subarea_iter = iter(region_loader) 86 | region_batch = next(subarea_iter) 87 | 88 | image, region_batch = region_batch[0].to(device, non_blocking=True), [ 89 | t.to(device) if t is not None else None for t in region_batch[1:]] 90 | 91 | idx_to_group_img, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids, \ 92 | image_atts, target_bbox, is_image = region_batch 93 | 94 | if config['calc_image_bbox_loss']: 95 | is_image = None 96 | 97 | 
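# Region branch: clear gradients, run the forward pass with bbox supervision, and add the bbox/GIoU terms on top of ITC/ITM/MLM.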
optimizer.zero_grad() 98 | 99 | loss_itc, loss_itm, loss_mlm, loss_bbox, loss_giou = \ 100 | model(image, text_ids, text_atts, text_ids_masked=text_ids_masked, masked_pos=masked_pos, masked_ids=masked_ids, 101 | image_atts=image_atts, idx_to_group_img=idx_to_group_img, target_bbox=target_bbox, is_image=is_image, ret_bbox_loss=True) 102 | 103 | loss = loss_itc + loss_itm + loss_mlm + loss_bbox + loss_giou 104 | accelerator.backward_step(loss, optimizer) 105 | 106 | accelerator_clip_grad_norm = float(config['accelerator']['CLIP_GRAD_NORM']) 107 | if accelerator_clip_grad_norm > 0: 108 | accelerator.optimizer_step(optimizer, model, accelerator_clip_grad_norm) 109 | optimizer.step() 110 | 111 | metric_logger.update(loss_bbox=loss_bbox.item()) 112 | metric_logger.update(loss_giou=loss_giou.item()) 113 | 114 | else: 115 | # fix it 116 | metric_logger.update(loss_bbox=0.5) 117 | metric_logger.update(loss_giou=0.5) 118 | 119 | image, batch = batch[0].to(device, non_blocking=True), [t.to(device) if t is not None else None for t in batch[1:]] 120 | text_ids, text_atts, text_ids_masked, masked_pos, masked_ids = batch 121 | 122 | optimizer.zero_grad() 123 | 124 | loss_itc, loss_itm, loss_mlm = model(image, text_ids, text_atts, text_ids_masked=text_ids_masked, 125 | masked_pos=masked_pos, masked_ids=masked_ids) 126 | 127 | loss = loss_itc + loss_itm + loss_mlm 128 | accelerator.backward_step(loss, optimizer) 129 | 130 | accelerator_clip_grad_norm = float(config['accelerator']['CLIP_GRAD_NORM']) 131 | if accelerator_clip_grad_norm > 0: 132 | accelerator.optimizer_step(optimizer, model, accelerator_clip_grad_norm) 133 | optimizer.step() 134 | scheduler.step() 135 | 136 | metric_logger.update(loss_itc=loss_itc.item()) 137 | metric_logger.update(loss_itm=loss_itm.item()) 138 | metric_logger.update(loss_mlm=loss_mlm.item()) 139 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 140 | metric_logger.update(lr_large=optimizer.param_groups[2]["lr"]) 141 | 142 | if utils.is_main_process(): 143 | current_epoch = global_step // step_per_epoch 144 | train_stats = {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 145 | 146 | if (global_step+1) % step_per_epoch == 0: 147 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 148 | 'epoch': current_epoch, 149 | } 150 | 151 | with open("log.txt", "a") as f: 152 | f.write(json.dumps(log_stats) + "\n") 153 | 154 | if (current_epoch+1) % config['ckpt_frequent'] == 0: 155 | model_without_ddp = model 156 | if hasattr(model, 'module'): 157 | model_without_ddp = model.module 158 | 159 | save_obj = { 160 | 'model': model_without_ddp.state_dict(), 161 | # 'optimizer': optimizer.state_dict(), 162 | # 'lr_scheduler': scheduler.state_dict(), 163 | 'config': config, 164 | # 'epoch': current_epoch, 165 | } 166 | checkpointer.save_checkpoint(model_state=save_obj, 167 | epoch=current_epoch, 168 | training_states=optimizer.state_dict()) 169 | 170 | if (global_step+1) % config['ckpt_frequent_step'] == 0: 171 | model_without_ddp = model 172 | if hasattr(model, 'module'): 173 | model_without_ddp = model.module 174 | 175 | save_obj = { 176 | 'model': model_without_ddp.state_dict(), 177 | # 'optimizer': optimizer.state_dict(), 178 | # 'lr_scheduler': scheduler.state_dict(), 179 | 'config': config, 180 | # 'epoch': current_epoch, 181 | } 182 | 183 | checkpointer.save_checkpoint(model_state=save_obj, 184 | epoch=current_epoch, step=global_step, 185 | training_states=optimizer.state_dict()) 186 | 187 | global_step += 1 188 | 189 | # gather 
the stats from all processes 190 | metric_logger.synchronize_between_processes() 191 | print("Averaged stats:", metric_logger.global_avg()) 192 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 193 | 194 | 195 | def main(args, config): 196 | utils.init_distributed_mode(args) 197 | device = torch.device(args.device) 198 | 199 | config['train_file'] = ','.join(config['train_file']) 200 | # config['train_file_regions'] = ','.join(config['train_file_regions']) 201 | config['batch_size'] = config['images']['batch_size'] 202 | 203 | seed = args.seed + utils.get_rank() 204 | torch.manual_seed(seed) 205 | np.random.seed(seed) 206 | random.seed(seed) 207 | cudnn.benchmark = True 208 | 209 | print("Creating dataset", flush=True) 210 | general_dataset, region_dataset = create_dataset('pretrain', config) 211 | 212 | if utils.is_main_process(): 213 | print(f"### train_file: {config['train_file']}", flush=True) 214 | print(f"### train_file_regions: {config['train_file_regions']}", flush=True) 215 | print(f"### batch size, {config['batch_size']} x {int(os.environ.get('WORLD_SIZE', 1))}") 216 | 217 | general_loader = torch.utils.data.DataLoader(general_dataset, batch_size=config['images']['batch_size'], 218 | num_workers=config['images']['num_workers'], 219 | pin_memory=True, 220 | drop_last=False, 221 | collate_fn=general_dataset.collate_fn) 222 | 223 | # region_loader = torch.utils.data.DataLoader(region_dataset, batch_size=config['regions']['max_images'], # batch_size = max_images * max_regions 224 | # num_workers=config['regions']['num_workers'], 225 | # pin_memory=True, 226 | # drop_last=False, 227 | # collate_fn=region_dataset.collate_fn) 228 | 229 | print("Creating model", flush=True) 230 | 231 | if args.model_name == 'XVLM': 232 | from models.model_retrieval import XVLM 233 | model = XVLM(config=config) 234 | elif args.model_name == 'ROBIN': 235 | from models.model_retrieval import ROBIN 236 | model = ROBIN(config=config) 237 | else: 238 | raise ValueError 239 | 240 | print(model) 241 | model = model.to(device) 242 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad), flush=True) 243 | 244 | rank = int(os.environ.get('RANK', 0)) 245 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 246 | 247 | arg_opt = utils.AttrDict(config['optimizer']) 248 | optimizer = create_optimizer(arg_opt, model) 249 | 250 | arg_sche = utils.AttrDict(config['schedular']) 251 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 252 | arg_sche['step_per_epoch'] = math.ceil(config['train_dataset_size'] / (config['batch_size'] * world_size)) 253 | lr_scheduler = create_scheduler(arg_sche, optimizer) 254 | 255 | arg_acc = utils.AttrDict(config['accelerator']) 256 | accelerator = ApexDDPAccelerator(arg_acc, logger=None) 257 | 258 | model, optimizer, lr_scheduler = accelerator.set_up(model, optimizer, lr_scheduler, local_rank, world_size, rank) 259 | reinit_scheduler_properties_mysched(optimizer, lr_scheduler, arg_sche) 260 | 261 | checkpointer = Checkpointer(args.output_dir) 262 | 263 | print("### output_dir, ", args.output_dir, flush=True) 264 | start_time = time.time() 265 | 266 | start_epoch = 0 267 | max_epoch = config['schedular']['epochs'] 268 | epoch_info = (start_epoch, max_epoch) 269 | 270 | print("Start training", flush=True) 271 | train(model, general_loader, region_loader, optimizer, epoch_info, device, lr_scheduler, config, 272 | accelerator, checkpointer) 273 | dist.barrier() 274 | 275 | if utils.is_main_process(): 276 | 
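# Rank 0 only: dump the accumulated training log to stdout and copy it to the output directory.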
os.system("cat log.txt") 277 | hcopy('log.txt', args.output_dir) 278 | 279 | total_time = time.time() - start_time 280 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 281 | print('Training time {}'.format(total_time_str), flush=True) 282 | 283 | print('### Time {}'.format(total_time_str)) 284 | 285 | 286 | if __name__ == '__main__': 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument('--config', type=str, required=True) 289 | parser.add_argument('--output_dir', type=str, default='output/pretrain') 290 | parser.add_argument('--seed', default=42, type=int) 291 | parser.add_argument('--device', default='cuda') 292 | parser.add_argument('--distributed', action='store_false') 293 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 294 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 295 | args = parser.parse_args() 296 | 297 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 298 | 299 | hmkdir(args.output_dir) 300 | 301 | yaml.dump(config, open('config.yaml', 'w')) 302 | hcopy('config.yaml', args.output_dir) 303 | 304 | main(args, config) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Prior Instruction Representation Framework for Remote Sensing Image-text Retrieval (MM'23 Oral) 2 | 3 | By [Jiancheng Pan](https://scholar.google.com/citations?user=nRPD3tAAAAAJ&hl=en&oi=ao), Qing Ma, [Cong Bai](https://scholar.google.com/citations?hl=zh-CN&user=XGZ4UZgAAAAJ&view_op=list_works&sortby=pubdate). 4 | 5 | This repo is the official implementation of "[A Prior Instruction Representation Framework for Remote Sensing Image-text Retrieval](https://dl.acm.org/doi/abs/10.1145/3591106.3592236)"(MM'23 Oral). 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-prior-instruction-representation-framework/cross-modal-retrieval-on-rsicd)](https://paperswithcode.com/sota/cross-modal-retrieval-on-rsicd?p=a-prior-instruction-representation-framework) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-prior-instruction-representation-framework/cross-modal-retrieval-on-rsitmd)](https://paperswithcode.com/sota/cross-modal-retrieval-on-rsitmd?p=a-prior-instruction-representation-framework) 9 | 10 | - [A Prior Instruction Representation Framework for Remote Sensing Image-text Retrieval (MM'23 Oral)](#a-prior-instruction-representation-framework-for-remote-sensing-image-text-retrieval-mm23-oral) 11 | - [ℹ️ Introduction](#ℹ️-introduction) 12 | - [🎯 Implementation](#-implementation) 13 | - [Project Files](#project-files) 14 | - [Environments](#environments) 15 | - [Train](#train) 16 | - [Test](#test) 17 | - [🌎 Datasets](#-datasets) 18 | - [📊 Results](#-results) 19 | - [🙏 Acknowledgement](#-acknowledgement) 20 | - [📝 Citation](#-citation) 21 | 22 | ## ℹ️ Introduction 23 | 24 | This paper presents a prior instruction representation framework (PIR) for remote sensing image-text retrieval, aimed at remote sensing vision-language understanding tasks to solve the semantic noise problem. Our highlight is the proposal of a paradigm that draws on prior knowledge to instruct adaptive learning of vision and text representations. 
Concretely, two progressive attention encoder (PAE) structures, Spatial-PAE and Temporal-PAE, are proposed to perform long-range dependency modeling to enhance key feature representation. In vision representation, Vision Instruction Representation (VIR) based on Spatial-PAE exploits the prior-guided knowledge of the remote sensing scene recognition by building a belief matrix to select key features for reducing the impact of semantic noise. In text representation, Language Cycle Attention (LCA) based on Temporal-PAE uses the previous time step to cyclically activate the current time step to enhance text representation capability. A cluster-wise affiliation loss is proposed to constrain the inter-classes and to reduce the semantic confusion zones in the common subspace. Comprehensive experiments demonstrate that using prior knowledge instruction could enhance vision and text representations and could outperform the state-of-the-art methods on two benchmark datasets, RSICD and RSITMD. 25 | 26 | ![pipline](assets/pipline.png) 27 | 28 | ## 🎯 Implementation 29 | ### Project Files 30 | The directory hierarchy is shown below, where the **checkpoints** and **data** files can be downloaded from here [[Baidu Disk]](https://pan.baidu.com/s/1aB-aSfD5h_PS6Ak_tt5RGA?pwd=tqv2) . 31 | 32 | ``` 33 | . 34 | ├── checkpoints 35 | │   └── PIR 36 | │   ├── full_rsicd 37 | │   │   ├── checkpoint_49.pth 38 | │   │   ├── checkpoint_best.pth 39 | │   │   ├── config.yaml 40 | │   │   └── log.txt 41 | │   └── full_rsitmd 42 | │   ├── checkpoint_49.pth 43 | │   ├── checkpoint_best.pth 44 | │   ├── config.yaml 45 | │   └── log.txt 46 | ├── configs 47 | │   ├── config_bert.json 48 | │   ├── config_swinT_224.json 49 | │   ├── Retrieval_rsicd.yaml 50 | │   └── Retrieval_rsitmd.yaml 51 | ├── data 52 | ├── dataset 53 | ├── models 54 | ├── utils 55 | ├── mytools.py 56 | ├── optim.py 57 | ├── Pretrain.py 58 | ├── Retrieval.py 59 | ├── run.py 60 | ├── scheduler.py 61 | └── requirements.txt 62 | ``` 63 | ### Environments 64 | 65 | ```bash 66 | pip install -r requirements.txt 67 | ``` 68 | 69 | ### Train 70 | If you encounter environmental problems, you can directly modify the `get_dist_launch` function of `run.py`, for example:(2 card GPU) 71 | ``` 72 | elif args.dist == 'f2': 73 | return "CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=2 " \ 74 | "--nnodes=1 " 75 | ``` 76 | For training, run cmd as follow: 77 | ```bash 78 | python run.py --task 'itr_rsitmd' --dist "f2" --config 'configs/Retrieval_rsitmd.yaml' --output_dir './checkpoints/PIR/full_rsitmd' 79 | 80 | python run.py --task 'itr_rsicd' --dist "f2" --config 'configs/Retrieval_rsicd.yaml' --output_dir './checkpoints/PIR/full_rsicd' 81 | ``` 82 | 83 | ### Test 84 | 85 | ```bash 86 | python run.py --task 'itr_rsitmd' --dist "f2" --config 'configs/Retrieval_rsitmd.yaml' --output_dir './checkpoints/PIR/test' --checkpoint './checkpoints/PIR/full_rsitmd/checkpoint_best.pth' --evaluate 87 | 88 | python run.py --task 'itr_rsicd' --dist "f2" --config 'configs/Retrieval_rsicd.yaml' --output_dir './checkpoints/PIR/test' --checkpoint './checkpoints/PIR/full_rsicd/checkpoint_best.pth' --evaluate 89 | ``` 90 | 91 | ## 🌎 Datasets 92 | 93 | All experiments are based on [RSITMD](https://github.com/xiaoyuan1996/AMFMN/tree/master/RSITMD) and [RSICD](https://github.com/201528014227051/RSICD_optimal) datasets. 
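For intuition, the cluster-wise affiliation loss mentioned in the Introduction can be pictured as pulling each embedding toward the center of its scene class while keeping different class centers apart. The snippet below is only a minimal sketch of that idea under assumed shapes (`affiliation_loss_sketch`, `centers` and the push term are our own naming and simplification, not the repository's code); the actual loss lives in `models/pir.py` and is enabled via `use_affil_loss` and weighted by `center_factor` in the config files.

```python
import torch
import torch.nn.functional as F

def affiliation_loss_sketch(embeds, labels, centers):
    """Illustrative only. embeds: (B, D) image/text embeddings,
    labels: (B,) scene-class ids, centers: (C, D) learnable class centers."""
    embeds = F.normalize(embeds, dim=-1)
    centers = F.normalize(centers, dim=-1)
    # pull each sample toward its own class center
    pull = (1.0 - (embeds * centers[labels]).sum(dim=-1)).mean()
    # push different class centers apart (ignore the diagonal)
    sim = centers @ centers.t()
    off_diag = sim - torch.eye(sim.size(0), device=sim.device)
    push = off_diag.clamp(min=0).mean()
    return pull + push
```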
94 | 95 | you also can download the images form [Baidu Desk](https://pan.baidu.com/s/1mLkQA8InOxKjseGgEVoaew?pwd=c3c5), and correspondingly modify the `yaml` file under configs files as follows: `image_root: './images/datasets_name/'` 96 | 97 | 98 | ## 📊 Results 99 | 100 | ![image-20230814214836481](assets/image-20230814214836481.png) 101 | 102 | ## 🙏 Acknowledgement 103 | 104 | - Basic code to thank [X-VLM](https://github.com/zengyan-97/X-VLM) by Zeng et al. 105 | 106 | ## 📝 Citation 107 | 108 | If you find this code useful for your work or use it in your project, please cite our paper as: 109 | 110 | ``` 111 | @inproceedings{pan2023prior, 112 | title={A Prior Instruction Representation Framework for Remote Sensing Image-text Retrieval}, 113 | author={Pan, Jiancheng and Ma, Qing and Bai, Cong}, 114 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia}, 115 | pages={611--620}, 116 | year={2023} 117 | } 118 | ``` 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /Retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import math 5 | import ruamel.yaml as yaml 6 | import numpy as np 7 | import random 8 | import time 9 | import datetime 10 | import json 11 | from pathlib import Path 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | from models.tokenization_bert import BertTokenizer 17 | import utils 18 | from dataset import create_dataset, create_sampler, create_loader 19 | from scheduler import create_scheduler 20 | from optim import create_optimizer 21 | from models.model_retrieval import PIR 22 | 23 | 24 | def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config): 25 | model.train() 26 | 27 | metric_logger = utils.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | 30 | if config['use_affil_loss']: 31 | metric_logger.add_meter('loss_affil', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 32 | metric_logger.add_meter('loss_contr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 33 | elif config['use_triplet_loss']: 34 | metric_logger.add_meter('loss_triplet', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 35 | else: 36 | metric_logger.add_meter('loss_contr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 37 | header = 'Train Epoch: [{}]'.format(epoch) 38 | print_freq = 50 39 | step_size = 100 40 | print('_________________{}__________________'.format(len(data_loader))) 41 | for i, (image, text, idx, label) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | 43 | image = image.to(device, non_blocking=True) 44 | idx = idx.to(device, non_blocking=True) 45 | ## token 长度调整 46 | text_input = tokenizer(text, padding='longest', max_length=config['max_tokens'], return_tensors="pt").to(device) 47 | # mask_text_input = tokenizer(mask_text, padding='longest', max_length=config['max_tokens'], return_tensors="pt").to(device) 48 | ## 损失函数选择 49 | if config['use_affil_loss']: 50 | loss_contr, loss_affil = model(image, text_input.input_ids, idx=idx, label=label) 51 | loss = loss_contr + config['center_factor'] * loss_affil 52 | elif config['use_triplet_loss']: 53 | loss_triplet = model(image, text_input.input_ids) 54 | loss = loss_triplet 55 | else: 56 | 
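# Default branch: plain cross-modal contrastive alignment loss (no affiliation or triplet term).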
loss_contr = model(image, text_input.input_ids, idx=idx, label=label) 57 | loss = loss_contr 58 | 59 | 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | scheduler.step() 64 | 65 | # 检测是否有没有参与反向传播的模块和参数 66 | # for name, param in model.named_parameters(): 67 | # if param.grad is None: 68 | # print('Miss grad module_name is :'.format(name)) 69 | 70 | 71 | if config['use_affil_loss']: 72 | metric_logger.update(loss_affil=loss_affil.item()) 73 | metric_logger.update(loss_contr=loss_contr.item()) 74 | elif config['use_triplet_loss']: 75 | metric_logger.update(loss_triplet=loss_triplet.item()) 76 | else: 77 | metric_logger.update(loss_contr=loss_contr.item()) 78 | 79 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 80 | 81 | # gather the stats from all processes 82 | metric_logger.synchronize_between_processes() 83 | print("Averaged stats:", metric_logger.global_avg()) 84 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 85 | 86 | 87 | @torch.no_grad() 88 | def evaluation(model, data_loader, tokenizer, device, config): 89 | model.eval() 90 | metric_logger = utils.MetricLogger(delimiter=" ") 91 | header = 'Evaluation:' 92 | print('Computing features for evaluation...') 93 | start_time = time.time() 94 | texts = data_loader.dataset.text 95 | # mask_texts = data_loader.dataset.mask_text 96 | num_text = len(texts) 97 | text_bs = config['batch_size_test_text'] # 256 98 | text_embeds = [] 99 | image_embeds = [] 100 | all_ = [] 101 | print('_________________{}__________________'.format(len(data_loader))) 102 | # Inference 图像特征 103 | for image, img_id in data_loader: 104 | image = image.to(device) 105 | if config['is_baseline']: 106 | image_embed = model.get_vision_embeds(image) 107 | else: 108 | # image_embed = model.get_vision_fusion_embeds(image, config) 109 | t1 = time.time() 110 | image_embed = model.get_vision_fusion_embeds(image, config) 111 | t2 = time.time() 112 | all_.append(t2 - t1) 113 | 114 | image_embeds.append(image_embed) 115 | print("infer image time:{:.2f}".format(np.average(all_))) 116 | # Inference 文本特征 117 | for i in range(0, num_text, text_bs): 118 | text = texts[i: min(num_text, i + text_bs)] 119 | text_input = tokenizer(text, padding='longest', truncation=True, max_length=config['max_tokens'], 120 | return_tensors="pt").to(device) 121 | if config['is_baseline']: 122 | text_embed = model.get_text_embeds(text_input.input_ids) 123 | else: 124 | text_embed = model.get_text_fusion_embeds(text_input.input_ids, config) 125 | 126 | text_embeds.append(text_embed) 127 | 128 | image_embeds = torch.cat(image_embeds, dim=0) 129 | text_embeds = torch.cat(text_embeds, dim=0) 130 | 131 | # 计算image_emb和text_emb的相似度矩阵 132 | sims_matrix = image_embeds @ text_embeds.t() 133 | 134 | score_matrix_i2t = sims_matrix 135 | score_matrix_t2i = sims_matrix.t() 136 | 137 | if args.distributed: 138 | dist.barrier() 139 | torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM) 140 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) 141 | if utils.is_main_process(): 142 | total_time = time.time() - start_time 143 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 144 | print('Evaluation time {}'.format(total_time_str)) 145 | 146 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 147 | 148 | 149 | @torch.no_grad() 150 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): 151 | # Images->Text 152 | ranks = np.zeros(scores_i2t.shape[0]) 153 
| for index, score in enumerate(scores_i2t): 154 | inds = np.argsort(score)[::-1] 155 | # Score 156 | rank = 1e20 157 | for i in img2txt[index]: 158 | tmp = np.where(inds == i)[0][0] 159 | if tmp < rank: 160 | rank = tmp 161 | ranks[index] = rank 162 | 163 | # Compute metrics 164 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 165 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 166 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 167 | 168 | # Text->Images 169 | ranks = np.zeros(scores_t2i.shape[0]) 170 | 171 | for index, score in enumerate(scores_t2i): 172 | inds = np.argsort(score)[::-1] 173 | ranks[index] = np.where(inds == txt2img[index])[0][0] 174 | 175 | # Compute metrics 176 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 177 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 178 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 179 | 180 | tr_mean = (tr1 + tr5 + tr10) / 3 181 | ir_mean = (ir1 + ir5 + ir10) / 3 182 | r_mean = (tr_mean + ir_mean) / 2 183 | 184 | eval_result = {'txt_r1': round(tr1,2), 185 | 'txt_r5': round(tr5,2), 186 | 'txt_r10': round(tr10,2), 187 | 'img_r1': round(ir1,2), 188 | 'img_r5': round(ir5,2), 189 | 'img_r10': round(ir10,2), 190 | 'r_mean': round(r_mean,2)} 191 | return eval_result 192 | 193 | 194 | def main(args, config): 195 | utils.init_distributed_mode(args) 196 | device = torch.device(args.device) 197 | 198 | world_size = utils.get_world_size() 199 | 200 | if args.bs > 0: 201 | config['batch_size_train'] = args.bs // world_size 202 | 203 | seed = args.seed + utils.get_rank() 204 | torch.manual_seed(seed) 205 | np.random.seed(seed) 206 | random.seed(seed) 207 | cudnn.benchmark = True 208 | 209 | print("Creating model", flush=True) 210 | 211 | model = PIR(config=config) 212 | 213 | # load pre-trianed model 214 | # 不加载预训练模型 215 | if args.checkpoint != '-1': 216 | model.load_pretrained(args.checkpoint, config, is_eval=args.evaluate) 217 | model = model.to(device) 218 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 219 | 220 | model_without_ddp = model 221 | if args.distributed: 222 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 223 | model_without_ddp = model.module 224 | 225 | 226 | tokenizer = BertTokenizer.from_pretrained(config['text_encoder']) 227 | 228 | print("Creating retrieval dataset", flush=True) 229 | train_dataset, val_dataset, test_dataset = create_dataset('re', config, args.evaluate) 230 | 231 | start_time = time.time() 232 | print("### output_dir, ", args.output_dir, flush=True) 233 | 234 | if args.evaluate: 235 | print("Start evaluating", flush=True) 236 | test_loader = create_loader([test_dataset], [None], 237 | batch_size=[config['batch_size_test']], 238 | num_workers=[4], 239 | is_trains=[False], 240 | collate_fns=[None])[0] 241 | # val and test 242 | # score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, tokenizer, device, config) 243 | score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, tokenizer, device, config) 244 | 245 | if utils.is_main_process(): 246 | # val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt) 247 | # print(val_result) 248 | test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) 249 | print(test_result) 250 | 251 | dist.barrier() 252 | 253 | else: 254 | print("Start training", flush=True) 255 | 256 | train_dataset_size = 
len(train_dataset) 257 | 258 | if utils.is_main_process(): 259 | print(f"### data {train_dataset_size}, batch size, {config['batch_size_train']} x {world_size}") 260 | 261 | if args.distributed: 262 | num_tasks = utils.get_world_size() 263 | global_rank = utils.get_rank() 264 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] 265 | else: 266 | samplers = [None, None, None] 267 | 268 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers, 269 | batch_size=[config['batch_size_train']] + [ 270 | config['batch_size_test']] * 2, 271 | num_workers=[4, 4, 4], 272 | is_trains=[True, False, False], 273 | collate_fns=[None, None, None]) 274 | 275 | arg_opt = utils.AttrDict(config['optimizer']) 276 | optimizer = create_optimizer(arg_opt, model) 277 | arg_sche = utils.AttrDict(config['schedular']) 278 | arg_sche['step_per_epoch'] = math.ceil(train_dataset_size/(config['batch_size_train']*world_size)) 279 | lr_scheduler = create_scheduler(arg_sche, optimizer) 280 | 281 | max_epoch = config['schedular']['epochs'] 282 | best = 0 283 | best_epoch = 0 284 | 285 | for epoch in range(0, max_epoch): 286 | if args.distributed: 287 | train_loader.sampler.set_epoch(epoch) 288 | train_stats = train(model, train_loader, optimizer, tokenizer, epoch, device, lr_scheduler, config) 289 | 290 | # score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, tokenizer, device, config) 291 | score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, tokenizer, device, config) 292 | 293 | if utils.is_main_process(): 294 | # val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt) 295 | # print(val_result) 296 | test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) 297 | print(test_result) 298 | 299 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 300 | # **{f'val_{k}': v for k, v in val_result.items()}, 301 | **{f'test_{k}': v for k, v in test_result.items()}, 302 | 'epoch': epoch} 303 | 304 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 305 | f.write(json.dumps(log_stats) + "\n") 306 | 307 | if test_result['r_mean'] > best: 308 | save_obj = { 309 | 'model': model_without_ddp.state_dict(), 310 | # 'optimizer': optimizer.state_dict(), 311 | # 'lr_scheduler': lr_scheduler.state_dict(), 312 | 'config': config, 313 | # 'epoch': epoch, 314 | } 315 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 316 | best = test_result['r_mean'] 317 | best_epoch = epoch 318 | 319 | elif epoch >= config['schedular']['epochs'] - 1: 320 | save_obj = { 321 | 'model': model_without_ddp.state_dict(), 322 | # 'optimizer': optimizer.state_dict(), 323 | # 'lr_scheduler': lr_scheduler.state_dict(), 324 | 'config': config, 325 | # 'epoch': epoch, 326 | } 327 | torch.save(save_obj, os.path.join(args.output_dir, f'checkpoint_{epoch}.pth')) 328 | 329 | dist.barrier() 330 | torch.cuda.empty_cache() 331 | 332 | if utils.is_main_process(): 333 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 334 | f.write("best epoch: %d" % best_epoch) 335 | 336 | os.system(f"cat {args.output_dir}/log.txt") 337 | 338 | total_time = time.time() - start_time 339 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 340 | print('### Time {}'.format(total_time_str)) 341 | 342 | 343 | if __name__ == '__main__': 344 | parser = argparse.ArgumentParser() 345 | 
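# CLI: --evaluate runs test-only evaluation; passing --checkpoint '-1' trains without loading a pre-trained checkpoint.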
parser.add_argument('--checkpoint', type=str, required=True) 346 | parser.add_argument('--config', type=str, required=True) 347 | parser.add_argument('--output_dir', type=str, required=True) # this script works for both mscoco and flickr30k 348 | parser.add_argument('--device', default='cuda') 349 | parser.add_argument('--seed', default=42, type=int) 350 | parser.add_argument('--world_size', default=2, type=int, help='number of distributed processes') 351 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 352 | parser.add_argument('--distributed', action='store_false') 353 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus") 354 | parser.add_argument('--evaluate', action='store_true') 355 | 356 | args = parser.parse_args() 357 | 358 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 359 | 360 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 361 | 362 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 363 | 364 | main(args, config) 365 | -------------------------------------------------------------------------------- /assets/image-20230814214836481.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zjut-MultimediaPlus/PIR-pytorch/aeadf5ff0fba18ceb496495058c30501b2b5bb34/assets/image-20230814214836481.png -------------------------------------------------------------------------------- /assets/pipline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zjut-MultimediaPlus/PIR-pytorch/aeadf5ff0fba18ceb496495058c30501b2b5bb34/assets/pipline.png -------------------------------------------------------------------------------- /configs/Retrieval_rsicd.yaml: -------------------------------------------------------------------------------- 1 | ############## The train & val & test set root ############## 2 | train_file: ['data/finetune/rsicd_train.json'] 3 | val_file: 'data/finetune/rsicd_val.json' 4 | test_file: 'data/finetune/rsicd_test.json' 5 | image_root: '../X-VLM-pytorch/images/rsicd/' 6 | 7 | ############## Vision encoder setting ############## 8 | vision_config: 'configs/config_swinT_224.json' # configs/config_swinT_224.json 'configs/config_swinB_224.json' 9 | resnet_ckpt: 'data/aid_28-rsp-resnet-50-ckpt.pth' # 'data/aid_28-rsp-resnet-50-ckpt.pth' or 'INS/aid_resnet50.pth' 10 | finetune_conv: False # whether fintue the conv encoder 11 | use_swin: True # if use swin, using 'True' 12 | image_res: 224 # no need modify 13 | patch_size: 32 #if use swin, set the patch_size to 32, else 16 14 | 15 | ############## Text encoder setting ############## 16 | text_config: 'configs/config_bert.json' 17 | text_encoder: 'data/bert-base-uncased' 18 | 19 | ################ Training setting ################ 20 | #== no need revise in general 21 | batch_size_train: 128 22 | batch_size_test: 128 23 | batch_size_test_text: 128 24 | max_tokens: 47 25 | embed_dim: 512 26 | temp1: 0.07 27 | temp2: 0.07 28 | k_test: 512 29 | is_baseline: False # whether is baseline 30 | 31 | ############## Other Settings ############## 32 | optimizer: {opt: adamW, lr: 6e-5, weight_decay: 0.01, lr_mult: 2} # 3e-5 ana 6e-5 are all you need 33 | schedular: {sched: linear, lr: 6e-5, epochs: 50, num_warmup_steps: 0.1} # need to set the epoches, if needed, also lr 34 | 35 | ################ Model setting 
####################################################################################### 36 | #== 1. Representation Alignment, RA #### 37 | use_affil_loss: True # use affil loss 38 | use_triplet_loss: False 39 | center_factor: 1 # if use affil loss, set the center factor #### 40 | # indistinct_margin: 0.01 #### 41 | #### 42 | #== 2. Vision Instruction Representation, VIR #### 43 | filter_size: 40 # modify the filter size of vision instruction representation #### 44 | instru_num: 2 #=## 45 | #== 3. Language Cycle Attention, LCA #### 46 | cycle_num: 3 # how many times of cycle attention #### 47 | #### 48 | #== 4. the SA & CA parameter (include VIR and LCA Module) #### 49 | dropout_r: 0.2 #=## 50 | head: 8 #=## 51 | ######################################################################################################################## 52 | -------------------------------------------------------------------------------- /configs/Retrieval_rsitmd.yaml: -------------------------------------------------------------------------------- 1 | ############## The train & val & test set root ############## 2 | train_file: ['data/finetune/rsitmd_train.json'] 3 | val_file: 'data/finetune/rsitmd_val.json' 4 | test_file: 'data/finetune/rsitmd_test.json' 5 | image_root: '../X-VLM-pytorch/images/rsitmd/' 6 | 7 | ############## Vision encoder setting ############## 8 | vision_config: 'configs/config_swinT_224.json' # configs/config_swinT_224.json 'configs/config_swinB_224.json' 9 | 10 | resnet_ckpt: 'data/aid_28-rsp-resnet-50-ckpt.pth' # 'data/aid_28-rsp-resnet-50-ckpt.pth' or 'INS/aid_resnet50.pth' 11 | 12 | finetune_conv: False # whether fintue the conv encoder 13 | use_swin: True # if use swin, using 'True' 14 | image_res: 224 # no need modify 15 | patch_size: 32 #if use swin, set the patch_size to 32, else 16 16 | 17 | ############## Text encoder setting ############## 18 | text_config: 'configs/config_bert.json' 19 | text_encoder: 'data/bert-base-uncased' 20 | 21 | ################ Training setting ################ 22 | #== no need revise in general 23 | batch_size_train: 128 24 | batch_size_test: 128 25 | batch_size_test_text: 128 26 | 27 | max_tokens: 47 28 | embed_dim: 512 29 | temp1: 0.07 30 | temp2: 0.07 31 | k_test: 512 32 | is_baseline: False # whether is baseline 33 | 34 | ############## Other Settings ############## 35 | optimizer: {opt: adamW, lr: 6e-5, weight_decay: 0.01, lr_mult: 2} # 3e-5 ana 6e-5 are all you need 36 | schedular: {sched: linear, lr: 6e-5, epochs: 50, num_warmup_steps: 0.1} # need to set the epoches, if needed, also lr 37 | 38 | ################ Model setting ####################################################################################### 39 | #== 1. Representation Alignment, RA #### 40 | use_affil_loss: True # use affil loss 41 | use_triplet_loss: False 42 | center_factor: 1 # if use affil loss, set the center factor #### 43 | # indistinct_margin: 0.01 #### 44 | #### 45 | #== 2. Vision Instruction Representation, VIR #### 46 | filter_size: 40 # modify the filter size of vision instruction representation #### 47 | instru_num: 2 # 2 #=## 48 | #== 3. Language Cycle Attention, LCA #### 49 | cycle_num: 3 # 3 # how many times of cycle attention #### 50 | #### 51 | #== 4. 
the SA & CA parameter (include VIR and LCA Module) #### 52 | dropout_r: 0.2 #=## 53 | head: 8 #=## 54 | ######################################################################################################################## 55 | 56 | -------------------------------------------------------------------------------- /configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "gradient_checkpointing": false, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "pad_token_id": 0, 18 | "position_embedding_type": "absolute", 19 | "transformers_version": "4.6.0.dev0", 20 | "type_vocab_size": 2, 21 | "use_cache": true, 22 | "vocab_size": 30522 23 | } 24 | -------------------------------------------------------------------------------- /configs/config_swinT_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "data/swin_tiny_patch4_window7_224_22k.pth", 3 | "vision_width": 768, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 96, 7 | "depths": [ 2, 2, 6, 2 ], 8 | "num_heads": [ 3, 6, 12, 24 ] 9 | } 10 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms 5 | from PIL import Image 6 | 7 | from dataset.re_dataset import re_train_dataset, re_eval_dataset 8 | from dataset.pretrain_dataset import ImageTextJsonDataset, RegionTextJsonDataset 9 | 10 | 11 | from dataset.randaugment import RandomAugment 12 | from torchvision.transforms import InterpolationMode 13 | 14 | def create_dataset(dataset, config, evaluate=False): 15 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 16 | 17 | pretrain_transform = transforms.Compose([ 18 | transforms.RandomResizedCrop(config['image_res'], scale=(0.2, 1.0), 19 | interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | 27 | train_transform = transforms.Compose([ 28 | transforms.RandomResizedCrop(config['image_res'], scale=(0.5, 1.0), 29 | interpolation=InterpolationMode.BICUBIC), 30 | transforms.RandomHorizontalFlip(), 31 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 32 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 33 | transforms.ToTensor(), 34 | normalize, 35 | ]) 36 | 37 | train_transform_wohflip = transforms.Compose([ 38 | transforms.RandomResizedCrop(config['image_res'], scale=(0.5, 1.0), 39 | interpolation=InterpolationMode.BICUBIC), 40 | # transforms.RandomHorizontalFlip(), 41 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 42 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 43 | 
transforms.ToTensor(), 44 | normalize, 45 | ]) 46 | 47 | box_transform = transforms.Compose([ 48 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness']), 49 | transforms.ToTensor(), 50 | normalize, 51 | ]) 52 | 53 | test_transform = transforms.Compose([ 54 | transforms.Resize((config['image_res'], config['image_res']), interpolation=InterpolationMode.BICUBIC), 55 | transforms.ToTensor(), 56 | normalize, 57 | ]) 58 | 59 | if dataset == 'pretrain': 60 | general_dataset = ImageTextJsonDataset(config, config['train_file'], rank=int(os.environ.get('RANK') or 0), 61 | world_size=int(os.environ.get('WORLD_SIZE') or 1), shuffle=True, repeat=True, 62 | transform=pretrain_transform) 63 | 64 | region_dataset = RegionTextJsonDataset(config, config['train_file_regions'], rank=int(os.environ.get('RANK') or 0), 65 | world_size=int(os.environ.get('WORLD_SIZE') or 1), shuffle=True, repeat=True, 66 | transform=pretrain_transform, box_transform=box_transform) 67 | 68 | return general_dataset, region_dataset 69 | 70 | elif dataset == 're': 71 | test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root']) 72 | if evaluate: 73 | return None, None, test_dataset 74 | 75 | train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root']) 76 | val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root']) 77 | return train_dataset, val_dataset, test_dataset 78 | 79 | elif dataset == 'vqa': 80 | vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], 81 | split='test', answer_list=config['answer_list'], 82 | text_encoder=config['text_encoder'], use_roberta=config['use_roberta']) 83 | if evaluate: 84 | return None, vqa_test_dataset 85 | 86 | train_dataset = vqa_dataset(config['train_file'], train_transform_wohflip, config['vqa_root'], config['vg_root'], 87 | split='train', text_encoder=config['text_encoder'], use_roberta=config['use_roberta']) 88 | return train_dataset, vqa_test_dataset 89 | 90 | elif dataset == 'nlvr_pretrain': 91 | general_dataset = ImageTextJsonDataset(config, config['train_file'], rank=int(os.environ.get('RANK') or 0), 92 | world_size=int(os.environ.get('WORLD_SIZE') or 1), shuffle=True, repeat=True, 93 | transform=pretrain_transform) 94 | 95 | return general_dataset 96 | 97 | elif dataset == 'nlvr': 98 | test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root']) 99 | if evaluate: 100 | return None, None, test_dataset 101 | 102 | train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root']) 103 | val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root']) 104 | return train_dataset, val_dataset, test_dataset 105 | 106 | elif dataset == 'grounding': 107 | test_dataset = grounding_dataset(config['test_file'], test_transform, config['image_root'], mode='test') 108 | if evaluate: 109 | return None, test_dataset 110 | 111 | train_transform = transforms.Compose([ 112 | transforms.Resize((config['image_res'], config['image_res']), interpolation=InterpolationMode.BICUBIC), 113 | transforms.RandomHorizontalFlip(), 114 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 115 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 116 | transforms.ToTensor(), 117 | normalize, 118 | ]) 119 | train_dataset = grounding_dataset(config['train_file'], train_transform, config['image_root'], 
mode='train') 120 | return train_dataset, test_dataset 121 | 122 | elif dataset == 'grounding_bbox_pretrain': 123 | region_dataset = RegionTextJsonDataset(config, config['train_file_regions'], rank=int(os.environ.get('RANK') or 0), 124 | world_size=int(os.environ.get('WORLD_SIZE') or 1), shuffle=True, repeat=True, 125 | transform=pretrain_transform, box_transform=box_transform) 126 | 127 | return region_dataset 128 | 129 | elif dataset == 'grounding_bbox': 130 | test_dataset = grounding_dataset_bbox(config['test_file'], test_transform, config['image_root'], mode='test', config=config) 131 | if evaluate: 132 | return None, test_dataset 133 | 134 | train_transform = transforms.Compose([ 135 | RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness']), 136 | transforms.ToTensor(), 137 | normalize, 138 | ]) 139 | train_dataset = grounding_dataset_bbox(config['train_file'], train_transform, config['image_root'], mode='train', config=config) 140 | return train_dataset, test_dataset 141 | 142 | elif dataset == 'captioning_pretrain': 143 | general_dataset = ImageTextJsonDataset(config, config['train_file'], rank=int(os.environ.get('RANK') or 0), 144 | world_size=int(os.environ.get('WORLD_SIZE') or 1), shuffle=True, repeat=True, 145 | transform=pretrain_transform, add_eos=True) 146 | return general_dataset 147 | 148 | elif dataset == 'caption_coco': 149 | train_dataset = coco_karpathy_train(train_transform, config['image_root'], config['train_file'], prompt=config['prompt'], max_words=config['max_tokens']) 150 | val_dataset = coco_karpathy_caption_eval(test_transform, config['image_root'], config['val_file'], 'val') 151 | test_dataset = coco_karpathy_caption_eval(test_transform, config['image_root'], config['test_file'], 'test') 152 | 153 | return train_dataset, val_dataset, test_dataset 154 | 155 | elif dataset == 'caption_coco_scst': 156 | train_dataset = coco_karpathy_train_scst(train_transform, config['image_root'], config['train_file'], 157 | prompt=config['prompt'], max_words=config['max_tokens']) 158 | val_dataset = coco_karpathy_caption_eval(test_transform, config['image_root'], config['val_file'], 'val') 159 | test_dataset = coco_karpathy_caption_eval(test_transform, config['image_root'], config['test_file'], 'test') 160 | 161 | return train_dataset, val_dataset, test_dataset 162 | 163 | else: 164 | raise NotImplementedError(f"dataset == {dataset}") 165 | 166 | 167 | def vqa_collate_fn(batch): 168 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] 169 | for image, question, answer, weights in batch: 170 | image_list.append(image) 171 | question_list.append(question) 172 | weight_list += weights 173 | answer_list += answer 174 | n.append(len(answer)) 175 | return torch.stack(image_list, dim=0), question_list, answer_list, torch.Tensor(weight_list), n 176 | 177 | 178 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 179 | samplers = [] 180 | for dataset, shuffle in zip(datasets, shuffles): 181 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 182 | samplers.append(sampler) 183 | return samplers 184 | 185 | 186 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 187 | loaders = [] 188 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, 189 | is_trains, collate_fns): 190 | if is_train: 191 | shuffle = (sampler is None) 192 | drop_last = True 193 
| else: 194 | shuffle = False 195 | drop_last = False 196 | loader = DataLoader( 197 | dataset, 198 | batch_size=bs, 199 | num_workers=n_worker, 200 | pin_memory=True, 201 | sampler=sampler, 202 | shuffle=shuffle, 203 | collate_fn=collate_fn, 204 | drop_last=drop_last 205 | ) 206 | loaders.append(loader) 207 | 208 | if len(loaders) <= 1: 209 | print(f"### be careful: func create_loader returns a list length of {len(loaders)}") 210 | 211 | return loaders 212 | -------------------------------------------------------------------------------- /dataset/dist_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 7 | 8 | import sys 9 | from typing import List, Any 10 | import warnings 11 | import random 12 | from itertools import cycle 13 | import torch 14 | from torch.utils.data import IterableDataset 15 | 16 | from utils.hdfs_io import hopen, hlist_files 17 | 18 | 19 | class DistLineReadingDataset(IterableDataset): # pylint: disable=W0223 20 | """ 21 | iterate a set of folders. 22 | """ 23 | def __init__(self, 24 | data_path: str, 25 | rank: int = 0, 26 | world_size: int = 1, 27 | shuffle: bool = False, 28 | repeat: bool = False): 29 | super().__init__() 30 | self.shuffle = shuffle 31 | self.rank = rank 32 | self.world_size = world_size 33 | 34 | self.files = hlist_files(data_path.split(',')) 35 | self.files = [f for f in self.files if f.find('_SUCCESS') < 0] 36 | self.is_hdfs = data_path.startswith('hdfs') 37 | 38 | self.repeat = repeat 39 | print('[DATA]--all dataset containing {} files.'.format(len(self.files))) 40 | if len(self.files) % self.world_size != 0: 41 | print('[DATA]--Whole dataset file num %s cannot split to worldsize %s ' % 42 | (len(self.files), self.world_size)) 43 | sys.stdout.flush() 44 | 45 | def generate(self): 46 | if self.world_size == 1 or len(self.files) == 1: 47 | cur_dataloader_files = self.files 48 | else: 49 | cur_dataloader_files = split_shard( 50 | self.files, self.rank, self.world_size) 51 | 52 | while True: 53 | if self.shuffle: 54 | random.shuffle(cur_dataloader_files) 55 | worker_info = torch.utils.data.get_worker_info() 56 | 57 | if worker_info is not None: 58 | if len(cur_dataloader_files) % worker_info.num_workers != 0: 59 | print('[DATA]--current dataloader %s file num %s cannot split to worker_num %s ' % 60 | (self.rank, len(cur_dataloader_files), worker_info.num_workers)) 61 | cur_worker_files = split_shard( 62 | cur_dataloader_files, worker_info.id, worker_info.num_workers) 63 | if worker_info.id == 0: 64 | print("[DataLoader] --> Rank:{} Workers:[{} ~ {}][{}] Size of process file:{} ...".format( 65 | self.rank, 0, worker_info.num_workers - 1, worker_info.id, len(cur_dataloader_files))) 66 | else: 67 | cur_worker_files = cur_dataloader_files 68 | 69 | if self.shuffle: 70 | random.shuffle(cur_worker_files) 71 | for filepath in cur_worker_files: 72 | if self.is_hdfs: 73 | with hopen(filepath, 'r') as reader: 74 | for line in reader: 75 | yield line.decode() 76 | continue 77 | with open(filepath, 'r') as reader: 78 | for line in reader: 79 | yield line 80 | 81 | if not self.repeat: 82 | break 83 | 84 | def __iter__(self): 85 | return self.generate() 86 | 87 | 88 | def split_shard(data: List[Any], shard_idx: int, shard_size: int): 
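# Contiguous sharding: shard i receives data[num * i // shard_size : num * (i + 1) // shard_size]; raises if len(data) < shard_size.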
89 | num = len(data) 90 | if num < shard_size: 91 | raise RuntimeError("num:{} < shard size:{}".format(num, shard_size)) 92 | start_idx = (num * shard_idx) // shard_size 93 | end_idx = (num * (shard_idx + 1)) // shard_size 94 | return data[start_idx: end_idx] 95 | -------------------------------------------------------------------------------- /dataset/grounding_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import math 4 | import random 5 | from random import random as rand 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | from torchvision.transforms.functional import hflip, resize 11 | 12 | from PIL import Image 13 | from dataset.utils import pre_caption 14 | # from refTools.refer_python3 import REFER 15 | 16 | 17 | class grounding_dataset(Dataset): 18 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train'): 19 | self.ann = [] 20 | for f in ann_file: 21 | self.ann += json.load(open(f, 'r')) 22 | self.transform = transform 23 | self.image_root = image_root 24 | self.max_words = max_words 25 | self.mode = mode 26 | 27 | if self.mode == 'train': 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.ann: 31 | img_id = ann['image'].split('/')[-1] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.ann) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.ann[index] 42 | 43 | image_path = os.path.join(self.image_root, ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = pre_caption(ann['text'], self.max_words) 48 | 49 | if self.mode == 'train': 50 | img_id = ann['image'].split('/')[-1] 51 | 52 | return image, caption, self.img_ids[img_id] 53 | else: 54 | return image, caption, ann['ref_id'] 55 | -------------------------------------------------------------------------------- /dataset/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class nlvr_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root): 10 | self.ann = [] 11 | for f in ann_file: 12 | self.ann += json.load(open(f, 'r')) 13 | self.transform = transform 14 | self.image_root = image_root 15 | self.max_words = 30 16 | 17 | def __len__(self): 18 | return len(self.ann) 19 | 20 | def __getitem__(self, index): 21 | 22 | ann = self.ann[index] 23 | 24 | image0_path = os.path.join(self.image_root, ann['images'][0]) 25 | image0 = Image.open(image0_path).convert('RGB') 26 | image0 = self.transform(image0) 27 | 28 | image1_path = os.path.join(self.image_root, ann['images'][1]) 29 | image1 = Image.open(image1_path).convert('RGB') 30 | image1 = self.transform(image1) 31 | 32 | sentence = pre_caption(ann['sentence'], self.max_words) 33 | 34 | if ann['label'] == 'True': 35 | label = 1 36 | 37 | elif ann['label'] == 'False': 38 | label = 0 39 | 40 | else: 41 | raise ValueError(f"unsupported label: {ann['label']}") 42 | 43 | return image0, image1, sentence, label 44 | -------------------------------------------------------------------------------- /dataset/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts 
(https://arxiv.org/abs/2111.08276) 2 | # Github: https://github.com/zengyan-97/X-VLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | import json 7 | import copy 8 | import math 9 | import random 10 | import sys 11 | import re 12 | import io 13 | import traceback 14 | from base64 import b64decode 15 | 16 | from random import randint, shuffle 17 | from random import random as rand 18 | 19 | import torch 20 | from torchvision.transforms.functional import hflip, resize 21 | from transformers import BertTokenizer, RobertaTokenizer 22 | 23 | from PIL import Image 24 | from PIL import ImageFile 25 | ImageFile.LOAD_TRUNCATED_IMAGES = True 26 | Image.MAX_IMAGE_PIXELS = None 27 | 28 | from dataset.utils import pre_caption 29 | from dataset.dist_dataset import DistLineReadingDataset 30 | 31 | 32 | class TextMaskingGenerator: 33 | def __init__(self, tokenizer, mask_prob, mask_max, skipgram_prb=0.2, skipgram_size=3, mask_whole_word=True, use_roberta=False): 34 | self.id2token = {i: w for w, i in tokenizer.get_vocab().items()} 35 | print("len(tokenizer.id2token), ", len(self.id2token), flush=True) 36 | 37 | self.use_roberta = use_roberta 38 | 39 | for i in range(len(self.id2token)): 40 | assert i in self.id2token.keys() # check 41 | 42 | self.cls_token = tokenizer.cls_token 43 | self.mask_token = tokenizer.mask_token 44 | print("mask_generator.cls_token, ", self.cls_token, flush=True) 45 | print("mask_generator.mask_token, ", self.mask_token, flush=True) 46 | 47 | self.mask_max = mask_max 48 | self.mask_prob = mask_prob 49 | 50 | self.skipgram_prb = skipgram_prb 51 | self.skipgram_size = skipgram_size 52 | self.mask_whole_word = mask_whole_word 53 | 54 | def get_random_word(self): 55 | i = randint(0, len(self.id2token) - 1) 56 | return self.id2token[i] 57 | 58 | def __call__(self, tokens: list): # tokens: [CLS] + ... 
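        # Summary of the masking procedure implemented below: pick up to
        # `mask_max` positions (about mask_prob * len(tokens), at least 1,
        # never position 0 / [CLS]); a sampled position may be expanded to a
        # whole word or a short skip-gram span. Each selected token is replaced
        # by [MASK] with probability 0.8, by a random vocabulary token with
        # probability ~0.1, and left unchanged otherwise. Returns the partly
        # masked token list together with the list of masked positions.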
59 | n_pred = min(self.mask_max, max( 60 | 1, int(round(len(tokens) * self.mask_prob)))) 61 | 62 | # candidate positions of masked tokens 63 | assert tokens[0] == self.cls_token 64 | special_pos = set([0]) # will not be masked 65 | cand_pos = list(range(1, len(tokens))) 66 | 67 | shuffle(cand_pos) 68 | masked_pos = set() 69 | max_cand_pos = max(cand_pos) 70 | for pos in cand_pos: 71 | if len(masked_pos) >= n_pred: 72 | break 73 | if pos in masked_pos: 74 | continue 75 | 76 | def _expand_whole_word(st, end): 77 | new_st, new_end = st, end 78 | 79 | if self.use_roberta: 80 | while (new_st > 1) and (tokens[new_st][0] != 'Ġ'): 81 | new_st -= 1 82 | while (new_end < len(tokens)) and (tokens[new_end][0] != 'Ġ'): 83 | new_end += 1 84 | else: 85 | # bert, WordPiece 86 | while (new_st >= 0) and tokens[new_st].startswith('##'): 87 | new_st -= 1 88 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 89 | new_end += 1 90 | 91 | return new_st, new_end 92 | 93 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 94 | # ngram 95 | cur_skipgram_size = randint(2, self.skipgram_size) 96 | if self.mask_whole_word: 97 | st_pos, end_pos = _expand_whole_word( 98 | pos, pos + cur_skipgram_size) 99 | else: 100 | st_pos, end_pos = pos, pos + cur_skipgram_size 101 | else: 102 | if self.mask_whole_word: 103 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 104 | else: 105 | st_pos, end_pos = pos, pos + 1 106 | 107 | for mp in range(st_pos, end_pos): 108 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 109 | masked_pos.add(mp) 110 | else: 111 | break 112 | 113 | masked_pos = list(masked_pos) 114 | n_real_pred = len(masked_pos) 115 | if n_real_pred > n_pred: 116 | shuffle(masked_pos) 117 | masked_pos = masked_pos[:n_pred] 118 | 119 | for pos in masked_pos: 120 | if rand() < 0.8: # 80% 121 | tokens[pos] = self.mask_token 122 | elif rand() < 0.5: # 10% 123 | tokens[pos] = self.get_random_word() 124 | 125 | return tokens, masked_pos 126 | 127 | 128 | class ImageTextJsonDataset(DistLineReadingDataset): 129 | def __init__(self, config, data_path, rank=0, world_size=1, shuffle=True, repeat=True, transform=None, add_eos=False): 130 | super().__init__(data_path, rank, world_size, shuffle, repeat) 131 | 132 | if 'images' in config.keys(): 133 | self.image_key = config['images']['image_key'] 134 | self.is_image_rpath = config['images']['is_image_rpath'] 135 | self.caption_key = config['images']['caption_key'] 136 | self.batch_size = config['images']['batch_size'] 137 | self.tokenized = config['images']['tokenized'] 138 | 139 | self.use_roberta = config['use_roberta'] 140 | self.tokenizer = RobertaTokenizer.from_pretrained(config['text_encoder']) if self.use_roberta else \ 141 | BertTokenizer.from_pretrained(config['text_encoder']) 142 | 143 | self.add_eos = add_eos 144 | 145 | self.cls_token = self.tokenizer.cls_token 146 | self.eos_token = self.tokenizer.sep_token 147 | self.pad_token_id = self.tokenizer.pad_token_id 148 | self.mask_token_id = self.tokenizer.mask_token_id 149 | 150 | print("dataset.cls_token, ", self.cls_token, flush=True) 151 | print("dataset.eos_token, ", self.eos_token, flush=True) 152 | print("dataset.pad_token_id, ", self.pad_token_id, flush=True) 153 | print("dataset.mask_token_id, ", self.mask_token_id, flush=True) 154 | 155 | self.mask_generator = TextMaskingGenerator(self.tokenizer, config['mask_prob'], 156 | config['max_masks'], config['skipgram_prb'], 157 | config['skipgram_size'], config['mask_whole_word']) 158 | 159 | 
self.PAD_mask = -100 # loss will ignore this 160 | self.max_words = config['max_words'] 161 | self.max_tokens = config['max_tokens'] 162 | self.max_masks = config['max_masks'] 163 | 164 | self.transform = transform 165 | self.image_res = config['image_res'] 166 | self.patch_size = config['patch_size'] 167 | assert self.image_res % self.patch_size == 0 168 | self.num_patch = int(self.image_res / self.patch_size) 169 | 170 | def __iter__(self): 171 | for example in self.generate(): 172 | try: 173 | ann = json.loads(example) 174 | assert isinstance(ann, dict), "ann is not dict" 175 | 176 | caption = ann[self.caption_key] 177 | if isinstance(caption, list): 178 | caption = random.choice(caption) 179 | 180 | if self.is_image_rpath: # read path or base64 encoding 181 | image = Image.open(ann[self.image_key]).convert('RGB') 182 | else: 183 | # if reading from HDFS, use this: 184 | image = Image.open(io.BytesIO(b64decode(ann[self.image_key]))).convert("RGB") 185 | 186 | image = self.transform(image) 187 | 188 | text_ids, text_atts, text_ids_masked, masked_pos, masked_ids = self.preprocess(caption) 189 | 190 | yield image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids 191 | 192 | except Exception as e: 193 | print(traceback.format_exc()) 194 | print('encounter broken data: %s' % e) 195 | print('-'*20) 196 | sys.stdout.flush() 197 | 198 | def preprocess(self, text): 199 | if self.tokenized: 200 | tokens = text.strip().split(' ') 201 | else: 202 | text = pre_caption(text, self.max_words) # be careful, if text is '', it will cause error 203 | tokens = self.tokenizer.tokenize(text) 204 | 205 | tokens = [self.cls_token] + tokens[:self.max_tokens - 1] 206 | 207 | if self.add_eos: 208 | tokens = tokens[:self.max_tokens - 1] 209 | tokens += [self.eos_token] 210 | 211 | n_tokens = len(tokens) 212 | assert n_tokens >= 2, "len(word tokens) < 2" 213 | 214 | text_ids = self.tokenizer.convert_tokens_to_ids(tokens) # list of int 215 | 216 | tokens_masked, masked_pos = self.mask_generator(copy.deepcopy(tokens)) 217 | text_ids_masked = self.tokenizer.convert_tokens_to_ids(tokens_masked) # list of int 218 | masked_ids = [text_ids[p] for p in masked_pos] 219 | 220 | # pad 221 | n_pad = self.max_tokens - n_tokens 222 | text_ids = text_ids + [self.pad_token_id] * n_pad 223 | text_atts = [1] * n_tokens + [0] * n_pad 224 | 225 | text_ids_masked = text_ids_masked + [self.pad_token_id] * n_pad 226 | n_pad = self.max_masks - len(masked_ids) 227 | masked_pos = masked_pos + [0] * n_pad 228 | masked_ids = masked_ids + [self.PAD_mask] * n_pad 229 | 230 | return text_ids, text_atts, text_ids_masked, masked_pos, masked_ids 231 | 232 | def collate_fn(self, batch): 233 | batch_tensors = [] 234 | for x in zip(*batch): 235 | if x[0] is None: 236 | batch_tensors.append(None) 237 | elif isinstance(x[0], torch.Tensor): 238 | batch_tensors.append(torch.stack(x)) 239 | else: 240 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 241 | 242 | return batch_tensors 243 | 244 | 245 | class RegionTextJsonDataset(ImageTextJsonDataset): 246 | def __init__(self, config, data_path, rank=0, world_size=1, shuffle=True, repeat=True, transform=None, box_transform=None): 247 | super().__init__(config, data_path, rank=rank, world_size=world_size, shuffle=shuffle, 248 | repeat=repeat, transform=transform) 249 | 250 | self.image_key = config['regions']['image_key'] 251 | self.is_image_rpath = config['regions']['is_image_rpath'] 252 | self.caption_key = config['regions']['caption_key'] 253 | assert self.caption_key == 'caption', 
"please follow my data format" 254 | self.batch_size = config['regions']['batch_size'] 255 | self.tokenized = config['regions']['tokenized'] 256 | self.careful_hflip = config['regions']['careful_hflip'] if 'careful_hflip' in config['regions'] else False 257 | 258 | self.box_transform = box_transform 259 | self.max_regions = config['regions']['max_regions'] 260 | self.min_perc_in_image = config['regions']['min_perc_in_image'] 261 | 262 | def get_bbox(self, ann): 263 | x, y, w, h = ann['bb'] 264 | return int(x), int(y), int(w), int(h) 265 | 266 | def left_or_right_in_caption(self, ann): 267 | def _in_it(elem): 268 | if isinstance(elem['caption'], list): 269 | for caption in elem['caption']: 270 | if ('left' in caption) or ('right' in caption): 271 | return True 272 | else: 273 | if ('left' in elem['caption']) or ('right' in elem['caption']): 274 | return True 275 | 276 | if 'caption' in ann.keys(): 277 | if _in_it(ann): 278 | return True 279 | 280 | for elem in ann['elems']: 281 | if _in_it(elem): 282 | return True 283 | 284 | return False 285 | 286 | def __iter__(self): 287 | for example in self.generate(): 288 | try: 289 | ann = json.loads(example) 290 | assert isinstance(ann, dict), "ann is not dict" 291 | 292 | try: 293 | image = Image.open(ann[self.image_key]).convert('RGB') if self.is_image_rpath \ 294 | else Image.open(io.BytesIO(b64decode(ann[self.image_key]))).convert("RGB") 295 | except Warning: 296 | raise ValueError("### Warning: RegionTextJsonDataset Image.open") 297 | 298 | W, H = image.size 299 | 300 | # random crop 301 | x, y, w, h = self.get_bbox(random.choice(ann['elems'])) 302 | assert (x >= 0) and (y >= 0) and (x + w <= W) and (y + h <= H) and (w > 0) and (h > 0), "elem invalid" 303 | 304 | x0, y0 = random.randint(0, math.floor(x)), random.randint(0, math.floor(y)) 305 | x1, y1 = random.randint(min(math.ceil(x + w), W), W), random.randint(min(math.ceil(y + h), H), H) 306 | w0, h0 = x1 - x0, y1 - y0 307 | assert (x0 >= 0) and (y0 >= 0) and (x0 + w0 <= W) and (y0 + h0 <= H) and (w0 > 0) and (h0 > 0), "elem randomcrop, invalid" 308 | 309 | image = image.crop((x0, y0, x0 + w0, y0 + h0)) 310 | W, H = image.size 311 | 312 | do_hflip = False 313 | if rand() < 0.5: 314 | if self.careful_hflip and self.left_or_right_in_caption(ann): 315 | pass 316 | else: 317 | image = hflip(image) 318 | do_hflip = True 319 | 320 | image = resize(image, [self.image_res, self.image_res], interpolation=Image.BICUBIC) 321 | image = self.box_transform(image) 322 | 323 | text_ids_list = [] 324 | text_ids_masked_list = [] 325 | text_atts_list = [] 326 | masked_pos_list = [] 327 | masked_ids_list = [] 328 | image_atts_list = [] 329 | 330 | target_bbox_list = [] 331 | is_image_list = [] 332 | 333 | max_elems = self.max_regions 334 | 335 | if 'caption' in ann.keys(): 336 | caption = random.choice(ann['caption']) if isinstance(ann['caption'], list) else ann['caption'] 337 | text_ids, text_atts, text_ids_masked, masked_pos, masked_ids = self.preprocess(caption) 338 | 339 | text_ids_list.append(text_ids) 340 | text_atts_list.append(text_atts) 341 | text_ids_masked_list.append(text_ids_masked) 342 | masked_pos_list.append(masked_pos) 343 | masked_ids_list.append(masked_ids) 344 | 345 | image_atts_list.append([1] * (self.num_patch ** 2 + 1)) 346 | target_bbox_list.append(torch.tensor([0.5, 0.5, 1, 1], dtype=torch.float)) 347 | is_image_list.append(1) 348 | 349 | max_elems -= 1 350 | 351 | elems = random.sample(ann['elems'], len(ann['elems'])) 352 | 353 | for elem in elems: 354 | if max_elems <= 0: 355 | break 
356 | 357 | x, y, w, h = self.get_bbox(elem) 358 | 359 | xx, yy = max(x0, x), max(y0, y) 360 | xm, ym = min(x0 + w0, x + w), min(y0 + h0, y + h) 361 | if (xm > xx) and (ym > yy): 362 | if (xm - xx) * (ym - yy) / (w * h) > self.min_perc_in_image: 363 | x, y, w, h = xx, yy, xm - xx, ym - yy # part inside the cropped image 364 | 365 | # axis transform: after crop 366 | x = x - x0 367 | y = y - y0 368 | 369 | if do_hflip: # flipped applied 370 | x = (W - x) - w # W is w0 371 | 372 | # resize applied 373 | x = self.image_res / W * x 374 | w = self.image_res / W * w 375 | y = self.image_res / H * y 376 | h = self.image_res / H * h 377 | 378 | caption = random.choice(elem['caption']) if isinstance(elem['caption'], list) else elem['caption'] 379 | 380 | if 'attributes' in elem.keys(): 381 | elem_attr = random.choice(elem['attributes']) if isinstance(elem['attributes'], list) else elem['attributes'] 382 | caption = elem_attr + ' ' + caption 383 | 384 | text_ids, text_atts, text_ids_masked, masked_pos, masked_ids = self.preprocess(caption) 385 | image_atts = self.get_image_attns(x, y, w, h) 386 | 387 | text_ids_list.append(text_ids) 388 | text_atts_list.append(text_atts) 389 | text_ids_masked_list.append(text_ids_masked) 390 | masked_pos_list.append(masked_pos) 391 | masked_ids_list.append(masked_ids) 392 | image_atts_list.append(image_atts) 393 | 394 | center_x = x + 1 / 2 * w 395 | center_y = y + 1 / 2 * h 396 | 397 | target_bbox_list.append(torch.tensor([center_x / self.image_res, center_y / self.image_res, 398 | w / self.image_res, h / self.image_res], 399 | dtype=torch.float)) 400 | 401 | is_image_list.append(0) 402 | 403 | max_elems -= 1 404 | 405 | image_list = [image] if len(text_ids_list) else [] 406 | 407 | yield image_list, text_ids_list, text_atts_list, text_ids_masked_list, masked_pos_list, \ 408 | masked_ids_list, image_atts_list, target_bbox_list, is_image_list 409 | 410 | except Exception as e: 411 | print(traceback.format_exc()) 412 | print('encounter broken data: %s' % e) 413 | print('-' * 20) 414 | sys.stdout.flush() 415 | 416 | def get_image_attns(self, x, y, w, h): 417 | x_min = min(math.floor(x / self.patch_size), self.num_patch - 1) 418 | x_max = max(x_min+1, min(math.ceil((x+w) / self.patch_size), self.num_patch)) # exclude 419 | 420 | y_min = min(math.floor(y / self.patch_size), self.num_patch - 1) 421 | y_max = max(y_min+1, min(math.ceil((y+h) / self.patch_size), self.num_patch)) # exclude 422 | 423 | image_atts = [0] * (1 + self.num_patch ** 2) 424 | image_atts[0] = 1 # always include [CLS] 425 | for j in range(x_min, x_max): 426 | for i in range(y_min, y_max): 427 | index = self.num_patch * i + j + 1 428 | assert (index > 0) and (index <= self.num_patch ** 2), f"patch index out of range, index: {index}" 429 | image_atts[index] = 1 430 | 431 | return image_atts 432 | 433 | def collate_fn(self, batch_sample): 434 | batch = [] 435 | for x in zip(*batch_sample): 436 | batch.append(x) 437 | 438 | images, batch = batch[0], batch[1:] 439 | 440 | idx_to_group_img = [] 441 | img_idx = -1 442 | for sample in batch[0]: 443 | n_elems = len(sample) 444 | if n_elems > 0: 445 | img_idx += 1 446 | idx_to_group_img.extend([img_idx] * n_elems) # flatten 447 | 448 | batch_size = self.batch_size 449 | n_elems = len(idx_to_group_img) 450 | to_keep = list(range(n_elems)) 451 | if n_elems >= batch_size: 452 | to_keep = random.sample(to_keep, batch_size) 453 | else: 454 | # fixed batch_size is required. otherwise, the process will be blocked. so, i do pad here. 
455 | # but pad causes wrong calculation for contrastive learning. 456 | # Set appropriate batch_size, max_images, and max_regions to avoid frequent padding. 457 | try: 458 | to_pad = random.sample(to_keep, batch_size - n_elems) 459 | to_keep += to_pad 460 | print("### warning: pad region_batch by sampling, ", len(to_pad), flush=True) 461 | 462 | except ValueError: 463 | print("### warning: pad region_batch by expanding, ", batch_size-len(to_keep), flush=True) 464 | to_keep = (to_keep * math.ceil(batch_size/len(to_keep)))[:batch_size] 465 | 466 | images = torch.stack(sum(images, [])) # flatten 467 | idx_to_group_img = torch.tensor([idx_to_group_img[index] for index in to_keep], dtype=torch.long) 468 | 469 | batch_tensors = [images, idx_to_group_img] 470 | for x in [sum(x, []) for x in batch]: 471 | 472 | x = [x[index] for index in to_keep] 473 | 474 | if x[0] is None: 475 | batch_tensors.append(None) 476 | elif isinstance(x[0], torch.Tensor): 477 | batch_tensors.append(torch.stack(x)) 478 | else: 479 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 480 | 481 | return batch_tensors 482 | -------------------------------------------------------------------------------- /dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 
255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, 
replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | 
} 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /dataset/re_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset 5 | from jieba import analyse 6 | from PIL import Image 7 | from PIL import ImageFile 8 | 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from dataset.utils import pre_caption 13 | 14 | 15 | class re_train_dataset(Dataset): 16 | def __init__(self, ann_file, transform, image_root, max_words=30): 17 | self.ann = [] 18 | for f in ann_file: 19 | self.ann += json.load(open(f, 'r')) 20 | self.transform = transform 21 | self.image_root = image_root 22 | self.max_words = max_words 23 | self.img_ids = {} 24 | 25 | n = 0 26 | for ann in self.ann: 27 | img_id = ann['image_id'] 28 | if img_id not in self.img_ids.keys(): 29 | self.img_ids[img_id] = n 30 | n += 1 31 | 32 | def __len__(self): 33 | return len(self.ann) 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.ann[index] 38 | 39 | image_path = os.path.join(self.image_root, ann['image']) 40 | image = Image.open(image_path).convert('RGB') 41 | image = self.transform(image) 42 | 43 | caption = pre_caption(ann['caption'], self.max_words) 44 | 45 | # t = analyse.extract_tags(caption, topK=4, withWeight=False) 46 | # ii = caption.split(' ') 47 | # k = "" 48 | # fl = 0 49 | # for j in range(len(ii)): 50 | # if fl == 1: 51 | # k += " " 52 | # fl = 1 53 | # if ii[j] not in t: 54 | # k += "[MASK]" 55 | # else: 56 | # k += ii[j] 57 | # 58 | # mask_text = pre_caption(k, self.max_words) 59 | # print('caption: {}'.format(caption)) 60 | # print('mask_texts: {}'.format(mask_texts)) 61 | 62 | label = torch.tensor(ann['label']) 63 | 64 | ## if no need label, set value to zero or others: 65 | # label = 0 66 | # return image, caption, mask_text, self.img_ids[ann['image_id']], label 67 | return image, caption, self.img_ids[ann['image_id']], label 68 | 69 | 70 | class re_eval_dataset(Dataset): 71 | def __init__(self, ann_file, transform, image_root, max_words=30): 72 | self.ann = json.load(open(ann_file, 'r')) 73 | self.transform = transform 74 | self.image_root = image_root 75 | self.max_words = max_words 76 | 77 | self.text = [] 78 | # self.mask_text = [] 79 | self.image = [] 80 | # self.image_data = [] 81 | self.txt2img = {} 82 | self.img2txt = {} 83 | 84 | txt_id = 0 85 | for img_id, ann in enumerate(self.ann): 86 | self.image.append(ann['image']) 87 | self.img2txt[img_id] = [] 88 | for i, caption in enumerate(ann['caption']): 89 | self.text.append(pre_caption(caption, self.max_words)) 90 | self.img2txt[img_id].append(txt_id) 91 | 
self.txt2img[txt_id] = img_id 92 | txt_id += 1 93 | 94 | # t = analyse.extract_tags(caption, topK=4, withWeight=False) 95 | # ii = caption.split(' ') 96 | # k = "" 97 | # fl = 0 98 | # for j in range(len(ii)): 99 | # if fl == 1: 100 | # k += " " 101 | # fl = 1 102 | # if ii[j] not in t: 103 | # k += "[MASK]" 104 | # else: 105 | # k += ii[j] 106 | # self.mask_text.append(pre_caption(k, self.max_words)) 107 | 108 | # image_path = os.path.join(self.image_root, ann['image']) 109 | # image = Image.open(image_path).convert('RGB') 110 | # image = self.transform(image) 111 | # self.image_data.append(image.unsqueeze(dim=0)) 112 | 113 | def __len__(self): 114 | return len(self.image) 115 | 116 | def __getitem__(self, index): 117 | 118 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 119 | image = Image.open(image_path).convert('RGB') 120 | image = self.transform(image) 121 | 122 | return image, index 123 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | from tqdm import tqdm 11 | 12 | # from utils.hdfs_io import hexists, hcopy, hopen 13 | # from vqaTools.vqaEval import VQAEval 14 | # from refTools.evaluation.refEvaluation import RefEvaluation 15 | 16 | 17 | def pre_question(question, max_ques_words): 18 | question = re.sub( 19 | r"([,.'!?\"()*#:;~])", 20 | '', 21 | question.lower(), 22 | ).replace('-', ' ').replace('/', ' ') 23 | question = question.rstrip(' ') 24 | 25 | # truncate question 26 | question_words = question.split(' ') 27 | if len(question_words) > max_ques_words: 28 | question = ' '.join(question_words[:max_ques_words]) 29 | 30 | return question 31 | 32 | 33 | def pre_caption(caption, max_words): 34 | caption = re.sub( 35 | r"([,.'!?\"()*#:;~])", 36 | '', 37 | caption.lower(), 38 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 39 | 40 | caption = re.sub( 41 | r"\s{2,}", 42 | ' ', 43 | caption, 44 | ) 45 | caption = caption.rstrip('\n') 46 | caption = caption.strip(' ') 47 | 48 | # truncate caption 49 | caption_words = caption.split(' ') 50 | if len(caption_words) > max_words: 51 | caption = ' '.join(caption_words[:max_words]) 52 | # print(caption) 53 | if not len(caption): 54 | # print('=========') 55 | # print(caption) 56 | # print('=========') 57 | raise ValueError("pre_caption yields invalid text") 58 | 59 | return caption 60 | 61 | 62 | def write_json(result: list, wpath: str): 63 | if wpath.startswith('hdfs'): 64 | with hopen(wpath, 'w') as f: 65 | for res in result: 66 | to_write = json.dumps(res) + '\n' 67 | f.write(to_write.encode()) 68 | else: 69 | with open(wpath, 'wt') as f: 70 | for res in result: 71 | f.write(json.dumps(res) + '\n') 72 | 73 | 74 | def read_json(rpath: str): 75 | result = [] 76 | if rpath.startswith('hdfs'): 77 | with hopen(rpath, 'r') as f: 78 | for line in f: 79 | result.append(json.loads(line.decode().strip())) 80 | else: 81 | with open(rpath, 'rt') as f: 82 | for line in f: 83 | result.append(json.loads(line.strip())) 84 | 85 | return result 86 | 87 | 88 | def collect_result(result, filename, local_wdir, hdfs_wdir, write_to_hdfs=False, save_result=False, remove_duplicate='', do_not_collect=False): 89 | assert isinstance(result, list) 90 | write_json(result, os.path.join(hdfs_wdir if write_to_hdfs else 
local_wdir, 91 | '%s_rank%d.json' % (filename, utils.get_rank()))) 92 | dist.barrier() 93 | 94 | if do_not_collect: 95 | return None 96 | 97 | result = [] 98 | final_result_file = '' 99 | if utils.is_main_process(): 100 | # combine results from all processes 101 | for rank in range(utils.get_world_size()): 102 | result += read_json(os.path.join(hdfs_wdir if write_to_hdfs else local_wdir, 103 | '%s_rank%d.json' % (filename, rank))) 104 | 105 | if remove_duplicate: # for evaluating captioning tasks 106 | result_new = [] 107 | id_list = set() 108 | for res in result: 109 | if res[remove_duplicate] not in id_list: 110 | id_list.add(res[remove_duplicate]) 111 | result_new.append(res) 112 | result = result_new 113 | 114 | if save_result: 115 | final_result_file = os.path.join(local_wdir, '%s.json' % filename) 116 | json.dump(result, open(final_result_file, 'w'), indent=4) 117 | print('result file saved to %s' % final_result_file) 118 | if write_to_hdfs: 119 | hcopy(final_result_file, os.path.join(hdfs_wdir, '%s.json' % filename)) 120 | print('result file saved to %s' % os.path.join(hdfs_wdir, '%s.json' % filename)) 121 | 122 | dist.barrier() 123 | 124 | return final_result_file if save_result else result 125 | 126 | 127 | def collect_tensor_result(result, filename, local_wdir, hdfs_wdir, write_to_hdfs=False): 128 | wpath = os.path.join(local_wdir, '%s_rank%d.pth' % (filename, utils.get_rank())) 129 | torch.save(result, wpath) 130 | if write_to_hdfs: 131 | hcopy(wpath, hdfs_wdir) 132 | 133 | dist.barrier() 134 | 135 | result = [] 136 | if utils.is_main_process(): 137 | # combine results from all processes 138 | for rank in range(utils.get_world_size()): 139 | rpath = os.path.join(local_wdir, '%s_rank%d.pth' % (filename, rank)) 140 | if write_to_hdfs: 141 | hcopy(os.path.join(hdfs_wdir, '%s_rank%d.pth' % (filename, rank)), rpath) 142 | 143 | result += torch.load(rpath) 144 | 145 | dist.barrier() 146 | 147 | return result 148 | 149 | 150 | def grounding_eval(results, dets, cocos, refer, alpha, mask_size=24): 151 | correct_A_d, correct_B_d, correct_val_d = 0, 0, 0 152 | correct_A, correct_B, correct_val = 0, 0, 0 153 | num_A, num_B, num_val = 0, 0, 0 154 | 155 | for res in tqdm(results): 156 | 157 | ref_id = res['ref_id'] 158 | ref = refer.Refs[ref_id] 159 | ref_box = refer.refToAnn[ref_id]['bbox'] 160 | image = refer.Imgs[ref['image_id']] 161 | 162 | mask = res['pred'].cuda().view(1, 1, mask_size, mask_size) 163 | mask = F.interpolate(mask, size=(image['height'], image['width']), mode='bicubic').squeeze() 164 | 165 | # rank detection boxes 166 | max_score = 0 167 | for det in dets[str(ref['image_id'])]: 168 | score = mask[int(det[1]):int(det[1] + det[3]), int(det[0]):int(det[0] + det[2])] 169 | area = det[2] * det[3] 170 | score = score.sum() / area ** alpha 171 | if score > max_score: 172 | pred_box = det[:4] 173 | max_score = score 174 | 175 | IoU_det = computeIoU(ref_box, pred_box) 176 | 177 | if ref['split'] == 'testA': 178 | num_A += 1 179 | if IoU_det >= 0.5: 180 | correct_A_d += 1 181 | elif ref['split'] == 'testB': 182 | num_B += 1 183 | if IoU_det >= 0.5: 184 | correct_B_d += 1 185 | elif ref['split'] == 'val': 186 | num_val += 1 187 | if IoU_det >= 0.5: 188 | correct_val_d += 1 189 | 190 | eval_result = {'val_d': correct_val_d / num_val, 'testA_d': correct_A_d / num_A, 'testB_d': correct_B_d / num_B} 191 | 192 | for metric, acc in eval_result.items(): 193 | print(f'{metric}: {acc:.3f}') 194 | 195 | return eval_result 196 | 197 | 198 | def grounding_eval_vlue(results, test_json, alpha, 
mask_size=24): 199 | correct_val_d = 0 200 | num_val = 0 201 | 202 | ref_id_map = {} 203 | with open(test_json, 'r') as f: 204 | for sample in json.load(f): 205 | ref_id_map[sample['ref_id']] = sample 206 | 207 | for res in tqdm(results): 208 | 209 | ref_id = res['ref_id'] 210 | 211 | ref_box = ref_id_map[ref_id]['bbox'] 212 | height = ref_id_map[ref_id]['height'] 213 | width = ref_id_map[ref_id]['width'] 214 | dets = ref_id_map[ref_id]['dets'] # (x, y, w, h) 215 | 216 | mask = res['pred'].cuda().view(1, 1, mask_size, mask_size) 217 | mask = F.interpolate(mask, size=(height, width), mode='bicubic').squeeze() 218 | 219 | # rank detection boxes 220 | max_score = 0 221 | for det in dets: 222 | score = mask[int(det[1]):int(det[1] + det[3]), int(det[0]):int(det[0] + det[2])] 223 | area = det[2] * det[3] 224 | score = score.sum() / area ** alpha 225 | if score > max_score: 226 | pred_box = det[:4] 227 | max_score = score 228 | 229 | IoU_det = computeIoU(ref_box, pred_box) 230 | 231 | num_val += 1 232 | if IoU_det >= 0.5: 233 | correct_val_d += 1 234 | 235 | eval_result = {'score': correct_val_d / num_val} 236 | 237 | for metric, acc in eval_result.items(): 238 | print(f'{metric}: {acc:.3f}') 239 | 240 | return eval_result 241 | 242 | 243 | def grounding_eval_bbox(results, refer): 244 | correct_A_d, correct_B_d, correct_val_d = 0, 0, 0 245 | num_A, num_B, num_val = 0, 0, 0 246 | 247 | for res in tqdm(results): 248 | ref_id = res['ref_id'] 249 | ref = refer.Refs[ref_id] 250 | ref_box = refer.refToAnn[ref_id]['bbox'] 251 | image = refer.Imgs[ref['image_id']] 252 | 253 | coord = res['pred'].cuda() 254 | coord[0::2] *= image['width'] 255 | coord[1::2] *= image['height'] 256 | 257 | coord[0] -= coord[2] / 2 258 | coord[1] -= coord[3] / 2 259 | 260 | IoU_det = computeIoU(ref_box, coord) 261 | 262 | if ref['split'] == 'testA': 263 | num_A += 1 264 | if IoU_det >= 0.5: 265 | correct_A_d += 1 266 | elif ref['split'] == 'testB': 267 | num_B += 1 268 | if IoU_det >= 0.5: 269 | correct_B_d += 1 270 | elif ref['split'] == 'val': 271 | num_val += 1 272 | if IoU_det >= 0.5: 273 | correct_val_d += 1 274 | 275 | eval_result = {'val_d': correct_val_d / num_val, 'testA_d': correct_A_d / num_A, 'testB_d': correct_B_d / num_B} 276 | 277 | for metric, acc in eval_result.items(): 278 | print(f'{metric}: {acc:.3f}') 279 | 280 | return eval_result 281 | 282 | 283 | def grounding_eval_bbox_vlue(results, test_json): 284 | correct_val_d = 0 285 | num_val = 0 286 | 287 | ref_id_map = {} 288 | with open(test_json, 'r') as f: 289 | for sample in json.load(f): 290 | ref_id_map[sample['ref_id']] = sample 291 | 292 | for res in tqdm(results): 293 | ref_id = res['ref_id'] 294 | 295 | ref_box = ref_id_map[ref_id]['bbox'] 296 | height = ref_id_map[ref_id]['height'] 297 | width = ref_id_map[ref_id]['width'] 298 | 299 | coord = res['pred'].cuda() 300 | coord[0::2] *= width 301 | coord[1::2] *= height 302 | 303 | coord[0] -= coord[2] / 2 304 | coord[1] -= coord[3] / 2 305 | 306 | IoU_det = computeIoU(ref_box, coord) 307 | 308 | num_val += 1 309 | if IoU_det >= 0.5: 310 | correct_val_d += 1 311 | 312 | eval_result = {'score': correct_val_d / num_val} 313 | 314 | for metric, acc in eval_result.items(): 315 | print(f'{metric}: {acc:.3f}') 316 | 317 | return eval_result 318 | 319 | 320 | # IoU function 321 | def computeIoU(box1, box2): 322 | # each box is of [x1, y1, w, h] 323 | inter_x1 = max(box1[0], box2[0]) 324 | inter_y1 = max(box1[1], box2[1]) 325 | inter_x2 = min(box1[0] + box1[2] - 1, box2[0] + box2[2] - 1) 326 | inter_y2 = 
min(box1[1] + box1[3] - 1, box2[1] + box2[3] - 1) 327 | 328 | if inter_x1 < inter_x2 and inter_y1 < inter_y2: 329 | inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1) 330 | else: 331 | inter = 0 332 | union = box1[2] * box1[3] + box2[2] * box2[3] - inter 333 | return float(inter) / union 334 | 335 | 336 | from pycocotools.coco import COCO 337 | from pycocoevalcap.eval import COCOEvalCap 338 | 339 | 340 | def coco_caption_eval(annotation_file, results_file): 341 | assert os.path.exists(annotation_file) 342 | 343 | # create coco object and coco_result object 344 | coco = COCO(annotation_file) 345 | coco_result = coco.loadRes(results_file) 346 | 347 | # create coco_eval object by taking coco and coco_result 348 | coco_eval = COCOEvalCap(coco, coco_result) 349 | 350 | # evaluate on a subset of images by setting 351 | # coco_eval.params['image_id'] = coco_result.getImgIds() 352 | # please remove this line when evaluating the full validation set 353 | # coco_eval.params['image_id'] = coco_result.getImgIds() 354 | 355 | # evaluate results 356 | # SPICE will take a few minutes the first time, but speeds up due to caching 357 | coco_eval.evaluate() 358 | 359 | # print output evaluation scores 360 | for metric, score in coco_eval.eval.items(): 361 | print(f'{metric}: {score:.3f}', flush=True) 362 | 363 | return coco_eval -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.pir import build_mlp 2 | from models.pir import PIRBase 3 | from models.pir import load_pretrained_pir 4 | -------------------------------------------------------------------------------- /models/model_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import PIRBase, load_pretrained_pir 3 | import torch.nn.functional as F 4 | 5 | 6 | class PIR(PIRBase): 7 | def __init__(self, config): 8 | super().__init__(config, load_vision_params=True, load_text_params=True, use_contrastive_loss=True, \ 9 | use_affil_loss=False) 10 | self.config = config 11 | self.use_affil_loss = config['use_affil_loss'] 12 | self.use_triplet_loss = config['use_triplet_loss'] 13 | 14 | def load_pretrained(self, ckpt_rpath, config, is_eval=False): 15 | state_dict = load_pretrained_pir(ckpt_rpath, config, is_eval=is_eval, load_text=True) 16 | msg = self.load_state_dict(state_dict, strict=False) 17 | print('load checkpoint from %s' % ckpt_rpath) 18 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) 19 | print("unexpected_keys: ", msg.unexpected_keys) 20 | 21 | def forward(self, image, text_ids, idx=None, label=None): 22 | ## Baseline(Swin-T+Bert-B) 23 | if self.config['is_baseline']: 24 | img_emb = self.get_vision_embeds(image) 25 | txt_emb = self.get_text_embeds(text_ids) 26 | else: 27 | img_emb= self.get_vision_fusion_embeds(image, self.config) 28 | txt_emb = self.get_text_fusion_embeds(text_ids, self.config) 29 | 30 | if self.use_affil_loss: 31 | loss_contr = self.get_contr_loss(img_emb, txt_emb, idx=idx, label=label, config=self.config) 32 | loss_affil = self.get_affil_loss(img_emb, txt_emb, idx=idx, label=label, config=self.config) 33 | return loss_contr, loss_affil 34 | elif self.use_triplet_loss: 35 | loss_triplet = self.get_triplet_loss(img_emb, txt_emb) 36 | return loss_triplet 37 | else: 38 | loss_contr = self.get_contr_loss(img_emb, txt_emb, idx=idx, label=label, config=self.config) 39 | 
return loss_contr -------------------------------------------------------------------------------- /models/mytools.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # Author:Zhiqiang Yuan 3 | """导入一些包""" 4 | import os 5 | import time, random 6 | import json 7 | import numpy as np 8 | from sklearn.decomposition import PCA 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import Axes3D 11 | 12 | """ 打印一些东西 """ 13 | """----------------------------------------------------------------------""" 14 | 15 | 16 | # 打印列表按照竖行的形式 17 | def print_list(list): 18 | print("++++++++++++++++++++++++++++++++++++++++++++") 19 | for l in list: 20 | print(l) 21 | print("++++++++++++++++++++++++++++++++++++++++++++") 22 | 23 | 24 | # 打印字典按照竖行的形式 25 | def print_dict(dict): 26 | print("++++++++++++++++++++++++++++++++++++++++++++") 27 | for k, v in dict.items(): 28 | print("key:", k, " value:", v) 29 | print("++++++++++++++++++++++++++++++++++++++++++++") 30 | 31 | 32 | # 打印一些东西,加入标识符 33 | def print_with_log(info): 34 | print("++++++++++++++++++++++++++++++++++++++++++++") 35 | print(info) 36 | print("++++++++++++++++++++++++++++++++++++++++++++") 37 | 38 | 39 | # 打印标识符 40 | def print_log(): 41 | print("++++++++++++++++++++++++++++++++++++++++++++") 42 | 43 | 44 | """ 文件存储 """ 45 | """----------------------------------------------------------------------""" 46 | 47 | 48 | # 保存结果到json文件 49 | def save_to_json(info, filename, encoding='UTF-8'): 50 | with open(filename, "w", encoding=encoding) as f: 51 | json.dump(info, f, indent=2, separators=(',', ':')) 52 | 53 | 54 | # 从json文件中读取 55 | def load_from_json(filename): 56 | with open(filename, encoding='utf-8') as f: 57 | info = json.load(f) 58 | return info 59 | 60 | 61 | # 储存为npy文件 62 | def save_to_npy(info, filename): 63 | np.save(filename, info, allow_pickle=True) 64 | 65 | 66 | # 从npy中读取 67 | def load_from_npy(filename): 68 | info = np.load(filename, allow_pickle=True) 69 | return info 70 | 71 | 72 | # 保存结果到txt文件 73 | def log_to_txt(contexts=None, filename="save.txt", mark=False, encoding='UTF-8', add_n=False): 74 | f = open(filename, "a", encoding=encoding) 75 | if mark: 76 | sig = "------------------------------------------------\n" 77 | f.write(sig) 78 | elif isinstance(contexts, dict): 79 | tmp = "" 80 | for c in contexts.keys(): 81 | tmp += str(c) + " | " + str(contexts[c]) + "\n" 82 | contexts = tmp 83 | f.write(contexts) 84 | else: 85 | if isinstance(contexts, list): 86 | tmp = "" 87 | for c in contexts: 88 | if add_n: 89 | tmp += str(c) + "\n" 90 | else: 91 | tmp += str(c) 92 | contexts = tmp 93 | else: 94 | contexts = contexts + "\n" 95 | f.write(contexts) 96 | 97 | f.close() 98 | 99 | 100 | # 从txt中读取行 101 | def load_from_txt(filename, encoding="utf-8"): 102 | f = open(filename, 'r', encoding=encoding) 103 | contexts = f.readlines() 104 | return contexts 105 | 106 | 107 | """ 字典变换 """ 108 | """----------------------------------------------------------------------""" 109 | 110 | 111 | # 键值互换 112 | def dict_k_v_exchange(dict): 113 | tmp = {} 114 | for key, value in dict.items(): 115 | tmp[value] = key 116 | return tmp 117 | 118 | 119 | # 2维数组转字典 120 | def d2array_to_dict(d2array): 121 | # Input: N x 2 list 122 | # Output: dict 123 | dict = {} 124 | for item in d2array: 125 | if item[0] not in dict.keys(): 126 | dict[item[0]] = [item[1]] 127 | else: 128 | dict[item[0]].append(item[1]) 129 | return dict 130 | 131 | 132 | """ 绘图 """ 133 | 
"""----------------------------------------------------------------------""" 134 | 135 | 136 | # 绘制3D图像 137 | def visual_3d_points(list, color=True): 138 | """ 139 | :param list: N x (dim +1) 140 | N 为点的数量 141 | dim 为 输入数据的维度 142 | 1 为类别, 即可视化的颜色 当且仅当color为True时 143 | """ 144 | list = np.array(list) 145 | if color: 146 | data = list[:, :4] 147 | label = list[:, -1] 148 | else: 149 | data = list 150 | label = None 151 | 152 | # PCA降维 153 | pca = PCA(n_components=3, whiten=True).fit(data) 154 | data = pca.transform(data) 155 | 156 | # 定义坐标轴 157 | fig = plt.figure() 158 | ax1 = plt.axes(projection='3d') 159 | if label is not None: 160 | color = label 161 | else: 162 | color = "blue" 163 | ax1.scatter3D(np.transpose(data)[0], np.transpose(data)[1], np.transpose(data)[2], c=color) # 绘制散点图 164 | 165 | plt.show() 166 | 167 | 168 | """ 实用工具 """ 169 | """----------------------------------------------------------------------""" 170 | 171 | 172 | # 计算数组中元素出现的个数 173 | def count_list(lens): 174 | dict = {} 175 | for key in lens: 176 | dict[key] = dict.get(key, 0) + 1 177 | dict = sorted(dict.items(), key=lambda x: x[1], reverse=True) 178 | 179 | print_list(dict) 180 | return dict 181 | 182 | 183 | # list 加法 w1、w2为权重 184 | def list_add(list1, list2, w1=1, w2=1): 185 | return [l1 * w1 + l2 * w2 for (l1, l2) in zip(list1, list2)] 186 | -------------------------------------------------------------------------------- /models/pir.py: -------------------------------------------------------------------------------- 1 | import os import copy import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist import math from utils import read_json from functools import partial from models.swin_transformer import SwinTransformer, interpolate_relative_pos_embed from models.vit import VisionTransformer, interpolate_pos_embed from models.bert import BertModel, BertConfig from models.resnet import resnet50, resnet101 from torchvision.models import vgg16, vgg19_bn from torchvision import models from torch.autograd import Variable class AllGather(torch.autograd.Function): """An autograd function that performs allgather on a tensor.""" @staticmethod def forward(ctx, tensor, rank, world_size): output = [torch.empty_like(tensor) for _ in range(world_size)] dist.all_gather(output, tensor) ctx.rank = rank ctx.batch_size = tensor.shape[0] return torch.cat(output, 0) @staticmethod def backward(ctx, grad_output): return ( grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], None, None ) allgather = AllGather.apply def build_vision_encoder(config, load_vision_params=False): """ Args: load_params: False when building fine-tuning models """ num_patches = (config['image_res'] // config['patch_size']) ** 2 if config['use_swin']: vision_config = read_json(config['vision_config']) assert config['image_res'] == vision_config['image_res'] assert config['patch_size'] == 32 vision_width = vision_config['vision_width'] vision_encoder = SwinTransformer(img_size=vision_config['image_res'], patch_size=4, in_chans=3, embed_dim=vision_config['embed_dim'], depths=vision_config['depths'], num_heads=vision_config['num_heads'], window_size=vision_config['window_size'], mlp_ratio=4., qkv_bias=True, drop_rate=0.0, drop_path_rate=0.1, ape=False, patch_norm=True, use_checkpoint=False) if load_vision_params: # download from https://github.com/microsoft/Swin-Transformer state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] for k in list(state_dict.keys()): if 
'relative_position_bias_table' in k: dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) elif ('relative_position_index' in k) or ('attn_mask' in k): del state_dict[k] else: assert config['patch_size'] == 16 vision_width = 384 vision_encoder = VisionTransformer( img_size=config['image_res'], patch_size=config['patch_size'], embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), local_attn_depth=4) if load_vision_params: # download from https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth state_dict = torch.load("data/deit_small_patch16_224-cd65a155.pth", map_location="cpu")["model"] pos_embed_reshaped = interpolate_pos_embed(state_dict['pos_embed'], num_patches=num_patches, num_extra_tokens=1) state_dict['pos_embed'] = pos_embed_reshaped if load_vision_params: if config['use_swin']: print("### Load Trans-Encoder[SWin-T]: ", flush=True) else: print("### Load Trans-Encoder[ViT]: ", flush=True) msg = vision_encoder.load_state_dict(state_dict, strict=False) # print("missing_keys: ", msg.missing_keys) # print("unexpected_keys: ", msg.unexpected_keys) return vision_encoder, vision_width def build_conv_encoder(config, load_vision_params=False, ins='resnet'): resnet_ckpt = config['resnet_ckpt'] finetune_conv = config['finetune_conv'] # # 加载resnet前n-1层 if ins == 'resnet': ## resnet as ins-encoder resnet_with_last = nn.Sequential(*list(resnet50(num_classes=30).children())[:-1]) # resnet_with_last = nn.Sequential(*list(resnet101(num_classes=30).children())[:-1]) conv_width = 2048 # 特征维度大小为2048 elif ins == 'vgg': ## vgg as ins-encoder # vgg = vgg16(num_classes=30) vgg = vgg19_bn(num_classes=30) vgg.classifier[6] = nn.Linear(4096, 2048) resnet_with_last = vgg conv_width = 2048 # 特征维度大小为2048 else: raise ValueError if load_vision_params: print("### Load Conv-Encoder[ResNet-50]: ", flush=True) state_dict = torch.load(resnet_ckpt, map_location="cpu") if len(state_dict) < 10: state_dict = state_dict['model'] if ins == 'vgg': state_dict.pop('classifier.6.weight') state_dict.pop('classifier.6.bias') resnet_with_last.load_state_dict(state_dict, strict=False) # 更新参数 for child in resnet_with_last.children(): for param in child.parameters(): param.requires_grad = finetune_conv return resnet_with_last, conv_width def build_text_encoder(config, load_text_params=False): # 加载text config text_config = read_json(config['text_config']) text_width = text_config['hidden_size'] # 建立bert模型 bert_config = BertConfig.from_json_file(config['text_config']) text_encoder = BertModel(bert_config) if load_text_params: # 加载预训练参数 print("### Load Trans-Encoder[Bert-B]: ", flush=True) init_checkpoint = config['text_encoder'] + '/pytorch_model.bin' state_dict = torch.load(init_checkpoint, map_location='cpu') text_encoder.load_state_dict(state_dict, strict=False) # 更新参数 for child in text_encoder.children(): for param in child.parameters(): param.requires_grad = True return text_encoder, text_width def build_mlp(input_dim, output_dim): return nn.Sequential( nn.Linear(input_dim, input_dim * 2), nn.LayerNorm(input_dim * 2), nn.GELU(), nn.Linear(input_dim * 2, output_dim)) def build_self_attention(config, model='image'): embed_dim = config['embed_dim'] dropout_r = config['dropout_r'] head = config['head'] if model == 'cross': return CroSA(embed_dim, dropout_r, head) elif model == 'image': return MHSA(embed_dim, dropout_r, head) elif model == 'text': return 
MHSA(embed_dim, dropout_r, head) else: raise ValueError def clones(module, N): """Produce N identical layers. """ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) def load_pretrained_pir(ckpt_rpath, config, is_eval=False, load_text=False): checkpoint = torch.load(ckpt_rpath, map_location='cpu') state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint if is_eval: return state_dict num_patches = (config['image_res'] // config['patch_size']) ** 2 print("### Loading pretrained vision encoder", flush=True) window_size = read_json(config['vision_config'])['window_size'] for k in list(state_dict.keys()): if 'relative_position_bias_table' in k: dst_num_pos = (2 * window_size - 1) ** 2 state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) elif ('relative_position_index' in k) or ('attn_mask' in k): del state_dict[k] if load_text: print("### Loading pretrained text encoder", flush=True) for key in list(state_dict.keys()): if 'text_encoder.' in key: if 'bert.' in key: encoder_key = key.replace('bert.', '') state_dict[encoder_key] = state_dict[key] del state_dict[key] return state_dict class CroSA(nn.Module): def __init__(self,embed_dim, dropout_r=0.2, head=8): super(CroSA, self).__init__() self.embed_dim = embed_dim self.dropout_r = dropout_r self.head = head self.mhatt_cross = MHAtt(self.embed_dim, self.dropout_r, self.head) self.ffn = build_mlp(self.embed_dim, self.embed_dim) self.dropout1 = nn.Dropout(self.dropout_r) self.norm1 = nn.LayerNorm(self.embed_dim) self.dropout2 = nn.Dropout(self.dropout_r) self.norm2 = nn.LayerNorm(self.embed_dim) def forward(self, x, y, x_mask=None, y_mask=None): y = self.norm1(y + self.dropout1(self.mhatt_cross(y, y, x, y_mask))) y = self.norm2(y + self.dropout2(self.ffn(y))) return y class MHSA(nn.Module): def __init__(self, embed_dim, dropout_r=0.2, head=8): super(MHSA, self).__init__() self.embed_dim = embed_dim self.dropout_r = dropout_r self.head = head self.mhatt = MHAtt(self.embed_dim, self.dropout_r, self.head) self.ffn = build_mlp(self.embed_dim, self.embed_dim) self.dropout1 = nn.Dropout(self.dropout_r) self.norm1 = nn.LayerNorm(self.embed_dim) self.dropout2 = nn.Dropout(self.dropout_r) self.norm2 = nn.LayerNorm(self.embed_dim) def forward(self, x, x_mask=None): x = self.norm1(x + self.dropout1(self.mhatt(x, x, x, x_mask))) x = self.norm2(x + self.dropout2(self.ffn(x))) return x class MHAtt(nn.Module): def __init__(self, embed_dim, dropout_r=0.2, head=8): super(MHAtt, self).__init__() self.embed_dim = embed_dim self.dropout_r = dropout_r self.head = head self.linear_v = nn.Linear(self.embed_dim, self.embed_dim) self.linear_k = nn.Linear(self.embed_dim, self.embed_dim) self.linear_q = nn.Linear(self.embed_dim, self.embed_dim) self.linear_merge = nn.Linear(self.embed_dim, self.embed_dim) self.dropout = nn.Dropout(self.dropout_r) def forward(self, v, k, q, mask=None): # print('q: {}'.format(q.shape)) bs = q.size(0) v = self.linear_v(v).view(bs, -1, self.head, self.embed_dim // self.head).transpose(1, 2) k = self.linear_k(k).view(bs, -1, self.head, self.embed_dim // self.head).transpose(1, 2) q = self.linear_q(q).view(bs, -1, self.head, self.embed_dim // self.head).transpose(1, 2) # print('q_linear: {}'.format(q.shape)) atted = self.att(v, k, q, mask) # print('atted: {}'.format(atted.shape)) atted = atted.transpose(1, 2).contiguous().view(bs, -1, self.embed_dim) atted = self.linear_merge(atted) return atted def att(self, v, k, q, mask=None): d_k = q.shape[-1] scores = torch.matmul(q, 
k.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask, -1e9) # print('scores: {}'.format(scores.shape)) att_map = torch.softmax(scores, dim=-1) att_map = self.dropout(att_map) return torch.matmul(att_map, v) class PIRBase(nn.Module): def __init__(self, config=None, load_vision_params=False, load_text_params=True, use_contrastive_loss=False, use_affil_loss=False): super().__init__() if config['is_baseline']: self.vision_encoder, vision_width = build_vision_encoder(config, load_vision_params=load_vision_params) self.text_encoder, text_width = build_text_encoder(config, load_text_params=load_text_params) self.vision_width = vision_width self.text_width = text_width self.embed_dim = config['embed_dim'] self.max_tokens = config['max_tokens'] if config['use_triplet_loss'] == False: self.temp = nn.Parameter(torch.ones([]) * config['temp1']) if config['use_affil_loss']: self.temp2 = nn.Parameter(torch.ones([]) * config['temp2']) # without AFLoss self.vision_proj = nn.Linear(self.vision_width, self.embed_dim) self.text_proj = nn.Linear(self.text_width, self.embed_dim) else: self.vision_encoder, vision_width = build_vision_encoder(config, load_vision_params=load_vision_params) self.text_encoder, text_width = build_text_encoder(config, load_text_params=load_text_params) self.conv_encoder, conv_width = build_conv_encoder(config, load_vision_params=load_vision_params) # without VIR self.vision_width = vision_width self.text_width = text_width self.conv_width = conv_width # without VIR self.embed_dim = config['embed_dim'] self.max_tokens = config['max_tokens'] self.temp = nn.Parameter(torch.ones([]) * config['temp1']) if config['use_affil_loss']: self.temp2 = nn.Parameter(torch.ones([]) * config['temp2']) # without AFLoss self.vision_proj = nn.Linear(self.vision_width, self.embed_dim) self.text_proj = nn.Linear(self.text_width, self.embed_dim) self.conv_proj = nn.Linear(self.conv_width, self.embed_dim) # without VIR self.mapping_img = clones(nn.Linear(self.embed_dim, self.embed_dim), config['instru_num']) # without VIR self.mapping_txt = clones(nn.Linear(self.embed_dim, self.embed_dim), config['cycle_num']) # without LCA self.head_img = nn.Linear(self.embed_dim, self.embed_dim) # without VIR self.head_txt = nn.Linear(self.embed_dim, self.embed_dim) # without LCA self.img_sa = clones(build_self_attention(config, model='image'), config['instru_num']) # without VIR self.img_ca = clones(build_self_attention(config, model='cross'), config['instru_num']) # without VIR self.txt_sa = clones(build_self_attention(config, model='text'), config['cycle_num'])# without LCA self.txt_ca = clones(build_self_attention(config, model='cross'),config['cycle_num'])# without LCA def load_pretrained_pir(self, ckpt_rpath, config, is_eval=False): state_dict = load_pretrained_pir(ckpt_rpath, config, is_eval=is_eval, load_text=True) msg = self.load_state_dict(state_dict, strict=False) print('load checkpoint from %s' % ckpt_rpath) print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) print("unexpected_keys: ", msg.unexpected_keys) def get_vision_embeds(self, image): """ vision_embeds: cls + patch embeds """ return F.normalize(self.vision_proj(self.vision_encoder(image))[:, 0, :]) def get_text_embeds(self, text_ids): """ text_embeds: cls + sequence embeds """ return F.normalize(self.text_proj(self.text_encoder(text_ids))[:, 0, :]) def get_vision_fusion_embeds(self, image, config): """ Vision Instruction Representation-VLR """ filter_size = config['filter_size'] 
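# The statements below implement the rank-and-filter step of Vision Instruction
# Representation (VIR): each Swin patch token is scored by softmax similarity to the
# pooled ResNet/VGG "instruction" feature, only the top `filter_size` tokens are kept,
# and those tokens go through `instru_num` rounds of self-attention (img_sa) followed
# by cross-attention (img_ca) in which the mapped instruction feature acts as the query.
# The first token of the result is added to the Swin [CLS] embedding via head_img
# before L2-normalization.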
swin_feat = self.vision_proj(self.vision_encoder(image)) # ResNet and VGG conv_feat = self.conv_proj(self.conv_encoder(image).squeeze()).unsqueeze(dim=1) # ## Swin and ViT # conv_feat = self.conv_proj(self.conv_encoder(image)[:, 0, :].squeeze()).unsqueeze(dim=1) image_g_emb = swin_feat[:, 0, :] swin_feat_loc = swin_feat[:, 0:50,:] # Rank & Filter score_feat = F.softmax(torch.matmul(swin_feat_loc, conv_feat.transpose(-2, -1)).squeeze(), dim=-1) sorted_fscore, sorted_ind = torch.sort(score_feat, dim=1, descending=True) swin_feat_fi_ = [] for i in range(swin_feat_loc.shape[0]): swin_feat_fi_.append(torch.index_select(swin_feat_loc[i,:,:], 0, sorted_ind[i]).unsqueeze(dim=0)) swin_feat_fi = torch.cat(swin_feat_fi_, dim=0)[:, :filter_size, :] conv_feats = conv_feat.expand(swin_feat_fi.shape) # print('conv_feats{}'.format(conv_feats.shape)) for i in range(config['instru_num']): swin_att = self.img_sa[i](swin_feat_fi) # swin_inst = self.img_ca[i](swin_att, self.mapping_img[i](conv_feat)) swin_inst = self.img_ca[i](self.mapping_img[i](conv_feats), swin_att) # print('swin_inst: {}'.format(swin_inst.shape)) swin_feat_fi = swin_inst img_l_adj = swin_inst[:, 0, :] return F.normalize(image_g_emb + self.head_img(img_l_adj)) def get_text_fusion_embeds(self, text_ids, config): """ Language Cycle Attention--LCA """ text_feat = self.text_proj(self.text_encoder(text_ids)) # mask_text_feat = self.text_proj(self.text_encoder(mask_text_ids)) nu_text = text_feat.shape[1] text_g_emb = text_feat[:, 0, :] text_l_embs = text_feat[:, 0:nu_text, :] for i in range(config['cycle_num']): text_att = self.txt_sa[i](text_l_embs) # text_cros = self.txt_ca[i](text_att, self.mapping_txt[i](text_l_embs)) text_cros = self.txt_ca[i](self.mapping_txt[i](text_l_embs), text_att) text_l_embs = text_cros text_l_adj = text_l_embs[:, 0, :] return F.normalize(text_g_emb + self.head_txt(text_l_adj)) ### 等参SA实验 text_feat = self.text_proj(self.text_encoder(text_ids)) nu_text = text_feat.shape[1] text_g_emb = text_feat[:, 0, :] text_l_embs = text_feat[:, 0:nu_text, :] for i in range(config['cycle_num']): text_att = self.txt_sa[i](text_l_embs) text_l_embs = text_att text_l_adj = text_l_embs[:, 0, :] return F.normalize(self.head_txt(text_l_adj)) def get_contr_loss(self, image_feat, text_feat, idx=None, label=None, config=None): """ Args: image_feat, text_feat: normalized Returns: contrastive loss """ assert image_feat.size(-1) == self.embed_dim assert text_feat.size(-1) == self.embed_dim image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) logits = image_feat_all @ text_feat_all.t() / self.temp # print(logits) bsz = image_feat_all.shape[0] if idx is None: labels = torch.arange(bsz, device=image_feat.device) loss_i2t = F.cross_entropy(logits, labels) loss_t2i = F.cross_entropy(logits.t(), labels) else: idx = idx.view(-1, 1) assert idx.size(0) == image_feat.size(0) ## 生成对角阵 idx_all = allgather(idx, torch.distributed.get_rank(), torch.distributed.get_world_size()) pos_idx = torch.eq(idx_all, idx_all.t()).float() labels = pos_idx / pos_idx.sum(dim=1, keepdim=True) loss_i2t = -torch.sum(F.log_softmax(logits, dim=1) * labels, dim=1).mean() loss_t2i = -torch.sum(F.log_softmax(logits.t(), dim=1) * labels, dim=1).mean() return (loss_i2t + loss_t2i) / 2 def get_affil_loss(self, image_feat, text_feat, idx=None, label=None, config=None): assert image_feat.size(-1) == self.embed_dim assert 
text_feat.size(-1) == self.embed_dim # logits = image_feat @ text_feat.t() la_idx = torch.eq(label.unsqueeze(dim=1), label.unsqueeze(dim=1).t()).float() # 然后计算他们的聚类中心 img_centers = [] txt_centers = [] for i in range(image_feat.shape[0]): # # 计算均值聚类中心 mod = la_idx[i].unsqueeze(dim=1) mask = mod.repeat(1, 512) non_zero_num = torch.sum(mod, dim=0) # print(non_zero_num) img_center = (image_feat * mask).sum(dim=0, keepdim=True) / non_zero_num txt_center = (text_feat * mask).sum(dim=0, keepdim=True) / non_zero_num img_centers.append(img_center) txt_centers.append(txt_center) img_centers = torch.cat(img_centers, dim=0) txt_centers = torch.cat(txt_centers, dim=0) img_centers_all = allgather(img_centers, torch.distributed.get_rank(), torch.distributed.get_world_size()) txt_centers_all = allgather(txt_centers, torch.distributed.get_rank(), torch.distributed.get_world_size()) image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) img2txt_center = image_feat_all @ txt_centers_all.t() / self.temp2 txt2img_center = text_feat_all @ img_centers_all.t() / self.temp2 bsz = img2txt_center.shape[0] labels = torch.eye(bsz, device=image_feat.device) loss_i2t = -torch.sum(F.log_softmax(img2txt_center, dim=1) * labels, dim=1).mean() loss_t2i = -torch.sum(F.log_softmax(txt2img_center.t(), dim=1) * labels, dim=1).mean() return (loss_i2t + loss_t2i) / 2 def get_triplet_loss(self, image_feat, text_feat, margin=0.2, max_violation=False): assert image_feat.size(-1) == self.embed_dim assert text_feat.size(-1) == self.embed_dim image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) scores = image_feat_all @ text_feat_all.t() # print(logits) bsz = image_feat_all.shape[0] diagonal = scores.diag().view(bsz, 1) d1 = diagonal.expand_as(scores) d2 = diagonal.t().expand_as(scores) # compare every diagonal score to scores in its column # caption retrieval cost_s = (margin + scores - d1).clamp(min=0) # compare every diagonal score to scores in its row # image retrieval cost_im = (margin + scores - d2).clamp(min=0) mask = torch.eye(scores.size(0)) > .5 I = Variable(mask) if torch.cuda.is_available(): I = I.cuda(device=image_feat.device) cost_s = cost_s.masked_fill_(I, 0) cost_im = cost_im.masked_fill_(I, 0) if max_violation: cost_s = cost_s.max(1)[0] cost_im = cost_im.max(0)[0] sum_cost_s = cost_s.sum() sum_cost_im = cost_im.sum() return sum_cost_s + sum_cost_im -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | import os 6 | 7 | import torchvision 8 | 9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=dilation, groups=groups, bias=False, dilation=dilation) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, 
downsample=None, groups=1, 24 | base_width=64, dilation=1, norm_layer=None): 25 | super(BasicBlock, self).__init__() 26 | if norm_layer is None: 27 | norm_layer = nn.BatchNorm2d 28 | if groups != 1 or base_width != 64: 29 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 30 | if dilation > 1: 31 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 32 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = norm_layer(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = norm_layer(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 64 | base_width=64, dilation=1, norm_layer=None): 65 | super(Bottleneck, self).__init__() 66 | if norm_layer is None: 67 | norm_layer = nn.BatchNorm2d 68 | width = int(planes * (base_width / 64.)) * groups 69 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 70 | self.conv1 = conv1x1(inplanes, width) 71 | self.bn1 = norm_layer(width) 72 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 73 | self.bn2 = norm_layer(width) 74 | self.conv3 = conv1x1(width, planes * self.expansion) 75 | self.bn3 = norm_layer(planes * self.expansion) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | identity = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, num_classes=51, zero_init_residual=False, 106 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 107 | norm_layer=None): 108 | super(ResNet, self).__init__() 109 | if norm_layer is None: 110 | norm_layer = nn.BatchNorm2d 111 | self._norm_layer = norm_layer 112 | 113 | self.inplanes = 64 114 | self.dilation = 1 115 | if replace_stride_with_dilation is None: 116 | # each element in the tuple indicates if we should replace 117 | # the 2x2 stride with a dilated convolution instead 118 | replace_stride_with_dilation = [False, False, False] 119 | if len(replace_stride_with_dilation) != 3: 120 | raise ValueError("replace_stride_with_dilation should be None " 121 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 122 | self.groups = groups 123 | self.base_width = width_per_group 124 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 125 | bias=False) 126 | self.bn1 = norm_layer(self.inplanes) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 129 | self.layer1 = self._make_layer(block, 64, layers[0]) 130 
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 131 | dilate=replace_stride_with_dilation[0]) 132 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 133 | dilate=replace_stride_with_dilation[1]) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 135 | dilate=replace_stride_with_dilation[2]) 136 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 137 | self.fc = nn.Linear(512 * block.expansion, num_classes) 138 | 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 142 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 143 | nn.init.constant_(m.weight, 1) 144 | nn.init.constant_(m.bias, 0) 145 | 146 | # Zero-initialize the last BN in each residual branch, 147 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 148 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 149 | if zero_init_residual: 150 | for m in self.modules(): 151 | if isinstance(m, Bottleneck): 152 | nn.init.constant_(m.bn3.weight, 0) 153 | elif isinstance(m, BasicBlock): 154 | nn.init.constant_(m.bn2.weight, 0) 155 | 156 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 157 | norm_layer = self._norm_layer 158 | downsample = None 159 | previous_dilation = self.dilation 160 | if dilate: 161 | self.dilation *= stride 162 | stride = 1 163 | if stride != 1 or self.inplanes != planes * block.expansion: 164 | downsample = nn.Sequential( 165 | conv1x1(self.inplanes, planes * block.expansion, stride), 166 | norm_layer(planes * block.expansion), 167 | ) 168 | 169 | layers = [] 170 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 171 | self.base_width, previous_dilation, norm_layer)) 172 | self.inplanes = planes * block.expansion 173 | for _ in range(1, blocks): 174 | layers.append(block(self.inplanes, planes, groups=self.groups, 175 | base_width=self.base_width, dilation=self.dilation, 176 | norm_layer=norm_layer)) 177 | 178 | return nn.Sequential(*layers) 179 | 180 | def forward(self, x): 181 | x = self.conv1(x) 182 | x = self.bn1(x) 183 | x = self.relu(x) 184 | x = self.maxpool(x) 185 | 186 | x = self.layer1(x) 187 | x = self.layer2(x) 188 | x = self.layer3(x) 189 | x = self.layer4(x) 190 | 191 | x = self.avgpool(x) 192 | x = torch.flatten(x, 1) 193 | x = self.fc(x) 194 | 195 | return x 196 | 197 | 198 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 199 | model = ResNet(block, layers, **kwargs) 200 | return model 201 | 202 | 203 | 204 | def resnet50(pretrained=False, progress=True, **kwargs): 205 | r"""ResNet-50 model from 206 | `"Deep Residual Learning for Image Recognition" `_ 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | progress (bool): If True, displays a progress bar of the download to stderr 211 | """ 212 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 213 | **kwargs) 214 | 215 | def resnet101(pretrained=False, progress=True, **kwargs): 216 | r"""ResNet-101 model from 217 | `"Deep Residual Learning for Image Recognition" `_ 218 | 219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | progress (bool): If True, displays a progress bar of the download to stderr 222 | """ 223 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 224 | **kwargs) 225 | 226 | 227 | 
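The builders above are consumed by `build_conv_encoder` in models/pir.py, which strips the classification head and uses the pooled 2048-d feature as the instruction signal. A minimal sketch of that usage (not part of the repository sources; the batch size and input resolution below are assumptions):

import torch
import torch.nn as nn
from models.resnet import resnet50

# Keep every child module except the final fc layer, so the forward pass ends
# at the adaptive average pool and yields a (B, 2048, 1, 1) tensor.
backbone = nn.Sequential(*list(resnet50(num_classes=30).children())[:-1])
backbone.eval()

images = torch.randn(2, 3, 224, 224)        # illustrative batch of RGB images
with torch.no_grad():
    pooled = backbone(images)               # shape: (2, 2048, 1, 1)
conv_feat = pooled.squeeze(-1).squeeze(-1)  # shape: (2, 2048) == conv_width
print(conv_feat.shape)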
-------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | from timm.models.vision_transformer import _cfg, PatchEmbed 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_, DropPath 11 | 12 | 13 | class Mlp(nn.Module): 14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 15 | """ 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 40 | self.scale = qk_scale or head_dim ** -0.5 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | self.attn_gradients = None 46 | self.attention_map = None 47 | 48 | def save_attn_gradients(self, attn_gradients): 49 | self.attn_gradients = attn_gradients 50 | 51 | def get_attn_gradients(self): 52 | return self.attn_gradients 53 | 54 | def save_attention_map(self, attention_map): 55 | self.attention_map = attention_map 56 | 57 | def get_attention_map(self): 58 | return self.attention_map 59 | 60 | def forward(self, x, register_hook=False, image_atts=None): 61 | B, N, C = x.shape 62 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 63 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 64 | 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | 67 | if image_atts is not None: 68 | attn += image_atts 69 | 70 | attn = attn.softmax(dim=-1) 71 | attn = self.attn_drop(attn) 72 | 73 | if register_hook: 74 | self.save_attention_map(attn) 75 | attn.register_hook(self.save_attn_gradients) 76 | 77 | # attn: (bs, num_heads, num_patches, num_patches) 78 | # v: (bs, num_heads, num_patches, d) 79 | # attn @ v: (bs, num_heads, num_patches, d) 80 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 81 | x = self.proj(x) 82 | x = self.proj_drop(x) 83 | return x 84 | 85 | 86 | class Block(nn.Module): 87 | 88 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 89 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 90 | super().__init__() 91 | self.norm1 = norm_layer(dim) 92 | self.attn = Attention( 93 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 94 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 95 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 96 | self.norm2 = norm_layer(dim) 97 | mlp_hidden_dim = int(dim * mlp_ratio) 98 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 99 | 100 | def forward(self, x, register_hook=False, image_atts=None): 101 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook, image_atts=image_atts)) 102 | x = x + self.drop_path(self.mlp(self.norm2(x))) 103 | return x 104 | 105 | 106 | class VisionTransformer(nn.Module): 107 | """ Vision Transformer 108 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 109 | https://arxiv.org/abs/2010.11929 110 | """ 111 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 112 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 113 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, local_attn_depth=0): 114 | """ 115 | Args: 116 | img_size (int, tuple): input image size 117 | patch_size (int, tuple): patch size 118 | in_chans (int): number of input channels 119 | num_classes (int): number of classes for classification head 120 | embed_dim (int): embedding dimension 121 | depth (int): depth of transformer 122 | num_heads (int): number of attention heads 123 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 124 | qkv_bias (bool): enable bias for qkv if True 125 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 126 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 127 | drop_rate (float): dropout rate 128 | attn_drop_rate (float): attention dropout rate 129 | drop_path_rate (float): stochastic depth rate 130 | norm_layer: (nn.Module): normalization layer 131 | """ 132 | super().__init__() 133 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 134 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 135 | 136 | self.patch_embed = PatchEmbed( 137 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 138 | 139 | self.num_patch_embed = self.patch_embed.num_patches 140 | 141 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 142 | 143 | self.num_pos_embed = self.num_patch_embed + 1 144 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_pos_embed, embed_dim)) 145 | 146 | self.pos_drop = nn.Dropout(p=drop_rate) 147 | 148 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 149 | self.blocks = nn.ModuleList([ 150 | Block( 151 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 152 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 153 | for i in range(depth)]) 154 | 155 | self.depth = depth 156 | self.local_attn_depth = local_attn_depth # do local attn from index=(depth - local_attn_depth) 157 | 158 | self.norm = norm_layer(embed_dim) 159 | 160 | trunc_normal_(self.pos_embed, std=.02) 161 | trunc_normal_(self.cls_token, std=.02) 162 | self.apply(self._init_weights) 163 | 164 | def _init_weights(self, m): 165 | if isinstance(m, nn.Linear): 166 | trunc_normal_(m.weight, std=.02) 167 | if isinstance(m, nn.Linear) and m.bias is not None: 168 | nn.init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.LayerNorm): 170 | nn.init.constant_(m.bias, 0) 171 | nn.init.constant_(m.weight, 1.0) 172 | 173 | @torch.jit.ignore 174 | def 
no_weight_decay(self): 175 | return {'pos_embed', 'cls_token'} 176 | 177 | def forward(self, x, register_blk=-1, idx_to_group_img=None, image_atts=None): 178 | 179 | B = x.shape[0] 180 | x = self.patch_embed(x) 181 | 182 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 183 | x = torch.cat((cls_tokens, x), dim=1) 184 | 185 | x = x + self.pos_embed[:,:x.size(1),:] 186 | x = self.pos_drop(x) 187 | 188 | do_gather = True if idx_to_group_img is not None else False 189 | 190 | if do_gather and (image_atts is not None): 191 | full_atts = torch.ones(x.shape[:2], dtype=x.dtype).to(x.device) 192 | image_atts_blk = torch.cat([image_atts, full_atts], dim=0) 193 | 194 | image_atts_blk = image_atts_blk.unsqueeze(1).unsqueeze(2) 195 | image_atts_blk = (1.0 - image_atts_blk) * -10000.0 196 | else: 197 | image_atts_blk = None 198 | 199 | for i, blk in enumerate(self.blocks): 200 | if (self.local_attn_depth > 0) and (i >= self.depth-self.local_attn_depth): 201 | if do_gather: 202 | do_gather = False 203 | 204 | x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) 205 | x = torch.cat([x_bs, x], dim=0) 206 | 207 | x = blk(x, register_blk == i, image_atts=image_atts_blk) 208 | 209 | else: 210 | x = blk(x, register_blk==i, image_atts=None) 211 | 212 | x = self.norm(x) 213 | 214 | if idx_to_group_img is not None: 215 | bs = len(idx_to_group_img) 216 | x_bs, x_fullatts = torch.split(x, [bs, x.size(0)-bs]) 217 | return x_bs, x_fullatts 218 | 219 | return x 220 | 221 | 222 | def interpolate_pos_embed(pos_embed_checkpoint, num_patches, num_extra_tokens=1): 223 | # num_patches = visual_encoder.num_patch_embed 224 | # num_extra_tokens = visual_encoder.num_pos_embed - visual_encoder.num_patch_embed 225 | 226 | # interpolate position embedding 227 | embedding_size = pos_embed_checkpoint.shape[-1] 228 | # height (== width) for the checkpoint position embedding 229 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 230 | # height (== width) for the new position embedding 231 | new_size = int(num_patches ** 0.5) 232 | 233 | if orig_size != new_size: 234 | # class_token and dist_token are kept unchanged 235 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 236 | # only the position tokens are interpolated 237 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 238 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 239 | pos_tokens = torch.nn.functional.interpolate( 240 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 241 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 242 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 243 | print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) 244 | 245 | return new_pos_embed 246 | else: 247 | return pos_embed_checkpoint 248 | -------------------------------------------------------------------------------- /mytools.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # Author:Zhiqiang Yuan 3 | """导入一些包""" 4 | import os 5 | import time, random 6 | import json 7 | import numpy as np 8 | from sklearn.decomposition import PCA 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import Axes3D 11 | 12 | """ 打印一些东西 """ 13 | """----------------------------------------------------------------------""" 14 | 15 | 16 | # 打印列表按照竖行的形式 17 | def print_list(list): 18 | 
print("++++++++++++++++++++++++++++++++++++++++++++") 19 | for l in list: 20 | print(l) 21 | print("++++++++++++++++++++++++++++++++++++++++++++") 22 | 23 | 24 | # 打印字典按照竖行的形式 25 | def print_dict(dict): 26 | print("++++++++++++++++++++++++++++++++++++++++++++") 27 | for k, v in dict.items(): 28 | print("key:", k, " value:", v) 29 | print("++++++++++++++++++++++++++++++++++++++++++++") 30 | 31 | 32 | # 打印一些东西,加入标识符 33 | def print_with_log(info): 34 | print("++++++++++++++++++++++++++++++++++++++++++++") 35 | print(info) 36 | print("++++++++++++++++++++++++++++++++++++++++++++") 37 | 38 | 39 | # 打印标识符 40 | def print_log(): 41 | print("++++++++++++++++++++++++++++++++++++++++++++") 42 | 43 | 44 | """ 文件存储 """ 45 | """----------------------------------------------------------------------""" 46 | 47 | 48 | # 保存结果到json文件 49 | def save_to_json(info, filename, encoding='UTF-8'): 50 | with open(filename, "w", encoding=encoding) as f: 51 | json.dump(info, f, indent=2, separators=(',', ':')) 52 | 53 | 54 | # 从json文件中读取 55 | def load_from_json(filename): 56 | with open(filename, encoding='utf-8') as f: 57 | info = json.load(f) 58 | return info 59 | 60 | 61 | # 储存为npy文件 62 | def save_to_npy(info, filename): 63 | np.save(filename, info, allow_pickle=True) 64 | 65 | 66 | # 从npy中读取 67 | def load_from_npy(filename): 68 | info = np.load(filename, allow_pickle=True) 69 | return info 70 | 71 | 72 | # 保存结果到txt文件 73 | def log_to_txt(contexts=None, filename="save.txt", mark=False, encoding='UTF-8', add_n=False): 74 | f = open(filename, "a", encoding=encoding) 75 | if mark: 76 | sig = "------------------------------------------------\n" 77 | f.write(sig) 78 | elif isinstance(contexts, dict): 79 | tmp = "" 80 | for c in contexts.keys(): 81 | tmp += str(c) + " | " + str(contexts[c]) + "\n" 82 | contexts = tmp 83 | f.write(contexts) 84 | else: 85 | if isinstance(contexts, list): 86 | tmp = "" 87 | for c in contexts: 88 | if add_n: 89 | tmp += str(c) + "\n" 90 | else: 91 | tmp += str(c) 92 | contexts = tmp 93 | else: 94 | contexts = contexts + "\n" 95 | f.write(contexts) 96 | 97 | f.close() 98 | 99 | 100 | # 从txt中读取行 101 | def load_from_txt(filename, encoding="utf-8"): 102 | f = open(filename, 'r', encoding=encoding) 103 | contexts = f.readlines() 104 | return contexts 105 | 106 | 107 | """ 字典变换 """ 108 | """----------------------------------------------------------------------""" 109 | 110 | 111 | # 键值互换 112 | def dict_k_v_exchange(dict): 113 | tmp = {} 114 | for key, value in dict.items(): 115 | tmp[value] = key 116 | return tmp 117 | 118 | 119 | # 2维数组转字典 120 | def d2array_to_dict(d2array): 121 | # Input: N x 2 list 122 | # Output: dict 123 | dict = {} 124 | for item in d2array: 125 | if item[0] not in dict.keys(): 126 | dict[item[0]] = [item[1]] 127 | else: 128 | dict[item[0]].append(item[1]) 129 | return dict 130 | 131 | 132 | """ 绘图 """ 133 | """----------------------------------------------------------------------""" 134 | 135 | 136 | # 绘制3D图像 137 | def visual_3d_points(list, color=True): 138 | """ 139 | :param list: N x (dim +1) 140 | N 为点的数量 141 | dim 为 输入数据的维度 142 | 1 为类别, 即可视化的颜色 当且仅当color为True时 143 | """ 144 | list = np.array(list) 145 | if color: 146 | data = list[:, :4] 147 | label = list[:, -1] 148 | else: 149 | data = list 150 | label = None 151 | 152 | # PCA降维 153 | pca = PCA(n_components=3, whiten=True).fit(data) 154 | data = pca.transform(data) 155 | 156 | # 定义坐标轴 157 | fig = plt.figure() 158 | ax1 = plt.axes(projection='3d') 159 | if label is not None: 160 | color = label 161 | else: 162 | color = 
"blue" 163 | ax1.scatter3D(np.transpose(data)[0], np.transpose(data)[1], np.transpose(data)[2], c=color) # 绘制散点图 164 | 165 | plt.show() 166 | 167 | 168 | """ 实用工具 """ 169 | """----------------------------------------------------------------------""" 170 | 171 | 172 | # 计算数组中元素出现的个数 173 | def count_list(lens): 174 | dict = {} 175 | for key in lens: 176 | dict[key] = dict.get(key, 0) + 1 177 | dict = sorted(dict.items(), key=lambda x: x[1], reverse=True) 178 | 179 | print_list(dict) 180 | return dict 181 | 182 | 183 | # list 加法 w1、w2为权重 184 | def list_add(list1, list2, w1=1, w2=1): 185 | return [l1 * w1 + l2 * w2 for (l1, l2) in zip(list1, list2)] 186 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | from transformers.optimization import AdamW 2 | from torch.optim import Adam 3 | 4 | def create_optimizer(args, model): 5 | lr = args.lr 6 | wd = args.weight_decay 7 | lr_mult = getattr(args, 'lr_mult', 1) 8 | print("### lr_mult, ", lr_mult) 9 | 10 | optimizer_grouped_parameters = [ 11 | {"params": [], "weight_decay": wd, "lr": lr}, 12 | {"params": [], "weight_decay": 0.0, "lr": lr}, 13 | {"params": [], "weight_decay": wd, "lr": lr * lr_mult}, 14 | {"params": [], "weight_decay": 0.0, "lr": lr * lr_mult} 15 | ] 16 | 17 | no_decay = {"bias", 18 | "LayerNorm.bias", 19 | "LayerNorm.weight", 20 | "norm.bias", 21 | "norm.weight", 22 | "norm1.bias", 23 | "norm1.weight", 24 | "norm2.bias", 25 | "norm2.weight"} 26 | 27 | if hasattr(model, 'init_params'): 28 | large_lr = model.init_params 29 | print("### model has 'init_params', ", len(large_lr)) 30 | else: 31 | large_lr = {} 32 | 33 | for n, p in model.named_parameters(): 34 | if not p.requires_grad: 35 | continue # frozen weights 36 | 37 | if any(nd in n for nd in no_decay): 38 | if n in large_lr: 39 | optimizer_grouped_parameters[3]['params'].append(p) 40 | else: 41 | optimizer_grouped_parameters[1]['params'].append(p) 42 | else: # decay 43 | if n in large_lr: 44 | optimizer_grouped_parameters[2]['params'].append(p) 45 | else: 46 | optimizer_grouped_parameters[0]['params'].append(p) 47 | 48 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)) 49 | # 使用Adam优化器 50 | # optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) 51 | return optimizer 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Package Version 2 | --------------------- ------------ 3 | attrs 22.2.0 4 | backcall 0.2.0 5 | certifi 2022.12.7 6 | charset-normalizer 3.0.1 7 | click 8.1.3 8 | coloredlogs 15.0.1 9 | cycler 0.11.0 10 | debugpy 1.6.6 11 | decorator 5.1.1 12 | entrypoints 0.4 13 | filelock 3.9.0 14 | flatbuffers 23.3.3 15 | fonttools 4.38.0 16 | htmlmin 0.1.12 17 | huggingface-hub 0.12.1 18 | humanfriendly 10.0 19 | idna 3.4 20 | ImageHash 4.3.1 21 | imageio 2.26.0 22 | importlib-metadata 6.0.0 23 | install 1.3.5 24 | ipykernel 6.16.2 25 | ipython 7.34.0 26 | ipywidgets 8.0.4 27 | jedi 0.18.2 28 | jieba 0.42.1 29 | Jinja2 3.1.2 30 | joblib 1.2.0 31 | jupyter_client 7.4.9 32 | jupyter_core 4.12.0 33 | jupyterlab-widgets 3.0.5 34 | kiwisolver 1.4.4 35 | MarkupSafe 2.1.2 36 | matplotlib 3.5.3 37 | matplotlib-inline 0.1.6 38 | mpmath 1.3.0 39 | multimethod 1.9.1 40 | nest-asyncio 1.5.6 41 | networkx 2.6.3 42 | numpy 1.21.6 43 | onnx 1.14.0 44 | onnxruntime 1.14.1 45 
| opencv-python 4.7.0.72 46 | packaging 23.0 47 | pandas 1.3.5 48 | pandas-profiling 3.6.6 49 | parso 0.8.3 50 | patsy 0.5.3 51 | pexpect 4.8.0 52 | phik 0.12.3 53 | pickleshare 0.7.5 54 | Pillow 9.4.0 55 | pip 22.3.1 56 | prompt-toolkit 3.0.38 57 | protobuf 4.22.3 58 | psutil 5.9.4 59 | ptyprocess 0.7.0 60 | pycocoevalcap 1.2 61 | pycocotools 2.0.6 62 | pydantic 1.10.6 63 | Pygments 2.14.0 64 | pyparsing 3.0.9 65 | python-dateutil 2.8.2 66 | pytz 2022.7.1 67 | PyWavelets 1.3.0 68 | PyYAML 6.0 69 | pyzmq 25.0.1 70 | regex 2022.10.31 71 | requests 2.28.2 72 | ruamel.yaml 0.17.21 73 | ruamel.yaml.clib 0.2.7 74 | sacremoses 0.0.53 75 | scikit-image 0.19.3 76 | scikit-learn 1.0.2 77 | scipy 1.7.3 78 | seaborn 0.12.2 79 | setuptools 65.6.3 80 | six 1.16.0 81 | statsmodels 0.13.5 82 | sympy 1.10.1 83 | tangled-up-in-unicode 0.2.0 84 | tensorboard-logger 0.1.0 85 | termcolor 2.2.0 86 | threadpoolctl 3.1.0 87 | tifffile 2021.11.2 88 | timm 0.4.9 89 | tokenizers 0.10.3 90 | torch 1.10.0+cu113 91 | torch-cluster 1.5.9 92 | torch-geometric 2.2.0 93 | torch-scatter 2.0.9 94 | torch-sparse 0.6.12 95 | torch-spline-conv 1.2.1 96 | torchaudio 0.10.0+cu113 97 | torchvision 0.11.1+cu113 98 | tornado 6.2 99 | tqdm 4.64.1 100 | traitlets 5.9.0 101 | transformers 4.12.5 102 | typeguard 2.13.3 103 | typing_extensions 4.5.0 104 | urllib3 1.26.14 105 | visions 0.7.5 106 | wcwidth 0.2.6 107 | wheel 0.38.4 108 | widgetsnbextension 4.0.5 109 | yacs 0.1.8 110 | ydata-profiling 4.1.1 111 | zipp 3.15.0 112 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | 7 | from utils.hdfs_io import HADOOP_BIN, hexists, hmkdir, hcopy 8 | 9 | def get_dist_launch(args): # some examples 10 | 11 | if args.dist == 'f4': 12 | return "CUDA_VISIBLE_DEVICES=0,1,2,3 WORLD_SIZE=4 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=4 " \ 13 | "--nnodes=1 " 14 | 15 | elif args.dist == 'f2': 16 | return "CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=2 " \ 17 | "--nnodes=1 " 18 | 19 | elif args.dist == 'f3': 20 | return "CUDA_VISIBLE_DEVICES=0,1,2 WORLD_SIZE=3 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=3 " \ 21 | "--nnodes=1 " 22 | 23 | elif args.dist == 'f12': 24 | return "CUDA_VISIBLE_DEVICES=1,2 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=2 " \ 25 | "--nnodes=1 " 26 | 27 | elif args.dist == 'f02': 28 | return "CUDA_VISIBLE_DEVICES=0,2 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=2 " \ 29 | "--nnodes=1 " 30 | 31 | elif args.dist == 'f03': 32 | return "CUDA_VISIBLE_DEVICES=0,3 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=2 " \ 33 | "--nnodes=1 " 34 | 35 | elif args.dist == 'l2': 36 | return "CUDA_VISIBLE_DEVICES=2,3 WORLD_SIZE=2 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9998 --nproc_per_node=2 " \ 37 | "--nnodes=1 " 38 | 39 | elif args.dist.startswith('gpu'): # use one gpu, --dist "gpu0" 40 | num = int(args.dist[3:]) 41 | assert 
0 <= num <= 8 42 | return "CUDA_VISIBLE_DEVICES={:} WORLD_SIZE=1 /home/pjc/.conda/envs/xlvm/bin/python -W ignore -m torch.distributed.launch --master_port 9999 --nproc_per_node=1 " \ 43 | "--nnodes=1 ".format(num) 44 | 45 | else: 46 | raise ValueError 47 | 48 | 49 | def get_from_hdfs(file_hdfs): 50 | """ 51 | compatible to HDFS path or local path 52 | """ 53 | if file_hdfs.startswith('hdfs'): 54 | file_local = os.path.split(file_hdfs)[-1] 55 | 56 | if os.path.exists(file_local): 57 | print(f"rm existing {file_local}") 58 | os.system(f"rm {file_local}") 59 | 60 | hcopy(file_hdfs, file_local) 61 | 62 | else: 63 | file_local = file_hdfs 64 | assert os.path.exists(file_local) 65 | 66 | return file_local 67 | 68 | 69 | def run_retrieval(args): 70 | dist_launch = get_dist_launch(args) 71 | 72 | os.system(f"{dist_launch} " 73 | f"--use_env Retrieval.py --config {args.config} " 74 | f"--output_dir {args.output_dir} --bs {args.bs} --checkpoint {args.checkpoint} {'--evaluate' if args.evaluate else ''}") 75 | 76 | 77 | def run(args): 78 | if args.task == 'itr_rsicd': 79 | assert os.path.exists("../X-VLM-pytorch/images/rsicd") 80 | args.config = 'configs/Retrieval_rsicd.yaml' 81 | run_retrieval(args) 82 | 83 | elif args.task == 'itr_rsitmd': 84 | assert os.path.exists("../X-VLM-pytorch/images/rsitmd") 85 | args.config = 'configs/Retrieval_rsitmd.yaml' 86 | run_retrieval(args) 87 | 88 | elif args.task == 'itr_coco': 89 | assert os.path.exists("../X-VLM-pytorch/images/coco") 90 | args.config = 'configs/Retrieval_coco.yaml' 91 | run_retrieval(args) 92 | 93 | elif args.task == 'itr_nwpu': 94 | assert os.path.exists("../X-VLM-pytorch/images/NWPU") 95 | args.config = 'configs/Retrieval_nwpu.yaml' 96 | run_retrieval(args) 97 | else: 98 | raise NotImplementedError(f"task == {args.task}") 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--task', type=str, default='itr_rsitmd') 104 | parser.add_argument('--dist', type=str, default='f2', help="see func get_dist_launch for details") 105 | parser.add_argument('--config', default='configs/Retrieval_rsitmd.yaml', type=str, help="if not given, use default") 106 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus; " 107 | "this option only works for fine-tuning scripts.") 108 | parser.add_argument('--seed', default=42, type=int) 109 | parser.add_argument('--checkpoint', default='-1', type=str, help="for fine-tuning") 110 | parser.add_argument('--load_ckpt_from', default=' ', type=str, help="load domain pre-trained params") 111 | # write path: local or HDFS 112 | parser.add_argument('--output_dir', type=str, default='./outputs/test', help='for fine-tuning, local path; ' 113 | 'for pre-training, local and HDFS are both allowed.') 114 | parser.add_argument('--evaluate', action='store_true', default=False, help="evaluation on downstream tasks") 115 | args = parser.parse_args() 116 | assert hexists(os.path.dirname(args.output_dir)) 117 | hmkdir(args.output_dir) 118 | run(args) 119 | 120 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | def create_scheduler(args, optimizer): 5 | if 'num_training_steps' not in args: 6 | args['num_training_steps'] = args['epochs'] * args['step_per_epoch'] 7 | print("### num_training_steps, ", args['num_training_steps'], flush=True) 8 | 9 | if 
isinstance(args['num_warmup_steps'], float): 10 | assert 0 <= args['num_warmup_steps'] < 1 11 | args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps']) 12 | print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True) 13 | 14 | if args.sched == 'linear': 15 | def lr_lambda(current_step: int): 16 | if current_step < args.num_warmup_steps: 17 | return float(current_step) / float(max(1, args.num_warmup_steps)) 18 | return max( 19 | 0.0, float(args.num_training_steps - current_step) / float( 20 | max(1, args.num_training_steps - args.num_warmup_steps)) 21 | ) 22 | 23 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) 24 | 25 | else: 26 | raise NotImplementedError(f"args.sched == {args.sched}") 27 | 28 | return lr_scheduler 29 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zjut-MultimediaPlus/PIR-pytorch/aeadf5ff0fba18ceb496495058c30501b2b5bb34/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from collections import defaultdict, deque, OrderedDict 5 | 6 | import datetime 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD 14 | 15 | 16 | class ScstRewardCriterion(torch.nn.Module): 17 | CIDER_REWARD_WEIGHT = 1 18 | 19 | def __init__(self, cider_cached_tokens='corpus', baseline_type='greedy'): 20 | self.CiderD_scorer = CiderD(df=cider_cached_tokens) 21 | assert baseline_type in ['greedy', 'sample'] 22 | self.baseline_type = baseline_type 23 | self._cur_score = None 24 | super().__init__() 25 | 26 | def forward(self, gt_res, greedy_res, sample_res, sample_logprobs): 27 | batch_size = len(gt_res) 28 | sample_res_size = len(sample_res) 29 | seq_per_img = sample_res_size // batch_size 30 | 31 | gen_res = [] 32 | gen_res.extend(sample_res) 33 | gt_idx = [i // seq_per_img for i in range(sample_res_size)] 34 | if self.baseline_type == 'greedy': 35 | assert len(greedy_res) == batch_size 36 | gen_res.extend(greedy_res) 37 | gt_idx.extend([i for i in range(batch_size)]) 38 | 39 | scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res) 40 | 41 | if self.baseline_type == 'greedy': 42 | baseline = scores[-batch_size:][:, np.newaxis] 43 | else: 44 | sc_ = scores.reshape(batch_size, seq_per_img) 45 | baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1) 46 | 47 | # sample - baseline 48 | reward = scores[:sample_res_size].reshape(batch_size, seq_per_img) 49 | self._cur_score = reward.mean() 50 | 51 | reward = reward - baseline 52 | reward = reward.reshape(sample_res_size) 53 | 54 | reward = torch.as_tensor(reward, device=sample_logprobs.device, dtype=torch.float) 55 | loss = - sample_logprobs * reward 56 | loss = loss.mean() 57 | return loss 58 | 59 | def get_score(self): 60 | return self._cur_score 61 | 62 | def _calculate_eval_scores(self, gen_res, gt_idx, gt_res): 63 | ''' 64 | gen_res: generated captions, list of str 65 | gt_idx: list of int, of the same length as gen_res 66 | gt_res: ground truth captions, list of list of str. 
67 | gen_res[i] corresponds to gt_res[gt_idx[i]] 68 | Each image can have multiple ground truth captions 69 | ''' 70 | gen_res_size = len(gen_res) 71 | 72 | res = OrderedDict() 73 | for i in range(gen_res_size): 74 | res[i] = [self._wrap_sentence(gen_res[i])] 75 | 76 | gts = OrderedDict() 77 | gt_res_ = [ 78 | [self._wrap_sentence(gt_res[i][j]) for j in range(len(gt_res[i]))] 79 | for i in range(len(gt_res)) 80 | ] 81 | for i in range(gen_res_size): 82 | gts[i] = gt_res_[gt_idx[i]] 83 | 84 | res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))] 85 | _, batch_cider_scores = self.CiderD_scorer.compute_score(gts, res_) 86 | scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores 87 | return scores 88 | 89 | @classmethod 90 | def _wrap_sentence(self, s): 91 | # ensure the sentence ends with token 92 | # in order to keep consisitent with cider_cached_tokens 93 | r = s.strip() 94 | if r.endswith('.'): 95 | r = r[:-1] 96 | r += ' ' 97 | return r 98 | 99 | 100 | class SmoothedValue(object): 101 | """Track a series of values and provide access to smoothed values over a 102 | window or the global series average. 103 | """ 104 | 105 | def __init__(self, window_size=20, fmt=None): 106 | if fmt is None: 107 | fmt = "{median:.4f} ({global_avg:.4f})" 108 | self.deque = deque(maxlen=window_size) 109 | self.total = 0.0 110 | self.count = 0 111 | self.fmt = fmt 112 | 113 | def update(self, value, n=1): 114 | self.deque.append(value) 115 | self.count += n 116 | self.total += value * n 117 | 118 | def synchronize_between_processes(self): 119 | """ 120 | Warning: does not synchronize the deque! 121 | """ 122 | if not is_dist_avail_and_initialized(): 123 | return 124 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 125 | dist.barrier() 126 | dist.all_reduce(t) 127 | t = t.tolist() 128 | self.count = int(t[0]) 129 | self.total = t[1] 130 | 131 | @property 132 | def median(self): 133 | d = torch.tensor(list(self.deque)) 134 | return d.median().item() 135 | 136 | @property 137 | def avg(self): 138 | d = torch.tensor(list(self.deque), dtype=torch.float32) 139 | return d.mean().item() 140 | 141 | @property 142 | def global_avg(self): 143 | return self.total / self.count 144 | 145 | @property 146 | def max(self): 147 | return max(self.deque) 148 | 149 | @property 150 | def value(self): 151 | return self.deque[-1] 152 | 153 | def __str__(self): 154 | return self.fmt.format( 155 | median=self.median, 156 | avg=self.avg, 157 | global_avg=self.global_avg, 158 | max=self.max, 159 | value=self.value) 160 | 161 | 162 | class MetricLogger(object): 163 | def __init__(self, delimiter="\t"): 164 | self.meters = defaultdict(SmoothedValue) 165 | self.delimiter = delimiter 166 | 167 | def update(self, **kwargs): 168 | for k, v in kwargs.items(): 169 | if isinstance(v, torch.Tensor): 170 | v = v.item() 171 | assert isinstance(v, (float, int)) 172 | self.meters[k].update(v) 173 | 174 | def __getattr__(self, attr): 175 | if attr in self.meters: 176 | return self.meters[attr] 177 | if attr in self.__dict__: 178 | return self.__dict__[attr] 179 | raise AttributeError("'{}' object has no attribute '{}'".format( 180 | type(self).__name__, attr)) 181 | 182 | def __str__(self): 183 | loss_str = [] 184 | for name, meter in self.meters.items(): 185 | loss_str.append( 186 | "{}: {}".format(name, str(meter)) 187 | ) 188 | return self.delimiter.join(loss_str) 189 | 190 | def global_avg(self): 191 | loss_str = [] 192 | for name, meter in self.meters.items(): 193 | loss_str.append( 194 | "{}: 
{:.4f}".format(name, meter.global_avg) 195 | ) 196 | return self.delimiter.join(loss_str) 197 | 198 | def synchronize_between_processes(self): 199 | for meter in self.meters.values(): 200 | meter.synchronize_between_processes() 201 | 202 | def add_meter(self, name, meter): 203 | self.meters[name] = meter 204 | 205 | def log_every(self, iterable, print_freq, header=None, dataset_len=None, epoch_info=None): 206 | if not header: 207 | header = '' 208 | if not dataset_len: 209 | dataset_len = len(iterable) 210 | start_time = time.time() 211 | end = time.time() 212 | iter_time = SmoothedValue(fmt='{avg:.4f}') 213 | data_time = SmoothedValue(fmt='{avg:.4f}') 214 | space_fmt = ':' + str(len(str(dataset_len))) + 'd' 215 | 216 | _msg = [ 217 | '[{0' + space_fmt + '}/{1}]', 218 | 'eta: {eta}', 219 | '{meters}', 220 | 'time: {time}', 221 | 'data: {data}' 222 | ] 223 | if torch.cuda.is_available(): 224 | _msg.append('max mem: {memory:.0f}') 225 | _msg = self.delimiter.join(_msg) 226 | MB = 1024.0 * 1024.0 227 | iterable = iter(iterable) 228 | train_steps = dataset_len 229 | if epoch_info: 230 | start_epoch, end_epoch = epoch_info 231 | train_steps = (end_epoch - start_epoch) * dataset_len 232 | for i in range(train_steps): 233 | obj = next(iterable) 234 | data_time.update(time.time() - end) 235 | yield obj 236 | iter_time.update(time.time() - end) 237 | if epoch_info: 238 | header = int(i / dataset_len) + start_epoch 239 | header = 'Train step: [{}]'.format(header) 240 | log_msg = header + " " + _msg 241 | if (i % dataset_len) % print_freq == 0 or i == dataset_len - 1: 242 | eta_seconds = iter_time.global_avg * (dataset_len - i % dataset_len) 243 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 244 | if torch.cuda.is_available(): 245 | print(log_msg.format( 246 | i % dataset_len, dataset_len, eta=eta_string, 247 | meters=str(self), 248 | time=str(iter_time), data=str(data_time), 249 | memory=torch.cuda.max_memory_allocated() / MB)) 250 | else: 251 | print(log_msg.format( 252 | i % dataset_len, dataset_len, eta=eta_string, 253 | meters=str(self), 254 | time=str(iter_time), data=str(data_time))) 255 | 256 | end = time.time() 257 | total_time = time.time() - start_time 258 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 259 | print('{} Total time: {} ({:.4f} s / it)'.format( 260 | header, total_time_str, total_time / dataset_len)) 261 | 262 | 263 | class AttrDict(dict): 264 | def __init__(self, *args, **kwargs): 265 | super(AttrDict, self).__init__(*args, **kwargs) 266 | self.__dict__ = self 267 | 268 | 269 | def compute_acc(logits, label, reduction='mean'): 270 | ret = (torch.argmax(logits, dim=1) == label).float() 271 | if reduction == 'none': 272 | return ret.detach() 273 | elif reduction == 'mean': 274 | return ret.mean().item() 275 | 276 | 277 | def compute_n_params(model, return_str=True): 278 | tot = 0 279 | for p in model.parameters(): 280 | w = 1 281 | for x in p.shape: 282 | w *= x 283 | tot += w 284 | if return_str: 285 | if tot >= 1e6: 286 | return '{:.1f}M'.format(tot / 1e6) 287 | else: 288 | return '{:.1f}K'.format(tot / 1e3) 289 | else: 290 | return tot 291 | 292 | 293 | def setup_for_distributed(is_master): 294 | """ 295 | This function disables printing when not in master process 296 | """ 297 | import builtins as __builtin__ 298 | builtin_print = __builtin__.print 299 | 300 | def print(*args, **kwargs): 301 | force = kwargs.pop('force', False) 302 | if is_master or force: 303 | builtin_print(*args, **kwargs) 304 | 305 | __builtin__.print = print 306 | 
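# Note (illustrative): once init_distributed_mode() has called
# setup_for_distributed(args.rank == 0), a plain print(...) is silently dropped on
# non-master ranks; pass force=True, e.g. print("per-rank debug", force=True),
# to print from every process.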
307 | 308 | def is_dist_avail_and_initialized(): 309 | if not dist.is_available(): 310 | return False 311 | if not dist.is_initialized(): 312 | return False 313 | return True 314 | 315 | 316 | def get_world_size(): 317 | if not is_dist_avail_and_initialized(): 318 | return 1 319 | return dist.get_world_size() 320 | 321 | 322 | def get_rank(): 323 | if not is_dist_avail_and_initialized(): 324 | return 0 325 | return dist.get_rank() 326 | 327 | 328 | def is_main_process(): 329 | return get_rank() == 0 330 | 331 | 332 | def save_on_master(*args, **kwargs): 333 | if is_main_process(): 334 | torch.save(*args, **kwargs) 335 | 336 | 337 | def init_distributed_mode(args): 338 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 339 | args.rank = int(os.environ["RANK"]) 340 | args.world_size = int(os.environ['WORLD_SIZE']) 341 | args.gpu = int(os.environ['LOCAL_RANK']) 342 | elif 'SLURM_PROCID' in os.environ: 343 | args.rank = int(os.environ['SLURM_PROCID']) 344 | args.gpu = args.rank % torch.cuda.device_count() 345 | else: 346 | print('Not using distributed mode') 347 | args.distributed = False 348 | return 349 | 350 | args.distributed = True 351 | 352 | torch.cuda.set_device(args.gpu) 353 | args.dist_backend = 'nccl' 354 | print('| distributed init (rank {}): {}'.format( 355 | args.rank, args.dist_url), flush=True) 356 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 357 | world_size=args.world_size, rank=args.rank) 358 | torch.distributed.barrier() 359 | setup_for_distributed(args.rank == 0) 360 | 361 | 362 | def read_json(rpath): 363 | with open(rpath, 'r') as f: 364 | return json.load(f) -------------------------------------------------------------------------------- /utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 2 | # Github: https://github.com/zengyan-97/X-VLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 
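# Usage sketch (illustrative; the output directory is an assumption):
#     ckpt = Checkpointer(serialization_dir="./outputs/test")
#     ckpt.save_checkpoint(epoch=1, model_state=model.state_dict(),
#                          training_states={"optimizer": optimizer.state_dict()})
# With step <= 0 this writes model_state_epoch_{epoch}.th plus training_state_latest.th;
# with step > 0 it writes model_state_step_{step}.th only. hdfs_torch_save accepts
# either a local path or an HDFS path.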
5 | 6 | from typing import Union, Dict, List, Tuple, Any, Callable 7 | import logging 8 | import os 9 | import re 10 | import time 11 | 12 | import torch 13 | 14 | from utils.hdfs_io import hexists, hmkdir, hcopy 15 | from utils.torch_io import save as hdfs_torch_save 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class Checkpointer: 20 | def __init__(self, 21 | serialization_dir: str = ".output") -> None: 22 | self._serialization_dir = serialization_dir 23 | if not hexists(self._serialization_dir): 24 | hmkdir(self._serialization_dir) 25 | 26 | def save_checkpoint(self, 27 | epoch: Union[int, str], 28 | model_state: Dict[str, Any], 29 | training_states: Dict[str, Any], 30 | step: int = -1) -> None: 31 | """ 32 | Save ckpt to local or HDFS 33 | """ 34 | if step > 0: 35 | model_path = os.path.join( 36 | self._serialization_dir, "model_state_step_{}.th".format(step)) 37 | hdfs_torch_save(model_state, model_path) 38 | 39 | else: 40 | model_path = os.path.join( 41 | self._serialization_dir, "model_state_epoch_{}.th".format(epoch)) 42 | 43 | training_path = os.path.join(self._serialization_dir, 44 | "training_state_latest.th") 45 | hdfs_torch_save(model_state, model_path) 46 | hdfs_torch_save({**training_states, "epoch": epoch}, training_path) 47 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # 4 | # Description: Describes the class to compute the CIDEr 5 | # (Consensus-Based Image Description Evaluation) Metric 6 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 7 | # 8 | # Creation Date: Sun Feb 8 14:16:54 2015 9 | # 10 | # Authors: Ramakrishna Vedantam and 11 | # Tsung-Yi Lin 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | from .cider_scorer import CiderScorer 17 | 18 | 19 | class Cider: 20 | """ 21 | Main Class to compute the CIDEr metric 22 | 23 | """ 24 | def __init__(self, n=4, df="corpus"): 25 | """ 26 | Initialize the CIDEr scoring function 27 | : param n (int): n-gram size 28 | : param df (string): specifies where to get the IDF values from 29 | takes values 'corpus', 'coco-train' 30 | : return: None 31 | """ 32 | # set cider to sum over 1 to 4-grams 33 | self._n = n 34 | self._df = df 35 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 36 | 37 | def compute_score(self, gts, res): 38 | """ 39 | Main function to compute CIDEr score 40 | : param gts (dict) : {image:tokenized reference sentence} 41 | : param res (dict) : {image:tokenized candidate sentence} 42 | : return: cider (float) : computed CIDEr score for the corpus 43 | """ 44 | 45 | # clear all the previous hypos and refs 46 | self.cider_scorer.clear() 47 | 48 | for res_id in res: 49 | 50 | hypo = res_id['caption'] 51 | ref = gts[res_id['image_id']] 52 | 53 | # Sanity check. 
54 | assert(type(hypo) is list) 55 | assert(len(hypo) == 1) 56 | assert(type(ref) is list) 57 | assert(len(ref) > 0) 58 | self.cider_scorer += (hypo[0], ref) 59 | 60 | (score, scores) = self.cider_scorer.compute_score() 61 | 62 | return score, scores 63 | 64 | def method(self): 65 | return "CIDEr" 66 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | import six 10 | from six.moves import cPickle 11 | from collections import defaultdict 12 | import numpy as np 13 | import math 14 | import os 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | words = s.split() 26 | counts = defaultdict(int) 27 | for k in range(1,n+1): 28 | for i in range(len(words)-k+1): 29 | ngram = tuple(words[i:i+k]) 30 | counts[ngram] += 1 31 | return counts 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | def cook_test(test, n=4): 44 | '''Takes a test sentence and returns an object that 45 | encapsulates everything that BLEU needs to know about it. 46 | :param test: list of string : hypothesis sentence for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (dict) 49 | ''' 50 | return precook(test, n, True) 51 | 52 | class CiderScorer(object): 53 | """CIDEr scorer. 
54 | """ 55 | 56 | def copy(self): 57 | ''' copy the refs.''' 58 | new = CiderScorer(n=self.n) 59 | new.ctest = copy.copy(self.ctest) 60 | new.crefs = copy.copy(self.crefs) 61 | return new 62 | 63 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0): 64 | ''' singular instance ''' 65 | self.n = n 66 | self.sigma = sigma 67 | self.crefs = [] 68 | self.ctest = [] 69 | self.df_mode = df_mode 70 | self.ref_len = None 71 | if self.df_mode != "corpus": 72 | pkl_file = cPickle.load(open(os.path.join('data', df_mode + '.p'),'rb'), **(dict(encoding='latin1') if six.PY3 else {})) 73 | self.ref_len = np.log(float(pkl_file['ref_len'])) 74 | self.document_frequency = pkl_file['document_frequency'] 75 | self.cook_append(test, refs) 76 | 77 | def clear(self): 78 | self.crefs = [] 79 | self.ctest = [] 80 | 81 | def cook_append(self, test, refs): 82 | '''called by constructor and __iadd__ to avoid creating new instances.''' 83 | 84 | if refs is not None: 85 | self.crefs.append(cook_refs(refs)) 86 | if test is not None: 87 | self.ctest.append(cook_test(test)) ## N.B.: -1 88 | else: 89 | self.ctest.append(None) # lens of crefs and ctest have to match 90 | 91 | def size(self): 92 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 93 | return len(self.crefs) 94 | 95 | def __iadd__(self, other): 96 | '''add an instance (e.g., from another sentence).''' 97 | 98 | if type(other) is tuple: 99 | ## avoid creating new CiderScorer instances 100 | self.cook_append(other[0], other[1]) 101 | else: 102 | self.ctest.extend(other.ctest) 103 | self.crefs.extend(other.crefs) 104 | 105 | return self 106 | def compute_doc_freq(self): 107 | ''' 108 | Compute term frequency for reference data. 109 | This will be used to compute idf (inverse document frequency later) 110 | The term frequency is stored in the object 111 | :return: None 112 | ''' 113 | for refs in self.crefs: 114 | # refs, k ref captions of one image 115 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 116 | self.document_frequency[ngram] += 1 117 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 118 | 119 | def compute_cider(self): 120 | def counts2vec(cnts): 121 | """ 122 | Function maps counts of ngram to vector of tfidf weights. 123 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 124 | The n-th entry of array denotes length of n-grams. 125 | :param cnts: 126 | :return: vec (array of dict), norm (array of float), length (int) 127 | """ 128 | vec = [defaultdict(float) for _ in range(self.n)] 129 | length = 0 130 | norm = [0.0 for _ in range(self.n)] 131 | for (ngram,term_freq) in cnts.items(): 132 | # give word count 1 if it doesn't appear in reference corpus 133 | df = np.log(max(1.0, self.document_frequency[ngram])) 134 | # ngram index 135 | n = len(ngram)-1 136 | # tf (term_freq) * idf (precomputed idf) for n-grams 137 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 138 | # compute norm for the vector. the norm will be used for 139 | # computing similarity 140 | norm[n] += pow(vec[n][ngram], 2) 141 | 142 | if n == 1: 143 | length += term_freq 144 | norm = [np.sqrt(n) for n in norm] 145 | return vec, norm, length 146 | 147 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 148 | ''' 149 | Compute the cosine similarity of two vectors. 
150 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 151 | :param vec_ref: array of dictionary for vector corresponding to reference 152 | :param norm_hyp: array of float for vector corresponding to hypothesis 153 | :param norm_ref: array of float for vector corresponding to reference 154 | :param length_hyp: int containing length of hypothesis 155 | :param length_ref: int containing length of reference 156 | :return: array of score for each n-grams cosine similarity 157 | ''' 158 | delta = float(length_hyp - length_ref) 159 | # measure consine similarity 160 | val = np.array([0.0 for _ in range(self.n)]) 161 | for n in range(self.n): 162 | # ngram 163 | for (ngram,count) in vec_hyp[n].items(): 164 | val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram] 165 | 166 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 167 | val[n] /= (norm_hyp[n]*norm_ref[n]) 168 | 169 | assert(not math.isnan(val[n])) 170 | return val 171 | 172 | # compute log reference length 173 | if self.df_mode == "corpus": 174 | self.ref_len = np.log(float(len(self.crefs))) 175 | 176 | scores = [] 177 | for test, refs in zip(self.ctest, self.crefs): 178 | # compute vector for test captions 179 | vec, norm, length = counts2vec(test) 180 | # compute vector for ref captions 181 | score = np.array([0.0 for _ in range(self.n)]) 182 | for ref in refs: 183 | vec_ref, norm_ref, length_ref = counts2vec(ref) 184 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 185 | # change by vrama91 - mean of ngram scores, instead of sum 186 | score_avg = np.mean(score) 187 | # divide by number of references 188 | score_avg /= len(refs) 189 | # multiply score by 10 190 | score_avg *= 10.0 191 | # append score of an image to the score list 192 | scores.append(score_avg) 193 | return scores 194 | 195 | def compute_score(self, option=None, verbose=0): 196 | # compute idf 197 | if self.df_mode == "corpus": 198 | self.document_frequency = defaultdict(float) 199 | self.compute_doc_freq() 200 | # assert to check document frequency 201 | assert(len(self.ctest) >= max(self.document_frequency.values())) 202 | # import json for now and write the corresponding files 203 | # compute cider score 204 | score = self.compute_cider() 205 | # debug 206 | # print score 207 | return np.mean(np.array(score)), np.array(score) 208 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/ciderD.py: -------------------------------------------------------------------------------- 1 | # Filename: ciderD.py 2 | # 3 | # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from .ciderD_scorer import CiderScorer 14 | import pdb 15 | 16 | class CiderD: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, n=4, sigma=6.0, df="corpus"): 22 | # set cider to sum over 1 to 4-grams 23 | self._n = n 24 | # set the standard deviation parameter for gaussian 
penalty 25 | self._sigma = sigma 26 | # set which where to compute document frequencies from 27 | self._df = df 28 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 29 | 30 | def compute_score(self, gts, res): 31 | """ 32 | Main function to compute CIDEr score 33 | :param hypo_for_image (dict) : dictionary with key and value 34 | ref_for_image (dict) : dictionary with key and value 35 | :return: cider (float) : computed CIDEr score for the corpus 36 | """ 37 | 38 | # clear all the previous hypos and refs 39 | tmp_cider_scorer = self.cider_scorer.copy_empty() 40 | tmp_cider_scorer.clear() 41 | for res_id in res: 42 | 43 | hypo = res_id['caption'] 44 | ref = gts[res_id['image_id']] 45 | 46 | # Sanity check. 47 | assert(type(hypo) is list) 48 | assert(len(hypo) == 1) 49 | assert(type(ref) is list) 50 | assert(len(ref) > 0) 51 | tmp_cider_scorer += (hypo[0], ref) 52 | 53 | (score, scores) = tmp_cider_scorer.compute_score() 54 | 55 | return score, scores 56 | 57 | def method(self): 58 | return "CIDEr-D" 59 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/ciderD_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | from collections import defaultdict 10 | import numpy as np 11 | import pdb 12 | import math 13 | import six 14 | from six.moves import cPickle 15 | import os 16 | 17 | def precook(s, n=4, out=False): 18 | """ 19 | Takes a string as input and returns an object that can be given to 20 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 21 | can take string arguments as well. 22 | :param s: string : sentence to be converted into ngrams 23 | :param n: int : number of ngrams for which representation is calculated 24 | :return: term frequency vector for occuring ngrams 25 | """ 26 | words = s.split() 27 | counts = defaultdict(int) 28 | for k in range(1,n+1): 29 | for i in range(len(words)-k+1): 30 | ngram = tuple(words[i:i+k]) 31 | counts[ngram] += 1 32 | return counts 33 | 34 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 35 | '''Takes a list of reference sentences for a single segment 36 | and returns an object that encapsulates everything that BLEU 37 | needs to know about them. 38 | :param refs: list of string : reference sentences for some image 39 | :param n: int : number of ngrams for which (ngram) representation is calculated 40 | :return: result (list of dict) 41 | ''' 42 | return [precook(ref, n) for ref in refs] 43 | 44 | def cook_test(test, n=4): 45 | '''Takes a test sentence and returns an object that 46 | encapsulates everything that BLEU needs to know about it. 47 | :param test: list of string : hypothesis sentence for some image 48 | :param n: int : number of ngrams for which (ngram) representation is calculated 49 | :return: result (dict) 50 | ''' 51 | return precook(test, n, True) 52 | 53 | class CiderScorer(object): 54 | """CIDEr scorer. 
55 | """ 56 | 57 | def copy(self): 58 | ''' copy the refs.''' 59 | new = CiderScorer(n=self.n) 60 | new.ctest = copy.copy(self.ctest) 61 | new.crefs = copy.copy(self.crefs) 62 | return new 63 | 64 | def copy_empty(self): 65 | new = CiderScorer(df_mode="corpus", n=self.n, sigma=self.sigma) 66 | new.df_mode = self.df_mode 67 | new.ref_len = self.ref_len 68 | new.document_frequency = self.document_frequency 69 | return new 70 | 71 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0): 72 | ''' singular instance ''' 73 | self.n = n 74 | self.sigma = sigma 75 | self.crefs = [] 76 | self.ctest = [] 77 | self.df_mode = df_mode 78 | self.ref_len = None 79 | if self.df_mode != "corpus": 80 | pkl_file = cPickle.load(open(df_mode,'rb'), **(dict(encoding='latin1') if six.PY3 else {})) 81 | self.ref_len = np.log(float(pkl_file['ref_len'])) 82 | self.document_frequency = pkl_file['document_frequency'] 83 | else: 84 | self.document_frequency = None 85 | self.cook_append(test, refs) 86 | 87 | def clear(self): 88 | self.crefs = [] 89 | self.ctest = [] 90 | 91 | def cook_append(self, test, refs): 92 | '''called by constructor and __iadd__ to avoid creating new instances.''' 93 | 94 | if refs is not None: 95 | self.crefs.append(cook_refs(refs)) 96 | if test is not None: 97 | self.ctest.append(cook_test(test)) ## N.B.: -1 98 | else: 99 | self.ctest.append(None) # lens of crefs and ctest have to match 100 | 101 | def size(self): 102 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 103 | return len(self.crefs) 104 | 105 | def __iadd__(self, other): 106 | '''add an instance (e.g., from another sentence).''' 107 | 108 | if type(other) is tuple: 109 | ## avoid creating new CiderScorer instances 110 | self.cook_append(other[0], other[1]) 111 | else: 112 | self.ctest.extend(other.ctest) 113 | self.crefs.extend(other.crefs) 114 | 115 | return self 116 | def compute_doc_freq(self): 117 | ''' 118 | Compute term frequency for reference data. 119 | This will be used to compute idf (inverse document frequency later) 120 | The term frequency is stored in the object 121 | :return: None 122 | ''' 123 | for refs in self.crefs: 124 | # refs, k ref captions of one image 125 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 126 | self.document_frequency[ngram] += 1 127 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 128 | 129 | def compute_cider(self): 130 | def counts2vec(cnts): 131 | """ 132 | Function maps counts of ngram to vector of tfidf weights. 133 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 134 | The n-th entry of array denotes length of n-grams. 135 | :param cnts: 136 | :return: vec (array of dict), norm (array of float), length (int) 137 | """ 138 | vec = [defaultdict(float) for _ in range(self.n)] 139 | length = 0 140 | norm = [0.0 for _ in range(self.n)] 141 | for (ngram,term_freq) in cnts.items(): 142 | # give word count 1 if it doesn't appear in reference corpus 143 | df = np.log(max(1.0, self.document_frequency[ngram])) 144 | # ngram index 145 | n = len(ngram)-1 146 | # tf (term_freq) * idf (precomputed idf) for n-grams 147 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 148 | # compute norm for the vector. 
the norm will be used for computing similarity 149 | norm[n] += pow(vec[n][ngram], 2) 150 | 151 | if n == 1: 152 | length += term_freq 153 | norm = [np.sqrt(n) for n in norm] 154 | return vec, norm, length 155 | 156 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 157 | ''' 158 | Compute the cosine similarity of two vectors. 159 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 160 | :param vec_ref: array of dictionary for vector corresponding to reference 161 | :param norm_hyp: array of float for vector corresponding to hypothesis 162 | :param norm_ref: array of float for vector corresponding to reference 163 | :param length_hyp: int containing length of hypothesis 164 | :param length_ref: int containing length of reference 165 | :return: array of score for each n-grams cosine similarity 166 | ''' 167 | delta = float(length_hyp - length_ref) 168 | # measure consine similarity 169 | val = np.array([0.0 for _ in range(self.n)]) 170 | for n in range(self.n): 171 | # ngram 172 | for (ngram,count) in vec_hyp[n].items(): 173 | # vrama91 : added clipping 174 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 175 | 176 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 177 | val[n] /= (norm_hyp[n]*norm_ref[n]) 178 | 179 | assert(not math.isnan(val[n])) 180 | # vrama91: added a length based gaussian penalty 181 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 182 | return val 183 | 184 | # compute log reference length 185 | if self.df_mode == "corpus": 186 | self.ref_len = np.log(float(len(self.crefs))) 187 | #elif self.df_mode == "coco-val-df": 188 | # if coco option selected, use length of coco-val set 189 | # self.ref_len = np.log(float(40504)) 190 | 191 | scores = [] 192 | for test, refs in zip(self.ctest, self.crefs): 193 | # compute vector for test captions 194 | vec, norm, length = counts2vec(test) 195 | # compute vector for ref captions 196 | score = np.array([0.0 for _ in range(self.n)]) 197 | for ref in refs: 198 | vec_ref, norm_ref, length_ref = counts2vec(ref) 199 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 200 | # change by vrama91 - mean of ngram scores, instead of sum 201 | score_avg = np.mean(score) 202 | # divide by number of references 203 | score_avg /= len(refs) 204 | # multiply score by 10 205 | score_avg *= 10.0 206 | # append score of an image to the score list 207 | scores.append(score_avg) 208 | return scores 209 | 210 | def compute_score(self, option=None, verbose=0): 211 | # compute idf 212 | if self.df_mode == "corpus": 213 | self.document_frequency = defaultdict(float) 214 | self.compute_doc_freq() 215 | # assert to check document frequency 216 | assert(len(self.ctest) >= max(self.document_frequency.values())) 217 | # import json for now and write the corresponding files 218 | # compute cider score 219 | score = self.compute_cider() 220 | # debug 221 | # print score 222 | return np.mean(np.array(score)), np.array(score) 223 | -------------------------------------------------------------------------------- /utils/hdfs_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 
7 | 8 | import sys 9 | from typing import IO, Any, List 10 | 11 | import shutil 12 | import subprocess 13 | from contextlib import contextmanager 14 | import os 15 | import glob 16 | import threading 17 | 18 | HADOOP_BIN = 'HADOOP_ROOT_LOGGER=ERROR,console /SET/PATH/TO/hadoop/bin/hdfs' 19 | 20 | __all__ = ['hlist_files', 'hopen', 'hexists', 'hmkdir'] 21 | 22 | 23 | @contextmanager # type: ignore 24 | def hopen(hdfs_path: str, mode: str = "r") -> IO[Any]: 25 | """ 26 | open a file on hdfs with contextmanager. 27 | 28 | Args: 29 | mode (str): supports ["r", "w", "wa"] 30 | """ 31 | pipe = None 32 | if mode.startswith("r"): 33 | pipe = subprocess.Popen( 34 | "{} dfs -text {}".format(HADOOP_BIN, hdfs_path), shell=True, stdout=subprocess.PIPE) 35 | yield pipe.stdout 36 | pipe.stdout.close() # type: ignore 37 | pipe.wait() 38 | return 39 | if mode == "wa" or mode == "a": 40 | pipe = subprocess.Popen( 41 | "{} dfs -appendToFile - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 42 | yield pipe.stdin 43 | pipe.stdin.close() # type: ignore 44 | pipe.wait() 45 | return 46 | if mode.startswith("w"): 47 | pipe = subprocess.Popen( 48 | "{} dfs -put -f - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 49 | yield pipe.stdin 50 | pipe.stdin.close() # type: ignore 51 | pipe.wait() 52 | return 53 | raise RuntimeError("unsupported io mode: {}".format(mode)) 54 | 55 | 56 | def hlist_files(folders: List[str]) -> List[str]: 57 | files = [] 58 | for folder in folders: 59 | if folder.startswith('hdfs'): 60 | pipe = subprocess.Popen("{} dfs -ls {}".format(HADOOP_BIN, folder), shell=True, 61 | stdout=subprocess.PIPE) 62 | # output, _ = pipe.communicate() 63 | for line in pipe.stdout: # type: ignore 64 | line = line.strip() 65 | # drwxr-xr-x - user group 4 file 66 | if len(line.split()) < 5: 67 | continue 68 | files.append(line.split()[-1].decode("utf8")) 69 | pipe.stdout.close() # type: ignore 70 | pipe.wait() 71 | else: 72 | if os.path.isdir(folder): 73 | files.extend([os.path.join(folder, d) for d in os.listdir(folder)]) 74 | elif os.path.isfile(folder): 75 | files.append(folder) 76 | else: 77 | print('Path {} is invalid'.format(folder)) 78 | sys.stdout.flush() 79 | 80 | return files 81 | 82 | 83 | def hexists(file_path: str) -> bool: 84 | """ hdfs capable to check whether a file_path is exists """ 85 | if file_path.startswith('hdfs'): 86 | return os.system("{} dfs -test -e {}".format(HADOOP_BIN, file_path)) == 0 87 | return os.path.exists(file_path) 88 | 89 | 90 | def hmkdir(file_path: str) -> bool: 91 | """ hdfs mkdir """ 92 | if file_path.startswith('hdfs'): 93 | os.system("{} dfs -mkdir -p {}".format(HADOOP_BIN, file_path)) # exist ok 94 | else: 95 | if not os.path.exists(file_path): 96 | os.mkdir(file_path) 97 | return True 98 | 99 | 100 | def hcopy(from_path: str, to_path: str) -> bool: 101 | """ hdfs copy """ 102 | if to_path.startswith("hdfs"): 103 | if from_path.startswith("hdfs"): 104 | os.system("{} dfs -cp -f {} {}".format(HADOOP_BIN, from_path, to_path)) 105 | else: 106 | os.system("{} dfs -copyFromLocal -f {} {}".format(HADOOP_BIN, from_path, to_path)) 107 | else: 108 | if from_path.startswith("hdfs"): 109 | os.system("{} dfs -text {} > {}".format(HADOOP_BIN, from_path, to_path)) 110 | else: 111 | shutil.copy(from_path, to_path) 112 | return True 113 | 114 | -------------------------------------------------------------------------------- /utils/torch_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 7 | 8 | import io 9 | import torch 10 | 11 | from .hdfs_io import hopen 12 | 13 | 14 | def load(filepath: str, **kwargs): 15 | """ load model """ 16 | if not filepath.startswith("hdfs://"): 17 | return torch.load(filepath, **kwargs) 18 | with hopen(filepath, "rb") as reader: 19 | accessor = io.BytesIO(reader.read()) 20 | state_dict = torch.load(accessor, **kwargs) 21 | del accessor 22 | return state_dict 23 | 24 | 25 | def save(obj, filepath: str, **kwargs): 26 | """ save model """ 27 | if filepath.startswith("hdfs://"): 28 | with hopen(filepath, "wb") as writer: 29 | torch.save(obj, writer, **kwargs) 30 | else: 31 | torch.save(obj, filepath, **kwargs) 32 | --------------------------------------------------------------------------------
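Two hedged usage sketches follow. They are illustrations only, not files from the repository. The first shows how utils/checkpointer.py and utils/torch_io.py cooperate to write and reload a checkpoint either locally or on HDFS; the model, optimizer, and "output" directory below are stand-ins.

# Hedged sketch: save a checkpoint with Checkpointer, then reload it with the
# HDFS-aware loader. All names here (model, optimizer, "output") are assumptions.
import torch
from utils.checkpointer import Checkpointer
from utils.torch_io import load as hdfs_torch_load

model = torch.nn.Linear(8, 2)                       # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ckpt = Checkpointer(serialization_dir="output")     # an hdfs:// path also works
ckpt.save_checkpoint(epoch=0,
                     model_state=model.state_dict(),
                     training_states={"optimizer": optimizer.state_dict()})

state = hdfs_torch_load("output/model_state_epoch_0.th", map_location="cpu")
model.load_state_dict(state)

The second sketch scores candidate captions with the bundled CIDEr-D code. The image ids and captions are made up, captions are assumed to be pre-tokenized strings, and the import path assumes the repository root is on sys.path (the cider directories are treated as namespace packages).

# Hedged sketch: CIDEr-D over a toy two-image "corpus". gts maps image_id to a
# list of reference captions; res is a list of {image_id, caption} entries whose
# caption field is a one-element list, matching compute_score's sanity checks.
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD

gts = {
    0: ["a plane is parked at the airport", "an airplane parked near the terminal"],
    1: ["many cars are parked beside the road", "a parking lot next to a highway"],
}
res = [
    {"image_id": 0, "caption": ["a plane parked at the airport terminal"]},
    {"image_id": 1, "caption": ["cars parked along the road"]},
]

scorer = CiderD(n=4, sigma=6.0, df="corpus")        # corpus mode: idf computed from gts itself
corpus_score, per_image = scorer.compute_score(gts, res)
print(corpus_score, per_image)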