├── README.md ├── data.py ├── evaluation.py ├── meter ├── __init__.py ├── config.py ├── datamodules │ ├── __init__.py │ ├── coco_caption_karpathy_datamodule.py │ ├── conceptual_caption_datamodule.py │ ├── datamodule_base.py │ ├── f30k_caption_karpathy_datamodule.py │ ├── multitask_datamodule.py │ ├── nlvr2_datamodule.py │ ├── sbu_datamodule.py │ ├── snli_datamodule.py │ ├── vg_caption_datamodule.py │ └── vqav2_datamodule.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_caption_karpathy_dataset.py │ ├── conceptual_caption_dataset.py │ ├── f30k_caption_karpathy_dataset.py │ ├── nlvr2_dataset.py │ ├── sbu_caption_dataset.py │ ├── snli_dataset.py │ ├── vg_caption_dataset.py │ └── vqav2_dataset.py ├── gadgets │ ├── __init__.py │ └── my_metrics.py ├── modules │ ├── __init__.py │ ├── bert.py │ ├── convit.py │ ├── convit_models.py │ ├── convnext.py │ ├── dist_utils.py │ ├── eff_model.py │ ├── ensemble_models.py │ ├── eval_gl.py │ ├── file_utils.py │ ├── fusion_model.py │ ├── meter_module.py │ ├── meter_utils.py │ ├── objectives.py │ ├── swin_helpers.py │ ├── swin_transformer.py │ ├── textual_encoder.py │ ├── utils.py │ └── visual_encoder.py ├── transforms │ ├── __init__.py │ ├── randaug.py │ ├── transform.py │ └── utils.py └── utils │ ├── __init__.py │ ├── glossary.py │ ├── write_coco_karpathy.py │ ├── write_conceptual_caption.py │ ├── write_f30k_karpathy.py │ ├── write_nlvr2.py │ ├── write_sbu.py │ ├── write_snli.py │ ├── write_vg.py │ └── write_vqa.py ├── requirements.txt └── run.py /README.md: -------------------------------------------------------------------------------- 1 | # HACAN 2 | The code for our paper "**HACAN: Hybrid Attention-Driven Cross-Layer Alignment Network for Image-Text Retrieval**". 3 | 4 | 5 | ## Introduction 6 | Although fine-grained image-text matching and cross-modal retrieval have advanced considerably, current methods often focus solely on the direct connections between visual elements in images and textual keywords. This overlooks the complex semantic interactions between the modalities at both the local and the global level and leads to semantic ambiguity. We introduce a **H**ybrid **A**ttention-Driven **C**ross-layer **A**lignment **N**etwork (**HACAN**) that leverages BERT and ConvNeXt to combine global and local strategies, addressing semantic ambiguity and alignment issues. By proposing a global contrastive divergence (GCD) loss, HACAN strengthens the complementarity between vision and language and improves the model's ability to distinguish positive from negative samples. By incorporating hierarchical inference strategies, HACAN significantly improves retrieval efficiency. On the Flickr30K and MS-COCO datasets, HACAN surpasses state-of-the-art image-to-text retrieval methods by a margin of 5% to 8% in the Rsum metric. 7 | 8 | 9 | ## Preparation 10 | ### Dependencies 11 | We recommend using Anaconda to install the following packages. 12 | - python >= 3.8 13 | - [torch](http://pytorch.org/) (>=1.8.1) 14 | - [lightning](https://lightning.ai/) (1.8.0) 15 | - [transformers](https://huggingface.co/docs/transformers) (4.24.0) 16 | - torchvision 17 | - opencv-python 18 | 19 | 20 | ### Data 21 | The experimental datasets can be downloaded from [Flickr30K](http://shannon.cs.illinois.edu/DenotationGraph/) and [MSCOCO](http://mscoco.org/). The pre-trained models used in our experiments will be released publicly at a later date.
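The data loaders in this repository read pre-serialized Arrow files produced by the conversion scripts in `meter/utils/` (e.g. `write_f30k_karpathy.py`, `write_coco_karpathy.py`). As a sketch based on the dataset classes in `meter/datasets/` (the exact set depends on which datasets you convert), the data folder is expected to contain files along these lines:
```
$DATASET_PATH/
├── f30k_caption_karpathy_train.arrow
├── f30k_caption_karpathy_val.arrow
├── f30k_caption_karpathy_test.arrow
├── coco_caption_karpathy_train.arrow
├── coco_caption_karpathy_val.arrow
└── coco_caption_karpathy_test.arrow
```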
We refer to the directory of the extracted files as `$DATASET_PATH` and to the location of the trained checkpoint as `$MODEL_PATH`. 22 | 23 | 24 | ## Evaluation 25 | Run `run.py` to evaluate the trained models on Flickr30K or MSCOCO. 26 | ```bash 27 | # Test on Flickr30K 28 | python run.py with data_root=$DATASET_PATH test_only=True checkpoint=$MODEL_PATH 29 | 30 | # Test on MSCOCO 31 | python run.py with coco_config data_root=$DATASET_PATH test_only=True checkpoint=$MODEL_PATH 32 | ``` 33 | 34 | 35 | ## Training 36 | Run `run.py` to train the model on Flickr30K or MSCOCO. 37 | ```bash 38 | # Train on Flickr30K 39 | python run.py with data_root=$DATASET_PATH loss="GCD" 40 | 41 | # Train on MSCOCO 42 | python run.py with coco_config data_root=$DATASET_PATH loss="GCD" 43 | ``` 44 | 45 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy 3 | 4 | 5 | def i2t_SCAN(sims, npts=None, return_ranks=False, fold5=False): 6 | """ 7 | Images->Text (Image Annotation) 8 | Images: (N, n_region, d) matrix of images 9 | Captions: (5N, max_n_word, d) matrix of captions 10 | CapLens: (5N) array of caption lengths 11 | sims: (N, 5N) matrix of similarity im-cap 12 | """ 13 | npts = sims.shape[0] 14 | ranks = np.zeros(npts) 15 | top1 = np.zeros(npts) 16 | for index in range(npts): 17 | inds = np.argsort(sims[index])[::-1] 18 | # Score 19 | rank = 1e20 20 | for i in range(5 * index, 5 * index + 5, 1): 21 | tmp = np.where(inds == i)[0][0] 22 | if tmp < rank: 23 | rank = tmp 24 | ranks[index] = rank 25 | top1[index] = inds[0] 26 | 27 | # Compute metrics 28 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 29 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 30 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 31 | 32 | r20 = 100.0 * len(np.where(ranks < 20)[0]) / len(ranks) 33 | r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks) 34 | r70 = 100.0 * len(np.where(ranks < 70)[0]) / len(ranks) 35 | r100 = 100.0 * len(np.where(ranks < 100)[0]) / len(ranks) 36 | 37 | medr = np.floor(np.median(ranks)) + 1 38 | meanr = ranks.mean() + 1 39 | if return_ranks: 40 | if fold5: 41 | return (r1, r5, r10, medr, meanr), (ranks, top1) 42 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr), (ranks, top1) 43 | else: 44 | if fold5: 45 | return (r1, r5, r10, medr, meanr) 46 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr) 47 | 48 | def t2i_SCAN(sims, npts=None, return_ranks=False, fold5=False): 49 | """ 50 | Text->Images (Image Search) 51 | Images: (N, n_region, d) matrix of images 52 | Captions: (5N, max_n_word, d) matrix of captions 53 | CapLens: (5N) array of caption lengths 54 | sims: (N, 5N) matrix of similarity im-cap 55 | """ 56 | npts = sims.shape[0] 57 | ranks = np.zeros(5 * npts) 58 | top1 = np.zeros(5 * npts) 59 | 60 | # --> (5N(caption), N(image)) 61 | sims = sims.T 62 | 63 | for index in range(npts): 64 | for i in range(5): 65 | inds = np.argsort(sims[5 * index + i])[::-1] 66 | ranks[5 * index + i] = np.where(inds == index)[0][0] 67 | top1[5 * index + i] = inds[0] 68 | 69 | # Compute metrics 70 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 71 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 72 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 73 | 74 | r20 = 100.0 * len(np.where(ranks < 20)[0]) / len(ranks) 75 | r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks) 76 | r70 = 100.0 * 
len(np.where(ranks < 70)[0]) / len(ranks) 77 | r100 = 100.0 * len(np.where(ranks < 100)[0]) / len(ranks) 78 | 79 | medr = np.floor(np.median(ranks)) + 1 80 | meanr = ranks.mean() + 1 81 | if return_ranks: 82 | if fold5: 83 | return (r1, r5, r10, medr, meanr), (ranks, top1) 84 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr), (ranks, top1) 85 | else: 86 | if fold5: 87 | return (r1, r5, r10, medr, meanr) 88 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr) 89 | 90 | '''def i2t(images, captions, npts=None, measure='cosine', return_ranks=False): 91 | """ 92 | Images->Text (Image Annotation) 93 | Images: (5N, K) matrix of images 94 | Captions: (5N, K) matrix of captions 95 | """ 96 | if npts is None: 97 | npts = int(images.shape[0] / 5) 98 | index_list = [] 99 | 100 | ranks = numpy.zeros(npts) 101 | top1 = numpy.zeros(npts) 102 | for index in range(npts): 103 | 104 | # Get query image 105 | im = images[5 * index].reshape(1, images.shape[1]) 106 | 107 | # Compute scores 108 | d = numpy.dot(im, captions.T).flatten() 109 | inds = numpy.argsort(d)[::-1] 110 | index_list.append(inds[0]) 111 | 112 | # Score 113 | rank = 1e20 114 | for i in range(5 * index, 5 * index + 5, 1): 115 | tmp = numpy.where(inds == i)[0][0] 116 | if tmp < rank: 117 | rank = tmp 118 | ranks[index] = rank 119 | top1[index] = inds[0] 120 | 121 | # Compute metrics 122 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 123 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 124 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 125 | 126 | r20 = 100.0 * len(numpy.where(ranks < 20)[0]) / len(ranks) 127 | r50 = 100.0 * len(numpy.where(ranks < 50)[0]) / len(ranks) 128 | r70 = 100.0 * len(numpy.where(ranks < 70)[0]) / len(ranks) 129 | r100 = 100.0 * len(numpy.where(ranks < 100)[0]) / len(ranks) 130 | 131 | medr = numpy.floor(numpy.median(ranks)) + 1 132 | meanr = ranks.mean() + 1 133 | if return_ranks: 134 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr), (ranks, top1) 135 | else: 136 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr) 137 | 138 | 139 | def t2i(images, captions, npts=None, measure='cosine', return_ranks=False): 140 | """ 141 | Text->Images (Image Search) 142 | Images: (5N, K) matrix of images 143 | Captions: (5N, K) matrix of captions 144 | """ 145 | if npts is None: 146 | npts = int(images.shape[0] / 5) 147 | ims = numpy.array([images[i] for i in range(0, len(images), 5)]) 148 | 149 | ranks = numpy.zeros(5 * npts) 150 | top1 = numpy.zeros(5 * npts) 151 | for index in range(npts): 152 | 153 | # Get query captions 154 | queries = captions[5 * index:5 * index + 5] 155 | 156 | # Compute scores 157 | 158 | d = numpy.dot(queries, ims.T) 159 | inds = numpy.zeros(d.shape) 160 | for i in range(len(inds)): 161 | inds[i] = numpy.argsort(d[i])[::-1] 162 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0] 163 | top1[5 * index + i] = inds[i][0] 164 | 165 | # Compute metrics 166 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 167 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 168 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 169 | 170 | r20 = 100.0 * len(numpy.where(ranks < 20)[0]) / len(ranks) 171 | r50 = 100.0 * len(numpy.where(ranks < 50)[0]) / len(ranks) 172 | r70 = 100.0 * len(numpy.where(ranks < 70)[0]) / len(ranks) 173 | r100 = 100.0 * len(numpy.where(ranks < 100)[0]) / len(ranks) 174 | 175 | medr = numpy.floor(numpy.median(ranks)) + 1 176 | meanr = ranks.mean() + 1 177 | if return_ranks: 178 | return (r1, r5, r10, r20, r50, 
r70, r100, medr, meanr), (ranks, top1) 179 | else: 180 | return (r1, r5, r10, r20, r50, r70, r100, medr, meanr)''' 181 | -------------------------------------------------------------------------------- /meter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLyu0110/HACAN/35bdbc5cb2a9a62870fa9dce180c03c4c9d54206/meter/__init__.py -------------------------------------------------------------------------------- /meter/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | # 3 | ex = Experiment("METER") 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "itm": 0, 9 | "mlm": 0, 10 | "mpp": 0, 11 | "vqa": 0, 12 | "vcr": 0, 13 | "vcr_qar": 0, 14 | "nlvr2": 0, 15 | "irtr": 0, 16 | "contras": 0, 17 | "snli": 0, 18 | } 19 | ret.update(d) 20 | return ret 21 | 22 | 23 | @ex.config 24 | def config(): 25 | exp_name = "finetune_irtr_f30k" 26 | seed = 0 27 | datasets = "f30k" 28 | loss_names = _loss_names({"irtr": 1}) 29 | batch_size = 64 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 30 | margin = 0.2 31 | 32 | # Image setting 33 | #train_transform_keys = ["clip"] 34 | #val_transform_keys = ["clip"] 35 | image_size = 224 36 | patch_size = 32 37 | #draw_false_image = 1 38 | image_only = False 39 | 40 | # Text Setting 41 | #vqav2_label_size = 3129 42 | max_text_len = 32 43 | tokenizer = "/home/ls/bert-base-uncased" 44 | vocab_size = 30522 45 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 46 | mlm_prob = 0.15 47 | #draw_false_text = 0 48 | 49 | # Transformer Setting 50 | num_top_layer = 6 51 | input_image_embed_size = 1024 52 | input_text_embed_size = 768 53 | vit = "swin_base_patch4_window7_224_in22k" 54 | hidden_size = 768 55 | num_heads = 12 56 | num_layers = 6 57 | mlp_ratio = 4 58 | drop_rate = 0.1 59 | 60 | # Optimizer Setting 61 | optim_type = "adamw" 62 | learning_rate = 1e-4 63 | lr_update = 10 64 | weight_decay = 0.01 65 | decay_power = 1 66 | max_epoch = 30 67 | max_steps = None 68 | warmup_steps = 10000 69 | end_lr = 0 70 | lr_mult_head = 5 # multiply lr for downstream heads 71 | lr_mult_cross_modal = 5 # multiply lr for the cross-modal module 72 | 73 | # Downstream Setting 74 | get_recall_metric = False 75 | 76 | # PL Trainer Setting 77 | resume_from = None 78 | fast_dev_run = False 79 | val_check_interval = 1.0 80 | test_only = False 81 | checkpoint = '/data3/lihaoxuan/New_Time/TKDE/github/runs/i2t_freeze/epoch=68-step=172499-v1.ckpt' 82 | 83 | # below params varies with the environment 84 | data_root = '/data1/lihaoxuan/orignal-datasets/' 85 | log_dir = "result" 86 | per_gpu_batchsize = 64 # you should define this manually with per_gpu_batch_size=# 87 | num_gpus = 1 88 | num_nodes = 1 89 | load_path = "" 90 | num_workers = 8 91 | precision = 16 92 | 93 | #SCAN 94 | direction = 't2i' 95 | lambda_softmax = 9 96 | embed_dim = 768 97 | 98 | # add 99 | margin = 0.2 100 | loss = "GCD" # F_HN / M_HN / GCD 101 | activation = 'leaky_relu' # tanh / relu / leaky_relu 102 | experiment_name = '' 103 | fold5 = False 104 | save_path='' 105 | # save_path='runs/i2t_freeze_f30k_last_but_one_t2i' 106 | focal_type = 'prob' 107 | 108 | 109 | 110 | @ex.named_config 111 | def coco_config(): 112 | exp_name = "finetune_irtr_coco" 113 | seed = 0 114 | datasets = "coco" 115 | loss_names = _loss_names({"irtr": 1}) 116 | batch_size = 64 # this is a desired batch size; pl 
trainer will accumulate gradients when per step batch is smaller. 117 | margin = 0.2 118 | 119 | # Image setting 120 | #train_transform_keys = ["clip"] 121 | #val_transform_keys = ["clip"] 122 | image_size = 224 123 | patch_size = 32 124 | #draw_false_image = 1 125 | image_only = False 126 | 127 | # Text Setting 128 | #vqav2_label_size = 3129 129 | max_text_len = 32 130 | tokenizer = "/home/ls/bert-base-uncased" 131 | vocab_size = 30522 132 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 133 | mlm_prob = 0.15 134 | #draw_false_text = 0 135 | 136 | # Transformer Setting 137 | num_top_layer = 6 138 | input_image_embed_size = 1024 139 | input_text_embed_size = 768 140 | vit = "swin_base_patch4_window7_224_in22k" 141 | hidden_size = 768 142 | num_heads = 12 143 | num_layers = 6 144 | mlp_ratio = 4 145 | drop_rate = 0.1 146 | 147 | # Optimizer Setting 148 | optim_type = "adamw" 149 | learning_rate = 1e-4 150 | lr_update = 10 151 | weight_decay = 0.01 152 | decay_power = 1 153 | max_epoch = 30 154 | max_steps = None 155 | warmup_steps = 10000 156 | end_lr = 0 157 | lr_mult_head = 5 # multiply lr for downstream heads 158 | lr_mult_cross_modal = 5 # multiply lr for the cross-modal module 159 | 160 | # Downstream Setting 161 | get_recall_metric = False 162 | 163 | # PL Trainer Setting 164 | resume_from = None 165 | fast_dev_run = False 166 | val_check_interval = 1.0 167 | test_only = False 168 | checkpoint = '/data3/lihaoxuan/New_Time/TKDE/github/runs/i2t_freeze/last.ckpt' 169 | 170 | # below params varies with the environment 171 | data_root = '/data1/lihaoxuan/orignal-datasets/' 172 | log_dir = "result" 173 | per_gpu_batchsize = 64 # you should define this manually with per_gpu_batch_size=# 174 | num_gpus = 1 175 | num_nodes = 1 176 | load_path = "" 177 | num_workers = 8 178 | precision = 16 179 | 180 | #SCAN 181 | direction = 'i2t' 182 | lambda_softmax = 9 183 | 184 | # add 185 | margin = 0.2 186 | loss = "GCD" 187 | activation = 'leaky_relu' 188 | experiment_name = '' 189 | fold5 = False 190 | save_path='runs/i2t_freeze_coco_last_layers' 191 | focal_type = 'prob' 192 | 193 | -------------------------------------------------------------------------------- /meter/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_datamodule import VisualGenomeCaptionDataModule 2 | from .f30k_caption_karpathy_datamodule import F30KCaptionKarpathyDataModule 3 | from .coco_caption_karpathy_datamodule import CocoCaptionKarpathyDataModule 4 | from .conceptual_caption_datamodule import ConceptualCaptionDataModule 5 | from .sbu_datamodule import SBUCaptionDataModule 6 | from .vqav2_datamodule import VQAv2DataModule 7 | from .nlvr2_datamodule import NLVR2DataModule 8 | from .snli_datamodule import SNLIDataModule 9 | 10 | _datamodules = { 11 | "vg": VisualGenomeCaptionDataModule, 12 | "f30k": F30KCaptionKarpathyDataModule, 13 | "coco": CocoCaptionKarpathyDataModule, 14 | "gcc": ConceptualCaptionDataModule, 15 | "sbu": SBUCaptionDataModule, 16 | "vqa": VQAv2DataModule, 17 | "nlvr2": NLVR2DataModule, 18 | "snli": SNLIDataModule, 19 | } 20 | -------------------------------------------------------------------------------- /meter/datamodules/coco_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import CocoCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class 
CocoCaptionKarpathyDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return CocoCaptionKarpathyDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return CocoCaptionKarpathyDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "coco" 20 | -------------------------------------------------------------------------------- /meter/datamodules/conceptual_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import ConceptualCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class ConceptualCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return ConceptualCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "gcc" 16 | -------------------------------------------------------------------------------- /meter/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | BertTokenizer, 9 | RobertaTokenizer, 10 | ) 11 | 12 | 13 | def get_pretrained_tokenizer(from_pretrained): 14 | if torch.distributed.is_initialized(): 15 | if torch.distributed.get_rank() == 0: 16 | if 'roberta' in from_pretrained: 17 | RobertaTokenizer.from_pretrained(from_pretrained) 18 | else: 19 | BertTokenizer.from_pretrained( 20 | from_pretrained, do_lower_case="uncased" in from_pretrained 21 | ) 22 | torch.distributed.barrier() 23 | 24 | if 'roberta' in from_pretrained: 25 | return RobertaTokenizer.from_pretrained(from_pretrained) 26 | return BertTokenizer.from_pretrained( 27 | from_pretrained, do_lower_case="uncased" in from_pretrained 28 | ) 29 | 30 | 31 | class BaseDataModule(LightningDataModule): 32 | def __init__(self, _config): 33 | super().__init__() 34 | 35 | self.data_dir = _config["data_root"] 36 | 37 | self.num_workers = _config["num_workers"] 38 | self.batch_size = _config["per_gpu_batchsize"] 39 | self.eval_batch_size = self.batch_size 40 | 41 | self.image_size = _config["image_size"] 42 | self.max_text_len = _config["max_text_len"] 43 | self.draw_false_image = _config["draw_false_image"] 44 | self.draw_false_text = _config["draw_false_text"] 45 | self.image_only = _config["image_only"] 46 | 47 | self.train_transform_keys = ( 48 | ["default_train"] 49 | if len(_config["train_transform_keys"]) == 0 50 | else _config["train_transform_keys"] 51 | ) 52 | 53 | self.val_transform_keys = ( 54 | ["default_val"] 55 | if len(_config["val_transform_keys"]) == 0 56 | else _config["val_transform_keys"] 57 | ) 58 | 59 | tokenizer = _config["tokenizer"] 60 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 61 | self.vocab_size = self.tokenizer.vocab_size 62 | 63 | collator = ( 64 | DataCollatorForWholeWordMask 65 | if _config["whole_word_masking"] 66 | else DataCollatorForLanguageModeling 67 | ) 68 | 69 | self.mlm_collator = collator( 70 | tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"] 71 | ) 72 | self.setup_flag = False 73 | 74 | @property 75 | def dataset_cls(self): 76 | raise NotImplementedError("return tuple of dataset class") 77 | 78 | @property 79 | def 
dataset_name(self): 80 | raise NotImplementedError("return name of dataset") 81 | 82 | def set_train_dataset(self): 83 | self.train_dataset = self.dataset_cls( 84 | self.data_dir, 85 | self.train_transform_keys, 86 | split="train", 87 | image_size=self.image_size, 88 | max_text_len=self.max_text_len, 89 | draw_false_image=self.draw_false_image, 90 | draw_false_text=self.draw_false_text, 91 | image_only=self.image_only, 92 | tokenizer=self.tokenizer, 93 | ) 94 | 95 | def set_val_dataset(self): 96 | self.val_dataset = self.dataset_cls( 97 | self.data_dir, 98 | self.val_transform_keys, 99 | split="val", 100 | image_size=self.image_size, 101 | max_text_len=self.max_text_len, 102 | draw_false_image=self.draw_false_image, 103 | draw_false_text=self.draw_false_text, 104 | image_only=self.image_only, 105 | tokenizer=self.tokenizer, 106 | ) 107 | 108 | if hasattr(self, "dataset_cls_no_false"): 109 | self.val_dataset_no_false = self.dataset_cls_no_false( 110 | self.data_dir, 111 | self.val_transform_keys, 112 | split="val", 113 | image_size=self.image_size, 114 | max_text_len=self.max_text_len, 115 | draw_false_image=0, 116 | draw_false_text=0, 117 | image_only=self.image_only, 118 | tokenizer=self.tokenizer, 119 | ) 120 | 121 | def make_no_false_val_dset(self, image_only=False): 122 | return self.dataset_cls_no_false( 123 | self.data_dir, 124 | self.val_transform_keys, 125 | split="val", 126 | image_size=self.image_size, 127 | max_text_len=self.max_text_len, 128 | draw_false_image=0, 129 | draw_false_text=0, 130 | image_only=image_only, 131 | tokenizer=self.tokenizer, 132 | ) 133 | 134 | def set_test_dataset(self): 135 | self.test_dataset = self.dataset_cls( 136 | self.data_dir, 137 | self.val_transform_keys, 138 | split="test", 139 | image_size=self.image_size, 140 | max_text_len=self.max_text_len, 141 | draw_false_image=self.draw_false_image, 142 | draw_false_text=self.draw_false_text, 143 | image_only=self.image_only, 144 | tokenizer=self.tokenizer, 145 | ) 146 | 147 | def setup(self, stage): 148 | if not self.setup_flag: 149 | self.set_train_dataset() 150 | self.set_val_dataset() 151 | self.set_test_dataset() 152 | 153 | self.train_dataset.tokenizer = self.tokenizer 154 | self.val_dataset.tokenizer = self.tokenizer 155 | self.test_dataset.tokenizer = self.tokenizer 156 | 157 | self.setup_flag = True 158 | 159 | def train_dataloader(self): 160 | loader = DataLoader( 161 | self.train_dataset, 162 | batch_size=self.batch_size, 163 | shuffle=True, 164 | num_workers=self.num_workers, 165 | pin_memory=True, 166 | collate_fn=self.train_dataset.collate, 167 | ) 168 | return loader 169 | 170 | def val_dataloader(self): 171 | loader = DataLoader( 172 | self.val_dataset, 173 | batch_size=self.eval_batch_size, 174 | shuffle=False, 175 | num_workers=self.num_workers, 176 | pin_memory=True, 177 | collate_fn=self.val_dataset.collate, 178 | ) 179 | return loader 180 | 181 | def test_dataloader(self): 182 | loader = DataLoader( 183 | self.test_dataset, 184 | batch_size=self.eval_batch_size, 185 | shuffle=False, 186 | num_workers=self.num_workers, 187 | pin_memory=True, 188 | collate_fn=self.test_dataset.collate, 189 | ) 190 | return loader 191 | -------------------------------------------------------------------------------- /meter/datamodules/f30k_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import F30KCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class 
F30KCaptionKarpathyDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return F30KCaptionKarpathyDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return F30KCaptionKarpathyDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "f30k" 20 | 21 | def train_dataloader(self): 22 | loader = DataLoader( 23 | self.train_dataset, 24 | batch_size=self.batch_size, 25 | shuffle=True, 26 | num_workers=0, 27 | pin_memory=True, 28 | collate_fn=self.train_dataset.collate, 29 | ) 30 | return loader 31 | 32 | def val_dataloader(self): 33 | loader = DataLoader( 34 | self.val_dataset, 35 | batch_size=self.eval_batch_size, 36 | shuffle=False, 37 | num_workers=0, 38 | pin_memory=True, 39 | collate_fn=self.val_dataset.collate, 40 | ) 41 | return loader 42 | 43 | def test_dataloader(self): 44 | loader = DataLoader( 45 | self.test_dataset, 46 | batch_size=self.eval_batch_size, 47 | shuffle=False, 48 | num_workers=0, 49 | pin_memory=True, 50 | collate_fn=self.test_dataset.collate, 51 | ) 52 | return loader 53 | -------------------------------------------------------------------------------- /meter/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataset import ConcatDataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from . import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | 41 | self.collate = functools.partial( 42 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 43 | ) 44 | 45 | if self.dist: 46 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 47 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 48 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 49 | else: 50 | self.train_sampler = None 51 | self.val_sampler = None 52 | self.test_sampler = None 53 | 54 | def train_dataloader(self): 55 | loader = DataLoader( 56 | self.train_dataset, 57 | batch_size=self.batch_size, 58 | sampler=self.train_sampler, 59 | num_workers=self.num_workers, 60 | collate_fn=self.collate, 61 | ) 62 | return loader 63 | 64 | def val_dataloader(self, batch_size=None): 65 | loader = DataLoader( 66 | self.val_dataset, 67 | batch_size=batch_size if batch_size is not None else self.batch_size, 68 | 
sampler=self.val_sampler, 69 | num_workers=self.num_workers, 70 | collate_fn=self.collate, 71 | ) 72 | return loader 73 | 74 | def test_dataloader(self): 75 | loader = DataLoader( 76 | self.test_dataset, 77 | batch_size=self.batch_size, 78 | sampler=self.test_sampler, 79 | num_workers=self.num_workers, 80 | collate_fn=self.collate, 81 | ) 82 | return loader 83 | -------------------------------------------------------------------------------- /meter/datamodules/nlvr2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import NLVR2Dataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class NLVR2DataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return NLVR2Dataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "nlvr2" 16 | -------------------------------------------------------------------------------- /meter/datamodules/sbu_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SBUCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class SBUCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return SBUCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "sbu" 16 | -------------------------------------------------------------------------------- /meter/datamodules/snli_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SNLIDataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | class SNLIDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return SNLIDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "snli" 17 | -------------------------------------------------------------------------------- /meter/datamodules/vg_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VisualGenomeCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class VisualGenomeCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return VisualGenomeCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "vg" 16 | -------------------------------------------------------------------------------- /meter/datamodules/vqav2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VQAv2Dataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | class VQAv2DataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return VQAv2Dataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "vqa" 17 | 18 | def setup(self, stage): 19 | super().setup(stage) 20 | 21 | train_answers = self.train_dataset.table["answers"].to_pandas().tolist() 22 | val_answers = self.val_dataset.table["answers"].to_pandas().tolist() 23 | train_labels = 
self.train_dataset.table["answer_labels"].to_pandas().tolist() 24 | val_labels = self.val_dataset.table["answer_labels"].to_pandas().tolist() 25 | 26 | all_answers = [c for c in train_answers + val_answers if c is not None] 27 | all_answers = [l for lll in all_answers for ll in lll for l in ll] 28 | all_labels = [c for c in train_labels + val_labels if c is not None] 29 | all_labels = [l for lll in all_labels for ll in lll for l in ll] 30 | 31 | self.answer2id = {k: v for k, v in zip(all_answers, all_labels)} 32 | sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1]) 33 | self.num_class = max(self.answer2id.values()) + 1 34 | 35 | self.id2answer = defaultdict(lambda: "unknown") 36 | for k, v in sorted_a2i: 37 | self.id2answer[v] = k 38 | -------------------------------------------------------------------------------- /meter/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_dataset import VisualGenomeCaptionDataset 2 | from .coco_caption_karpathy_dataset import CocoCaptionKarpathyDataset 3 | from .f30k_caption_karpathy_dataset import F30KCaptionKarpathyDataset 4 | from .conceptual_caption_dataset import ConceptualCaptionDataset 5 | from .sbu_caption_dataset import SBUCaptionDataset 6 | from .vqav2_dataset import VQAv2Dataset 7 | from .nlvr2_dataset import NLVR2Dataset 8 | from .snli_dataset import SNLIDataset 9 | -------------------------------------------------------------------------------- /meter/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import io 4 | import pyarrow as pa 5 | import os 6 | 7 | from PIL import Image 8 | from ..transforms import keys_to_transforms 9 | 10 | 11 | class BaseDataset(torch.utils.data.Dataset): 12 | def __init__( 13 | self, 14 | data_dir: str, 15 | transform_keys: list, 16 | image_size: int, 17 | names: list, 18 | text_column_name: str = "", 19 | remove_duplicate=True, 20 | max_text_len=40, 21 | draw_false_image=0, 22 | draw_false_text=0, 23 | image_only=False, 24 | tokenizer=None, 25 | ): 26 | """ 27 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 28 | transform_keys : keys for generating augmented views of images 29 | text_column_name : pyarrow table column name that has list of strings as elements 30 | """ 31 | assert len(transform_keys) >= 1 32 | super().__init__() 33 | 34 | self.transforms = keys_to_transforms(transform_keys, size=image_size) 35 | self.clip_transform = False 36 | for transform_key in transform_keys: 37 | if 'clip' in transform_key: 38 | self.clip_transform = True 39 | break 40 | self.text_column_name = text_column_name 41 | self.names = names 42 | self.max_text_len = max_text_len 43 | self.draw_false_image = draw_false_image 44 | self.draw_false_text = draw_false_text 45 | self.image_only = image_only 46 | self.data_dir = data_dir 47 | 48 | if len(names) != 0: 49 | tables = [ 50 | pa.ipc.RecordBatchFileReader( 51 | pa.memory_map(f"{data_dir}/{name}.arrow", "r") 52 | ).read_all() 53 | for name in names 54 | if os.path.isfile(f"{data_dir}/{name}.arrow") 55 | ] 56 | 57 | self.table_names = list() 58 | for i, name in enumerate(names): 59 | self.table_names += [name] * len(tables[i]) 60 | 61 | self.table = pa.concat_tables(tables, promote=True) 62 | if text_column_name != "": 63 | self.text_column_name = text_column_name 64 | self.all_texts = self.table[text_column_name].to_pandas().tolist() 65 | if 
type(self.all_texts[0][0]) == str: 66 | self.all_texts = ( 67 | [list(set(texts)) for texts in self.all_texts] 68 | if remove_duplicate 69 | else self.all_texts 70 | ) 71 | else: #snli 72 | self.all_texts = ( 73 | [[t[1].strip() for t in texts] for texts in self.all_texts] 74 | ) 75 | else: 76 | self.all_texts = list() 77 | else: 78 | self.all_texts = list() 79 | 80 | self.index_mapper = dict() 81 | 82 | if text_column_name != "" and not self.image_only: 83 | j = 0 84 | for i, texts in enumerate(self.all_texts): 85 | for _j in range(len(texts)): 86 | self.index_mapper[j] = (i, _j) 87 | j += 1 88 | else: 89 | for i in range(len(self.table)): 90 | self.index_mapper[i] = (i, None) 91 | 92 | @property 93 | def corpus(self): 94 | return [text for texts in self.all_texts for text in texts] 95 | 96 | def __len__(self): 97 | return len(self.index_mapper) 98 | 99 | def get_raw_image(self, index, image_key="image"): 100 | index, caption_index = self.index_mapper[index] 101 | image_bytes = io.BytesIO(self.table[image_key][index].as_py()) 102 | image_bytes.seek(0) 103 | if self.clip_transform: 104 | return Image.open(image_bytes).convert("RGBA") 105 | else: 106 | return Image.open(image_bytes).convert("RGB") 107 | 108 | def get_image(self, index, image_key="image"): 109 | image = self.get_raw_image(index, image_key=image_key) 110 | image_tensor = [tr(image) for tr in self.transforms] 111 | return { 112 | "image": image_tensor, 113 | "img_index": self.index_mapper[index][0], 114 | "cap_index": self.index_mapper[index][1], 115 | "raw_index": index, 116 | } 117 | 118 | def get_false_image(self, rep, image_key="image"): 119 | random_index = random.randint(0, len(self.index_mapper) - 1) 120 | image = self.get_raw_image(random_index, image_key=image_key) 121 | image_tensor = [tr(image) for tr in self.transforms] 122 | return {f"false_image_{rep}": image_tensor} 123 | 124 | def get_text(self, raw_index): 125 | index, caption_index = self.index_mapper[raw_index] 126 | 127 | text = self.all_texts[index][caption_index] 128 | encoding = self.tokenizer( 129 | text, 130 | padding="max_length", 131 | truncation=True, 132 | max_length=self.max_text_len, 133 | return_special_tokens_mask=True, 134 | ) 135 | return { 136 | "text": (text, encoding), 137 | "img_index": index, 138 | "cap_index": caption_index, 139 | "raw_index": raw_index, 140 | } 141 | 142 | def get_false_text(self, rep): 143 | random_index = random.randint(0, len(self.index_mapper) - 1) 144 | 145 | index, caption_index = self.index_mapper[random_index] 146 | text = self.all_texts[index][caption_index] 147 | encoding = self.tokenizer( 148 | text, 149 | truncation=True, 150 | max_length=self.max_text_len, 151 | return_special_tokens_mask=True, 152 | ) 153 | return {f"false_text_{rep}": (text, encoding)} 154 | 155 | def get_suite(self, index): 156 | result = None 157 | while result is None: 158 | try: 159 | ret = dict() 160 | ret.update(self.get_image(index)) 161 | if not self.image_only: 162 | txt = self.get_text(index) 163 | ret.update({"replica": True if txt["cap_index"] > 0 else False}) 164 | ret.update(txt) 165 | 166 | for i in range(self.draw_false_image): 167 | ret.update(self.get_false_image(i)) 168 | for i in range(self.draw_false_text): 169 | ret.update(self.get_false_text(i)) 170 | result = True 171 | except Exception as e: 172 | print(f"Error while read file idx {index} in {self.names[0]} -> {e}") 173 | index = random.randint(0, len(self.index_mapper) - 1) 174 | return ret 175 | 176 | def collate(self, batch, mlm_collator): 177 | batch_size = 
len(batch) 178 | keys = set([key for b in batch for key in b.keys()]) 179 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 180 | 181 | img_keys = [k for k in list(dict_batch.keys()) if "image" in k] 182 | img_sizes = list() 183 | 184 | for img_key in img_keys: 185 | img = dict_batch[img_key] 186 | img_sizes += [ii.shape for i in img if i is not None for ii in i] 187 | 188 | for size in img_sizes: 189 | assert ( 190 | len(size) == 3 191 | ), f"Collate error, an image should be in shape of (3, H, W), instead of given {size}" 192 | 193 | if len(img_keys) != 0: 194 | max_height = max([i[1] for i in img_sizes]) 195 | max_width = max([i[2] for i in img_sizes]) 196 | 197 | for img_key in img_keys: 198 | img = dict_batch[img_key] 199 | view_size = len(img[0]) 200 | 201 | new_images = [ 202 | torch.zeros(batch_size, 3, max_height, max_width) 203 | for _ in range(view_size) 204 | ] 205 | 206 | for bi in range(batch_size): 207 | orig_batch = img[bi] 208 | for vi in range(view_size): 209 | if orig_batch is None: 210 | new_images[vi][bi] = None 211 | else: 212 | orig = img[bi][vi] 213 | new_images[vi][bi, :, : orig.shape[1], : orig.shape[2]] = orig 214 | 215 | dict_batch[img_key] = new_images 216 | 217 | txt_keys = [k for k in list(dict_batch.keys()) if "text" in k] 218 | 219 | if len(txt_keys) != 0: 220 | texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys] 221 | encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys] 222 | draw_text_len = len(encodings) 223 | flatten_encodings = [e for encoding in encodings for e in encoding] 224 | flatten_mlms = mlm_collator(flatten_encodings) 225 | 226 | for i, txt_key in enumerate(txt_keys): 227 | texts, encodings = ( 228 | [d[0] for d in dict_batch[txt_key]], 229 | [d[1] for d in dict_batch[txt_key]], 230 | ) 231 | 232 | mlm_ids, mlm_labels = ( 233 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 234 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 235 | ) 236 | 237 | input_ids = torch.zeros_like(mlm_ids) 238 | attention_mask = torch.zeros_like(mlm_ids) 239 | for _i, encoding in enumerate(encodings): 240 | _input_ids, _attention_mask = ( 241 | torch.tensor(encoding["input_ids"]), 242 | torch.tensor(encoding["attention_mask"]), 243 | ) 244 | input_ids[_i, : len(_input_ids)] = _input_ids 245 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 246 | 247 | dict_batch[txt_key] = texts 248 | dict_batch[f"{txt_key}_ids"] = input_ids 249 | dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100) 250 | dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids 251 | dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels 252 | dict_batch[f"{txt_key}_masks"] = attention_mask 253 | 254 | return dict_batch 255 | -------------------------------------------------------------------------------- /meter/datasets/coco_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import io 3 | from PIL import Image 4 | 5 | class CocoCaptionKarpathyDataset(BaseDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["coco_caption_karpathy_train", "coco_caption_karpathy_val"] 12 | elif split == "val": 13 | names = ["coco_caption_karpathy_val"] 14 | elif split == "test": 15 | names = ["coco_caption_karpathy_test"] 16 | 17 | super().__init__(*args, **kwargs, names=names, 
text_column_name="caption") 18 | 19 | 20 | def __getitem__(self, index): 21 | suite = self.get_suite(index) 22 | 23 | if "test" in self.split: 24 | _index, _question_index = self.index_mapper[index] 25 | iid = self.table["image_id"][_index].as_py() 26 | iid = int(iid.split(".")[0].split("_")[-1]) 27 | suite.update({"iid": iid}) 28 | 29 | return suite 30 | -------------------------------------------------------------------------------- /meter/datasets/conceptual_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from .base_dataset import BaseDataset 3 | import io 4 | from PIL import Image 5 | 6 | 7 | class ConceptualCaptionDataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | if split == "test": 11 | split = "val" 12 | 13 | if split == "train": 14 | names = [f"conceptual_caption_train_{i}" for i in range(31)] 15 | elif split == "val": 16 | names = [] 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | 21 | def __getitem__(self, index): 22 | return self.get_suite(index) 23 | 24 | def get_text(self, raw_index): 25 | index, caption_index = self.index_mapper[raw_index] 26 | 27 | text = self.all_texts[index][caption_index] 28 | encoding = self.tokenizer( 29 | text, 30 | padding="max_length", 31 | truncation=True, 32 | max_length=self.max_text_len, 33 | return_special_tokens_mask=True, 34 | ) 35 | return { 36 | "text": (text, encoding), 37 | "img_index": index, 38 | "cap_index": caption_index, 39 | "raw_index": raw_index, 40 | } 41 | -------------------------------------------------------------------------------- /meter/datasets/f30k_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class F30KCaptionKarpathyDataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | 8 | if split == "train": 9 | names = ["f30k_caption_karpathy_train", "f30k_caption_karpathy_val"] 10 | elif split == "val": 11 | names = ["f30k_caption_karpathy_test"] 12 | elif split == "test": 13 | names = ["f30k_caption_karpathy_test"] 14 | 15 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 16 | 17 | def __getitem__(self, index): 18 | return self.get_suite(index) 19 | -------------------------------------------------------------------------------- /meter/datasets/nlvr2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import sys 3 | import random 4 | 5 | 6 | class NLVR2Dataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["nlvr2_train"] 13 | elif split == "val": 14 | names = ["nlvr2_dev", "nlvr2_test1"] 15 | elif split == "test": 16 | names = ["nlvr2_dev", "nlvr2_test1"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="questions", 23 | remove_duplicate=False, 24 | ) 25 | 26 | def __getitem__(self, index): 27 | result = None 28 | while result is None: 29 | try: 30 | image_tensor_0 = self.get_image(index, image_key="image_0")["image"] 31 | image_tensor_1 = self.get_image(index, image_key="image_1")["image"] 32 | text = self.get_text(index)["text"] 33 | result = True 34 | except: 35 
| print( 36 | f"error while read file idx {index} in {self.names[0]}", 37 | file=sys.stderr, 38 | ) 39 | index = random.randint(0, len(self.index_mapper) - 1) 40 | 41 | index, question_index = self.index_mapper[index] 42 | answers = self.table["answers"][index][question_index].as_py() 43 | answers = answers == "True" 44 | 45 | return { 46 | "image_0": image_tensor_0, 47 | "image_1": image_tensor_1, 48 | "text": text, 49 | "answers": answers, 50 | "table_name": self.table_names[index], 51 | } 52 | -------------------------------------------------------------------------------- /meter/datasets/sbu_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from .base_dataset import BaseDataset 3 | import io 4 | from PIL import Image 5 | 6 | 7 | class SBUCaptionDataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | if split == "test": 11 | split = "val" 12 | 13 | if split == "train": 14 | names = [f"sbu_{i}" for i in range(9)] 15 | elif split == "val": 16 | names = [] 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | def __getitem__(self, index): 21 | return self.get_suite(index) 22 | 23 | def get_text(self, raw_index): 24 | index, caption_index = self.index_mapper[raw_index] 25 | 26 | text = self.all_texts[index][caption_index] 27 | encoding = self.tokenizer( 28 | text, 29 | padding="max_length", 30 | truncation=True, 31 | max_length=self.max_text_len, 32 | return_special_tokens_mask=True, 33 | ) 34 | return { 35 | "text": (text, encoding), 36 | "img_index": index, 37 | "cap_index": caption_index, 38 | "raw_index": raw_index, 39 | } 40 | -------------------------------------------------------------------------------- /meter/datasets/snli_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class SNLIDataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["snli_train"] 11 | elif split == "val": 12 | names = ["snli_dev", "snli_test"] 13 | elif split == "test": 14 | names = ["snli_dev", "snli_test"] 15 | 16 | super().__init__( 17 | *args, 18 | **kwargs, 19 | names=names, 20 | text_column_name="sentences", 21 | remove_duplicate=False, 22 | ) 23 | 24 | def __getitem__(self, index): 25 | image_tensor = self.get_image(index)["image"] 26 | text = self.get_text(index)["text"] 27 | 28 | index, question_index = self.index_mapper[index] 29 | 30 | labels = self.table["labels"][index][question_index].as_py() 31 | 32 | return { 33 | "image": image_tensor, 34 | "text": text, 35 | "labels": labels, 36 | "table_name": self.table_names[index], 37 | } 38 | -------------------------------------------------------------------------------- /meter/datasets/vg_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import io 3 | from PIL import Image 4 | 5 | 6 | class VisualGenomeCaptionDataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | if split == "test": 10 | split = "val" 11 | 12 | if split == "train": 13 | names = ["vg"] 14 | elif split == "val": 15 | names = [] 16 | 17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 18 | 19 | 
def __getitem__(self, index): 20 | return self.get_suite(index) 21 | 22 | def get_text(self, raw_index): 23 | index, caption_index = self.index_mapper[raw_index] 24 | 25 | text = self.all_texts[index][caption_index] 26 | encoding = self.tokenizer( 27 | text, 28 | padding="max_length", 29 | truncation=True, 30 | max_length=self.max_text_len, 31 | return_special_tokens_mask=True, 32 | ) 33 | return { 34 | "text": (text, encoding), 35 | "img_index": index, 36 | "cap_index": caption_index, 37 | "raw_index": raw_index, 38 | } 39 | -------------------------------------------------------------------------------- /meter/datasets/vqav2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class VQAv2Dataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["vqav2_train", "vqav2_val"] 11 | elif split == "val": 12 | names = ["vqav2_val"] 13 | elif split == "test": 14 | names = ["vqav2_test"] 15 | 16 | super().__init__( 17 | *args, 18 | **kwargs, 19 | names=names, 20 | text_column_name="questions", 21 | remove_duplicate=False, 22 | ) 23 | 24 | def __getitem__(self, index): 25 | image_tensor = self.get_image(index)["image"] 26 | text = self.get_text(index)["text"] 27 | 28 | index, question_index = self.index_mapper[index] 29 | qid = self.table["question_id"][index][question_index].as_py() 30 | 31 | if self.split != "test": 32 | answers = self.table["answers"][index][question_index].as_py() 33 | labels = self.table["answer_labels"][index][question_index].as_py() 34 | scores = self.table["answer_scores"][index][question_index].as_py() 35 | else: 36 | answers = list() 37 | labels = list() 38 | scores = list() 39 | 40 | return { 41 | "image": image_tensor, 42 | "text": text, 43 | "vqa_answer": answers, 44 | "vqa_labels": labels, 45 | "vqa_scores": scores, 46 | "qid": qid, 47 | } 48 | -------------------------------------------------------------------------------- /meter/gadgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLyu0110/HACAN/35bdbc5cb2a9a62870fa9dce180c03c4c9d54206/meter/gadgets/__init__.py -------------------------------------------------------------------------------- /meter/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | def __init__(self, dist_sync_on_step=False): 7 | super().__init__(dist_sync_on_step=dist_sync_on_step) 8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | 11 | def update(self, logits, target): 12 | logits, target = ( 13 | logits.detach().to(self.correct.device), 14 | target.detach().to(self.correct.device), 15 | ) 16 | preds = logits.argmax(dim=-1) 17 | preds = preds[target != -100] 18 | target = target[target != -100] 19 | if target.numel() == 0: 20 | return 1 21 | 22 | assert preds.shape == target.shape 23 | 24 | self.correct += torch.sum(preds == target) 25 | self.total += target.numel() 26 | 27 | def compute(self): 28 | return self.correct / self.total 29 | 30 | 31 | class Scalar(Metric): 32 | def __init__(self, dist_sync_on_step=False): 33 | super().__init__(dist_sync_on_step=dist_sync_on_step) 34 | 
self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 35 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 36 | 37 | def update(self, scalar): 38 | if isinstance(scalar, torch.Tensor): 39 | scalar = scalar.detach().to(self.scalar.device) 40 | else: 41 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 42 | self.scalar += scalar 43 | self.total += 1 44 | 45 | def compute(self): 46 | return self.scalar / self.total 47 | 48 | 49 | class VQAScore(Metric): 50 | def __init__(self, dist_sync_on_step=False): 51 | super().__init__(dist_sync_on_step=dist_sync_on_step) 52 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 53 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 54 | 55 | def update(self, logits, target): 56 | logits, target = ( 57 | logits.detach().float().to(self.score.device), 58 | target.detach().float().to(self.score.device), 59 | ) 60 | logits = torch.max(logits, 1)[1] 61 | one_hots = torch.zeros(*target.size()).to(target) 62 | one_hots.scatter_(1, logits.view(-1, 1), 1) 63 | scores = one_hots * target 64 | 65 | self.score += scores.sum() 66 | self.total += len(logits) 67 | 68 | def compute(self): 69 | return self.score / self.total 70 | -------------------------------------------------------------------------------- /meter/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .meter_module import METERTransformerSS 2 | -------------------------------------------------------------------------------- /meter/modules/convit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | '''These modules are adapted from those of timm, see 9 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 10 | ''' 11 | 12 | import torch 13 | import torch.nn as nn 14 | from functools import partial 15 | import torch.nn.functional as F 16 | from timm.models.helpers import load_pretrained 17 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 18 | from timm.models.registry import register_model 19 | 20 | import torch 21 | import torch.nn as nn 22 | import matplotlib.pyplot as plt 23 | 24 | 25 | class Mlp(nn.Module): 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | self.apply(self._init_weights) 35 | 36 | def _init_weights(self, m): 37 | if isinstance(m, nn.Linear): 38 | trunc_normal_(m.weight, std=.02) 39 | if isinstance(m, nn.Linear) and m.bias is not None: 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.LayerNorm): 42 | nn.init.constant_(m.bias, 0) 43 | nn.init.constant_(m.weight, 1.0) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class GPSA(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 56 | locality_strength=1., use_local_init=True): 57 | super().__init__() 58 | self.num_heads = num_heads 59 | self.dim = dim 60 | head_dim = dim // num_heads 61 | self.scale = qk_scale or head_dim ** -0.5 62 | 63 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 65 | 66 | self.attn_drop = nn.Dropout(attn_drop) 67 | self.proj = nn.Linear(dim, dim) 68 | self.pos_proj = nn.Linear(3, num_heads) 69 | self.proj_drop = nn.Dropout(proj_drop) 70 | self.locality_strength = locality_strength 71 | self.gating_param = nn.Parameter(torch.ones(self.num_heads)) 72 | self.apply(self._init_weights) 73 | if use_local_init: 74 | self.local_init(locality_strength=locality_strength) 75 | 76 | def _init_weights(self, m): 77 | if isinstance(m, nn.Linear): 78 | trunc_normal_(m.weight, std=.02) 79 | if isinstance(m, nn.Linear) and m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | elif isinstance(m, nn.LayerNorm): 82 | nn.init.constant_(m.bias, 0) 83 | nn.init.constant_(m.weight, 1.0) 84 | 85 | def forward(self, x): 86 | B, N, C = x.shape 87 | if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N: 88 | self.get_rel_indices(N) 89 | 90 | attn = self.get_attention(x) 91 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | return x 96 | 97 | def get_attention(self, x): 98 | B, N, C = x.shape 99 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 100 | q, k = qk[0], qk[1] 101 | pos_score = self.rel_indices.expand(B, -1, -1,-1) 102 | pos_score = self.pos_proj(pos_score).permute(0,3,1,2) 103 | patch_score = (q @ k.transpose(-2, -1)) * self.scale 104 | patch_score = patch_score.softmax(dim=-1) 105 | pos_score = pos_score.softmax(dim=-1) 106 | 107 | gating = 
self.gating_param.view(1,-1,1,1) 108 | attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score 109 | attn /= attn.sum(dim=-1).unsqueeze(-1) 110 | attn = self.attn_drop(attn) 111 | return attn 112 | 113 | def get_attention_map(self, x, return_map = False): 114 | 115 | attn_map = self.get_attention(x).mean(0) # average over batch 116 | distances = self.rel_indices.squeeze()[:,:,-1]**.5 117 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 118 | dist /= distances.size(0) 119 | if return_map: 120 | return dist, attn_map 121 | else: 122 | return dist 123 | 124 | def local_init(self, locality_strength=1.): 125 | 126 | self.v.weight.data.copy_(torch.eye(self.dim)) 127 | locality_distance = 1 #max(1,1/locality_strength**.5) 128 | 129 | kernel_size = int(self.num_heads**.5) 130 | center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2 131 | for h1 in range(kernel_size): 132 | for h2 in range(kernel_size): 133 | position = h1+kernel_size*h2 134 | self.pos_proj.weight.data[position,2] = -1 135 | self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance 136 | self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance 137 | self.pos_proj.weight.data *= locality_strength 138 | 139 | def get_rel_indices(self, num_patches): 140 | img_size = int(num_patches**.5) 141 | rel_indices = torch.zeros(1, num_patches, num_patches, 3) 142 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 143 | indx = ind.repeat(img_size,img_size) 144 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 145 | indd = indx**2 + indy**2 146 | rel_indices[:,:,:,2] = indd.unsqueeze(0) 147 | rel_indices[:,:,:,1] = indy.unsqueeze(0) 148 | rel_indices[:,:,:,0] = indx.unsqueeze(0) 149 | device = self.qk.weight.device 150 | self.rel_indices = rel_indices.to(device) 151 | 152 | 153 | class MHSA(nn.Module): 154 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 155 | super().__init__() 156 | self.num_heads = num_heads 157 | head_dim = dim // num_heads 158 | self.scale = qk_scale or head_dim ** -0.5 159 | 160 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 161 | self.attn_drop = nn.Dropout(attn_drop) 162 | self.proj = nn.Linear(dim, dim) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | self.apply(self._init_weights) 165 | 166 | def _init_weights(self, m): 167 | if isinstance(m, nn.Linear): 168 | trunc_normal_(m.weight, std=.02) 169 | if isinstance(m, nn.Linear) and m.bias is not None: 170 | nn.init.constant_(m.bias, 0) 171 | elif isinstance(m, nn.LayerNorm): 172 | nn.init.constant_(m.bias, 0) 173 | nn.init.constant_(m.weight, 1.0) 174 | 175 | def get_attention_map(self, x, return_map = False): 176 | B, N, C = x.shape 177 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 178 | q, k, v = qkv[0], qkv[1], qkv[2] 179 | attn_map = (q @ k.transpose(-2, -1)) * self.scale 180 | attn_map = attn_map.softmax(dim=-1).mean(0) 181 | 182 | img_size = int(N**.5) 183 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 184 | indx = ind.repeat(img_size,img_size) 185 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 186 | indd = indx**2 + indy**2 187 | distances = indd**.5 188 | distances = distances.to('cuda') 189 | 190 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 191 | dist /= N 192 | 193 | if return_map: 194 | return dist, attn_map 195 | else: 196 | return dist 197 | 
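# --- [Editor's sketch: not part of the original convit.py; illustration only] ---
# GPSA above blends a content-based attention map (softmax of q @ k^T) with a purely
# positional map built from relative patch offsets via pos_proj, mixing the two per head
# through a learned sigmoid gate; local_init biases the positional heads toward a
# convolution-like neighbourhood. A minimal shape check, with dim/num_heads chosen
# arbitrarily and a square 14x14 patch grid (GPSA assumes N is a perfect square):
import torch

_gpsa = GPSA(dim=192, num_heads=4, locality_strength=1.0)
_tokens = torch.randn(2, 196, 192)      # (batch, patches, dim); 196 = 14 * 14
_out = _gpsa(_tokens)                   # rel_indices are built lazily on the first call
assert _out.shape == (2, 196, 192)
# --- [end of editor's sketch] ---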
198 | 199 | def forward(self, x): 200 | B, N, C = x.shape 201 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 202 | q, k, v = qkv[0], qkv[1], qkv[2] 203 | 204 | attn = (q @ k.transpose(-2, -1)) * self.scale 205 | attn = attn.softmax(dim=-1) 206 | attn = self.attn_drop(attn) 207 | 208 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 209 | x = self.proj(x) 210 | x = self.proj_drop(x) 211 | return x 212 | 213 | class Block(nn.Module): 214 | 215 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 216 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): 217 | super().__init__() 218 | self.norm1 = norm_layer(dim) 219 | self.use_gpsa = use_gpsa 220 | if self.use_gpsa: 221 | self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) 222 | else: 223 | self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) 224 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 225 | self.norm2 = norm_layer(dim) 226 | mlp_hidden_dim = int(dim * mlp_ratio) 227 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 228 | 229 | def forward(self, x): 230 | x = x + self.drop_path(self.attn(self.norm1(x))) 231 | x = x + self.drop_path(self.mlp(self.norm2(x))) 232 | return x 233 | 234 | 235 | class PatchEmbed(nn.Module): 236 | """ Image to Patch Embedding, from timm 237 | """ 238 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 239 | super().__init__() 240 | img_size = to_2tuple(img_size) 241 | patch_size = to_2tuple(patch_size) 242 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 243 | self.img_size = img_size 244 | self.patch_size = patch_size 245 | self.num_patches = num_patches 246 | 247 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 248 | self.apply(self._init_weights) 249 | def forward(self, x): 250 | B, C, H, W = x.shape 251 | assert H == self.img_size[0] and W == self.img_size[1], \ 252 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
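# --- [Editor's sketch: not part of the original convit.py; illustration only] ---
# Block above wraps either GPSA (use_gpsa=True, the local early layers) or plain MHSA
# (use_gpsa=False, used once the class token is appended) between pre-norm residual
# branches, so both variants map (B, N, C) -> (B, N, C). The dimensions below are
# arbitrary; locality_strength is only accepted by the GPSA variant:
import torch

_blk_local = Block(dim=192, num_heads=4, use_gpsa=True, locality_strength=1.0)
_blk_global = Block(dim=192, num_heads=4, use_gpsa=False)
_x = torch.randn(2, 196, 192)           # GPSA needs a square number of patches
assert _blk_local(_x).shape == _blk_global(_x).shape == (2, 196, 192)
# --- [end of editor's sketch] ---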
253 | x = self.proj(x).flatten(2).transpose(1, 2) 254 | return x 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | 264 | 265 | class HybridEmbed(nn.Module): 266 | """ CNN Feature Map Embedding, from timm 267 | """ 268 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 269 | super().__init__() 270 | assert isinstance(backbone, nn.Module) 271 | img_size = to_2tuple(img_size) 272 | self.img_size = img_size 273 | self.backbone = backbone 274 | if feature_size is None: 275 | with torch.no_grad(): 276 | training = backbone.training 277 | if training: 278 | backbone.eval() 279 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 280 | feature_size = o.shape[-2:] 281 | feature_dim = o.shape[1] 282 | backbone.train(training) 283 | else: 284 | feature_size = to_2tuple(feature_size) 285 | feature_dim = self.backbone.feature_info.channels()[-1] 286 | self.num_patches = feature_size[0] * feature_size[1] 287 | self.proj = nn.Linear(feature_dim, embed_dim) 288 | self.apply(self._init_weights) 289 | 290 | def forward(self, x): 291 | x = self.backbone(x)[-1] 292 | x = x.flatten(2).transpose(1, 2) 293 | x = self.proj(x) 294 | return x 295 | 296 | 297 | class VisionTransformer(nn.Module): 298 | """ Vision Transformer with support for patch or hybrid CNN input stage 299 | """ 300 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=48, depth=12, 301 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 302 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, 303 | local_up_to_layer=10, locality_strength=1., use_pos_embed=True): 304 | super().__init__() 305 | self.num_classes = num_classes 306 | self.local_up_to_layer = local_up_to_layer 307 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 308 | self.locality_strength = locality_strength 309 | self.use_pos_embed = use_pos_embed 310 | 311 | if hybrid_backbone is not None: 312 | self.patch_embed = HybridEmbed( 313 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 314 | else: 315 | self.patch_embed = PatchEmbed( 316 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 317 | num_patches = self.patch_embed.num_patches 318 | self.num_patches = num_patches 319 | 320 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 321 | self.pos_drop = nn.Dropout(p=drop_rate) 322 | 323 | if self.use_pos_embed: 324 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 325 | trunc_normal_(self.pos_embed, std=.02) 326 | 327 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 328 | self.blocks = nn.ModuleList([ 329 | Block( 330 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 331 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 332 | use_gpsa=True, 333 | locality_strength=locality_strength) 334 | if i < local_up_to_layer else 335 | Block( 336 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 337 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 338 | use_gpsa=False) 339 | for i in range(depth)]) 340 | self.norm = norm_layer(embed_dim) 341 | 342 | # Classifier head 343 | self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] 344 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 345 | 346 | trunc_normal_(self.cls_token, std=.02) 347 | self.head.apply(self._init_weights) 348 | 349 | def _init_weights(self, m): 350 | if isinstance(m, nn.Linear):
351 | trunc_normal_(m.weight, std=.02) 352 | if isinstance(m, nn.Linear) and m.bias is not None: 353 | nn.init.constant_(m.bias, 0) 354 | elif isinstance(m, nn.LayerNorm): 355 | nn.init.constant_(m.bias, 0) 356 | nn.init.constant_(m.weight, 1.0) 357 | 358 | @torch.jit.ignore 359 | def no_weight_decay(self): 360 | return {'pos_embed', 'cls_token'} 361 | 362 | def get_classifier(self): 363 | return self.head 364 | 365 | def reset_classifier(self, num_classes, global_pool=''): 366 | self.num_classes = num_classes 367 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 368 | 369 | def forward_features(self, x): 370 | B = x.shape[0] 371 | x = self.patch_embed(x) 372 | 373 | cls_tokens = self.cls_token.expand(B, -1, -1) 374 | 375 | if self.use_pos_embed: 376 | x = x + self.pos_embed 377 | x = self.pos_drop(x) 378 | 379 | x_stage = [] 380 | for u,blk in enumerate(self.blocks): 381 | if u == self.local_up_to_layer : 382 | x = torch.cat((cls_tokens, x), dim=1) 383 | x = blk(x) 384 | x_stage.append(x) 385 | 386 | x_stage = [self.norm(x) for x in x_stage] 387 | x = self.norm(x) 388 | return x[:, 0], x_stage 389 | 390 | def forward(self, x): 391 | x, x_stage = self.forward_features(x) 392 | x = self.head(x) 393 | return x, x_stage 394 | 395 | -------------------------------------------------------------------------------- /meter/modules/convit_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | from .convit import VisionTransformer 13 | from timm.models.efficientnet import EfficientNet 14 | from timm.models.vision_transformer import _cfg 15 | from timm.models.registry import register_model 16 | 17 | @register_model 18 | def convit_tiny(pretrained=False, **kwargs): 19 | num_heads = 4 20 | kwargs['embed_dim'] *= num_heads 21 | model = VisionTransformer( 22 | num_heads=num_heads, 23 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 24 | model.default_cfg = _cfg() 25 | if pretrained: 26 | checkpoint = torch.hub.load_state_dict_from_url( 27 | url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth", 28 | map_location="cpu", check_hash=True 29 | ) 30 | model.load_state_dict(checkpoint) 31 | return model 32 | 33 | @register_model 34 | def convit_small(pretrained=False, **kwargs): 35 | num_heads = 9 36 | kwargs['embed_dim'] *= num_heads 37 | model = VisionTransformer( 38 | num_heads=num_heads, 39 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 40 | model.default_cfg = _cfg() 41 | if pretrained: 42 | checkpoint = torch.hub.load_state_dict_from_url( 43 | url="https://dl.fbaipublicfiles.com/convit/convit_small.pth", 44 | map_location="cpu", check_hash=True 45 | ) 46 | model.load_state_dict(checkpoint) 47 | return model 48 | 49 | @register_model 50 | def convit_base(pretrained=False, **kwargs): 51 | num_heads = 16 52 | kwargs['embed_dim'] *= num_heads 53 | model = VisionTransformer( 54 | num_heads=num_heads, 55 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 56 | model.default_cfg = _cfg() 57 | if pretrained: 58 | checkpoint = torch.hub.load_state_dict_from_url( 59 | url="https://dl.fbaipublicfiles.com/convit/convit_base.pth", 60 | map_location="cpu", check_hash=True 61 | ) 62 | 
model.load_state_dict(checkpoint) 63 | return model -------------------------------------------------------------------------------- /meter/modules/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | 15 | 16 | class Block(nn.Module): 17 | r""" ConvNeXt Block. There are two equivalent implementations: 18 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 19 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 20 | We use (2) as we find it slightly faster in PyTorch 21 | 22 | Args: 23 | dim (int): Number of input channels. 24 | drop_path (float): Stochastic depth rate. Default: 0.0 25 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 26 | """ 27 | 28 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 29 | super().__init__() 30 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 31 | self.norm = LayerNorm(dim, eps=1e-6) 32 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 33 | self.act = nn.GELU() 34 | self.pwconv2 = nn.Linear(4 * dim, dim) 35 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 36 | requires_grad=True) if layer_scale_init_value > 0 else None 37 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 38 | 39 | def forward(self, x): 40 | input = x 41 | x = self.dwconv(x) 42 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 43 | x = self.norm(x) 44 | x = self.pwconv1(x) 45 | x = self.act(x) 46 | x = self.pwconv2(x) 47 | if self.gamma is not None: 48 | x = self.gamma * x 49 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 50 | 51 | x = input + self.drop_path(x) 52 | return x 53 | 54 | 55 | class ConvNeXt(nn.Module): 56 | r""" ConvNeXt 57 | A PyTorch impl of : `A ConvNet for the 2020s` - 58 | https://arxiv.org/pdf/2201.03545.pdf 59 | 60 | Args: 61 | in_chans (int): Number of input image channels. Default: 3 62 | num_classes (int): Number of classes for classification head. Default: 1000 63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 65 | drop_path_rate (float): Stochastic depth rate. Default: 0. 66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
68 | """ 69 | 70 | def __init__(self, in_chans=3, num_classes=1000, 71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 72 | layer_scale_init_value=1e-6, head_init_scale=1., 73 | ): 74 | super().__init__() 75 | 76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 77 | stem = nn.Sequential( 78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 80 | ) 81 | self.downsample_layers.append(stem) 82 | for i in range(3): 83 | downsample_layer = nn.Sequential( 84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 85 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 86 | ) 87 | self.downsample_layers.append(downsample_layer) 88 | 89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 90 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 91 | cur = 0 92 | for i in range(4): 93 | stage = nn.Sequential( 94 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 95 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 96 | ) 97 | self.stages.append(stage) 98 | cur += depths[i] 99 | 100 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 101 | self.head = nn.Linear(dims[-1], num_classes) 102 | 103 | self.apply(self._init_weights) 104 | self.head.weight.data.mul_(head_init_scale) 105 | self.head.bias.data.mul_(head_init_scale) 106 | 107 | def _init_weights(self, m): 108 | if isinstance(m, (nn.Conv2d, nn.Linear)): 109 | trunc_normal_(m.weight, std=.02) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def forward_features(self, x): 113 | x_stage = [] 114 | for i in range(4): 115 | x = self.downsample_layers[i](x) 116 | x = self.stages[i](x) 117 | x_stage.append(x) 118 | return self.norm(x.mean([-2, -1])), x_stage # global average pooling, (N, C, H, W) -> (N, C) 119 | 120 | def forward(self, x): 121 | x, x_stage = self.forward_features(x) 122 | # x = self.head(x) 123 | return x, x_stage 124 | 125 | 126 | class LayerNorm(nn.Module): 127 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 128 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 129 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 130 | with shape (batch_size, channels, height, width).
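# --- [Editor's sketch: not part of the original convnext.py; illustration only] ---
# forward_features above runs the four stem/downsample + stage pairs and keeps every
# intermediate map in x_stage, so forward() returns both a pooled global descriptor and
# a per-stage feature pyramid; meter_module.py later projects stage-sized features
# (256/512/1024 for convnext_base) with fc1/fc2/fc3. A shape check with the
# convnext_tiny layout (depths/dims as in the factory functions below):
import torch

_net = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
_img = torch.randn(1, 3, 224, 224)
_pooled, _stages = _net(_img)
print(_pooled.shape)                       # torch.Size([1, 768]); the head is bypassed in forward()
print([tuple(s.shape) for s in _stages])   # strides 4/8/16/32: 56x56, 28x28, 14x14, 7x7 maps
# --- [end of editor's sketch] ---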
131 | """ 132 | 133 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 134 | super().__init__() 135 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 136 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 137 | self.eps = eps 138 | self.data_format = data_format 139 | if self.data_format not in ["channels_last", "channels_first"]: 140 | raise NotImplementedError 141 | self.normalized_shape = (normalized_shape,) 142 | 143 | def forward(self, x): 144 | if self.data_format == "channels_last": 145 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 146 | elif self.data_format == "channels_first": 147 | u = x.mean(1, keepdim=True) 148 | s = (x - u).pow(2).mean(1, keepdim=True) 149 | x = (x - u) / torch.sqrt(s + self.eps) 150 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 151 | return x 152 | 153 | 154 | model_urls = { 155 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 156 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 157 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 158 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 159 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 160 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 161 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 162 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 163 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 164 | } 165 | 166 | 167 | @register_model 168 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs): 169 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 170 | if pretrained: 171 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 173 | model.load_state_dict(checkpoint["model"]) 174 | return model 175 | 176 | 177 | @register_model 178 | def convnext_small(pretrained=False, in_22k=False, **kwargs): 179 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 180 | if pretrained: 181 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 182 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 183 | model.load_state_dict(checkpoint["model"]) 184 | return model 185 | 186 | 187 | @register_model 188 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 189 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 190 | if pretrained: 191 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 192 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 193 | model.load_state_dict(checkpoint["model"]) 194 | return model 195 | 196 | 197 | @register_model 198 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 199 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 200 | if pretrained: 201 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 202 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 203 | 
model.load_state_dict(checkpoint["model"]) 204 | return model 205 | 206 | 207 | @register_model 208 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 209 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 210 | if pretrained: 211 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 212 | url = model_urls['convnext_xlarge_22k'] 213 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 214 | model.load_state_dict(checkpoint["model"]) 215 | return model -------------------------------------------------------------------------------- /meter/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import torch 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert _LOCAL_PROCESS_GROUP is not None 49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 50 | 51 | 52 | def get_local_size() -> int: 53 | """ 54 | Returns: 55 | The size of the per-machine process group, 56 | i.e. the number of processes per machine. 57 | """ 58 | if not dist.is_available(): 59 | return 1 60 | if not dist.is_initialized(): 61 | return 1 62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 63 | 64 | 65 | def is_main_process() -> bool: 66 | return get_rank() == 0 67 | 68 | 69 | def synchronize(): 70 | """ 71 | Helper function to synchronize (barrier) among all processes when 72 | using distributed training 73 | """ 74 | if not dist.is_available(): 75 | return 76 | if not dist.is_initialized(): 77 | return 78 | world_size = dist.get_world_size() 79 | if world_size == 1: 80 | return 81 | dist.barrier() 82 | 83 | 84 | @functools.lru_cache() 85 | def _get_global_gloo_group(): 86 | """ 87 | Return a process group based on gloo backend, containing all the ranks 88 | The result is cached. 
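# --- [Editor's sketch: not part of the original dist_utils.py; illustration only] ---
# The rank/world-size helpers above degrade gracefully when torch.distributed is not
# initialised (world size 1, rank 0), so the same code path works for single-GPU
# debugging and for multi-GPU training. Typical guard-and-barrier usage:
if is_main_process():
    print("rank 0 only: logging / checkpoint writing goes here")
synchronize()   # no-op in a single process; a barrier across ranks under DDP
# --- [end of editor's sketch] ---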
89 | """ 90 | if dist.get_backend() == "nccl": 91 | return dist.new_group(backend="gloo") 92 | else: 93 | return dist.group.WORLD 94 | 95 | 96 | def _serialize_to_tensor(data, group): 97 | backend = dist.get_backend(group) 98 | assert backend in ["gloo", "nccl"] 99 | device = torch.device("cpu" if backend == "gloo" else "cuda") 100 | 101 | buffer = pickle.dumps(data) 102 | if len(buffer) > 1024 ** 3: 103 | logger = logging.getLogger(__name__) 104 | logger.warning( 105 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 106 | get_rank(), len(buffer) / (1024 ** 3), device 107 | ) 108 | ) 109 | storage = torch.ByteStorage.from_buffer(buffer) 110 | tensor = torch.ByteTensor(storage).to(device=device) 111 | return tensor 112 | 113 | 114 | def _pad_to_largest_tensor(tensor, group): 115 | """ 116 | Returns: 117 | list[int]: size of the tensor, on each rank 118 | Tensor: padded tensor that has the max size 119 | """ 120 | world_size = dist.get_world_size(group=group) 121 | assert ( 122 | world_size >= 1 123 | ), "comm.gather/all_gather must be called from ranks within the given group!" 124 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 125 | size_list = [ 126 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 127 | for _ in range(world_size) 128 | ] 129 | dist.all_gather(size_list, local_size, group=group) 130 | size_list = [int(size.item()) for size in size_list] 131 | 132 | max_size = max(size_list) 133 | 134 | # we pad the tensor because torch all_gather does not support 135 | # gathering tensors of different shapes 136 | if local_size != max_size: 137 | padding = torch.zeros( 138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 139 | ) 140 | tensor = torch.cat((tensor, padding), dim=0) 141 | return size_list, tensor 142 | 143 | 144 | def all_gather(data, group=None): 145 | """ 146 | Run all_gather on arbitrary picklable data (not necessarily tensors). 147 | 148 | Args: 149 | data: any picklable object 150 | group: a torch process group. By default, will use a group which 151 | contains all ranks on gloo backend. 152 | 153 | Returns: 154 | list[data]: list of data gathered from each rank 155 | """ 156 | if get_world_size() == 1: 157 | return [data] 158 | if group is None: 159 | group = _get_global_gloo_group() 160 | if dist.get_world_size(group) == 1: 161 | return [data] 162 | 163 | tensor = _serialize_to_tensor(data, group) 164 | 165 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 166 | max_size = max(size_list) 167 | 168 | # receiving Tensor from all ranks 169 | tensor_list = [ 170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 171 | for _ in size_list 172 | ] 173 | dist.all_gather(tensor_list, tensor, group=group) 174 | 175 | data_list = [] 176 | for size, tensor in zip(size_list, tensor_list): 177 | buffer = tensor.cpu().numpy().tobytes()[:size] 178 | data_list.append(pickle.loads(buffer)) 179 | 180 | return data_list 181 | 182 | 183 | def gather(data, dst=0, group=None): 184 | """ 185 | Run gather on arbitrary picklable data (not necessarily tensors). 186 | 187 | Args: 188 | data: any picklable object 189 | dst (int): destination rank 190 | group: a torch process group. By default, will use a group which 191 | contains all ranks on gloo backend. 192 | 193 | Returns: 194 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 195 | an empty list. 
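# --- [Editor's sketch: not part of the original dist_utils.py; illustration only] ---
# all_gather above works on arbitrary picklable objects: each rank pickles its payload
# into a uint8 tensor, ranks exchange byte-lengths, every tensor is right-padded to the
# largest length so dist.all_gather can run, and the padding is sliced off before
# unpickling. Called outside an initialised process group it simply echoes the input
# (the metric values below are placeholders):
_metrics = {"rank": get_rank(), "r1": 57.3, "r5": 83.1}
print(all_gather(_metrics))   # one dict here; one dict per rank under DDP
# --- [end of editor's sketch] ---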
196 | """ 197 | if get_world_size() == 1: 198 | return [data] 199 | if group is None: 200 | group = _get_global_gloo_group() 201 | if dist.get_world_size(group=group) == 1: 202 | return [data] 203 | rank = dist.get_rank(group=group) 204 | 205 | tensor = _serialize_to_tensor(data, group) 206 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 207 | 208 | # receiving Tensor from all ranks 209 | if rank == dst: 210 | max_size = max(size_list) 211 | tensor_list = [ 212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 213 | for _ in size_list 214 | ] 215 | dist.gather(tensor, tensor_list, dst=dst, group=group) 216 | 217 | data_list = [] 218 | for size, tensor in zip(size_list, tensor_list): 219 | buffer = tensor.cpu().numpy().tobytes()[:size] 220 | data_list.append(pickle.loads(buffer)) 221 | return data_list 222 | else: 223 | dist.gather(tensor, [], dst=dst, group=group) 224 | return [] 225 | 226 | 227 | def shared_random_seed(): 228 | """ 229 | Returns: 230 | int: a random number that is the same across all workers. 231 | If workers need a shared RNG, they can use this shared seed to 232 | create one. 233 | 234 | All workers must call this function, otherwise it will deadlock. 235 | """ 236 | ints = np.random.randint(2 ** 31) 237 | all_ints = all_gather(ints) 238 | return all_ints[0] 239 | 240 | 241 | def reduce_dict(input_dict, average=True): 242 | """ 243 | Reduce the values in the dictionary from all processes so that process with rank 244 | 0 has the reduced results. 245 | 246 | Args: 247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 248 | average (bool): whether to do average or sum 249 | 250 | Returns: 251 | a dict with the same keys as input_dict, after reduction. 252 | """ 253 | world_size = get_world_size() 254 | if world_size < 2: 255 | return input_dict 256 | with torch.no_grad(): 257 | names = [] 258 | values = [] 259 | # sort the keys so that they are consistent across processes 260 | for k in sorted(input_dict.keys()): 261 | names.append(k) 262 | values.append(input_dict[k]) 263 | values = torch.stack(values, dim=0) 264 | dist.reduce(values, dst=0) 265 | if dist.get_rank() == 0 and average: 266 | # only main process gets accumulated, so only divide by 267 | # world_size in this case 268 | values /= world_size 269 | reduced_dict = {k: v for k, v in zip(names, values)} 270 | return reduced_dict 271 | -------------------------------------------------------------------------------- /meter/modules/ensemble_models.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Evaluate two model ensemble 3 | # Usage: 4 | # ```python 5 | # from vocab import Vocabulary 6 | # import evaluation_ensemble 7 | # evaluation_ensemble.evalrank("$RUN_PATH/coco_scan/t-i_AVG/model_best.pth.tar", "$RUN_PATH/coco_scan/i-t_LSE/model_best.pth.tar", data_path="$DATA_PATH", split="testall",fold5=True) 8 | # ``` 9 | # --------------------------------------------------------------- 10 | """Evaluation Ensemble""" 11 | 12 | from __future__ import print_function 13 | import os 14 | 15 | import sys 16 | from data import get_test_loader 17 | import time 18 | import numpy as np 19 | from vocab import Vocabulary, deserialize_vocab # NOQA 20 | import torch 21 | from model import SCAN, xattn_score_t2i, xattn_score_i2t 22 | from collections import OrderedDict 23 | import time 24 | from torch.autograd import Variable 25 | 26 | class AverageMeter(object): 27 | 
"""Computes and stores the average and current value""" 28 | 29 | def __init__(self): 30 | self.reset() 31 | 32 | def reset(self): 33 | self.val = 0 34 | self.avg = 0 35 | self.sum = 0 36 | self.count = 0 37 | 38 | def update(self, val, n=0): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / (.0001 + self.count) 43 | 44 | def __str__(self): 45 | """String representation for logging 46 | """ 47 | # for values that should be recorded exactly e.g. iteration number 48 | if self.count == 0: 49 | return str(self.val) 50 | # for stats 51 | return '%.4f (%.4f)' % (self.val, self.avg) 52 | 53 | 54 | class LogCollector(object): 55 | """A collection of logging objects that can change from train to val""" 56 | 57 | def __init__(self): 58 | # to keep the order of logged variables deterministic 59 | self.meters = OrderedDict() 60 | 61 | def update(self, k, v, n=0): 62 | # create a new meter if previously not recorded 63 | if k not in self.meters: 64 | self.meters[k] = AverageMeter() 65 | self.meters[k].update(v, n) 66 | 67 | def __str__(self): 68 | """Concatenate the meters in one log line 69 | """ 70 | s = '' 71 | for i, (k, v) in enumerate(self.meters.iteritems()): 72 | if i > 0: 73 | s += ' ' 74 | s += k + ' ' + str(v) 75 | return s 76 | 77 | def tb_log(self, tb_logger, prefix='', step=None): 78 | """Log using tensorboard 79 | """ 80 | for k, v in self.meters.iteritems(): 81 | tb_logger.log_value(prefix + k, v.val, step=step) 82 | 83 | 84 | def encode_data(model, data_loader, log_step=10, logging=print): 85 | """Encode all images and captions loadable by `data_loader` 86 | """ 87 | batch_time = AverageMeter() 88 | val_logger = LogCollector() 89 | 90 | # switch to evaluate mode 91 | model.val_start() 92 | 93 | end = time.time() 94 | 95 | # np array to keep all the embeddings 96 | img_embs = None 97 | cap_embs = None 98 | cap_lens = None 99 | 100 | max_n_word = 0 101 | for i, (images, captions, lengths, ids) in enumerate(data_loader): 102 | max_n_word = max(max_n_word, max(lengths)) 103 | 104 | for i, (images, captions, lengths, ids) in enumerate(data_loader): 105 | # make sure val logger is used 106 | model.logger = val_logger 107 | 108 | # compute the embeddings 109 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths, volatile=True) 110 | #print(img_emb) 111 | if img_embs is None: 112 | if img_emb.dim() == 3: 113 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 114 | else: 115 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 116 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 117 | cap_lens = [0] * len(data_loader.dataset) 118 | # cache embeddings 119 | img_embs[ids] = img_emb.data.cpu().numpy().copy() 120 | cap_embs[ids,:max(lengths),:] = cap_emb.data.cpu().numpy().copy() 121 | for j, nid in enumerate(ids): 122 | cap_lens[nid] = cap_len[j] 123 | 124 | # measure accuracy and record loss 125 | model.forward_loss(img_emb, cap_emb, cap_len) 126 | 127 | # measure elapsed time 128 | batch_time.update(time.time() - end) 129 | end = time.time() 130 | 131 | if i % log_step == 0: 132 | logging('Test: [{0}/{1}]\t' 133 | '{e_log}\t' 134 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 135 | .format( 136 | i, len(data_loader), batch_time=batch_time, 137 | e_log=str(model.logger))) 138 | del images, captions 139 | return img_embs, cap_embs, cap_lens 140 | 141 | 142 | def evalrank(model_path, model_path2, data_path=None, split='dev', fold5=False): 143 | """ 144 
| Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 145 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 146 | used for evaluation. 147 | """ 148 | # load model and options 149 | checkpoint = torch.load(model_path) 150 | opt = checkpoint['opt'] 151 | 152 | checkpoint2 = torch.load(model_path2) 153 | opt2 = checkpoint2['opt'] 154 | 155 | print(opt) 156 | if data_path is not None: 157 | opt.data_path = data_path 158 | 159 | # load vocabulary used by the model 160 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 161 | opt.vocab_size = len(vocab) 162 | 163 | # construct model 164 | model = SCAN(opt) 165 | model2 = SCAN(opt2) 166 | 167 | # load model state 168 | model.load_state_dict(checkpoint['model']) 169 | 170 | model2.load_state_dict(checkpoint['model']) 171 | 172 | print('Loading dataset') 173 | data_loader = get_test_loader(split, opt.data_name, vocab, 174 | opt.batch_size, opt.workers, opt) 175 | 176 | 177 | start_total=time.time() 178 | print('Computing results...') 179 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader) 180 | img_embs2, cap_embs2, cap_lens2 = encode_data(model2, data_loader) 181 | 182 | print('Images: %d, Captions: %d' % 183 | (img_embs.shape[0] / 5, cap_embs.shape[0])) 184 | 185 | 186 | if not fold5: 187 | # no cross-validation, full evaluation 188 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 189 | img_embs2 = np.array([img_embs2[i] for i in range(0, len(img_embs2), 5)]) 190 | 191 | start = time.time() 192 | if opt.cross_attn == 't2i': 193 | sims = shard_xattn_t2i(img_embs, cap_embs, cap_lens, opt, shard_size=128) 194 | sims2 = shard_xattn_t2i(img_embs2, cap_embs2, cap_lens2, opt2, shard_size=128) 195 | elif opt.cross_attn == 'i2t': 196 | sims = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt, shard_size=128) 197 | sims2 = shard_xattn_i2t(img_embs2, cap_embs2, cap_lens2, opt2, shard_size=128) 198 | else: 199 | raise NotImplementedError 200 | end = time.time() 201 | print("calculate similarity time:", end-start) 202 | 203 | sims = (sims + sims2)/2 204 | 205 | r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True) 206 | ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True) 207 | ar = (r[0] + r[1] + r[2]) / 3 208 | ari = (ri[0] + ri[1] + ri[2]) / 3 209 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 210 | print("rsum: %.1f" % rsum) 211 | print("Average i2t Recall: %.1f" % ar) 212 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 213 | print("Average t2i Recall: %.1f" % ari) 214 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 215 | else: 216 | # 5fold cross-validation, only for MSCOCO 217 | results = [] 218 | for i in range(5): 219 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 220 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 221 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 222 | 223 | img_embs_shard2 = img_embs2[i * 5000:(i + 1) * 5000:5] 224 | cap_embs_shard2 = cap_embs2[i * 5000:(i + 1) * 5000] 225 | cap_lens_shard2 = cap_lens2[i * 5000:(i + 1) * 5000] 226 | 227 | start = time.time() 228 | if opt.cross_attn == 't2i': 229 | sims = shard_xattn_t2i(img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=128) 230 | sims2 = shard_xattn_t2i(img_embs_shard2, cap_embs_shard2, cap_lens_shard2, opt2, shard_size=128) 231 | elif opt.cross_attn == 'i2t': 232 | sims = shard_xattn_i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=128) 233 | sims2 = 
shard_xattn_i2t(img_embs_shard2, cap_embs_shard2, cap_lens_shard2, opt2, shard_size=128) 234 | else: 235 | raise NotImplementedError 236 | end = time.time() 237 | print("calculate similarity time:", end-start) 238 | 239 | sims = (sims + sims2)/2 240 | 241 | 242 | r, rt0 = i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, sims, return_ranks=True) 243 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 244 | ri, rti0 = t2i(img_embs_shard, cap_embs_shard, cap_lens_shard, sims, return_ranks=True) 245 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 246 | 247 | if i == 0: 248 | rt, rti = rt0, rti0 249 | ar = (r[0] + r[1] + r[2]) / 3 250 | ari = (ri[0] + ri[1] + ri[2]) / 3 251 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 252 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 253 | results += [list(r) + list(ri) + [ar, ari, rsum]] 254 | 255 | print("-----------------------------------") 256 | print("Mean metrics: ") 257 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 258 | print("rsum: %.1f" % (mean_metrics[10] * 6)) 259 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 260 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 261 | mean_metrics[:5]) 262 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 263 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 264 | mean_metrics[5:10]) 265 | 266 | torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') 267 | 268 | end_total=time.time() 269 | print('test time (S): ' + str(end_total-start_total)) 270 | 271 | def softmax(X, axis): 272 | """ 273 | Compute the softmax of each element along an axis of X. 274 | """ 275 | y = np.atleast_2d(X) 276 | # subtract the max for numerical stability 277 | y = y - np.expand_dims(np.max(y, axis = axis), axis) 278 | # exponentiate y 279 | y = np.exp(y) 280 | # take the sum along the specified axis 281 | ax_sum = np.expand_dims(np.sum(y, axis = axis), axis) 282 | # finally: divide elementwise 283 | p = y / ax_sum 284 | return p 285 | 286 | 287 | def shard_xattn_t2i(images, captions, caplens, opt, shard_size=128): 288 | """ 289 | Computer pairwise t2i image-caption distance with locality sharding 290 | """ 291 | n_im_shard = (len(images)-1)/shard_size + 1 292 | n_cap_shard = (len(captions)-1)/shard_size + 1 293 | 294 | d = np.zeros((len(images), len(captions))) 295 | for i in range(n_im_shard): 296 | im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images)) 297 | for j in range(n_cap_shard): 298 | sys.stdout.write('\r>> shard_xattn_t2i batch (%d,%d)' % (i,j)) 299 | cap_start, cap_end = shard_size*j, min(shard_size*(j+1), len(captions)) 300 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 301 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 302 | l = caplens[cap_start:cap_end] 303 | sim = xattn_score_t2i(im, s, l, opt) 304 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 305 | sys.stdout.write('\n') 306 | return d 307 | 308 | 309 | def shard_xattn_i2t(images, captions, caplens, opt, shard_size=128): 310 | """ 311 | Computer pairwise i2t image-caption distance with locality sharding 312 | """ 313 | n_im_shard = (len(images)-1)/shard_size + 1 314 | n_cap_shard = (len(captions)-1)/shard_size + 1 315 | 316 | d = np.zeros((len(images), len(captions))) 317 | for i in range(n_im_shard): 318 | im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images)) 319 | for j in range(n_cap_shard): 320 | sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i,j)) 321 | cap_start, 
cap_end = shard_size*j, min(shard_size*(j+1), len(captions)) 322 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 323 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 324 | l = caplens[cap_start:cap_end] 325 | sim = xattn_score_i2t(im, s, l, opt) 326 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 327 | sys.stdout.write('\n') 328 | return d 329 | 330 | 331 | def i2t(images, captions, caplens, sims, npts=None, return_ranks=False): 332 | """ 333 | Images->Text (Image Annotation) 334 | Images: (N, n_region, d) matrix of images 335 | Captions: (5N, max_n_word, d) matrix of captions 336 | CapLens: (5N) array of caption lengths 337 | sims: (N, 5N) matrix of similarity im-cap 338 | """ 339 | npts = images.shape[0] 340 | ranks = np.zeros(npts) 341 | top1 = np.zeros(npts) 342 | for index in range(npts): 343 | inds = np.argsort(sims[index])[::-1] 344 | # Score 345 | rank = 1e20 346 | for i in range(5 * index, 5 * index + 5, 1): 347 | tmp = np.where(inds == i)[0][0] 348 | if tmp < rank: 349 | rank = tmp 350 | ranks[index] = rank 351 | top1[index] = inds[0] 352 | 353 | # Compute metrics 354 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 355 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 356 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 357 | medr = np.floor(np.median(ranks)) + 1 358 | meanr = ranks.mean() + 1 359 | if return_ranks: 360 | return (r1, r5, r10, medr, meanr), (ranks, top1) 361 | else: 362 | return (r1, r5, r10, medr, meanr) 363 | 364 | 365 | def t2i(images, captions, caplens, sims, npts=None, return_ranks=False): 366 | """ 367 | Text->Images (Image Search) 368 | Images: (N, n_region, d) matrix of images 369 | Captions: (5N, max_n_word, d) matrix of captions 370 | CapLens: (5N) array of caption lengths 371 | sims: (N, 5N) matrix of similarity im-cap 372 | """ 373 | npts = images.shape[0] 374 | ranks = np.zeros(5 * npts) 375 | top1 = np.zeros(5 * npts) 376 | 377 | # --> (5N(caption), N(image)) 378 | sims = sims.T 379 | 380 | for index in range(npts): 381 | for i in range(5): 382 | inds = np.argsort(sims[5 * index + i])[::-1] 383 | ranks[5 * index + i] = np.where(inds == index)[0][0] 384 | top1[5 * index + i] = inds[0] 385 | 386 | # Compute metrics 387 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 388 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 389 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 390 | medr = np.floor(np.median(ranks)) + 1 391 | meanr = ranks.mean() + 1 392 | if return_ranks: 393 | return (r1, r5, r10, medr, meanr), (ranks, top1) 394 | else: 395 | return (r1, r5, r10, medr, meanr) -------------------------------------------------------------------------------- /meter/modules/eval_gl.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import numpy as np 3 | 4 | 5 | def i2t_gl(full_img_emb_aggrs, full_cap_emb_aggrs, img_embs, cap_embs, img_emb_fusions, cab_emb_fusions, img_lenghts, 6 | cap_lenghts, npts=None, return_ranks=True, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, 7 | sim_function_new=None, cap_batches=1, pl_module=None, fold5=-1, topk=None, weight=None): 8 | # global_sims = np.matmul(full_img_emb_aggrs, full_cap_emb_aggrs.T) # 9 | 10 | if npts is None: 11 | npts = img_embs.shape[0] // 5 12 | 13 | index_list = [] 14 | ranks = numpy.zeros(npts) 15 | top1 = numpy.zeros(npts) 16 | 17 | full_img_emb_aggrs = np.array([full_img_emb_aggrs[i] for i in range(0, 
len(full_img_emb_aggrs), 5)]) 18 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 19 | img_emb_fusions = np.array([img_emb_fusions[i] for i in range(0, len(img_emb_fusions), 5)]) 20 | 21 | global_sims = np.matmul(full_img_emb_aggrs, full_cap_emb_aggrs.T) # (5N, 5N) 22 | sims = sim_function_new(img_embs, cap_embs, cap_lenghts, pl_module) 23 | final_sims = global_sims * weight + sims * (1 - weight) # (N, 5N) 24 | 25 | for index in range(npts): 26 | 27 | d_r = final_sims[index] 28 | g_inds = numpy.argsort(d_r)[::-1] 29 | top_g_inds = list(g_inds[0:topk]) 30 | 31 | im = img_embs[index].reshape(1, img_embs.shape[1], img_embs.shape[2]) # (1, N, dim) 32 | im_f = img_emb_fusions[index].reshape(1, img_emb_fusions.shape[1], img_emb_fusions.shape[2]) # (1, N, dim) 33 | 34 | cap_now = cap_embs[top_g_inds] 35 | cap_f_now = cab_emb_fusions[top_g_inds] 36 | cap_lenghts_now = list(cap_lenghts[top_g_inds]) # (150, ) 37 | 38 | # d: (1, 150) 39 | # d_l_1 = final_sims[5 * index][top_g_inds] 40 | d_l_1 = d_r[top_g_inds] 41 | d_l_2 = sim_function(im_f, cap_now, cap_lenghts_now, 9) 42 | d_l_3 = sim_function(im, cap_f_now, cap_lenghts_now, 9) 43 | d_f = (d_l_1 + d_l_2 + d_l_3).flatten() 44 | 45 | l_inds = numpy.argsort(d_f)[::-1] 46 | inds = g_inds[l_inds] 47 | index_list.append(inds[0]) 48 | 49 | # Score 50 | rank = 1e20 51 | tmp_inss = [] 52 | for i in range(5 * index, 5 * index + 5, 1): 53 | tmp_ins = list(numpy.where(inds == i)) 54 | if len(tmp_ins[0]) <= 0: 55 | continue 56 | tmp_inss.append(tmp_ins[0][0]) 57 | if len(tmp_inss) <= 0: 58 | tmp_inss.append(0) 59 | tmp = min(tmp_inss) 60 | if tmp < rank: 61 | rank = tmp 62 | ranks[index] = rank 63 | top1[index] = inds[0] 64 | 65 | # Compute metrics 66 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 67 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 68 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 69 | medr = numpy.floor(numpy.median(ranks)) + 1 70 | meanr = ranks.mean() + 1 71 | if return_ranks: 72 | return (r1, r5, r10, medr, meanr, 0., 0.), (ranks, top1), sims 73 | else: 74 | return (r1, r5, r10, medr, meanr, 0., 0.), sims 75 | 76 | 77 | def t2i_gl(full_img_emb_aggrs, full_cap_emb_aggrs, img_embs, cap_embs, img_emb_fusions, cab_emb_fusions, img_lenghts, 78 | cap_lenghts, npts=None, return_ranks=True, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, 79 | sim_function_new=None, cap_batches=1, pl_module=None, sims=None, topk=None, weight=None): 80 | # global_sims = np.matmul(full_img_emb_aggrs, full_cap_emb_aggrs.T) # 81 | 82 | if npts is None: 83 | npts = img_embs.shape[0] // 5 84 | 85 | index_list = [] 86 | ims = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 87 | f_ims = np.array([img_emb_fusions[i] for i in range(0, len(img_emb_fusions), 5)]) 88 | full_img_emb_aggrs = np.array([full_img_emb_aggrs[i] for i in range(0, len(full_img_emb_aggrs), 5)]) 89 | 90 | ranks = numpy.zeros(5 * npts) 91 | top50 = numpy.zeros((5 * npts, 5)) 92 | 93 | global_sims = np.matmul(full_img_emb_aggrs, full_cap_emb_aggrs.T).T # (N, 5N) -> (5N, N) 94 | # sims = sim_function_new(ims, cap_embs, cap_lenghts, pl_module).T 95 | # sims = (np.array([sims[i] for i in range(0, len(sims), 5)])).T 96 | sims = (sims).T 97 | 98 | final_sims = sims * (1 - weight) + global_sims * weight 99 | 100 | for index in range(npts): 101 | 102 | # Get query captions 103 | queries_1 = cap_embs[5 * index:5 * index + 5] 104 | queries_2 = cab_emb_fusions[5 * index:5 * index + 5] 105 | queries_len = cap_lenghts[5 * index:5 * 
index + 5] 106 | 107 | d_r = final_sims[5 * index: 5 * index + 5] 108 | inds = numpy.zeros((len(d_r), topk)) # (5, 150) 109 | 110 | for i in range(len(d_r)): 111 | di_g = d_r[i] 112 | g_inds = numpy.argsort(di_g)[::-1] 113 | cap_inds = list(g_inds[:topk]) 114 | 115 | quer = queries_1[i].reshape(1, queries_1[i].shape[0], queries_1[i].shape[1]) # (1, N, dim) 116 | f_quer = queries_2[i].reshape(1, queries_2[i].shape[0], queries_2[i].shape[1]) # (1, N, dim) 117 | 118 | quer_len = [queries_len[i]] 119 | 120 | ims_now = ims[cap_inds] 121 | f_ims_now = f_ims[cap_inds] 122 | 123 | # (1, 150) 124 | d_l_1 = d_r[i][cap_inds] 125 | # d_l_1 = sim_function_new(ims_now, quer, quer_len, pl_module).T 126 | d_l_2 = sim_function(f_ims_now, quer, quer_len, 9).T 127 | d_l_3 = sim_function(ims_now, f_quer, quer_len, 9).T 128 | d_f = (d_l_1 + d_l_2 + d_l_3).flatten() 129 | 130 | l_inds = numpy.argsort(d_f)[::-1] 131 | inds[i] = g_inds[l_inds] 132 | r_r = numpy.where(inds[i] == index)[0] 133 | if len(r_r) <= 0: 134 | ranks[5 * index + i] = 0 135 | else: 136 | ranks[5 * index + i] = r_r[0] 137 | top50[5 * index + i] = inds[i][0:5] 138 | 139 | # Compute metrics 140 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 141 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 142 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 143 | medr = numpy.floor(numpy.median(ranks)) + 1 144 | meanr = ranks.mean() + 1 145 | 146 | if return_ranks: 147 | return (r1, r5, r10, medr, meanr, 0., 0.), (ranks, top50) 148 | else: 149 | return (r1, r5, r10, medr, meanr, 0., 0.) 150 | -------------------------------------------------------------------------------- /meter/modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' 
+ etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 
135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
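# --- [Editor's sketch: not part of the original file_utils.py; illustration only] ---
# get_from_cache below downloads to a NamedTemporaryFile and then copies the result
# into the cache directory under the deterministic name produced by
# url_to_filename(url, etag): sha256(url) plus, when an ETag is available, a dot and
# sha256(etag). The URL below is the ConViT checkpoint used elsewhere in this repo;
# the ETag value is made up:
_name = url_to_filename("https://dl.fbaipublicfiles.com/convit/convit_base.pth",
                        etag="0123456789abcdef")
print(_name)   # '<64 hex chars>.<64 hex chars>' -> the key stored under PYTORCH_PRETRAINED_BERT_CACHE
# --- [end of editor's sketch] ---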
223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /meter/modules/meter_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from .bert import BertModel 5 | import pytorch_lightning as pl 6 | import torch.nn.functional as F 7 | from . 
import objectives, meter_utils 8 | import meter.modules.convnext as convnext 9 | from meter.modules.visual_encoder import Visual_Enconder 10 | from meter.modules.textual_encoder import Textual_Enconder 11 | 12 | 13 | def freeze_layers(model, requires_grad): 14 | for child in model.children(): 15 | for param in child.parameters(): 16 | param.requires_grad = requires_grad 17 | 18 | 19 | class METERTransformerSS(pl.LightningModule): 20 | def __init__(self, config): 21 | super().__init__() 22 | self.save_hyperparameters() 23 | 24 | self.eval_bool = False 25 | 26 | self.cross_modal_text_transform = nn.Linear(config['input_text_embed_size'], config['hidden_size']) 27 | self.cross_modal_text_transform.apply(objectives.init_weights) 28 | self.cross_modal_image_transform = nn.Linear(config['input_image_embed_size'], config['hidden_size']) 29 | self.cross_modal_image_transform.apply(objectives.init_weights) 30 | 31 | # auxiliary projections for intermediate text / image features 32 | self.cross_modal_text_transform_1 = nn.Linear(config['input_text_embed_size'], config['hidden_size']) 33 | self.cross_modal_text_transform_1.apply(objectives.init_weights) 34 | self.cross_modal_text_transform_2 = nn.Linear(config['input_text_embed_size'], config['hidden_size']) 35 | self.cross_modal_text_transform_2.apply(objectives.init_weights) 36 | 37 | self.cross_modal_image_transform_1 = nn.Linear(256, config['hidden_size']) 38 | self.cross_modal_image_transform_1.apply(objectives.init_weights) 39 | self.cross_modal_image_transform_2 = nn.Linear(512, config['hidden_size']) 40 | self.cross_modal_image_transform_2.apply(objectives.init_weights) 41 | 42 | # convnext 43 | self.convnexts = getattr(convnext, 'convnext_base')(pretrained=True, in_22k=True, num_classes=21841) 44 | 45 | self.fc1 = nn.Linear(256, 768) 46 | self.fc2 = nn.Linear(512, 768) 47 | self.fc3 = nn.Linear(1024, 768) 48 | 49 | act = config['activation'] 50 | # add Textual + Visual Encoder 51 | self.txt_enc = Textual_Enconder(textual_dim=768, factor=768, act=act) 52 | self.img_enc = Visual_Enconder(visual_dim=1024, factor=768, act=act) 53 | 54 | self.text_transformer = BertModel.from_pretrained(config['tokenizer']) 55 | 56 | freeze_layers(self.text_transformer.encoder, False) 57 | freeze_layers(self.text_transformer.embeddings, False) 58 | freeze_layers(self.text_transformer.pooler, False) 59 | freeze_layers(self.convnexts, False) 60 | 61 | def adjust_k(self): 62 | """ 63 | Update the loss hyper-parameter k: pinned to 1 64 | under max-violation, otherwise raised towards 1 65 | as 1 - beta**iteration 66 | """ 67 | self.iteration += 1 68 | 69 | if self.max_violation: 70 | self.k = 1 71 | return 1.
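# otherwise k = 1 - beta**iteration: it starts near 1 - beta and approaches 1 as training proceeds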
72 | 73 | self.k = (1.-self.beta**np.float(self.iteration)) 74 | return self.k 75 | 76 | def infer( 77 | self, 78 | batch, 79 | mask_text=False, 80 | mask_image=False, 81 | image_token_type_idx=1, 82 | img=None, 83 | ): 84 | if img is None: 85 | if f"image_{image_token_type_idx - 1}" in batch: 86 | imgkey = f"image_{image_token_type_idx - 1}" 87 | else: 88 | imgkey = "image" 89 | img = batch[imgkey][0] 90 | 91 | do_mlm = "_mlm" if mask_text else "" 92 | text_ids = batch[f"text_ids{do_mlm}"] 93 | text_types = batch[f"text_types"] 94 | text_masks = batch[f"text_masks"] 95 | text_lengths = batch[f'lengths'] 96 | 97 | # pooled_output: (64, 768) 98 | all_encoder_layers, pooled_output = self.text_transformer(text_ids, token_type_ids=text_types, attention_mask=text_masks) 99 | 100 | text_embeds = all_encoder_layers[-1] 101 | text_embeds = self.cross_modal_text_transform(text_embeds) 102 | text_embeds_2 = all_encoder_layers[9] 103 | text_embeds_2 = self.cross_modal_text_transform_2(text_embeds_2) 104 | 105 | # v4: convnexts 106 | image_embeds_global, image_embeds_all = self.convnexts(img) # (B, 1024) 107 | image_embeds_2 = image_embeds_all[2] # (B, 512, 14, 14) 108 | image_embeds_2 = image_embeds_2.reshape(image_embeds_2.size(0), image_embeds_2.size(1), -1).permute(0, 2, 1) # (B, 196, 512) -> (B, 196, 768) 109 | image_embeds_2 = self.fc2(image_embeds_2) 110 | 111 | image_embeds = image_embeds_all[3] # (B, 1024, 7, 7) 112 | image_embeds = image_embeds.reshape(image_embeds.size(0), image_embeds.size(1), -1).permute(0, 2, 1) # (B, 49, 1024) -> (B, 49, 768) 113 | image_embeds = self.fc3(image_embeds) 114 | 115 | full_cap_emb_aggr = self.txt_enc(pooled_output) # (B, 768) - > (B, 768) 116 | full_img_emb_aggr = self.img_enc(image_embeds_global) # (64, 1024) - > (B, 768) 117 | 118 | ret = { 119 | "text_feats_h": text_embeds, 120 | "image_feats_h": image_embeds, 121 | "text_feats_m": text_embeds_2, 122 | 'image_feats_m': image_embeds_2, 123 | "text_ids": text_ids, 124 | "text_masks": text_masks, 125 | 'text_lengths': text_lengths, 126 | 'full_img_emb_aggr': full_img_emb_aggr, 127 | 'full_cap_emb_aggr': full_cap_emb_aggr, 128 | } 129 | return ret 130 | 131 | 132 | def collate_function_both(self, batch): 133 | cap_node_albef = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 134 | cap_node_dot = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 135 | list_cap_edge_index = [] 136 | list_cap_edge_attr = [] 137 | cap_cls_albef = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 138 | cap_cls_dot = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 139 | cap_cls_albef_ori = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 140 | cap_cls_dot_ori = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 141 | 142 | img_node_albef = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 143 | img_node_dot = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 144 | list_img_edge_index = [] 145 | list_img_edge_attr = [] 146 | img_cls_albef = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 147 | img_cls_dot = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 148 | img_cls_albef_ori = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 149 | img_cls_dot_ori = torch.tensor(()).to(batch[0]['img']['node_albef'].device) 150 | 151 | list_id = [] 152 | list_n_img_node = [] 153 | list_n_img_node_albef = [] 154 | list_n_img_node_dot = [] 155 | list_n_cap_node = [] 156 | list_n_cap_node_albef = [] 157 | list_n_cap_node_dot = [] 158 | for x in batch: 159 
| img_cls_albef = torch.cat((img_cls_albef, x['img']['cls_albef']), dim=0) 160 | img_cls_dot = torch.cat((img_cls_dot, x['img']['cls_dot']), dim=0) 161 | img_node_albef = torch.cat((img_node_albef, x['img']['node_albef']), dim=0) 162 | img_node_dot = torch.cat((img_node_dot, x['img']['node_dot']), dim=0) 163 | list_img_edge_index.append(x['img']['edge_index']) 164 | list_img_edge_attr.append(x['img']['edge_attr']) 165 | list_n_img_node.append(x['img']['node_albef'].shape[0] + x['img']['node_dot'].shape[0]) 166 | list_n_img_node_albef.append(x['img']['node_albef'].shape[0]) 167 | list_n_img_node_dot.append(x['img']['node_dot'].shape[0]) 168 | img_cls_albef_ori = torch.cat((img_cls_albef_ori, x['img']['cls_albef_ori']), dim=0) 169 | img_cls_dot_ori = torch.cat((img_cls_dot_ori, x['img']['cls_dot_ori']), dim=0) 170 | 171 | cap_cls_albef = torch.cat((cap_cls_albef, x['cap']['cls_albef']), dim=0) 172 | cap_cls_dot = torch.cat((cap_cls_dot, x['cap']['cls_dot']), dim=0) 173 | cap_node_albef = torch.cat((cap_node_albef, x['cap']['node_albef']), dim=0) 174 | cap_node_dot = torch.cat((cap_node_dot, x['cap']['node_dot']), dim=0) 175 | list_cap_edge_index.append(x['cap']['edge_index']) 176 | list_cap_edge_attr.append(x['cap']['edge_attr']) 177 | list_n_cap_node.append(x['cap']['node_albef'].shape[0] + x['cap']['node_dot'].shape[0]) 178 | list_n_cap_node_albef.append(x['cap']['node_albef'].shape[0]) 179 | list_n_cap_node_dot.append(x['cap']['node_dot'].shape[0]) 180 | cap_cls_albef_ori = torch.cat((cap_cls_albef_ori, x['cap']['cls_albef_ori']), dim=0) 181 | cap_cls_dot_ori = torch.cat((cap_cls_dot_ori, x['cap']['cls_dot_ori']), dim=0) 182 | list_id.append(x['id']) 183 | 184 | bs = len(list_id) 185 | img_edge_attr = torch.cat(list_img_edge_attr).to(batch[0]['img']['node_albef'].device) 186 | cap_edge_attr = torch.cat(list_cap_edge_attr).to(batch[0]['img']['node_albef'].device) 187 | del list_img_edge_attr, list_cap_edge_attr 188 | img_batch_index = torch.tensor(np.repeat([x for x in range(bs)], list_n_img_node)).to(batch[0]['img']['node_albef'].device) 189 | cap_batch_index = torch.tensor(np.repeat([x for x in range(bs)], list_n_cap_node)).to(batch[0]['img']['node_albef'].device) 190 | count_img = 0 191 | count_cap = 0 192 | for idx in range(bs): 193 | list_img_edge_index[idx] = list_img_edge_index[idx] + count_img 194 | list_cap_edge_index[idx] = list_cap_edge_index[idx] + count_cap 195 | count_img += list_n_img_node[idx] 196 | count_cap += list_n_cap_node[idx] 197 | img_edge_index = torch.cat(list_img_edge_index, dim=1).to(batch[0]['img']['node_albef'].device) 198 | cap_edge_index = torch.cat(list_cap_edge_index, dim=1).to(batch[0]['img']['node_albef'].device) 199 | del list_img_edge_index, list_cap_edge_index 200 | n_img_node_albef = torch.tensor(list_n_img_node_albef).to(batch[0]['img']['node_albef'].device) 201 | n_img_node_dot = torch.tensor(list_n_img_node_dot).to(batch[0]['img']['node_albef'].device) 202 | n_cap_node_albef = torch.tensor(list_n_cap_node_albef).to(batch[0]['img']['node_albef'].device) 203 | n_cap_node_dot = torch.tensor(list_n_cap_node_dot).to(batch[0]['img']['node_albef'].device) 204 | del list_n_img_node_albef, list_n_img_node_dot, list_n_cap_node_albef, list_n_cap_node_dot 205 | img_dict = {'cls_albef': img_cls_albef, 'cls_dot': img_cls_dot, 'batch_index': img_batch_index, 206 | 'node_albef': img_node_albef, 'node_dot': img_node_dot, 207 | 'n_node_albef': n_img_node_albef, 'n_node_dot': n_img_node_dot, 208 | 'edge_index': img_edge_index, 'edge_attr': img_edge_attr, 209 | 
'cls_albef_ori': img_cls_albef_ori, 'cls_dot_ori': img_cls_dot_ori} 210 | cap_dict = {'cls_albef': cap_cls_albef, 'cls_dot': cap_cls_dot, 'batch_index': cap_batch_index, 211 | 'node_albef': cap_node_albef, 'node_dot': cap_node_dot, 212 | 'n_node_albef': n_cap_node_albef, 'n_node_dot': n_cap_node_dot, 213 | 'edge_index': cap_edge_index, 'edge_attr': cap_edge_attr, 214 | 'cls_albef_ori': cap_cls_albef_ori, 'cls_dot_ori': cap_cls_dot_ori} 215 | list_id = torch.tensor([[int(x.split('_')[0]) for x in list_id]]).reshape(-1, 1).to(batch[0]['img']['node_albef'].device) 216 | return img_dict, cap_dict, list_id 217 | 218 | def create_index_from_2_list(self, list_1, list_2, dual_index=False, self_loop=False): 219 | first = np.repeat(list_1, len(list_2)) 220 | second = np.tile(list_2, len(list_1)) 221 | result = np.asarray([first, second]) 222 | if dual_index: 223 | first = np.repeat(list_2, len(list_1)) 224 | second = np.tile(list_1, len(list_2)) 225 | result = np.concatenate((result, np.asarray([first, second])), axis=1) 226 | if self_loop: 227 | list_all = list_1 + list_2 228 | result = np.concatenate((result, np.asarray([list_all, list_all])), axis=1) 229 | return result 230 | 231 | def forward(self, batch): 232 | ret = dict() 233 | if len(self.current_tasks) == 0: 234 | ret.update(self.infer(batch)) 235 | return ret 236 | 237 | # Masked Language Modeling 238 | if "mlm" in self.current_tasks: 239 | ret.update(objectives.compute_mlm(self, batch)) 240 | 241 | # Image Text Matching 242 | if "itm" in self.current_tasks: 243 | ret.update(objectives.compute_itm(self, batch)) 244 | 245 | # Visual Question Answering 246 | if "vqa" in self.current_tasks: 247 | ret.update(objectives.compute_vqa(self, batch)) 248 | 249 | # Natural Language for Visual Reasoning 2 250 | if "nlvr2" in self.current_tasks: 251 | ret.update(objectives.compute_nlvr2(self, batch)) 252 | 253 | # SNLI Visual Entailment 254 | if "snli" in self.current_tasks: 255 | ret.update(objectives.compute_snli(self, batch)) 256 | 257 | # Image Retrieval and Text Retrieval 258 | if "irtr" in self.current_tasks: 259 | ret.update(objectives.compute_irtr_my(self, batch)) 260 | 261 | return ret 262 | 263 | def training_step(self, batch, batch_idx): 264 | self.eval_bool = True 265 | 266 | meter_utils.set_task(self) 267 | output = self(batch) 268 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 269 | 270 | return total_loss 271 | 272 | def training_epoch_end(self, outs): 273 | pass 274 | 275 | def validation_step(self, batch, batch_idx): 276 | pass 277 | '''meter_utils.set_task(self) 278 | output = self(batch)''' 279 | 280 | def validation_epoch_end(self, outs): 281 | if self.current_epoch!=0: 282 | meter_utils.epoch_eval_irtr(self) 283 | 284 | if self.current_epoch >= 10: 285 | freeze_layers(self.convnexts, True) 286 | freeze_layers(self.text_transformer.encoder, True) 287 | freeze_layers(self.text_transformer.embeddings, True) 288 | freeze_layers(self.text_transformer.pooler, True) 289 | 290 | def test_step(self, batch, batch_idx): 291 | pass 292 | 293 | def test_epoch_end(self, outs): 294 | #meter_utils.epoch_eval_irtr(self) 295 | meter_utils.epoch_eval_irtr(self, is_test=True) 296 | 297 | def configure_optimizers(self): 298 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.config['learning_rate']) 299 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 300 | 301 | return { 302 | "optimizer": optimizer, 303 | "lr_scheduler": { 304 | "scheduler": scheduler, 305 | "interval": "epoch", 306 | 
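# StepLR decays the learning rate by gamma=0.1 every 10 epochs; stepped once per epoch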
"frequency": 1, 307 | }, 308 | } 309 | -------------------------------------------------------------------------------- /meter/modules/meter_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from transformers.optimization import AdamW 5 | from transformers import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | get_cosine_schedule_with_warmup, 8 | ) 9 | from .objectives import compute_irtr_recall, compute_irtr_val, compute_irtr_test_gl, compute_irtr_testv2 10 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar 11 | 12 | 13 | def set_metrics(pl_module): 14 | for split in ["train", "val"]: 15 | for k, v in pl_module.hparams.config["loss_names"].items(): 16 | if v < 1: 17 | continue 18 | if k == "vqa": 19 | setattr(pl_module, f"{split}_vqa_score", VQAScore()) 20 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 21 | elif k == "nlvr2": 22 | if split == "train": 23 | setattr(pl_module, f"train_{k}_accuracy", Accuracy()) 24 | setattr(pl_module, f"train_{k}_loss", Scalar()) 25 | else: 26 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy()) 27 | setattr(pl_module, f"dev_{k}_loss", Scalar()) 28 | setattr(pl_module, f"test_{k}_accuracy", Accuracy()) 29 | setattr(pl_module, f"test_{k}_loss", Scalar()) 30 | elif k == "snli": 31 | if split == "train": 32 | setattr(pl_module, f"train_{k}_accuracy", Accuracy()) 33 | setattr(pl_module, f"train_{k}_loss", Scalar()) 34 | else: 35 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy()) 36 | setattr(pl_module, f"dev_{k}_loss", Scalar()) 37 | setattr(pl_module, f"test_{k}_accuracy", Accuracy()) 38 | setattr(pl_module, f"test_{k}_loss", Scalar()) 39 | elif k == "irtr": 40 | setattr(pl_module, f"{split}_irtr_loss", Scalar()) 41 | elif k == "mppd" or k == "mpfr": 42 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 43 | elif k == "itm": 44 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 45 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 46 | else: 47 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 48 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 49 | 50 | 51 | def epoch_wrapup(pl_module): 52 | phase = "train" if pl_module.training else "val" 53 | the_metric = 0 54 | 55 | if pl_module.hparams.config["get_recall_metric"] and not pl_module.training: 56 | (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) = compute_irtr_recall(pl_module) 57 | print((ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10), pl_module.global_step) 58 | pl_module.logger.experiment.add_scalar( 59 | "recalls/ir_r1", ir_r1, pl_module.global_step 60 | ) 61 | pl_module.logger.experiment.add_scalar( 62 | "recalls/ir_r5", ir_r5, pl_module.global_step 63 | ) 64 | pl_module.logger.experiment.add_scalar( 65 | "recalls/ir_r10", ir_r10, pl_module.global_step 66 | ) 67 | pl_module.logger.experiment.add_scalar( 68 | "recalls/tr_r1", tr_r1, pl_module.global_step 69 | ) 70 | pl_module.logger.experiment.add_scalar( 71 | "recalls/tr_r5", tr_r5, pl_module.global_step 72 | ) 73 | pl_module.logger.experiment.add_scalar( 74 | "recalls/tr_r10", tr_r10, pl_module.global_step 75 | ) 76 | the_metric += ir_r1.item() + tr_r1.item() 77 | 78 | for loss_name, v in pl_module.hparams.config["loss_names"].items(): 79 | if v < 1: 80 | continue 81 | 82 | value = 0 83 | 84 | if loss_name == "vqa": 85 | value = getattr(pl_module, f"{phase}_{loss_name}_score").compute() 86 | pl_module.log(f"{loss_name}/{phase}/score_epoch", value) 87 | getattr(pl_module, f"{phase}_{loss_name}_score").reset() 88 | pl_module.log( 89 
| f"{loss_name}/{phase}/loss_epoch", 90 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 91 | ) 92 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 93 | elif loss_name == "nlvr2" or loss_name == 'snli': 94 | if phase == "train": 95 | value = getattr(pl_module, f"train_{loss_name}_accuracy").compute() 96 | pl_module.log(f"{loss_name}/train/accuracy_epoch", value) 97 | getattr(pl_module, f"train_{loss_name}_accuracy").reset() 98 | pl_module.log( 99 | f"{loss_name}/train/loss_epoch", 100 | getattr(pl_module, f"train_{loss_name}_loss").compute(), 101 | ) 102 | getattr(pl_module, f"train_{loss_name}_loss").reset() 103 | else: 104 | value = getattr(pl_module, f"test_{loss_name}_accuracy").compute() 105 | pl_module.log(f"{loss_name}/test/accuracy_epoch", value) 106 | getattr(pl_module, f"test_{loss_name}_accuracy").reset() 107 | pl_module.log( 108 | f"{loss_name}/test/loss_epoch", 109 | getattr(pl_module, f"test_{loss_name}_loss").compute(), 110 | ) 111 | getattr(pl_module, f"test_{loss_name}_loss").reset() 112 | 113 | value = getattr(pl_module, f"dev_{loss_name}_accuracy").compute() 114 | pl_module.log(f"{loss_name}/dev/accuracy_epoch", value) 115 | getattr(pl_module, f"dev_{loss_name}_accuracy").reset() 116 | pl_module.log( 117 | f"{loss_name}/dev/loss_epoch", 118 | getattr(pl_module, f"dev_{loss_name}_loss").compute(), 119 | ) 120 | getattr(pl_module, f"dev_{loss_name}_loss").reset() 121 | elif loss_name == "irtr": 122 | pl_module.log( 123 | f"{loss_name}/{phase}/irtr_loss_epoch", 124 | getattr(pl_module, f"{phase}_irtr_loss").compute(), 125 | ) 126 | getattr(pl_module, f"{phase}_irtr_loss").reset() 127 | elif loss_name == "mppd" or loss_name == "mpfr": 128 | pl_module.log( 129 | f"{loss_name}/{phase}/loss_epoch", 130 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 131 | ) 132 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 133 | elif loss_name == "itm": 134 | value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute() 135 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 136 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 137 | pl_module.log( 138 | f"{loss_name}/{phase}/loss_epoch", 139 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 140 | ) 141 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 142 | else: 143 | value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute() 144 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 145 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 146 | pl_module.log( 147 | f"{loss_name}/{phase}/loss_epoch", 148 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 149 | ) 150 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 151 | 152 | the_metric += value 153 | 154 | pl_module.log(f"{phase}/the_metric", the_metric) 155 | 156 | 157 | def epoch_eval_irtr(pl_module, is_test=False): 158 | phase = "train" if pl_module.training else "val" 159 | 160 | if (not pl_module.training) and (not is_test): 161 | (ir_r1, ir_r5, ir_r10, ir_r20, ir_r50, ir_r70, ir_r100, tr_r1, tr_r5, tr_r10, tr_r20, tr_r50, tr_r70, tr_r100) = compute_irtr_val(pl_module) 162 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f" % (ir_r1, ir_r5, ir_r10, ir_r20, ir_r50, ir_r70, ir_r100)) 163 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f" % (tr_r1, tr_r5, tr_r10, tr_r20, tr_r50, tr_r70, tr_r100)) 164 | pl_module.logger.experiment.add_scalar( 165 | "recalls/ir_r1", ir_r1, pl_module.global_step 166 | ) 167 | 
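# only R@1/5/10 are sent to TensorBoard; the extended recalls (R@20-R@100) are printed above but not logged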
pl_module.logger.experiment.add_scalar( 168 | "recalls/ir_r5", ir_r5, pl_module.global_step 169 | ) 170 | pl_module.logger.experiment.add_scalar( 171 | "recalls/ir_r10", ir_r10, pl_module.global_step 172 | ) 173 | pl_module.logger.experiment.add_scalar( 174 | "recalls/tr_r1", tr_r1, pl_module.global_step 175 | ) 176 | pl_module.logger.experiment.add_scalar( 177 | "recalls/tr_r5", tr_r5, pl_module.global_step 178 | ) 179 | pl_module.logger.experiment.add_scalar( 180 | "recalls/tr_r10", tr_r10, pl_module.global_step 181 | ) 182 | else: 183 | compute_irtr_test_gl(pl_module, fold5=False) 184 | # compute_irtr_testv2(pl_module) 185 | 186 | 187 | 188 | def check_non_acc_grad(pl_module): 189 | if pl_module.token_type_embeddings.weight.grad is None: 190 | return True 191 | else: 192 | grad = pl_module.token_type_embeddings.weight.grad 193 | return (grad.sum() == 0).item() 194 | 195 | 196 | def set_task(pl_module): 197 | pl_module.current_tasks = [ 198 | k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1 199 | ] 200 | return 201 | 202 | def set_schedule(pl_module): 203 | lr = pl_module.hparams.config["learning_rate"] 204 | wd = pl_module.hparams.config["weight_decay"] 205 | 206 | no_decay = [ 207 | "bias", 208 | "LayerNorm.bias", 209 | "LayerNorm.weight", 210 | "norm.bias", 211 | "norm.weight", 212 | "norm1.bias", 213 | "norm1.weight", 214 | "norm2.bias", 215 | "norm2.weight", 216 | ] 217 | head_names = ["vqa_classifier", "nlvr2_classifier", "mlm_score", "itm_score", "snli_classifier"] 218 | cross_modal_names = ['cross_modal'] 219 | lr_mult_head = pl_module.hparams.config["lr_mult_head"] 220 | lr_mult_cross_modal = pl_module.hparams.config["lr_mult_cross_modal"] 221 | end_lr = pl_module.hparams.config["end_lr"] 222 | decay_power = pl_module.hparams.config["decay_power"] 223 | optim_type = pl_module.hparams.config["optim_type"] 224 | optimizer_grouped_parameters = [ 225 | { 226 | "params": [ 227 | p 228 | for n, p in pl_module.named_parameters() 229 | if not any(nd in n for nd in no_decay) 230 | and not any(bb in n for bb in head_names) 231 | and not any(ht in n for ht in cross_modal_names) 232 | ], 233 | "weight_decay": wd, 234 | "lr": lr, 235 | }, 236 | { 237 | "params": [ 238 | p 239 | for n, p in pl_module.named_parameters() 240 | if any(nd in n for nd in no_decay) 241 | and not any(bb in n for bb in head_names) 242 | and not any(ht in n for ht in cross_modal_names) 243 | ], 244 | "weight_decay": 0.0, 245 | "lr": lr, 246 | }, 247 | { 248 | "params": [ 249 | p 250 | for n, p in pl_module.named_parameters() 251 | if not any(nd in n for nd in no_decay) 252 | and any(bb in n for bb in head_names) 253 | and not any(ht in n for ht in cross_modal_names) 254 | ], 255 | "weight_decay": wd, 256 | "lr": lr * lr_mult_head, 257 | }, 258 | { 259 | "params": [ 260 | p 261 | for n, p in pl_module.named_parameters() 262 | if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names) 263 | and not any(ht in n for ht in cross_modal_names) 264 | ], 265 | "weight_decay": 0.0, 266 | "lr": lr * lr_mult_head, 267 | }, 268 | { 269 | "params": [ 270 | p 271 | for n, p in pl_module.named_parameters() 272 | if not any(nd in n for nd in no_decay) 273 | and not any(bb in n for bb in head_names) 274 | and any(ht in n for ht in cross_modal_names) 275 | ], 276 | "weight_decay": wd, 277 | "lr": lr * lr_mult_cross_modal, 278 | }, 279 | { 280 | "params": [ 281 | p 282 | for n, p in pl_module.named_parameters() 283 | if any(nd in n for nd in no_decay) 284 | and not any(bb in n for bb in head_names) 285 
| and any(ht in n for ht in cross_modal_names) 286 | ], 287 | "weight_decay": 0.0, 288 | "lr": lr * lr_mult_cross_modal, 289 | }, 290 | ] 291 | 292 | if optim_type == "adamw": 293 | optimizer = AdamW( 294 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98) 295 | ) 296 | elif optim_type == "adam": 297 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 298 | elif optim_type == "sgd": 299 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 300 | 301 | if pl_module.trainer.max_steps is None: 302 | max_steps = ( 303 | len(pl_module.trainer.datamodule.train_dataloader()) 304 | * pl_module.trainer.max_epochs 305 | // pl_module.trainer.accumulate_grad_batches 306 | ) 307 | else: 308 | max_steps = pl_module.trainer.max_steps 309 | 310 | warmup_steps = pl_module.hparams.config["warmup_steps"] 311 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 312 | warmup_steps = int(max_steps * warmup_steps) 313 | 314 | if decay_power == "cosine": 315 | scheduler = get_cosine_schedule_with_warmup( 316 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 317 | ) 318 | else: 319 | scheduler = get_polynomial_decay_schedule_with_warmup( 320 | optimizer, 321 | num_warmup_steps=warmup_steps, 322 | num_training_steps=max_steps, 323 | lr_end=end_lr, 324 | power=decay_power, 325 | ) 326 | 327 | sched = {"scheduler": scheduler, "interval": "step"} 328 | 329 | return ( 330 | [optimizer], 331 | [sched], 332 | ) 333 | -------------------------------------------------------------------------------- /meter/modules/textual_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | def l2norm(X, dim=1, eps=1e-8): 9 | """L2-normalize columns of X """ 10 | 11 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 12 | X = torch.div(X, norm) 13 | return X 14 | 15 | 16 | class Textual_Enconder(nn.Module): 17 | def __init__(self, textual_dim, factor, act='leaky_relu'): 18 | super(Textual_Enconder, self).__init__() 19 | self.v_fc1 = nn.Linear(textual_dim, textual_dim) 20 | self.v_drop = nn.Dropout(p=0.2, inplace=False) 21 | self.v_fc2 = nn.Linear(textual_dim, (textual_dim + factor) // 2) 22 | self.v_fc3 = nn.Linear((textual_dim + factor) // 2, factor) 23 | 24 | self.act = act 25 | 26 | def forward(self, x): 27 | if self.act == 'leaky_relu': 28 | x = F.leaky_relu(self.v_fc1(x)) 29 | x = self.v_drop(x) 30 | x = F.leaky_relu(self.v_fc2(x)) 31 | x = F.leaky_relu(self.v_fc3(x)) 32 | x = l2norm(x, dim=1) 33 | 34 | elif self.act == 'tanh': 35 | x = F.tanh(self.v_fc1(x)) 36 | x = self.v_drop(x) 37 | x = F.tanh(self.v_fc2(x)) 38 | x = F.tanh(self.v_fc3(x)) 39 | x = l2norm(x, dim=1) 40 | 41 | else: 42 | x = self.v_fc1(x) 43 | x = self.v_drop(x) 44 | x = self.v_fc2(x) 45 | x = self.v_fc3(x) 46 | x = l2norm(x, dim=1) 47 | 48 | return x 49 | -------------------------------------------------------------------------------- /meter/modules/visual_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Rs_GCN(nn.Module): 8 | 9 | def __init__(self, in_channels, inter_channels, bn_layer=True): 10 | super(Rs_GCN, self).__init__() 11 | 12 | self.in_channels = in_channels 13 | self.inter_channels = inter_channels 14 | 15 | if self.inter_channels is 
None: 16 | self.inter_channels = in_channels // 2 17 | if self.inter_channels == 0: 18 | self.inter_channels = 1 19 | 20 | conv_nd = nn.Conv1d 21 | max_pool = nn.MaxPool1d 22 | bn = nn.BatchNorm1d 23 | 24 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 25 | kernel_size=1, stride=1, padding=0) 26 | 27 | if bn_layer: 28 | self.W = nn.Sequential( 29 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 30 | kernel_size=1, stride=1, padding=0), 31 | bn(self.in_channels) 32 | ) 33 | nn.init.constant_(self.W[1].weight, 0) 34 | nn.init.constant_(self.W[1].bias, 0) 35 | else: 36 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | nn.init.constant_(self.W.weight, 0) 39 | nn.init.constant_(self.W.bias, 0) 40 | 41 | self.theta = None 42 | self.phi = None 43 | 44 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 47 | kernel_size=1, stride=1, padding=0) 48 | 49 | def forward(self, v): 50 | ''' 51 | :param v: (B, D, N) 52 | :return: 53 | ''' 54 | batch_size = v.size(0) 55 | 56 | g_v = self.g(v).view(batch_size, self.inter_channels, -1) 57 | g_v = g_v.permute(0, 2, 1) 58 | 59 | theta_v = self.theta(v).view(batch_size, self.inter_channels, -1) 60 | theta_v = theta_v.permute(0, 2, 1) 61 | phi_v = self.phi(v).view(batch_size, self.inter_channels, -1) 62 | R = torch.matmul(theta_v, phi_v) 63 | N = R.size(-1) 64 | R_div_C = R / N 65 | 66 | y = torch.matmul(R_div_C, g_v) 67 | y = y.permute(0, 2, 1).contiguous() 68 | y = y.view(batch_size, self.inter_channels, *v.size()[2:]) 69 | W_y = self.W(y) 70 | v_star = W_y + v 71 | 72 | return v_star 73 | 74 | 75 | class EncoderImagePrecompAttn(nn.Module): 76 | def __init__(self, img_dim, embed_size, data_name, use_abs=False, no_imgnorm=False): 77 | super(EncoderImagePrecompAttn, self).__init__() 78 | self.embed_size = embed_size 79 | self.no_imgnorm = no_imgnorm 80 | self.use_abs = use_abs 81 | self.data_name = data_name 82 | 83 | self.fc = nn.Linear(img_dim, embed_size) 84 | self.init_weights() 85 | 86 | # GSR 87 | self.img_rnn = nn.GRU(embed_size, embed_size, 1, batch_first=True) 88 | 89 | # GCN reasoning 90 | self.Rs_GCN_1 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 91 | self.Rs_GCN_2 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 92 | self.Rs_GCN_3 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 93 | self.Rs_GCN_4 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 94 | 95 | if self.data_name == 'f30k_precomp': 96 | self.bn = nn.BatchNorm1d(embed_size) 97 | 98 | def init_weights(self): 99 | """Xavier initialization for the fully connected layer 100 | """ 101 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + self.fc.out_features) 102 | self.fc.weight.data.uniform_(-r, r) 103 | self.fc.bias.data.fill_(0) 104 | 105 | def forward(self, images): 106 | """Extract image feature vectors.""" 107 | 108 | fc_img_emd = self.fc(images) 109 | if self.data_name != 'f30k_precomp': 110 | fc_img_emd = l2norm(fc_img_emd) 111 | 112 | # GCN reasoning 113 | # -> B,D,N 114 | GCN_img_emd = fc_img_emd.permute(0, 2, 1) 115 | GCN_img_emd = self.Rs_GCN_1(GCN_img_emd) 116 | GCN_img_emd = self.Rs_GCN_2(GCN_img_emd) 117 | GCN_img_emd = self.Rs_GCN_3(GCN_img_emd) 118 | GCN_img_emd = self.Rs_GCN_4(GCN_img_emd) 119 | # -> B,N,D 120 | GCN_img_emd = GCN_img_emd.permute(0, 2, 1) 121 | 122 | GCN_img_emd = l2norm(GCN_img_emd) 123 | 124 | rnn_img, hidden_state = self.img_rnn(GCN_img_emd) 125 | 126 | # features = torch.mean(rnn_img,dim=1) 127 | features = hidden_state[0] 128 | 129 | if self.data_name == 'f30k_precomp': 130 | features = self.bn(features) 131 | 132 | # normalize in the joint embedding space 133 | if not self.no_imgnorm: 134 | features = l2norm(features) 135 | 136 | # take the absolute value of embedding (used in order embeddings) 137 | if self.use_abs: 138 | features = torch.abs(features) 139 | 140 | return features, GCN_img_emd 141 | 142 | 143 | def l2norm(X, dim=1, eps=1e-8): 144 | """L2-normalize columns of X """ 145 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 146 | X = torch.div(X, norm) 147 | return X 148 | 149 | 150 | class Visual_Enconder(nn.Module): 151 | def __init__(self, visual_dim, factor, act='leaky_relu'): 152 | super(Visual_Enconder, self).__init__() 153 | self.v_fc1 = nn.Linear(visual_dim, visual_dim) 154 | self.v_drop = nn.Dropout(p=0.2, inplace=False) 155 | self.v_fc2 = nn.Linear(visual_dim, (visual_dim + factor) // 2) 156 | self.v_fc3 = nn.Linear((visual_dim + factor) // 2, factor) 157 | 158 | self.act = act 159 | 160 | def forward(self, x): 161 | if self.act == 'leaky_relu': 162 | x = F.leaky_relu(self.v_fc1(x)) 163 | x = self.v_drop(x) 164 | x = F.leaky_relu(self.v_fc2(x)) 165 | x = F.leaky_relu(self.v_fc3(x)) 166 | x = l2norm(x, dim=1) 167 | 168 | elif self.act == 'tanh': 169 | x = F.tanh(self.v_fc1(x)) 170 | x = self.v_drop(x) 171 | x = F.tanh(self.v_fc2(x)) 172 | x = F.tanh(self.v_fc3(x)) 173 | x = l2norm(x, dim=1) 174 | 175 | else: 176 | x = self.v_fc1(x) 177 | x = self.v_drop(x) 178 | x = self.v_fc2(x) 179 | x = self.v_fc3(x) 180 | x = l2norm(x, dim=1) 181 | 182 | return x 183 | -------------------------------------------------------------------------------- /meter/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | vit_transform, 5 | vit_transform_randaug, 6 | imagenet_transform, 7 | imagenet_transform_randaug, 8 | clip_transform, 9 | clip_transform_randaug, 10 | ) 11 | 12 | _transforms = { 13 | "pixelbert": pixelbert_transform, 14 | "pixelbert_randaug": pixelbert_transform_randaug, 15 | "vit": vit_transform, 16 | "vit_randaug": vit_transform_randaug, 17 | "imagenet": imagenet_transform, 18 | "imagenet_randaug": imagenet_transform_randaug, 19 | "clip": clip_transform, 20 | "clip_randaug": clip_transform_randaug, 21 | } 22 | 23 | def keys_to_transforms(keys: list, size=224): 24 | return [_transforms[key](size=size) for key in keys] 25 | -------------------------------------------------------------------------------- /meter/transforms/randaug.py: 
-------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3.txt, 0.3.txt] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3.txt, 0.3.txt] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.0: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # 
assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.0)) 136 | y0 = int(max(0, y0 - v / 2.0)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3.txt), # 0 166 | # (ShearY, 0., 0.3.txt), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3.txt 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | # (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0.0, 0.3), 197 | (ShearY, 0.0, 0.3), 198 | # (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0.0, 100), 200 | (TranslateYabs, 0.0, 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = ( 220 | self.eigvec.type_as(img) 221 | .clone() 222 | .mul(alpha.view(1, 3).expand(3, 3)) 223 | .mul(self.eigval.view(1, 3).expand(3, 3)) 224 | .sum(1) 225 | .squeeze() 226 | ) 227 | 228 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 229 | 230 | 231 | class CutoutDefault(object): 232 | """ 233 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 234 | """ 235 | 236 | def __init__(self, length): 237 | self.length = length 238 | 239 | def __call__(self, img): 240 | h, w = img.size(1), img.size(2) 241 | mask = np.ones((h, w), np.float32) 242 | y = np.random.randint(h) 243 | x = np.random.randint(w) 244 | 245 | y1 = np.clip(y - self.length // 2, 0, h) 246 | y2 = np.clip(y + self.length // 2, 0, h) 247 | x1 = np.clip(x - self.length // 2, 0, w) 248 | x2 = np.clip(x + self.length // 2, 0, w) 249 | 250 | mask[y1:y2, x1:x2] = 0.0 251 | mask = torch.from_numpy(mask) 252 | mask = mask.expand_as(img) 253 | img *= mask 254 | return img 255 | 256 | 257 | class RandAugment: 
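"""Apply `n` ops sampled from `augment_list()`, all at a shared magnitude `m`.

`m` is interpreted on a 0-30 scale and mapped into each op's own
(minval, maxval) range inside `__call__`; e.g. the *_randaug transforms
in transform.py use RandAugment(2, 9).
"""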
258 | def __init__(self, n, m): 259 | self.n = n 260 | self.m = m # [0, 30] 261 | self.augment_list = augment_list() 262 | 263 | def __call__(self, img): 264 | ops = random.choices(self.augment_list, k=self.n) 265 | for op, minval, maxval in ops: 266 | val = (float(self.m) / 30) * float(maxval - minval) + minval 267 | img = op(img, val) 268 | 269 | return img 270 | -------------------------------------------------------------------------------- /meter/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | inception_normalize, 3 | imagenet_normalize, 4 | MinMaxResize, 5 | ) 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from .randaug import RandAugment 10 | 11 | 12 | def pixelbert_transform(size=800): 13 | longer = int((1333 / 800) * size) 14 | return transforms.Compose( 15 | [ 16 | MinMaxResize(shorter=size, longer=longer), 17 | transforms.ToTensor(), 18 | inception_normalize, 19 | ] 20 | ) 21 | 22 | def pixelbert_transform_randaug(size=800): 23 | longer = int((1333 / 800) * size) 24 | trs = transforms.Compose( 25 | [ 26 | MinMaxResize(shorter=size, longer=longer), 27 | transforms.ToTensor(), 28 | inception_normalize, 29 | ] 30 | ) 31 | trs.transforms.insert(0, RandAugment(2, 9)) 32 | return trs 33 | 34 | def imagenet_transform(size=800): 35 | return transforms.Compose( 36 | [ 37 | Resize(size, interpolation=Image.BICUBIC), 38 | CenterCrop(size), 39 | transforms.ToTensor(), 40 | imagenet_normalize, 41 | ] 42 | ) 43 | 44 | def imagenet_transform_randaug(size=800): 45 | trs = transforms.Compose( 46 | [ 47 | Resize(size, interpolation=Image.BICUBIC), 48 | CenterCrop(size), 49 | transforms.ToTensor(), 50 | imagenet_normalize, 51 | ] 52 | ) 53 | trs.transforms.insert(0, RandAugment(2, 9)) 54 | return trs 55 | 56 | def vit_transform(size=800): 57 | return transforms.Compose( 58 | [ 59 | Resize(size, interpolation=Image.BICUBIC), 60 | CenterCrop(size), 61 | transforms.ToTensor(), 62 | inception_normalize, 63 | ] 64 | ) 65 | 66 | def vit_transform_randaug(size=800): 67 | trs = transforms.Compose( 68 | [ 69 | Resize(size, interpolation=Image.BICUBIC), 70 | CenterCrop(size), 71 | transforms.ToTensor(), 72 | inception_normalize, 73 | ] 74 | ) 75 | trs.transforms.insert(0, RandAugment(2, 9)) 76 | return trs 77 | 78 | def clip_transform(size): 79 | return Compose([ 80 | Resize(size, interpolation=Image.BICUBIC), 81 | CenterCrop(size), 82 | lambda image: image.convert("RGB"), 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | def clip_transform_randaug(size): 88 | trs = Compose([ 89 | Resize(size, interpolation=Image.BICUBIC), 90 | CenterCrop(size), 91 | lambda image: image.convert("RGB"), 92 | ToTensor(), 93 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 94 | ]) 95 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 96 | trs.transforms.insert(0, RandAugment(2, 9)) 97 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 98 | return trs 99 | 100 | -------------------------------------------------------------------------------- /meter/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | 
self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 61 | ) 62 | -------------------------------------------------------------------------------- /meter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLyu0110/HACAN/35bdbc5cb2a9a62870fa9dce180c03c4c9d54206/meter/utils/__init__.py -------------------------------------------------------------------------------- /meter/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | 
"shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | "theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3.txt", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /meter/utils/write_coco_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas 
as pd 4 | import pyarrow as pa 5 | import random 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | with open(path, "rb") as fp: 15 | binary = fp.read() 16 | captions = iid2captions[name] 17 | split = iid2split[name] 18 | return [binary, captions, name, split] 19 | 20 | 21 | def make_arrow(root, dataset_root): 22 | with open(f"{root}/karpathy/dataset_coco.json", "r") as fp: 23 | captions = json.load(fp) 24 | 25 | captions = captions["images"] 26 | 27 | iid2captions = defaultdict(list) 28 | iid2split = dict() 29 | 30 | for cap in tqdm(captions): 31 | filename = cap["filename"] 32 | iid2split[filename] = cap["split"] 33 | for c in cap["sentences"]: 34 | iid2captions[filename].append(c["raw"]) 35 | 36 | paths = list(glob(f"{root}/train2014/*.jpg")) + list(glob(f"{root}/val2014/*.jpg")) 37 | random.shuffle(paths) 38 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 39 | 40 | if len(paths) == len(caption_paths): 41 | print("all images have caption annotations") 42 | else: 43 | print("not all images have caption annotations") 44 | print( 45 | len(paths), len(caption_paths), len(iid2captions), 46 | ) 47 | 48 | bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)] 49 | 50 | for split in ["train", "val", "restval", "test"]: 51 | batches = [b for b in bs if b[-1] == split] 52 | 53 | dataframe = pd.DataFrame( 54 | batches, columns=["image", "caption", "image_id", "split"], 55 | ) 56 | 57 | table = pa.Table.from_pandas(dataframe) 58 | os.makedirs(dataset_root, exist_ok=True) 59 | with pa.OSFile( 60 | f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb" 61 | ) as sink: 62 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 63 | writer.write_table(table) 64 | -------------------------------------------------------------------------------- /meter/utils/write_conceptual_caption.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | split, _, name = path.split("/")[-3:] 14 | split = split.split("_")[-1] 15 | iid = name 16 | 17 | with open(path, "rb") as fp: 18 | binary = fp.read() 19 | 20 | captions = iid2captions[iid] 21 | 22 | return [ 23 | binary, 24 | captions, 25 | iid, 26 | split, 27 | ] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | for split in ["val", "train"]: 32 | with open(f"{root}/{split}_annot.json", "r") as fp: 33 | captions = json.load(fp) 34 | 35 | iid2captions = dict() 36 | for cap in tqdm(captions): 37 | iid = cap[0].split("/")[-1] 38 | iid2captions[iid] = [cap[1]] 39 | 40 | paths = list(glob(f"{root}/images_{split}/*/*")) 41 | random.shuffle(paths) 42 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 43 | if len(paths) == len(caption_paths): 44 | print("all images have caption annotations") 45 | else: 46 | print("not all images have caption annotations") 47 | print( 48 | len(paths), len(caption_paths), len(iid2captions), 49 | ) 50 | 51 | sub_len = int(len(caption_paths) // 100000) 52 | subs = list(range(sub_len + 1)) 53 | for sub in subs: 54 | sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000] 55 | bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)] 56 | 
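# each shard covers at most 100,000 image-caption pairs; every shard is written to its own .arrow file below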
dataframe = pd.DataFrame( 57 | bs, columns=["image", "caption", "image_id", "split"], 58 | ) 59 | 60 | table = pa.Table.from_pandas(dataframe) 61 | 62 | os.makedirs(dataset_root, exist_ok=True) 63 | with pa.OSFile( 64 | f"{dataset_root}/conceptual_caption_{split}_{sub}.arrow", "wb" 65 | ) as sink: 66 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 67 | writer.write_table(table) 68 | del dataframe 69 | del table 70 | del bs 71 | gc.collect() 72 | -------------------------------------------------------------------------------- /meter/utils/write_f30k_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | 15 | with open(path, "rb") as fp: 16 | binary = fp.read() 17 | 18 | captions = iid2captions[name] 19 | split = iid2split[name] 20 | 21 | return [binary, captions, name, split] 22 | 23 | 24 | def make_arrow(root, dataset_root): 25 | with open(f"{root}/karpathy/dataset_flickr30k.json", "r") as fp: 26 | captions = json.load(fp) 27 | 28 | captions = captions["images"] 29 | 30 | iid2captions = defaultdict(list) 31 | iid2split = dict() 32 | 33 | for cap in tqdm(captions): 34 | filename = cap["filename"] 35 | iid2split[filename] = cap["split"] 36 | for c in cap["sentences"]: 37 | iid2captions[filename].append(c["raw"]) 38 | 39 | paths = list(glob(f"{root}/flickr30k-images/*.jpg")) 40 | random.shuffle(paths) 41 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 42 | 43 | if len(paths) == len(caption_paths): 44 | print("all images have caption annotations") 45 | else: 46 | print("not all images have caption annotations") 47 | print( 48 | len(paths), len(caption_paths), len(iid2captions), 49 | ) 50 | 51 | bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)] 52 | 53 | for split in ["train", "val", "test"]: 54 | batches = [b for b in bs if b[-1] == split] 55 | 56 | dataframe = pd.DataFrame( 57 | batches, columns=["image", "caption", "image_id", "split"], 58 | ) 59 | 60 | table = pa.Table.from_pandas(dataframe) 61 | 62 | os.makedirs(dataset_root, exist_ok=True) 63 | with pa.OSFile( 64 | f"{dataset_root}/f30k_caption_karpathy_{split}.arrow", "wb" 65 | ) as sink: 66 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 67 | writer.write_table(table) 68 | -------------------------------------------------------------------------------- /meter/utils/write_nlvr2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | def process(root, iden, row): 11 | texts = [r["sentence"] for r in row] 12 | labels = [r["label"] for r in row] 13 | 14 | split = iden.split("-")[0] 15 | 16 | if iden.startswith("train"): 17 | directory = row[0]["directory"] 18 | path = f"{root}/images/train/{directory}/{iden}" 19 | else: 20 | path = f"{root}/{split}/{iden}" 21 | 22 | with open(f"{path}-img0.png", "rb") as fp: 23 | img0 = fp.read() 24 | with open(f"{path}-img1.png", "rb") as fp: 25 | img1 = fp.read() 26 | 27 | return [img0, img1, texts, labels, iden] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | train_data = list( 32 | 
map(json.loads, open(f"{root}/nlvr2/data/train.json").readlines()) 33 | ) 34 | test1_data = list( 35 | map(json.loads, open(f"{root}/nlvr2/data/test1.json").readlines()) 36 | ) 37 | dev_data = list(map(json.loads, open(f"{root}/nlvr2/data/dev.json").readlines())) 38 | 39 | balanced_test1_data = list( 40 | map( 41 | json.loads, 42 | open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(), 43 | ) 44 | ) 45 | balanced_dev_data = list( 46 | map( 47 | json.loads, 48 | open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(), 49 | ) 50 | ) 51 | 52 | unbalanced_test1_data = list( 53 | map( 54 | json.loads, 55 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(), 56 | ) 57 | ) 58 | unbalanced_dev_data = list( 59 | map( 60 | json.loads, 61 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(), 62 | ) 63 | ) 64 | 65 | splits = [ 66 | "train", 67 | "dev", 68 | "test1", 69 | "balanced_dev", 70 | "balanced_test1", 71 | "unbalanced_dev", 72 | "unbalanced_test1", 73 | ] 74 | 75 | datas = [ 76 | train_data, 77 | dev_data, 78 | test1_data, 79 | balanced_dev_data, 80 | balanced_test1_data, 81 | unbalanced_dev_data, 82 | unbalanced_test1_data, 83 | ] 84 | 85 | annotations = dict() 86 | 87 | for split, data in zip(splits, datas): 88 | _annot = defaultdict(list) 89 | for row in tqdm(data): 90 | _annot["-".join(row["identifier"].split("-")[:-1])].append(row) 91 | annotations[split] = _annot 92 | 93 | for split in splits: 94 | bs = [ 95 | process(root, iden, row) for iden, row in tqdm(annotations[split].items()) 96 | ] 97 | 98 | dataframe = pd.DataFrame( 99 | bs, columns=["image_0", "image_1", "questions", "answers", "identifier"], 100 | ) 101 | 102 | table = pa.Table.from_pandas(dataframe) 103 | 104 | os.makedirs(dataset_root, exist_ok=True) 105 | with pa.OSFile(f"{dataset_root}/nlvr2_{split}.arrow", "wb") as sink: 106 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 107 | writer.write_table(table) 108 | -------------------------------------------------------------------------------- /meter/utils/write_sbu.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | split, _, name = path.split("/")[-3:] 14 | split = split.split("_")[-1] 15 | iid = name 16 | 17 | with open(path, "rb") as fp: 18 | binary = fp.read() 19 | 20 | captions = iid2captions[iid] 21 | 22 | return [ 23 | binary, 24 | captions, 25 | iid, 26 | split, 27 | ] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | with open(f"{root}/annot.json", "r") as fp: 32 | captions = json.load(fp) 33 | 34 | iid2captions = dict() 35 | for cap in tqdm(captions): 36 | iid = cap[0].split("/")[-1] 37 | iid2captions[iid] = [cap[1]] 38 | 39 | paths = list(glob(f"{root}/images_train/*/*")) 40 | random.shuffle(paths) 41 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 42 | if len(paths) == len(caption_paths): 43 | print("all images have caption annotations") 44 | else: 45 | print("not all images have caption annotations") 46 | print( 47 | len(paths), len(caption_paths), len(iid2captions), 48 | ) 49 | 50 | sub_len = int(len(caption_paths) // 100000) 51 | subs = list(range(sub_len + 1)) 52 | for sub in subs: 53 | sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000] 54 | bs = [path2rest(path, iid2captions) for 
path in tqdm(sub_paths)] 55 | dataframe = pd.DataFrame(bs, columns=["image", "caption", "image_id", "split"],) 56 | 57 | table = pa.Table.from_pandas(dataframe) 58 | 59 | os.makedirs(dataset_root, exist_ok=True) 60 | with pa.OSFile(f"{dataset_root}/sbu_{sub}.arrow", "wb") as sink: 61 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 62 | writer.write_table(table) 63 | del dataframe 64 | del table 65 | del bs 66 | gc.collect() 67 | -------------------------------------------------------------------------------- /meter/utils/write_snli.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2} 11 | def process(root, imgid, ann): 12 | with open(f"{root}/Flickr30K/images/{imgid}.jpg", "rb") as fp: 13 | img = fp.read() 14 | 15 | sentences = ann['sentences'] 16 | 17 | labels = ann['labels'] 18 | 19 | return [img, sentences, labels] 20 | 21 | 22 | 23 | def make_arrow(root, dataset_root): 24 | train_data = list( 25 | map(json.loads, open(f"{root}/snli_ve_train.jsonl").readlines()) 26 | ) 27 | test_data = list( 28 | map(json.loads, open(f"{root}/snli_ve_test.jsonl").readlines()) 29 | ) 30 | dev_data = list( 31 | map(json.loads, open(f"{root}/snli_ve_dev.jsonl").readlines()) 32 | ) 33 | 34 | 35 | splits = [ 36 | "train", 37 | "dev", 38 | "test", 39 | ] 40 | 41 | 42 | annotations = dict() 43 | annotations['train'] = train_data 44 | annotations['dev'] = dev_data 45 | annotations['test'] = test_data 46 | annots = dict() 47 | for split in splits: 48 | annots[split] = {} 49 | for line in annotations[split]: 50 | imgid = line['Flickr30K_ID'] 51 | if not imgid in annots[split]: 52 | annots[split][imgid] = {} 53 | annots[split][imgid]['sentences'] = [] 54 | annots[split][imgid]['labels'] = [] 55 | annots[split][imgid]['sentences'].append( [line['sentence1'], line['sentence2']] ) 56 | annots[split][imgid]['labels'].append( label2id[line['gold_label']] ) 57 | 58 | 59 | 60 | for split in splits: 61 | bs = [process(root, imgid, annots[split][imgid]) for imgid in tqdm(annots[split])] 62 | 63 | dataframe = pd.DataFrame( 64 | bs, columns=["image", "sentences", "labels"] 65 | ) 66 | 67 | table = pa.Table.from_pandas(dataframe) 68 | 69 | os.makedirs(dataset_root, exist_ok=True) 70 | with pa.OSFile(f"{dataset_root}/snli_{split}.arrow", "wb") as sink: 71 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 72 | writer.write_table(table) 73 | -------------------------------------------------------------------------------- /meter/utils/write_vg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | name = path.split("/")[-1] 14 | iid = int(name[:-4]) 15 | 16 | with open(path, "rb") as fp: 17 | binary = fp.read() 18 | 19 | cdicts = iid2captions[iid] 20 | captions = [c["phrase"] for c in cdicts] 21 | widths = [c["width"] for c in cdicts] 22 | heights = [c["height"] for c in cdicts] 23 | xs = [c["x"] for c in cdicts] 24 | ys = [c["y"] for c in cdicts] 25 | 26 | return [ 27 | binary, 28 | captions, 29 | widths, 30 | heights, 31 | xs, 32 | ys, 33 | str(iid), 34 | ] 35 | 36 | 37 | def 
make_arrow(root, dataset_root): 38 | with open(f"{root}/annotations/region_descriptions.json", "r") as fp: 39 | captions = json.load(fp) 40 | 41 | iid2captions = defaultdict(list) 42 | for cap in tqdm(captions): 43 | cap = cap["regions"] 44 | for c in cap: 45 | iid2captions[c["image_id"]].append(c) 46 | 47 | paths = list(glob(f"{root}/images/VG_100K/*.jpg")) + list( 48 | glob(f"{root}/images/VG_100K_2/*.jpg") 49 | ) 50 | random.shuffle(paths) 51 | caption_paths = [ 52 | path for path in paths if int(path.split("/")[-1][:-4]) in iid2captions 53 | ] 54 | 55 | if len(paths) == len(caption_paths): 56 | print("all images have caption annotations") 57 | else: 58 | print("not all images have caption annotations") 59 | print( 60 | len(paths), len(caption_paths), len(iid2captions), 61 | ) 62 | 63 | bs = [path2rest(path, iid2captions) for path in tqdm(caption_paths)] 64 | dataframe = pd.DataFrame( 65 | bs, columns=["image", "caption", "width", "height", "x", "y", "image_id"], 66 | ) 67 | table = pa.Table.from_pandas(dataframe) 68 | 69 | os.makedirs(dataset_root, exist_ok=True) 70 | with pa.OSFile(f"{dataset_root}/vg.arrow", "wb") as sink: 71 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 72 | writer.write_table(table) 73 | -------------------------------------------------------------------------------- /meter/utils/write_vqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | 13 | def get_score(occurences): 14 | if occurences == 0: 15 | return 0.0 16 | elif occurences == 1: 17 | return 0.3 18 | elif occurences == 2: 19 | return 0.6 20 | elif occurences == 3: 21 | return 0.9 22 | else: 23 | return 1.0 24 | 25 | 26 | def path2rest(path, split, annotations, label2ans): 27 | iid = int(path.split("/")[-1].split("_")[-1][:-4]) 28 | 29 | with open(path, "rb") as fp: 30 | binary = fp.read() 31 | 32 | _annot = annotations[split][iid] 33 | _annot = list(_annot.items()) 34 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot] 35 | questions = [qa[0] for qa in qas] 36 | answers = [qa[1] for qa in qas] if "test" not in split else list(list()) 37 | answer_labels = ( 38 | [a["labels"] for a in answers] if "test" not in split else list(list()) 39 | ) 40 | answer_scores = ( 41 | [a["scores"] for a in answers] if "test" not in split else list(list()) 42 | ) 43 | answers = ( 44 | [[label2ans[l] for l in al] for al in answer_labels] 45 | if "test" not in split 46 | else list(list()) 47 | ) 48 | 49 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split] 50 | 51 | 52 | def make_arrow(root, dataset_root): 53 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp: 54 | questions_train2014 = json.load(fp)["questions"] 55 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp: 56 | questions_val2014 = json.load(fp)["questions"] 57 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp: 58 | questions_test2015 = json.load(fp)["questions"] 59 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp: 60 | questions_test_dev2015 = json.load(fp)["questions"] 61 | 62 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp: 63 | annotations_train2014 = json.load(fp)["annotations"] 64 | with 
open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp: 65 | annotations_val2014 = json.load(fp)["annotations"] 66 | 67 | annotations = dict() 68 | 69 | for split, questions in zip( 70 | ["train", "val", "test", "test-dev"], 71 | [ 72 | questions_train2014, 73 | questions_val2014, 74 | questions_test2015, 75 | questions_test_dev2015, 76 | ], 77 | ): 78 | _annot = defaultdict(dict) 79 | for q in tqdm(questions): 80 | _annot[q["image_id"]][q["question_id"]] = [q["question"]] 81 | 82 | annotations[split] = _annot 83 | 84 | all_major_answers = list() 85 | 86 | for split, annots in zip( 87 | ["train", "val"], [annotations_train2014, annotations_val2014], 88 | ): 89 | _annot = annotations[split] 90 | for q in tqdm(annots): 91 | all_major_answers.append(q["multiple_choice_answer"]) 92 | 93 | all_major_answers = [normalize_word(word) for word in tqdm(all_major_answers)] 94 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} 95 | ans2label = {k: i for i, k in enumerate(counter.keys())} 96 | label2ans = list(counter.keys()) 97 | 98 | for split, annots in zip( 99 | ["train", "val"], [annotations_train2014, annotations_val2014], 100 | ): 101 | _annot = annotations[split] 102 | for q in tqdm(annots): 103 | answers = q["answers"] 104 | answer_count = {} 105 | for answer in answers: 106 | answer_ = answer["answer"] 107 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 108 | 109 | labels = [] 110 | scores = [] 111 | for answer in answer_count: 112 | if answer not in ans2label: 113 | continue 114 | labels.append(ans2label[answer]) 115 | score = get_score(answer_count[answer]) 116 | scores.append(score) 117 | 118 | _annot[q["image_id"]][q["question_id"]].append( 119 | {"labels": labels, "scores": scores,} 120 | ) 121 | 122 | for split in ["train", "val"]: 123 | filtered_annot = dict() 124 | for ik, iv in annotations[split].items(): 125 | new_q = dict() 126 | for qk, qv in iv.items(): 127 | if len(qv[1]["labels"]) != 0: 128 | new_q[qk] = qv 129 | if len(new_q) != 0: 130 | filtered_annot[ik] = new_q 131 | annotations[split] = filtered_annot 132 | 133 | for split in [ 134 | "train", 135 | "val", 136 | "test", 137 | "test-dev", 138 | ]: 139 | annot = annotations[split] 140 | split_name = { 141 | "train": "train2014", 142 | "val": "val2014", 143 | "test": "test2015", 144 | "test-dev": "test2015", 145 | }[split] 146 | paths = list(glob(f"{root}/{split_name}/*.jpg")) 147 | random.shuffle(paths) 148 | annot_paths = [ 149 | path 150 | for path in paths 151 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot 152 | ] 153 | 154 | if len(paths) == len(annot_paths): 155 | print("all images have caption annotations") 156 | else: 157 | print("not all images have caption annotations") 158 | print( 159 | len(paths), len(annot_paths), len(annot), 160 | ) 161 | 162 | bs = [ 163 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths) 164 | ] 165 | 166 | dataframe = pd.DataFrame( 167 | bs, 168 | columns=[ 169 | "image", 170 | "questions", 171 | "answers", 172 | "answer_labels", 173 | "answer_scores", 174 | "image_id", 175 | "question_id", 176 | "split", 177 | ], 178 | ) 179 | 180 | table = pa.Table.from_pandas(dataframe) 181 | 182 | os.makedirs(dataset_root, exist_ok=True) 183 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink: 184 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 185 | writer.write_table(table) 186 | 187 | table = pa.ipc.RecordBatchFileReader( 188 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r") 189 | 
).read_all() 190 | 191 | pdtable = table.to_pandas() 192 | 193 | df1 = pdtable[:-1000] 194 | df2 = pdtable[-1000:] 195 | 196 | df1 = pa.Table.from_pandas(df1) 197 | df2 = pa.Table.from_pandas(df2) 198 | 199 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink: 200 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer: 201 | writer.write_table(df1) 202 | 203 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink: 204 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer: 205 | writer.write_table(df2) 206 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | aiobotocore==2.5.4 3 | aiohttp==3.8.6 4 | aioitertools==0.11.0 5 | aiosignal==1.3.1 6 | annotated-types==0.6.0 7 | anyio==3.7.1 8 | arrow==1.3.0 9 | async-timeout==4.0.3 10 | attrs==23.1.0 11 | beautifulsoup4==4.12.2 12 | boto3==1.28.62 13 | botocore==1.31.62 14 | cachetools==5.3.1 15 | certifi==2023.7.22 16 | charset-normalizer==3.3.0 17 | click==8.1.7 18 | colorama==0.4.6 19 | contourpy==1.1.1 20 | croniter==1.3.15 21 | cycler==0.12.1 22 | deepdiff==6.6.0 23 | docopt==0.6.2 24 | einops==0.6.1 25 | exceptiongroup==1.1.3 26 | fastapi==0.103.2 27 | filelock==3.12.4 28 | fire==0.5.0 29 | fonttools==4.43.1 30 | frozenlist==1.4.0 31 | fsspec==2023.9.2 32 | future==0.18.3 33 | gitdb==4.0.10 34 | GitPython==3.1.37 35 | google-auth==2.23.3 36 | google-auth-oauthlib==1.0.0 37 | grpcio==1.59.0 38 | h11==0.14.0 39 | huggingface-hub==0.18.0 40 | idna==3.4 41 | importlib-metadata==6.8.0 42 | importlib-resources==6.1.0 43 | itsdangerous==2.1.2 44 | jmespath==1.0.1 45 | joblib==1.3.2 46 | jsonpickle==3.0.2 47 | kiwisolver==1.4.5 48 | lightning==1.8.0 49 | lightning-cloud==0.5.39 50 | lightning-lite==1.8.6 51 | lightning-utilities==0.3.0 52 | Markdown==3.5 53 | markdown-it-py==3.0.0 54 | MarkupSafe==2.1.3 55 | matplotlib==3.7.3 56 | mdurl==0.1.2 57 | multidict==6.0.4 58 | munch==2.5.0 59 | nltk==3.8.1 60 | numpy==1.22.0 61 | oauthlib==3.2.2 62 | opencv-python==4.9.0.80 63 | ordered-set==4.1.0 64 | packaging==23.2 65 | pandas==2.0.3 66 | Pillow==10.0.1 67 | protobuf==4.24.4 68 | psutil==5.9.5 69 | py-cpuinfo==9.0.0 70 | pyasn1==0.5.0 71 | pyasn1-modules==0.3.0 72 | pycocotools==2.0.7 73 | pydantic==2.4.2 74 | pydantic_core==2.10.1 75 | pyDeprecate==0.3.0 76 | Pygments==2.16.1 77 | PyJWT==2.8.0 78 | pyparsing==3.1.1 79 | python-dateutil==2.8.2 80 | python-multipart==0.0.6 81 | pytorch-lightning==1.3.2 82 | pytz==2023.3.post1 83 | PyYAML==5.4.1 84 | regex==2023.10.3 85 | requests==2.31.0 86 | requests-oauthlib==1.3.1 87 | rich==13.6.0 88 | rsa==4.9 89 | s3fs==2023.9.2 90 | s3transfer==0.7.0 91 | sacred==0.8.4 92 | safetensors==0.4.0 93 | seaborn==0.13.2 94 | six==1.16.0 95 | smmap==5.0.1 96 | sniffio==1.3.0 97 | soupsieve==2.5 98 | starlette==0.27.0 99 | starsessions==1.3.0 100 | tensorboard==2.14.0 101 | tensorboard-data-server==0.7.1 102 | termcolor==2.3.0 103 | timm==0.4.12 104 | tokenizers==0.13.3 105 | torch==1.11.0+cu113 106 | torchaudio==0.11.0+cu113 107 | torchmetrics==0.6.0 108 | torchvision==0.12.0+cu113 109 | tqdm==4.66.1 110 | traitlets==5.11.2 111 | transformers==4.24.0 112 | types-python-dateutil==2.8.19.14 113 | typing_extensions==4.8.0 114 | tzdata==2023.4 115 | urllib3==1.25.11 116 | uvicorn==0.23.2 117 | websocket-client==1.6.4 118 | Werkzeug==3.0.0 119 | wrapt==1.15.0 120 | yarl==1.9.2 121 | zipp==3.17.0 122 | # 123 | 
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pytorch_lightning as pl 3 | from meter.config import ex 4 | from meter.modules import METERTransformerSS 5 | from data import F30kDataModule, MscocoDataModule 6 | import torch 7 | import os 8 | #os.environ['CUDA_VISIBLE_DEVICES'] = '2' 9 | 10 | 11 | @ex.automain 12 | def main(_config): 13 | 14 | _config = copy.deepcopy(_config) 15 | pl.seed_everything(_config["seed"], workers=True) 16 | print(_config) 17 | 18 | if 'f30k' in _config['exp_name']: 19 | dm = F30kDataModule(_config) 20 | else: 21 | dm = MscocoDataModule(_config) 22 | 23 | model = METERTransformerSS(_config) 24 | 25 | if _config['test_only']: 26 | ckpt = torch.load(_config['checkpoint'], map_location="cuda:0") 27 | model.load_state_dict(ckpt['state_dict']) 28 | 29 | exp_name = f'{_config["exp_name"]}' 30 | 31 | os.makedirs(_config["log_dir"], exist_ok=True) 32 | 33 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 34 | save_top_k=5, 35 | dirpath=_config['save_path'], 36 | monitor="best_irtr", 37 | mode="max", 38 | save_last=True, 39 | ) 40 | logger = pl.loggers.TensorBoardLogger( 41 | _config["log_dir"], 42 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 43 | version=_config['experiment_name'] 44 | ) 45 | 46 | callbacks = [checkpoint_callback] 47 | 48 | num_gpus = ( 49 | _config["num_gpus"] 50 | if isinstance(_config["num_gpus"], int) 51 | else len(_config["num_gpus"]) 52 | ) 53 | 54 | trainer = pl.Trainer( 55 | gpus=[0], 56 | precision=_config["precision"], 57 | # accelerator="ddp", 58 | # accelerator='ddp', 59 | # strategy='ddp', 60 | benchmark=True, 61 | deterministic=True, 62 | max_epochs=_config["max_epoch"], 63 | callbacks=callbacks, 64 | logger=logger, 65 | #replace_sampler_ddp=False, 66 | log_every_n_steps=10, 67 | flush_logs_every_n_steps=10, 68 | weights_summary="top", 69 | val_check_interval=_config["val_check_interval"], 70 | # gradient_clip_val=2.0 71 | ) 72 | # print("***********************{}".format(trainer.global_rank)) 73 | # if trainer.global_rank == 0: 74 | # print(_config) 75 | 76 | if not _config["test_only"]: 77 | trainer.fit(model, datamodule=dm) 78 | else: 79 | trainer.test(model, datamodule=dm) 80 | --------------------------------------------------------------------------------
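
For reference, a minimal sketch (not part of the repository) of how the `.arrow` files produced by the `meter/utils/write_*.py` scripts can be loaded back for a quick sanity check. It reuses the `pa.ipc.RecordBatchFileReader` pattern that already appears at the end of `write_vqa.py`; the `dataset_root` path and the choice of the Flickr30K test split file are placeholders for whatever directory `make_arrow()` wrote to.

```python
# Sketch: inspect one Arrow file written by meter/utils/write_f30k_karpathy.py.
# Assumes dataset_root points at the directory passed to make_arrow().
import pyarrow as pa

dataset_root = "/path/to/dataset_root"  # placeholder

# Same reader pattern as used in write_vqa.py to reload vqav2_val.arrow.
table = pa.ipc.RecordBatchFileReader(
    pa.memory_map(f"{dataset_root}/f30k_caption_karpathy_test.arrow", "r")
).read_all()

df = table.to_pandas()
print(df.columns.tolist())  # ["image", "caption", "image_id", "split"]
print(len(df), "rows")      # one row per image; "caption" holds the list of raw sentences
```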