├── .gitignore
├── .gitmodules
├── .idea
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── lxmert.iml
│   ├── misc.xml
│   └── modules.xml
├── LICENSE
├── README.md
├── experience_in_pretraining.md
├── extract_gqa_image.py
├── find_example.py
├── get_test.py
├── get_testdev.py
├── gqa_debug.py
├── requirements.txt
├── run
│   ├── README.md
│   ├── fsb2.bash
│   ├── gqa_finetuneft.bash
│   ├── nlvr2_ft.bash
│   └── vqa_finetuneft.bash
└── src
    ├── lxrt
    │   ├── entry.py
    │   ├── file_utils.py
    │   ├── modeling.py
    │   ├── modeling_big.py
    │   ├── optimization.py
    │   └── tokenization.py
    ├── param.py
    ├── pretrain
    │   ├── lxmert_data.py
    │   ├── lxmert_pretrain.py
    │   ├── lxmert_pretrain2.py
    │   └── qa_answer_table.py
    ├── tasks
    │   ├── gqa.py
    │   ├── gqa_data.py
    │   ├── gqa_model.py
    │   ├── nlvr2.py
    │   ├── nlvr2_data.py
    │   ├── nlvr2_model.py
    │   ├── vqa.py
    │   ├── vqa_constant.py
    │   ├── vqa_data.py
    │   └── vqa_model.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
*.caffemodel
*.tsv
/snap

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "data/nlvr2/nlvr"]
	path = data/nlvr2/nlvr
	url = https://github.com/lil-lab/nlvr.git

--------------------------------------------------------------------------------
/.idea/ (deployment.xml, encodings.xml, inspectionProfiles/Project_Default.xml, lxmert.iml, misc.xml, modules.xml):
--------------------------------------------------------------------------------
[PyCharm project files; their XML bodies were not preserved in this dump.]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Hao Tan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# The Most Important Thing.
## Our code is developed based on:
## LXMERT: Learning Cross-Modality Encoder Representations from Transformers (https://github.com/airsplay/lxmert)
## If you think our work is useful, please also cite their work!

## Introduction
PyTorch code for the CVPR 2021 paper ["Causal Attention for Vision-Language Tasks"](https://arxiv.org/pdf/2103.03493.pdf).
Slides of LXMERT's EMNLP 2019 talk are available [here](http://www.cs.unc.edu/~airsplay/EMNLP_2019_LXMERT_slides.pdf).
For the experimental setup (e.g., the PyTorch version and GPU settings), please refer to LXMERT (https://github.com/airsplay/lxmert).

## Results: 36-RoI version

| Split | [VQA](https://visualqa.org/) | [GQA](https://cs.stanford.edu/people/dorarad/gqa/) | [NLVR2](http://lil.nlp.cornell.edu/nlvr/) |
|----------- |:----: |:---: |:------:|
| Local Validation | 70.40% | 60.90% | 76.40% |
| Test-Dev | 72.81% | 60.84% | 76.40% (Test-P) |
| Test-Standard | 73.04% | 61.17% | 76.00% (Test-U) |

## Results: 64-RoI version
## Extracting more RoI visual features from each image substantially improves performance!

| Split | [VQA](https://visualqa.org/) | [GQA](https://cs.stanford.edu/people/dorarad/gqa/) | [NLVR2](http://lil.nlp.cornell.edu/nlvr/) |
|----------- |:----: |:---: |:------:|
| Test-Dev | 73.54% | 61.87% | 77.27% (Test-P) |
| Test-Standard | 73.63% | 62.07% | 77.23% (Test-U) |

## Pre-training
## Note that this part is the same as in LXMERT (https://github.com/airsplay/lxmert); we keep it here so that the repo is self-contained.

1. Download the aggregated LXMERT dataset from MS COCO, Visual Genome, VQA, and GQA (around 700MB in total). The joint answer labels are saved in `data/lxmert/all_ans.json`.
```bash
mkdir -p data/lxmert
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_train.json -P data/lxmert/
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_nominival.json -P data/lxmert/
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/vgnococo.json -P data/lxmert/
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_minival.json -P data/lxmert/
```

2. Download the detection features of the MS COCO images from LXMERT.
```bash
mkdir -p data/mscoco_imgfeat
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/mscoco_imgfeat/train2014_obj36.zip -P data/mscoco_imgfeat
unzip data/mscoco_imgfeat/train2014_obj36.zip -d data/mscoco_imgfeat && rm data/mscoco_imgfeat/train2014_obj36.zip
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/mscoco_imgfeat/val2014_obj36.zip -P data/mscoco_imgfeat
unzip data/mscoco_imgfeat/val2014_obj36.zip -d data && rm data/mscoco_imgfeat/val2014_obj36.zip
```

3. Download the detection features for the Visual Genome images.
```bash
mkdir -p data/vg_gqa_imgfeat
wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/vg_gqa_imgfeat/vg_gqa_obj36.zip -P data/vg_gqa_imgfeat
unzip data/vg_gqa_imgfeat/vg_gqa_obj36.zip -d data && rm data/vg_gqa_imgfeat/vg_gqa_obj36.zip
```

4. Test on a small split of the MS COCO + Visual Genome datasets:
```bash
bash run/fsb2.bash 0,1,2,3 --multiGPU --tiny
```

5. Run on all of the [MS COCO](http://cocodataset.org)- and [Visual Genome](https://visualgenome.org/)-related datasets (i.e., [VQA](https://visualqa.org/), [GQA](https://cs.stanford.edu/people/dorarad/gqa/index.html), [COCO caption](http://cocodataset.org/#captions-2015), [VG Caption](https://visualgenome.org/), [VG QA](https://github.com/yukezhu/visual7w-toolkit)).

## This part is ours:
The pre-training command is:
```bash
bash run/fsb2.bash 0,1,2,3 --multiGPU
```
After pre-training, the fine-tuning commands for VQA, GQA, and NLVR2 are:
```bash
bash run/vqa_finetuneft.bash 0 0.00004 0.00004
bash run/gqa_finetuneft.bash 0 0.000001 0.000001
bash run/nlvr2_ft.bash 0 0.00003 0.00003
```

--------------------------------------------------------------------------------
/experience_in_pretraining.md:
--------------------------------------------------------------------------------
# Experience in Pre-training
Since I finished this project with quite limited computational resources, I would like to share some of the experience. If you are also in a small group and plan to pre-train backbone models for fun, I hope it helps.

## Workflow
1. Design a model and its pre-training strategies.
2. Check that the code is correct by over-fitting a very small split of the aggregated data (typically 5000 images).
3. Pre-train it on **all aggregated pre-training data** for around 3 to 4 epochs. (At the very least, make sure that all the images are included!)
4. Test the pre-training performance on a **small split** of the fine-tuning tasks. I used 5000 images by setting the `--fast` option. Note that the number of fine-tuning epochs should be increased from 4 to around 20-50.
5. If the accuracy (i.e., the results) of the fine-tuning tasks keeps growing, the pre-training is effective!
6.
When the 3-4-epoch pre-training finishes, compare the results on the **full fine-tuning data** and select the best pre-training strategy.
7. Train on the **full aggregated data** and have a good one-week sleep ;).


## Tips
- **Do not** validate pre-training strategies (pre-training tasks, pre-training model) on a **small split** of the data. The behavior of pre-training on a small split is significantly different from that on the full pre-training dataset.
- Do not over-tune the pre-training hyper-parameters. Keep in mind that a good idea will overshadow all these cherry-picked hyper-parameters. Anyway, you would not have enough GPUs to do that.
- Add one component at a time; have a plan for it.
- Pipeline everything.
- You may rest, but GPUs never rest; GPUs sometimes break, but you never give up.

--------------------------------------------------------------------------------
/extract_gqa_image.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

# The root of the bottom-up-attention repo. No need to change this when using the provided docker file.
BUTD_ROOT = '/opt/butd/'

import os, sys
sys.path.insert(0, BUTD_ROOT + "/tools")
os.environ['GLOG_minloglevel'] = '2'

import _init_paths
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list
from fast_rcnn.test import im_detect, _get_blobs
from fast_rcnn.nms_wrapper import nms

import caffe
import argparse
import pprint
import base64
import numpy as np
import cv2
import csv
from tqdm import tqdm

csv.field_size_limit(sys.maxsize)

FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]

# Settings for the number of features per image. To re-create the pretrained features with 36
# boxes per image, set both values to 36. (This variant extracts a fixed 100 boxes per image.)
MIN_BOXES = 100
MAX_BOXES = 100


def load_image_ids(img_root):
    pathXid = []
    for name in os.listdir(img_root):
        idx = name.split(".")[0]
        pathXid.append(
            (
                os.path.join(img_root, name),
                idx))
    return pathXid

def generate_tsv(prototxt, weights, image_ids, outfile):
    # First check whether the output file exists and whether it is complete.
    # Never build the image list with a set: it loses the order!
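    # (Added note: the sets below are only used for membership bookkeeping when
    # resuming an interrupted run; the warning above is about keeping the image
    # list itself in a fixed order, since rows are appended to the TSV incrementally.)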
    wanted_ids = set([image_id[1] for image_id in image_ids])
    found_ids = set()
    if os.path.exists(outfile):
        with open(outfile, "r") as tsvfile:
            reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
            for item in reader:
                found_ids.add(item['img_id'])
    missing = wanted_ids - found_ids
    if len(missing) == 0:
        print('already completed {:d}'.format(len(image_ids)))
    else:
        print('missing {:d}/{:d}'.format(len(missing), len(image_ids)))
    if len(missing) > 0:
        caffe.set_mode_gpu()
        caffe.set_device(0)
        net = caffe.Net(prototxt, caffe.TEST, weights=weights)
        with open(outfile, 'ab') as tsvfile:
            writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
            for im_file, image_id in tqdm(image_ids):
                if image_id in missing:
                    try:
                        writer.writerow(get_detections_from_im(net, im_file, image_id))
                    except Exception as e:
                        print(e)


def get_detections_from_im(net, im_file, image_id, conf_thresh=0.2):
    """
    :param net:
    :param im_file: full path to an image
    :param image_id:
    :param conf_thresh:
    :return: all information from detection and attr prediction
    """
    im = cv2.imread(im_file)
    scores, boxes, attr_scores, rel_scores = im_detect(net, im)

    # Keep the original boxes, don't worry about the regression bbox outputs
    rois = net.blobs['rois'].data.copy()
    # unscale back to raw image space
    blobs, im_scales = _get_blobs(im, None)

    cls_boxes = rois[:, 1:5] / im_scales[0]
    cls_prob = net.blobs['cls_prob'].data
    attr_prob = net.blobs['attr_prob'].data
    pool5 = net.blobs['pool5_flat'].data

    # Keep only the best detections
    max_conf = np.zeros((rois.shape[0]))
    for cls_ind in range(1, cls_prob.shape[1]):
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        keep = np.array(nms(dets, cfg.TEST.NMS))
        max_conf[keep] = np.where(cls_scores[keep] > max_conf[keep], cls_scores[keep], max_conf[keep])

    keep_boxes = np.where(max_conf >= conf_thresh)[0]
    if len(keep_boxes) < MIN_BOXES:
        keep_boxes = np.argsort(max_conf)[::-1][:MIN_BOXES]
    elif len(keep_boxes) > MAX_BOXES:
        keep_boxes = np.argsort(max_conf)[::-1][:MAX_BOXES]

    objects = np.argmax(cls_prob[keep_boxes][:, 1:], axis=1)
    objects_conf = np.max(cls_prob[keep_boxes][:, 1:], axis=1)
    attrs = np.argmax(attr_prob[keep_boxes][:, 1:], axis=1)
    attrs_conf = np.max(attr_prob[keep_boxes][:, 1:], axis=1)

    return {
        "img_id": image_id,
        "img_h": np.size(im, 0),
        "img_w": np.size(im, 1),
        "objects_id": base64.b64encode(objects),  # int64
        "objects_conf": base64.b64encode(objects_conf),  # float32
        "attrs_id": base64.b64encode(attrs),  # int64
        "attrs_conf": base64.b64encode(attrs_conf),  # float32
        "num_boxes": len(keep_boxes),
        "boxes": base64.b64encode(cls_boxes[keep_boxes]),  # float32
        "features": base64.b64encode(pool5[keep_boxes])  # float32
    }


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use',
                        default='0', type=str)
    parser.add_argument('--def', dest='prototxt',
                        help='prototxt file defining the network',
                        default=None, type=str)
    parser.add_argument('--out', dest='outfile',
                        help='output filepath',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file', default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--imgroot', type=str, default='/workspace/images/')
    parser.add_argument('--split', type=str, default='valid')
    parser.add_argument('--caffemodel', type=str, default='pretrained/resnet101_faster_rcnn_final_iter_320000.caffemodel')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Set up the configuration; normally these do not need to be touched:
    args = parse_args()

    args.cfg_file = BUTD_ROOT + "experiments/cfgs/faster_rcnn_end2end_resnet.yml"  # s = 500
    args.prototxt = BUTD_ROOT + "models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt"
    args.outfile = "%s_obj100.tsv" % "vg_gqa"

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    print('Using config:')
    pprint.pprint(cfg)
    assert cfg.TEST.HAS_RPN

    # Load image ids; this needs modification for new datasets.
    image_ids = load_image_ids(args.imgroot)

    # Generate TSV files; normally no modification is needed.
    generate_tsv(args.prototxt, args.caffemodel, image_ids, args.outfile)

--------------------------------------------------------------------------------
/find_example.py:
--------------------------------------------------------------------------------
# Ad-hoc inspection script: print and dump the VQA `nominival` examples whose
# question mentions 'gender'.
import json

data = json.load(open('/data2/yangxu/lxmert/data/vqa/nominival.json'))
N = len(data)
data_new = []
do_txt = open("gender.txt", "w")
play_txt = open("play.txt", "w")

for i in range(N):
    if data[i]['sent'].find('gender') >= 0:
        data_new.append(data[i])
        print(data[i]['img_id'])
        print(data[i]['sent'])
        print(data[i]['label'])
        do_txt.write("{0}\n".format(data[i]['img_id']))
        do_txt.write("{0}\n".format(data[i]['sent']))
        do_txt.write("{0}\n".format(data[i]['label']))
    # if data[i]['sent'].find('gender') >= 0:
    #     data_new.append(data[i])
    #     print(data[i]['img_id'])
    #     print(data[i]['sent'])
    #     print(data[i]['label'])
    #     play_txt.write("{0}\n".format(data[i]['img_id']))
    #     play_txt.write("{0}\n".format(data[i]['sent']))
    #     play_txt.write("{0}\n".format(data[i]['label']))

do_txt.close()
# play_txt.close()

# data = json.load(open('/data2/yangxu/lxmert/data/vqa/train.json'))
# N = len(data)
# c0 = 0
# w0 = 'surfboard'
# c1 = 0
# w1 = 'man'
# c2 = 0
# w2 = 'woman'
#
# for i in range(N):
#     if data[i]['sent'].find(w0) >= 0:
#         c0 += 1
#     if data[i]['sent'].find(w1) >= 0:
#         c1 += 1
#     if data[i]['sent'].find(w2) >= 0:
#         c2 += 1
# print(w0, c0)
# print(w1, c1)
# print(w2, c2)

--------------------------------------------------------------------------------
/get_test.py:
--------------------------------------------------------------------------------
import json

import numpy as np
from torch.utils.data import Dataset

import csv
import sys
csv.field_size_limit(sys.maxsize)

FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id",
"objects_conf", 11 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] 12 | 13 | id_file = 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv' 14 | coco_file = 'data/mscoco_imgfeat/test2015_obj64.tsv' 15 | outfile = 'data/vg_gqa_imgfeat/vg_gqa_obj64.tsv' 16 | all_id = [] 17 | all_id_number = [] 18 | coco_data = [] 19 | 20 | with open(coco_file) as f: 21 | coco_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t") 22 | for i, item in enumerate(coco_reader): 23 | if i % 1000 == 0: 24 | print(i) 25 | coco_data.append(item) 26 | all_id_number.append(int(item['img_id'][14:])) 27 | 28 | N = 0 29 | with open(outfile, 'a+') as tsvfile: 30 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 31 | with open(id_file) as f: 32 | id_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t") 33 | for i, item in enumerate(id_reader): 34 | if i % 1000 == 0: 35 | print(i) 36 | nlvr2_id = item['img_id'] 37 | if nlvr2_id[0] == 'n': 38 | N+=1 39 | nlvr2_id_number=int(nlvr2_id[1:]) 40 | index_id = all_id_number.index(nlvr2_id_number) 41 | coco_datum=coco_data[index_id] 42 | coco_datum['img_id'] = item['img_id'] 43 | writer.writerow(coco_datum) 44 | print(N) -------------------------------------------------------------------------------- /get_testdev.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | import csv 7 | import sys 8 | csv.field_size_limit(sys.maxsize) 9 | 10 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", 11 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] 12 | 13 | id_file = 'data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv' 14 | coco_file = 'data/mscoco_imgfeat/test2015_obj64.tsv' 15 | outfile = 'data/vg_gqa_imgfeat/gqa_testdev_obj64.tsv' 16 | all_id = [] 17 | all_id_number = [] 18 | coco_data = [] 19 | 20 | with open(coco_file) as f: 21 | coco_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t") 22 | for i, item in enumerate(coco_reader): 23 | if i % 1000 == 0: 24 | print(i) 25 | coco_data.append(item) 26 | all_id_number.append(int(item['img_id'][14:])) 27 | 28 | with open(outfile, 'w') as tsvfile: 29 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 30 | with open(id_file) as f: 31 | id_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t") 32 | for i, item in enumerate(id_reader): 33 | nlvr2_id = item['img_id'] 34 | nlvr2_id_number=int(nlvr2_id[1:]) 35 | index_id = all_id_number.index(nlvr2_id_number) 36 | coco_datum=coco_data[index_id] 37 | coco_datum['img_id'] = item['img_id'] 38 | writer.writerow(coco_datum) 39 | -------------------------------------------------------------------------------- /gqa_debug.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 
3 | 4 | import os 5 | import collections 6 | 7 | import torch 8 | from tqdm import tqdm 9 | import torch.nn as nn 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | from src.param import args 13 | from src.pretrain.qa_answer_table import load_lxmert_qa 14 | from src.tasks.gqa_model import GQAModel 15 | from src.tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator 16 | 17 | 18 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') 19 | 20 | 21 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: 22 | dset = GQADataset(splits) 23 | tset = GQATorchDataset(dset) 24 | evaluator = GQAEvaluator(dset) 25 | data_loader = DataLoader( 26 | tset, batch_size=bs, 27 | shuffle=shuffle, num_workers=args.num_workers, 28 | drop_last=drop_last, pin_memory=True 29 | ) 30 | 31 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 32 | 33 | 34 | class GQA: 35 | def __init__(self): 36 | self.train_tuple = get_tuple( 37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True 38 | ) 39 | if args.valid != "": 40 | valid_bsize = 512 if args.multiGPU else 512 41 | self.valid_tuple = get_tuple( 42 | args.valid, bs=valid_bsize, 43 | shuffle=False, drop_last=False 44 | ) 45 | else: 46 | self.valid_tuple = None 47 | 48 | self.model = GQAModel(self.train_tuple.dataset.num_answers) 49 | 50 | # Load pre-trained weights 51 | if args.load_lxmert is not None: 52 | self.model.lxrt_encoder.load(args.load_lxmert) 53 | if args.load_lxmert_qa is not None: 54 | load_lxmert_qa(args.load_lxmert_qa, self.model, 55 | label2ans=self.train_tuple.dataset.label2ans) 56 | 57 | # GPU options 58 | self.model = self.model.cuda() 59 | if args.multiGPU: 60 | self.model.lxrt_encoder.multi_gpu() 61 | 62 | # Losses and optimizer 63 | self.bce_loss = nn.BCEWithLogitsLoss() 64 | self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1) 65 | if 'bert' in args.optim: 66 | batch_per_epoch = len(self.train_tuple.loader) 67 | t_total = int(batch_per_epoch * args.epochs) 68 | print("Total Iters: %d" % t_total) 69 | from lxrt.optimization import BertAdam 70 | self.optim = BertAdam(list(self.model.parameters()), 71 | lr=args.lr, 72 | warmup=0.1, 73 | t_total=t_total) 74 | else: 75 | self.optim = args.optimizer(list(self.model.parameters()), args.lr) 76 | 77 | self.output = args.output 78 | os.makedirs(self.output, exist_ok=True) 79 | 80 | def train(self, train_tuple, eval_tuple): 81 | dset, loader, evaluator = train_tuple 82 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) 83 | 84 | best_valid = 0. 85 | for epoch in range(args.epochs): 86 | quesid2ans = {} 87 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): 88 | 89 | self.model.train() 90 | self.optim.zero_grad() 91 | 92 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() 93 | logit = self.model(feats, boxes, sent) 94 | assert logit.dim() == target.dim() == 2 95 | if args.mce_loss: 96 | max_value, target = target.max(1) 97 | loss = self.mce_loss(logit, target) * logit.size(1) 98 | else: 99 | loss = self.bce_loss(logit, target) 100 | loss = loss * logit.size(1) 101 | 102 | loss.backward() 103 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 104 | self.optim.step() 105 | 106 | score, label = logit.max(1) 107 | for qid, l in zip(ques_id, label.cpu().numpy()): 108 | ans = dset.label2ans[l] 109 | quesid2ans[qid] = ans 110 | 111 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) 
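            # Added note: after each epoch, run validation (when a valid split is
            # configured), snapshot the best-scoring model as "BEST.pth", and append
            # the epoch summary to <output>/log.log; "LAST.pth" is saved at the end.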
112 | 113 | if self.valid_tuple is not None: # Do Validation 114 | valid_score = self.evaluate(eval_tuple) 115 | if valid_score > best_valid: 116 | best_valid = valid_score 117 | self.save("BEST") 118 | 119 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ 120 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 121 | 122 | print(log_str, end='') 123 | 124 | with open(self.output + "/log.log", 'a') as f: 125 | f.write(log_str) 126 | f.flush() 127 | 128 | self.save("LAST") 129 | 130 | def predict(self, eval_tuple: DataTuple, dump=None): 131 | self.model.eval() 132 | dset, loader, evaluator = eval_tuple 133 | quesid2ans = {} 134 | for i, datum_tuple in enumerate(loader): 135 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target 136 | with torch.no_grad(): 137 | feats, boxes = feats.cuda(), boxes.cuda() 138 | logit = self.model(feats, boxes, sent) 139 | score, label = logit.max(1) 140 | for qid, l in zip(ques_id, label.cpu().numpy()): 141 | ans = dset.label2ans[l] 142 | quesid2ans[qid] = ans 143 | if dump is not None: 144 | evaluator.dump_result(quesid2ans, dump) 145 | return quesid2ans 146 | 147 | def evaluate(self, eval_tuple: DataTuple, dump=None): 148 | dset, loader, evaluator = eval_tuple 149 | quesid2ans = self.predict(eval_tuple, dump) 150 | return evaluator.evaluate(quesid2ans) 151 | 152 | @staticmethod 153 | def oracle_score(data_tuple): 154 | dset, loader, evaluator = data_tuple 155 | quesid2ans = {} 156 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): 157 | _, label = target.max(1) 158 | for qid, l in zip(ques_id, label.cpu().numpy()): 159 | ans = dset.label2ans[l] 160 | quesid2ans[qid] = ans 161 | return evaluator.evaluate(quesid2ans) 162 | 163 | def save(self, name): 164 | torch.save(self.model.state_dict(), 165 | os.path.join(self.output, "%s.pth" % name)) 166 | 167 | def load(self, path): 168 | print("Load model from %s" % path) 169 | state_dict = torch.load("%s.pth" % path) 170 | for key in list(state_dict.keys()): 171 | if '.module' in key: 172 | state_dict[key.replace('.module', '')] = state_dict.pop(key) 173 | self.model.load_state_dict(state_dict, strict=False) 174 | 175 | 176 | if __name__ == "__main__": 177 | # Build Class 178 | gqa = GQA() 179 | 180 | # Load Model 181 | if args.load is not None: 182 | gqa.load(args.load) 183 | 184 | # Test or Train 185 | if args.test is not None: 186 | args.fast = args.tiny = False # Always loading all data in test 187 | if 'submit' in args.test: 188 | gqa.predict( 189 | get_tuple(args.test, bs=args.batch_size, 190 | shuffle=False, drop_last=False), 191 | dump=os.path.join(args.output, 'submit_predict.json') 192 | ) 193 | if 'testdev' in args.test: 194 | result = gqa.evaluate( 195 | get_tuple('testdev', bs=args.batch_size, 196 | shuffle=False, drop_last=False), 197 | dump=os.path.join(args.output, 'testdev_predict.json') 198 | ) 199 | print(result) 200 | else: 201 | # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100)) 202 | print('Splits in Train data:', gqa.train_tuple.dataset.splits) 203 | if gqa.valid_tuple is not None: 204 | print('Splits in Valid data:', gqa.valid_tuple.dataset.splits) 205 | print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100)) 206 | else: 207 | print("DO NOT USE VALIDATION") 208 | gqa.train(gqa.train_tuple, gqa.valid_tuple) 209 | 210 | 211 | -------------------------------------------------------------------------------- /requirements.txt: 
--------------------------------------------------------------------------------
# Copied from https://github.com/huggingface/pytorch-transformers

# PyTorch
torch>=1.0.0
# Progress bars in model download and training scripts
tqdm
# Accessing files from S3 directly.
boto3
# Used for downloading models over HTTP
requests

--------------------------------------------------------------------------------
/run/README.md:
--------------------------------------------------------------------------------
# Running Script Arguments

```
Data Splits:
    --train [str,str,...]: use these splits (comma-separated) in training.
    --valid [str,str,...]: use these splits (comma-separated) in validation.
    --test [str,str,...]: use these splits (comma-separated) in testing.
Model Architecture:
    --llayers [int]: number of layers in the language encoder.
    --xlayers [int]: number of layers in the cross-modality encoder.
    --rlayers [int]: number of layers in the object-relationship encoder.
Load Weights:
    --load [str='path/to/saved_model']: load the fine-tuned model path/to/saved_model.pth.
    --loadLXMERT [str='path/to/saved_model']: load the pre-trained model without answer heads from path/to/saved_model_LXRT.pth.
    --loadLXMERTQA [str='path/to/saved_model']: load the pre-trained model with an answer head from path/to/saved_model_LXRT.pth.
    --fromScratch: if none of the above loading parameters are set, the default mode
        loads the pre-trained BERT weights. As promised to the EMNLP reviewers, this
        one-line argument re-initializes the language encoder to test the performance without BERT weights.
Training Hyper Parameters:
    --batchSize [int]: batch size.
    --optim [str]: optimizer.
    --lr [float]: peak learning rate.
    --epochs [int]: number of training epochs.
Debugging:
    --tiny: load 512 images for each data split. (Note: the number of images might change due to dataset specifics.)
    --fast: load 5000 images for each data split. (Note: the number of images might change due to dataset specifics.)
```

# Pre-training-Specific Arguments
```
Pre-training Tasks:
    --taskMaskLM: use the masked language model task.
    --taskObjPredict: use the masked object prediction task.
    --taskMatched: use the cross-modality matched task.
    --taskQA: use the image QA task.
Visual Pre-training Losses (Tasks):
    --visualLosses [str,str,...]: the sub-tasks used to pre-train the visual modality. Each one is from 'obj,attr,feat'.
        obj: detected-object-label classification.
        attr: detected-object-attribute classification.
        feat: RoI-feature regression.
Mask Rate in Pre-training:
    --wordMaskRate [float]: the probability of masking a word.
    --objMaskRate [float]: the probability of masking an object.
Initialization:
    --fromScratch: the default mode loads the pre-trained BERT weights into the model.
        As promised to the EMNLP reviewers, this option re-initializes the language encoder.
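
Example combination (for illustration; it mirrors the flags set in run/fsb2.bash):
    python src/pretrain/lxmert_pretrain2.py --taskMaskLM --taskObjPredict \
        --taskMatched --taskQA --visualLosses obj,attr,feat \
        --wordMaskRate 0.15 --objMaskRate 0.15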
47 | ``` 48 | 49 | 50 | -------------------------------------------------------------------------------- /run/fsb2.bash: -------------------------------------------------------------------------------- 1 | # The name of experiment 2 | name=fsb2 3 | 4 | # Create dirs and make backup 5 | output=snap/pretrain/$name 6 | mkdir -p $output/src 7 | cp -r src/* $output/src/ 8 | cp $0 $output/run.bash 9 | 10 | # Pre-training 11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ 12 | python src/pretrain/lxmert_pretrain2.py \ 13 | --taskMaskLM --taskObjPredict --taskMatched --taskQA \ 14 | --visualLosses obj,attr,feat \ 15 | --wordMaskRate 0.15 --objMaskRate 0.15 \ 16 | --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival \ 17 | --llayers 9 --xlayers 5 --rlayers 5 \ 18 | --batchSize 196 --optim bert --lr 5e-5 --epochs 20 \ 19 | --MV_size 500 --ML_size 500 \ 20 | --tqdm --output $output --bert_type ft_same --multiGPU ${@:2} 21 | 22 | -------------------------------------------------------------------------------- /run/gqa_finetuneft.bash: -------------------------------------------------------------------------------- 1 | # The name of this experiment. 2 | name=$2 3 | 4 | # Save logs and models under snap/gqa; make backup. 5 | output=snap/gqa/fb4/Epoch20/$name 6 | mkdir -p $output/src 7 | cp -r src/* $output/src/ 8 | cp $0 $output/run.bash 9 | 10 | # See Readme.md for option details. 11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ 12 | python src/tasks/gqa.py \ 13 | --train train,valid --valid testdev \ 14 | --llayers 9 --xlayers 5 --rlayers 5 \ 15 | --loadLXMERTQA snap/pretrain/fb4/Epoch20 \ 16 | --MV_size 800 --ML_size 800 \ 17 | --bert_type ft\ 18 | --batchSize 32 --optim bert --lr 5e-6 --epochs 4 \ 19 | --tqdm --output $output ${@:3} 20 | -------------------------------------------------------------------------------- /run/nlvr2_ft.bash: -------------------------------------------------------------------------------- 1 | # The name of this experiment. 2 | name=$2 3 | 4 | # Save logs and models under snap/nlvr2; Make backup. 5 | output=snap/nlvr2/nfsb2/Epoch20/$name 6 | mkdir -p $output/src 7 | cp -r src/* $output/src/ 8 | cp $0 $output/run.bash 9 | 10 | # See run/Readme.md for option details. 11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ 12 | python src/tasks/nlvr2.py \ 13 | --train train --valid valid \ 14 | --llayers 9 --xlayers 5 --rlayers 5 \ 15 | --loadLXMERT snap/pretrain/nfsb2/Epoch20 \ 16 | --MV_size 500 --ML_size 500 \ 17 | --bert_type ft_same \ 18 | --batchSize 32 --optim bert --lr $3 --epochs 4 \ 19 | --tqdm --multiGPU --output $output ${@:4} 20 | 21 | -------------------------------------------------------------------------------- /run/vqa_finetuneft.bash: -------------------------------------------------------------------------------- 1 | # The name of this experiment. 2 | name=$2 3 | 4 | # Save logs and models under snap/vqa; make backup. 5 | output=snap/vqa/fb3/Epoch20/$name 6 | mkdir -p $output/src 7 | cp -r src/* $output/src/ 8 | cp $0 $output/run.bash 9 | 10 | # See Readme.md for option details. 
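# Usage note (inferred from the variables above): $1 sets CUDA_VISIBLE_DEVICES,
# $2 is the experiment name used in the snap/vqa output path, and any remaining
# arguments (${@:3}) are forwarded to src/tasks/vqa.py.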
11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ 12 | python src/tasks/vqa.py \ 13 | --train train,nominival --valid minival \ 14 | --llayers 9 --xlayers 5 --rlayers 5 \ 15 | --loadLXMERTQA snap/pretrain/fb3/Epoch20 \ 16 | --batchSize 32 --optim bert \ 17 | --MV_size 800 --ML_size 800 \ 18 | --lr_schedule warmup_linear_yx --lr 5e-5 --lr_min 0 --epochs 4 \ 19 | --tqdm --output $output --bert_type ft ${@:3} 20 | 21 | # output=snap/vqa/ft20/$name 22 | #--loadLXMERTQA snap/pretrain/ftlxmert/Epoch20 \ 23 | -------------------------------------------------------------------------------- /src/lxrt/entry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 project LXRT. 3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import os 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from lxrt.tokenization import BertTokenizer 24 | from lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG 25 | 26 | 27 | class InputFeatures(object): 28 | """A single set of features of data.""" 29 | 30 | def __init__(self, input_ids, input_mask, segment_ids): 31 | self.input_ids = input_ids 32 | self.input_mask = input_mask 33 | self.segment_ids = segment_ids 34 | 35 | 36 | def convert_sents_to_features(sents, max_seq_length, tokenizer): 37 | """Loads a data file into a list of `InputBatch`s.""" 38 | 39 | features = [] 40 | for (i, sent) in enumerate(sents): 41 | tokens_a = tokenizer.tokenize(sent.strip()) 42 | 43 | # Account for [CLS] and [SEP] with "- 2" 44 | if len(tokens_a) > max_seq_length - 2: 45 | tokens_a = tokens_a[:(max_seq_length - 2)] 46 | 47 | # Keep segment id which allows loading BERT-weights. 48 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 49 | segment_ids = [0] * len(tokens) 50 | 51 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 52 | 53 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 54 | # tokens are attended to. 55 | input_mask = [1] * len(input_ids) 56 | 57 | # Zero-pad up to the sequence length. 
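        # Illustrative example: with max_seq_length = 20 and 7 real tokens
        # ([CLS] + 5 word pieces + [SEP]), 13 zeros are appended below to
        # input_ids, input_mask, and segment_ids alike.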
58 | padding = [0] * (max_seq_length - len(input_ids)) 59 | input_ids += padding 60 | input_mask += padding 61 | segment_ids += padding 62 | 63 | assert len(input_ids) == max_seq_length 64 | assert len(input_mask) == max_seq_length 65 | assert len(segment_ids) == max_seq_length 66 | 67 | features.append( 68 | InputFeatures(input_ids=input_ids, 69 | input_mask=input_mask, 70 | segment_ids=segment_ids)) 71 | return features 72 | 73 | 74 | def set_visual_config(args): 75 | VISUAL_CONFIG.l_layers = args.llayers 76 | VISUAL_CONFIG.x_layers = args.xlayers 77 | VISUAL_CONFIG.r_layers = args.rlayers 78 | 79 | 80 | class LXRTEncoder(nn.Module): 81 | def __init__(self, args, max_seq_length, mode='x'): 82 | super().__init__() 83 | self.max_seq_length = max_seq_length 84 | set_visual_config(args) 85 | 86 | # Using the bert tokenizer 87 | self.tokenizer = BertTokenizer.from_pretrained( 88 | "bert-base-uncased", 89 | do_lower_case=True 90 | ) 91 | 92 | # Build LXRT Model 93 | self.model = VisualBertForLXRFeature.from_pretrained( 94 | "bert-base-uncased", 95 | args=args, 96 | mode=mode 97 | ) 98 | 99 | if args.from_scratch: 100 | print("initializing all the weights") 101 | self.model.apply(self.model.init_bert_weights) 102 | 103 | def multi_gpu(self): 104 | self.model = nn.DataParallel(self.model) 105 | 106 | @property 107 | def dim(self): 108 | return 768 109 | 110 | def forward(self, sents, feats, visual_attention_mask=None): 111 | train_features = convert_sents_to_features( 112 | sents, self.max_seq_length, self.tokenizer) 113 | 114 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda() 115 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda() 116 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda() 117 | 118 | output = self.model(input_ids, segment_ids, input_mask, 119 | visual_feats=feats, 120 | visual_attention_mask=visual_attention_mask) 121 | return output 122 | 123 | def save(self, path): 124 | torch.save(self.model.state_dict(), 125 | os.path.join("%s_LXRT.pth" % path)) 126 | 127 | def load(self, path): 128 | # Load state_dict from snapshot file 129 | print("Load LXMERT pre-trained model from %s" % path) 130 | # state_dict = torch.load("%s_LXRT.pth" % path) 131 | state_dict = torch.load("%s_LXRT.pth" % path)['state_dict'] 132 | new_state_dict = {} 133 | for key, value in state_dict.items(): 134 | if key.startswith("module."): 135 | new_state_dict[key[len("module."):]] = value 136 | else: 137 | new_state_dict[key] = value 138 | state_dict = new_state_dict 139 | 140 | # Print out the differences of pre-trained and model weights. 141 | load_keys = set(state_dict.keys()) 142 | model_keys = set(self.model.state_dict().keys()) 143 | print() 144 | print("Weights in loaded but not in model:") 145 | for key in sorted(load_keys.difference(model_keys)): 146 | print(key) 147 | print() 148 | print("Weights in model but not in loaded:") 149 | for key in sorted(model_keys.difference(load_keys)): 150 | print(key) 151 | print() 152 | 153 | # Load weights to model 154 | self.model.load_state_dict(state_dict, strict=False) 155 | tt=0 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /src/lxrt/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | import json 7 | import logging 8 | import os 9 | import shutil 10 | import tempfile 11 | from functools import wraps 12 | from hashlib import sha256 13 | import sys 14 | from io import open 15 | 16 | import boto3 17 | import requests 18 | from botocore.exceptions import ClientError 19 | from tqdm import tqdm 20 | 21 | try: 22 | from urllib.parse import urlparse 23 | except ImportError: 24 | from urlparse import urlparse 25 | 26 | try: 27 | from pathlib import Path 28 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 29 | Path.home() / '.pytorch_pretrained_bert')) 30 | print(PYTORCH_PRETRAINED_BERT_CACHE) 31 | except (AttributeError, ImportError): 32 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 34 | print(PYTORCH_PRETRAINED_BERT_CACHE) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 
107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext 250 | -------------------------------------------------------------------------------- /src/lxrt/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 project LXRT 3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch optimization for BERT model.""" 17 | 18 | import math 19 | import torch 20 | from torch.optim import Optimizer 21 | from torch.optim.optimizer import required 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002, opt = None): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002, opt = None): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002, opt = None): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. 
""" 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | def warmup_linear_yx(x, warmup=0.002, opt = None): 46 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 47 | After `t_total`-th training step, learning rate is zero. """ 48 | if x < warmup: 49 | return x/warmup 50 | lr_max = 1 51 | lr_min = opt.lr_min/opt.lr 52 | b = (lr_max - lr_min*warmup)/(1-warmup) 53 | a = (lr_min - lr_max)/(1-warmup) 54 | return a*x+b 55 | 56 | def warmup_stair(x, warmup=0.002, opt = None): 57 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 58 | After `t_total`-th training step, learning rate is zero. """ 59 | if x < warmup: 60 | return x/warmup 61 | return max((x-1.)/(warmup-1.), 0) 62 | 63 | SCHEDULES = { 64 | 'warmup_cosine': warmup_cosine, 65 | 'warmup_constant': warmup_constant, 66 | 'warmup_linear': warmup_linear, 67 | 'warmup_linear_yx': warmup_linear_yx, 68 | 'warmup_stair': warmup_stair, 69 | } 70 | 71 | 72 | class BertAdam(Optimizer): 73 | """Implements BERT version of Adam algorithm with weight decay fix. 74 | Params: 75 | lr: learning rate 76 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 77 | t_total: total number of training steps for the learning 78 | rate schedule, -1 means constant learning rate. Default: -1 79 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 80 | b1: Adams b1. Default: 0.9 81 | b2: Adams b2. Default: 0.999 82 | e: Adams epsilon. Default: 1e-6 83 | weight_decay: Weight decay. Default: 0.01 84 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 85 | """ 86 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 87 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 88 | max_grad_norm=1.0, args=None): 89 | if lr is not required and lr < 0.0: 90 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 91 | if schedule not in SCHEDULES: 92 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 93 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 94 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 95 | if not 0.0 <= b1 < 1.0: 96 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 97 | if not 0.0 <= b2 < 1.0: 98 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 99 | if not e >= 0.0: 100 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 101 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 102 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 103 | max_grad_norm=max_grad_norm) 104 | self.args = args 105 | super(BertAdam, self).__init__(params, defaults) 106 | 107 | def get_lr(self): 108 | lr = [] 109 | for group in self.param_groups: 110 | for p in group['params']: 111 | state = self.state[p] 112 | if len(state) == 0: 113 | return [0] 114 | if group['t_total'] != -1: 115 | schedule_fct = SCHEDULES[group['schedule']] 116 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'], self.args) 117 | else: 118 | lr_scheduled = group['lr'] 119 | lr.append(lr_scheduled) 120 | return lr 121 | 122 | def step(self, closure=None): 123 | """Performs a single optimization step. 
124 | 125 | Arguments: 126 | closure (callable, optional): A closure that reevaluates the model 127 | and returns the loss. 128 | """ 129 | loss = None 130 | if closure is not None: 131 | loss = closure() 132 | 133 | warned_for_t_total = False 134 | 135 | for group in self.param_groups: 136 | for p in group['params']: 137 | if p.grad is None: 138 | continue 139 | grad = p.grad.data 140 | if grad.is_sparse: 141 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 142 | 143 | state = self.state[p] 144 | 145 | # State initialization 146 | if len(state) == 0: 147 | state['step'] = 0 148 | # Exponential moving average of gradient values 149 | state['next_m'] = torch.zeros_like(p.data) 150 | # Exponential moving average of squared gradient values 151 | state['next_v'] = torch.zeros_like(p.data) 152 | 153 | next_m, next_v = state['next_m'], state['next_v'] 154 | beta1, beta2 = group['b1'], group['b2'] 155 | 156 | # LXRT: grad is clipped outside. 157 | # Add grad clipping 158 | # if group['max_grad_norm'] > 0: 159 | # clip_grad_norm_(p, group['max_grad_norm']) 160 | 161 | # Decay the first and second moment running average coefficient 162 | # In-place operations to update the averages at the same time 163 | next_m.mul_(beta1).add_(1 - beta1, grad) 164 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 165 | update = next_m / (next_v.sqrt() + group['e']) 166 | 167 | # Just adding the square of the weights to the loss function is *not* 168 | # the correct way of using L2 regularization/weight decay with Adam, 169 | # since that will interact with the m and v parameters in strange ways. 170 | # 171 | # Instead we want to decay the weights in a manner that doesn't interact 172 | # with the m/v parameters. This is equivalent to adding the square 173 | # of the weights to the loss with plain (non-momentum) SGD. 174 | if group['weight_decay'] > 0.0: 175 | update += group['weight_decay'] * p.data 176 | 177 | if group['t_total'] != -1: 178 | schedule_fct = SCHEDULES[group['schedule']] 179 | progress = state['step']/group['t_total'] 180 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'], self.args) 181 | # warning for exceeding t_total (only active with warmup_linear 182 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 183 | logger.warning( 184 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 185 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 186 | warned_for_t_total = True 187 | # end warning 188 | else: 189 | lr_scheduled = group['lr'] 190 | 191 | update_with_lr = lr_scheduled * update 192 | p.data.add_(-update_with_lr) 193 | 194 | state['step'] += 1 195 | 196 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 197 | # No bias correction 198 | # bias_correction1 = 1 - beta1 ** state['step'] 199 | # bias_correction2 = 1 - beta2 ** state['step'] 200 | 201 | return loss 202 | -------------------------------------------------------------------------------- /src/lxrt/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | import collections 18 | import logging 19 | import os 20 | import unicodedata 21 | from io import open 22 | 23 | from .file_utils import cached_path 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 28 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 29 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 30 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 31 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 32 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 33 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 34 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 35 | } 36 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 37 | 'bert-base-uncased': 512, 38 | 'bert-large-uncased': 512, 39 | 'bert-base-cased': 512, 40 | 'bert-large-cased': 512, 41 | 'bert-base-multilingual-uncased': 512, 42 | 'bert-base-multilingual-cased': 512, 43 | 'bert-base-chinese': 512, 44 | } 45 | VOCAB_NAME = 'vocab.txt' 46 | 47 | 48 | def load_vocab(vocab_file): 49 | """Loads a vocabulary file into a dictionary.""" 50 | vocab = collections.OrderedDict() 51 | index = 0 52 | with open(vocab_file, "r", encoding="utf-8") as reader: 53 | while True: 54 | token = reader.readline() 55 | if not token: 56 | break 57 | token = token.strip() 58 | vocab[token] = index 59 | index += 1 60 | return vocab 61 | 62 | 63 | def whitespace_tokenize(text): 64 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 65 | text = text.strip() 66 | if not text: 67 | return [] 68 | tokens = text.split() 69 | return tokens 70 | 71 | 72 | class BertTokenizer(object): 73 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 74 | 75 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 76 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 77 | """Constructs a BertTokenizer. 78 | 79 | Args: 80 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 81 | do_lower_case: Whether to lower case the input 82 | Only has an effect when do_wordpiece_only=False 83 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 84 | max_len: An artificial maximum length to truncate tokenized sequences to; 85 | Effective maximum length is always the minimum of this 86 | value (if specified) and the underlying BERT model's 87 | sequence length. 88 | never_split: List of tokens which will never be split during tokenization. 89 | Only has an effect when do_wordpiece_only=False 90 | """ 91 | if not os.path.isfile(vocab_file): 92 | raise ValueError( 93 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained "
94 |                 "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
95 |         self.vocab = load_vocab(vocab_file)
96 |         self.ids_to_tokens = collections.OrderedDict(
97 |             [(ids, tok) for tok, ids in self.vocab.items()])
98 |         self.do_basic_tokenize = do_basic_tokenize
99 |         if do_basic_tokenize:
100 |             self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
101 |                                                   never_split=never_split)
102 |         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
103 |         self.max_len = max_len if max_len is not None else int(1e12)
104 | 
105 |     def tokenize(self, text):
106 |         if self.do_basic_tokenize:
107 |             split_tokens = []
108 |             for token in self.basic_tokenizer.tokenize(text):
109 |                 for sub_token in self.wordpiece_tokenizer.tokenize(token):
110 |                     split_tokens.append(sub_token)
111 |         else:
112 |             split_tokens = self.wordpiece_tokenizer.tokenize(text)
113 |         return split_tokens
114 | 
115 |     def convert_tokens_to_ids(self, tokens):
116 |         """Converts a sequence of tokens into ids using the vocab."""
117 |         ids = []
118 |         for token in tokens:
119 |             ids.append(self.vocab[token])
120 |         if len(ids) > self.max_len:
121 |             logger.warning(
122 |                 "Token indices sequence length is longer than the specified maximum "
123 |                 "sequence length for this BERT model ({} > {}). Running this"
124 |                 " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
125 |             )
126 |         return ids
127 | 
128 |     def convert_ids_to_tokens(self, ids):
129 |         """Converts a sequence of ids into wordpiece tokens using the vocab."""
130 |         tokens = []
131 |         for i in ids:
132 |             tokens.append(self.ids_to_tokens[i])
133 |         return tokens
134 | 
135 |     @classmethod
136 |     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
137 |         """
138 |         Instantiate a PreTrainedBertModel from a pre-trained model file.
139 |         Download and cache the pre-trained model file if needed.
140 |         """
141 |         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
142 |             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
143 |         else:
144 |             vocab_file = pretrained_model_name_or_path
145 |         if os.path.isdir(vocab_file):
146 |             vocab_file = os.path.join(vocab_file, VOCAB_NAME)
147 |         # redirect to the cache, if necessary
148 |         try:
149 |             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
150 |         except EnvironmentError:
151 |             logger.error(
152 |                 "Model name '{}' was not found in model name list ({}). "
153 |                 "We assumed '{}' was a path or url but couldn't find any file "
154 |                 "associated to this path or url.".format(
155 |                     pretrained_model_name_or_path,
156 |                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
157 |                     vocab_file))
158 |             return None
159 |         if resolved_vocab_file == vocab_file:
160 |             logger.info("loading vocabulary file {}".format(vocab_file))
161 |         else:
162 |             logger.info("loading vocabulary file {} from cache at {}".format(
163 |                 vocab_file, resolved_vocab_file))
164 |         if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
165 |             # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
166 |             # than the number of positional embeddings
167 |             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
168 |             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
169 |         # Instantiate tokenizer.
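        # A minimal usage sketch (illustrative; the exact wordpiece split
        # depends on the downloaded vocab):
        # >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # >>> tokens = tokenizer.tokenize("unaffable")   # e.g. ['un', '##aff', '##able']
        # >>> ids = tokenizer.convert_tokens_to_ids(tokens)
        # >>> assert tokenizer.convert_ids_to_tokens(ids) == tokens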
170 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 171 | return tokenizer 172 | 173 | 174 | class BasicTokenizer(object): 175 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 176 | 177 | def __init__(self, 178 | do_lower_case=True, 179 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 180 | """Constructs a BasicTokenizer. 181 | 182 | Args: 183 | do_lower_case: Whether to lower case the input. 184 | """ 185 | self.do_lower_case = do_lower_case 186 | self.never_split = never_split 187 | 188 | def tokenize(self, text): 189 | """Tokenizes a piece of text.""" 190 | text = self._clean_text(text) 191 | # This was added on November 1st, 2018 for the multilingual and Chinese 192 | # models. This is also applied to the English models now, but it doesn't 193 | # matter since the English models were not trained on any Chinese data 194 | # and generally don't have any Chinese data in them (there are Chinese 195 | # characters in the vocabulary because Wikipedia does have some Chinese 196 | # words in the English Wikipedia.). 197 | text = self._tokenize_chinese_chars(text) 198 | orig_tokens = whitespace_tokenize(text) 199 | split_tokens = [] 200 | for token in orig_tokens: 201 | if self.do_lower_case and token not in self.never_split: 202 | token = token.lower() 203 | token = self._run_strip_accents(token) 204 | split_tokens.extend(self._run_split_on_punc(token)) 205 | 206 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 207 | return output_tokens 208 | 209 | def _run_strip_accents(self, text): 210 | """Strips accents from a piece of text.""" 211 | text = unicodedata.normalize("NFD", text) 212 | output = [] 213 | for char in text: 214 | cat = unicodedata.category(char) 215 | if cat == "Mn": 216 | continue 217 | output.append(char) 218 | return "".join(output) 219 | 220 | def _run_split_on_punc(self, text): 221 | """Splits punctuation on a piece of text.""" 222 | if text in self.never_split: 223 | return [text] 224 | chars = list(text) 225 | i = 0 226 | start_new_word = True 227 | output = [] 228 | while i < len(chars): 229 | char = chars[i] 230 | if _is_punctuation(char): 231 | output.append([char]) 232 | start_new_word = True 233 | else: 234 | if start_new_word: 235 | output.append([]) 236 | start_new_word = False 237 | output[-1].append(char) 238 | i += 1 239 | 240 | return ["".join(x) for x in output] 241 | 242 | def _tokenize_chinese_chars(self, text): 243 | """Adds whitespace around any CJK character.""" 244 | output = [] 245 | for char in text: 246 | cp = ord(char) 247 | if self._is_chinese_char(cp): 248 | output.append(" ") 249 | output.append(char) 250 | output.append(" ") 251 | else: 252 | output.append(char) 253 | return "".join(output) 254 | 255 | def _is_chinese_char(self, cp): 256 | """Checks whether CP is the codepoint of a CJK character.""" 257 | # This defines a "chinese character" as anything in the CJK Unicode block: 258 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 259 | # 260 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 261 | # despite its name. The modern Korean Hangul alphabet is a different block, 262 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 263 | # space-separated words, so they are not treated specially and handled 264 | # like the all of the other languages. 
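        # For example, ord('中') == 0x4E2D falls in the first CJK block tested
        # below, so it is surrounded with spaces above, while Hiragana 'の'
        # (0x306E) falls outside every block and is left untouched.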
265 |         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
266 |                 (cp >= 0x3400 and cp <= 0x4DBF) or  #
267 |                 (cp >= 0x20000 and cp <= 0x2A6DF) or  #
268 |                 (cp >= 0x2A700 and cp <= 0x2B73F) or  #
269 |                 (cp >= 0x2B740 and cp <= 0x2B81F) or  #
270 |                 (cp >= 0x2B820 and cp <= 0x2CEAF) or
271 |                 (cp >= 0xF900 and cp <= 0xFAFF) or  #
272 |                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
273 |             return True
274 | 
275 |         return False
276 | 
277 |     def _clean_text(self, text):
278 |         """Performs invalid character removal and whitespace cleanup on text."""
279 |         output = []
280 |         for char in text:
281 |             cp = ord(char)
282 |             if cp == 0 or cp == 0xfffd or _is_control(char):
283 |                 continue
284 |             if _is_whitespace(char):
285 |                 output.append(" ")
286 |             else:
287 |                 output.append(char)
288 |         return "".join(output)
289 | 
290 | 
291 | class WordpieceTokenizer(object):
292 |     """Runs WordPiece tokenization."""
293 | 
294 |     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
295 |         self.vocab = vocab
296 |         self.unk_token = unk_token
297 |         self.max_input_chars_per_word = max_input_chars_per_word
298 | 
299 |     def tokenize(self, text):
300 |         """Tokenizes a piece of text into its word pieces.
301 | 
302 |         This uses a greedy longest-match-first algorithm to perform tokenization
303 |         using the given vocabulary.
304 | 
305 |         For example:
306 |             input = "unaffable"
307 |             output = ["un", "##aff", "##able"]
308 | 
309 |         Args:
310 |             text: A single token or whitespace separated tokens. This should have
311 |                 already been passed through `BasicTokenizer`.
312 | 
313 |         Returns:
314 |             A list of wordpiece tokens.
315 |         """
316 | 
317 |         output_tokens = []
318 |         for token in whitespace_tokenize(text):
319 |             chars = list(token)
320 |             if len(chars) > self.max_input_chars_per_word:
321 |                 output_tokens.append(self.unk_token)
322 |                 continue
323 | 
324 |             is_bad = False
325 |             start = 0
326 |             sub_tokens = []
327 |             while start < len(chars):
328 |                 end = len(chars)
329 |                 cur_substr = None
330 |                 while start < end:
331 |                     substr = "".join(chars[start:end])
332 |                     if start > 0:
333 |                         substr = "##" + substr
334 |                     if substr in self.vocab:
335 |                         cur_substr = substr
336 |                         break
337 |                     end -= 1
338 |                 if cur_substr is None:
339 |                     is_bad = True
340 |                     break
341 |                 sub_tokens.append(cur_substr)
342 |                 start = end
343 | 
344 |             if is_bad:
345 |                 output_tokens.append(self.unk_token)
346 |             else:
347 |                 output_tokens.extend(sub_tokens)
348 |         return output_tokens
349 | 
350 | 
351 | def _is_whitespace(char):
352 |     """Checks whether `char` is a whitespace character."""
353 |     # \t, \n, and \r are technically control characters but we treat them
354 |     # as whitespace since they are generally considered as such.
355 |     if char == " " or char == "\t" or char == "\n" or char == "\r":
356 |         return True
357 |     cat = unicodedata.category(char)
358 |     if cat == "Zs":
359 |         return True
360 |     return False
361 | 
362 | 
363 | def _is_control(char):
364 |     """Checks whether `char` is a control character."""
365 |     # These are technically control characters but we count them as whitespace
366 |     # characters.
367 |     if char == "\t" or char == "\n" or char == "\r":
368 |         return False
369 |     cat = unicodedata.category(char)
370 |     if cat.startswith("C"):
371 |         return True
372 |     return False
373 | 
374 | 
375 | def _is_punctuation(char):
376 |     """Checks whether `char` is a punctuation character."""
377 |     cp = ord(char)
378 |     # We treat all non-letter/number ASCII as punctuation.
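    # For example, unicodedata.category('$') == 'Sc' (a currency symbol, not
    # punctuation), yet '$' (cp 36) still tests True via the 33-47 range below.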
379 | # Characters such as "^", "$", and "`" are not in the Unicode 380 | # Punctuation class but we treat them as punctuation anyways, for 381 | # consistency. 382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 384 | return True 385 | cat = unicodedata.category(char) 386 | if cat.startswith("P"): 387 | return True 388 | return False 389 | -------------------------------------------------------------------------------- /src/param.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 3 | 4 | import argparse 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def get_optimizer(optim): 12 | # Bind the optimizer 13 | if optim == 'rms': 14 | print("Optimizer: Using RMSProp") 15 | optimizer = torch.optim.RMSprop 16 | elif optim == 'adam': 17 | print("Optimizer: Using Adam") 18 | optimizer = torch.optim.Adam 19 | elif optim == 'adamax': 20 | print("Optimizer: Using Adamax") 21 | optimizer = torch.optim.Adamax 22 | elif optim == 'sgd': 23 | print("Optimizer: sgd") 24 | optimizer = torch.optim.SGD 25 | elif 'bert' in optim: 26 | optimizer = 'bert' # The bert optimizer will be bind later. 27 | else: 28 | assert False, "Please add your optimizer %s in the list." % optim 29 | 30 | return optimizer 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | 36 | # Data Splits 37 | parser.add_argument("--train", default='train') 38 | parser.add_argument("--valid", default='valid') 39 | parser.add_argument("--test", default=None) 40 | 41 | # Training Hyper-parameters 42 | parser.add_argument('--batchSize', dest='batch_size', type=int, default=256) 43 | parser.add_argument('--optim', default='bert') 44 | parser.add_argument('--lr', type=float, default=1e-4) 45 | # parser.add_argument('--lr', type=str, default="1e-4") 46 | parser.add_argument('--epochs', type=int, default=10) 47 | parser.add_argument('--dropout', type=float, default=0.1) 48 | parser.add_argument('--seed', type=int, default=9595, help='random seed') 49 | 50 | # Debugging 51 | parser.add_argument('--output', type=str, default='snap/test') 52 | parser.add_argument("--fast", action='store_const', default=False, const=True) 53 | parser.add_argument("--tiny", action='store_const', default=False, const=True) 54 | parser.add_argument("--tqdm", action='store_const', default=False, const=True) 55 | 56 | # Model Loading 57 | parser.add_argument('--load', type=str, default=None, 58 | help='Load the model (usually the fine-tuned model).') 59 | parser.add_argument('--loadLXMERT', dest='load_lxmert', type=str, default=None, 60 | help='Load the pre-trained LXMERT model.') 61 | parser.add_argument('--loadLXMERTQA', dest='load_lxmert_qa', type=str, default=None, 62 | help='Load the pre-trained LXMERT model with QA answer head.') 63 | parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True, 64 | help='If none of the --load, --loadLXMERT, --loadLXMERTQA is set, ' 65 | 'the model would be trained from scratch. If --fromScratch is' 66 | ' not specified, the model would load BERT-pre-trained weights by' 67 | ' default. 
') 68 | 69 | # Optimization 70 | parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True) 71 | 72 | # LXRT Model Config 73 | # Note: LXRT = L, X, R (three encoders), Transformer 74 | parser.add_argument("--llayers", default=9, type=int, help='Number of Language layers') 75 | parser.add_argument("--xlayers", default=5, type=int, help='Number of CROSS-modality layers.') 76 | parser.add_argument("--rlayers", default=5, type=int, help='Number of object Relationship layers.') 77 | 78 | # LXMERT Pre-training Config 79 | parser.add_argument("--taskMatched", dest='task_matched', action='store_const', default=False, const=True) 80 | parser.add_argument("--taskMaskLM", dest='task_mask_lm', action='store_const', default=False, const=True) 81 | parser.add_argument("--taskObjPredict", dest='task_obj_predict', action='store_const', default=False, const=True) 82 | parser.add_argument("--taskQA", dest='task_qa', action='store_const', default=False, const=True) 83 | parser.add_argument("--visualLosses", dest='visual_losses', default='obj,attr,feat', type=str) 84 | parser.add_argument("--qaSets", dest='qa_sets', default=None, type=str) 85 | parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float) 86 | parser.add_argument("--objMaskRate", dest='obj_mask_rate', default=0.15, type=float) 87 | 88 | # Training configuration 89 | parser.add_argument("--fp16", action='store_const', default=False, const=True) 90 | parser.add_argument("--multiGPU", action='store_const', default=False, const=True) 91 | parser.add_argument("--numWorkers", dest='num_workers', default=0) 92 | 93 | # whether ft 94 | parser.add_argument("--bert_type", type=str, default='bert') 95 | parser.add_argument("--cross_type", type=str, default='g2g') 96 | parser.add_argument("--ft_large", type=int, default=0) 97 | parser.add_argument("--MV_size", type=int, default=500) 98 | parser.add_argument("--ML_size", type=int, default=500) 99 | 100 | #start from check point 101 | parser.add_argument("--start_from", type=int, default=0) 102 | 103 | #distributed running 104 | parser.add_argument('--local_rank', default=0, type=int, 105 | help='node rank for distributed training') 106 | 107 | #optimizer related 108 | parser.add_argument("--lr_schedule", type=str, default='warmup_linear') 109 | parser.add_argument("--lr_max", type=float, default=5e-5) 110 | parser.add_argument("--lr_min", type=float, default=6e-5) 111 | 112 | parser.add_argument("--pretraining_index", type=int, default=0) 113 | 114 | #init dict 115 | parser.add_argument("--kmeans", type=int, default=1) 116 | parser.add_argument("--mv_path", type=str, default='data2/lxmert/mv_path.npy') 117 | parser.add_argument("--ml_path", type=str, default='data2/lxmert/ml_path.npy') 118 | 119 | # Parse the arguments. 120 | args = parser.parse_args() 121 | 122 | # Bind optimizer class. 123 | args.optimizer = get_optimizer(args.optim) 124 | 125 | # Set seeds 126 | torch.manual_seed(args.seed) 127 | random.seed(args.seed) 128 | np.random.seed(args.seed) 129 | 130 | return args 131 | 132 | 133 | args = parse_args() 134 | -------------------------------------------------------------------------------- /src/pretrain/lxmert_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 
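# Overview: LXMERTDataset loads the raw json splits and remaps answers through
# AnswerTable; LXMERTTorchDataset joins them with the image-feature tsv files
# and serves one (image, sentence) InputExample per item; LXMERTEvaluator
# scores QA predictions by uid.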
3 | 4 | from collections import defaultdict 5 | import json 6 | import random 7 | 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | 11 | from param import args 12 | from pretrain.qa_answer_table import AnswerTable 13 | from utils import load_obj_tsv 14 | 15 | TINY_IMG_NUM = 500 16 | FAST_IMG_NUM = 5000 17 | # 18 | # Split2ImgFeatPath = { 19 | # 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv', 20 | # 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv', 21 | # 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv', 22 | # 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv', 23 | # } 24 | 25 | Split2ImgFeatPath = { 26 | 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj64.tsv', 27 | 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj64.tsv', 28 | 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj64.tsv', 29 | 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj64.tsv', 30 | } 31 | 32 | 33 | class InputExample(object): 34 | """A single training/test example for the language model.""" 35 | def __init__(self, uid, sent, visual_feats=None, 36 | obj_labels=None, attr_labels=None, 37 | is_matched=None, label=None): 38 | self.uid = uid 39 | self.sent = sent 40 | self.visual_feats = visual_feats 41 | self.obj_labels = obj_labels 42 | self.attr_labels = attr_labels 43 | self.is_matched = is_matched # whether the visual and obj matched 44 | self.label = label 45 | 46 | 47 | class LXMERTDataset: 48 | def __init__(self, splits: str, qa_sets=None): 49 | """ 50 | :param splits: The data sources to be loaded 51 | :param qa_sets: if None, no action 52 | o.w., only takes the answers appearing in these dsets 53 | and remove all unlabeled data (MSCOCO captions) 54 | """ 55 | self.name = splits 56 | self.sources = splits.split(',') 57 | 58 | # Loading datasets to data 59 | self.data = [] 60 | for source in self.sources: 61 | self.data.extend(json.load(open("data/lxmert/%s.json" % source))) 62 | print("Load %d data from %s" % (len(self.data), self.name)) 63 | 64 | # Create answer table according to the qa_sets 65 | self.answer_table = AnswerTable(qa_sets) 66 | print("Load an answer table of size %d." 
% (len(self.answer_table.ans2id_map())))
67 | 
68 |         # Modify the answers
69 |         for datum in self.data:
70 |             labelf = datum['labelf']
71 |             for cat, labels in labelf.items():
72 |                 for label in labels:
73 |                     for ans in list(label.keys()):
74 |                         new_ans = self.answer_table.convert_ans(ans)
75 |                         if self.answer_table.used(new_ans):
76 |                             if ans != new_ans:
77 |                                 label[new_ans] = label.pop(ans)
78 |                         else:
79 |                             label.pop(ans)
80 | 
81 |     def __len__(self):
82 |         return len(self.data)
83 | 
84 | 
85 | def make_uid(img_id, dset, sent_idx):
86 |     return "%s_%s_%03d" % (img_id, dset, sent_idx)
87 | 
88 | 
89 | """
90 | Example in obj tsv:
91 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
92 |               "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
93 | """
94 | class LXMERTTorchDataset(Dataset):
95 |     def __init__(self, dataset: LXMERTDataset, topk=-1):
96 |         super().__init__()
97 |         self.raw_dataset = dataset
98 |         self.task_matched = args.task_matched
99 | 
100 |         if args.tiny:
101 |             topk = TINY_IMG_NUM
102 |         elif args.fast:
103 |             topk = FAST_IMG_NUM
104 | 
105 |         # Load the dataset
106 |         img_data = []
107 |         for source in self.raw_dataset.sources:
108 |             img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk, args.fp16))
109 | 
110 |         self.imgid2img = {}
111 |         for img_datum in img_data:
112 |             self.imgid2img[img_datum['img_id']] = img_datum
113 | 
114 |         # Filter out the dataset
115 |         used_data = []
116 |         for datum in self.raw_dataset.data:
117 |             if datum['img_id'] in self.imgid2img:
118 |                 used_data.append(datum)
119 | 
120 |         # Flatten the dataset (into one sent + one image entries)
121 |         self.data = []
122 |         for datum in used_data:
123 |             sentf = datum['sentf']
124 |             for sents_cat, sents in sentf.items():
125 |                 if sents_cat in datum['labelf']:
126 |                     labels = datum['labelf'][sents_cat]
127 |                 else:
128 |                     labels = None
129 |                 for sent_idx, sent in enumerate(sents):
130 |                     new_datum = {
131 |                         'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
132 |                         'img_id': datum['img_id'],
133 |                         'sent': sent
134 |                     }
135 |                     if labels is not None:
136 |                         new_datum['label'] = labels[sent_idx]
137 |                     self.data.append(new_datum)
138 |         print("Use %d data in torch dataset" % (len(self.data)))
139 | 
140 |     def __len__(self):
141 |         return len(self.data)
142 | 
143 |     def random_feat(self):
144 |         """Get a random obj feat from the dataset."""
145 |         datum = self.data[random.randint(0, len(self.data)-1)]
146 |         img_id = datum['img_id']
147 |         img_info = self.imgid2img[img_id]
148 |         feat = img_info['features'][random.randint(0, img_info['num_boxes'] - 1)]  # sample any of the image's num_boxes (64 in the obj64 tsvs) features
149 |         return feat
150 | 
151 |     def __getitem__(self, item: int):
152 |         datum = self.data[item]
153 | 
154 |         uid = datum['uid']
155 |         img_id = datum['img_id']
156 | 
157 |         # Get image info
158 |         img_info = self.imgid2img[img_id]
159 |         obj_num = img_info['num_boxes']
160 |         feats = img_info['features'].copy()
161 |         if not args.fp16 and feats.dtype != np.float32:  # Features may be kept as float16 to save CPU memory; upcast to float32 unless fp16 training consumes them directly.
162 |             feats = feats.astype(np.float32)
163 |         boxes = img_info['boxes'].copy()
164 |         obj_labels = img_info['objects_id'].copy()
165 |         obj_confs = img_info['objects_conf'].copy()
166 |         attr_labels = img_info['attrs_id'].copy()
167 |         attr_confs = img_info['attrs_conf'].copy()
168 |         assert obj_num == len(boxes) == len(feats)
169 | 
170 |         # Normalize the boxes (to 0 ~ 1)
171 |         img_h, img_w = img_info['img_h'], img_info['img_w']
172 |         boxes = boxes.copy()
173 |         boxes[:, (0, 2)] /= img_w
174 |         boxes[:, (1, 3)] /= img_h
175 |         np.testing.assert_array_less(boxes, 1+1e-5)
176 |         np.testing.assert_array_less(-boxes, 0+1e-5)
177 | 
178 |         # If calculating the matched loss, replace the sentence with a sentence
179 |         # corresponding to another image.
180 |         is_matched = 1
181 |         sent = datum['sent']
182 |         if self.task_matched:
183 |             if random.random() < 0.5:
184 |                 is_matched = 0
185 |                 other_datum = self.data[random.randint(0, len(self.data)-1)]
186 |                 while other_datum['img_id'] == img_id:
187 |                     other_datum = self.data[random.randint(0, len(self.data)-1)]
188 |                 sent = other_datum['sent']
189 | 
190 |         # Label, convert answer to id
191 |         if 'label' in datum:
192 |             label = datum['label'].copy()
193 |             for ans in list(label.keys()):
194 |                 label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
195 |         else:
196 |             label = None
197 | 
198 |         # Create target
199 |         example = InputExample(
200 |             uid, sent, (feats, boxes),
201 |             (obj_labels, obj_confs), (attr_labels, attr_confs),
202 |             is_matched, label
203 |         )
204 |         return example
205 | 
206 | 
207 | class LXMERTEvaluator:
208 |     def __init__(self, dataset: LXMERTDataset):
209 |         self.raw_dataset = dataset
210 | 
211 |         # Create QA Eval Data
212 |         self.data = []
213 |         for datum in self.raw_dataset.data:
214 |             sentf = datum['sentf']
215 |             for sents_cat, sents in sentf.items():
216 |                 if sents_cat in datum['labelf']:  # A labeled dataset
217 |                     labels = datum['labelf'][sents_cat]
218 |                     for sent_idx, sent in enumerate(sents):
219 |                         new_datum = {
220 |                             'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
221 |                             'img_id': datum['img_id'],
222 |                             'sent': sent,
223 |                             'dset': sents_cat,
224 |                             'label': labels[sent_idx]
225 |                         }
226 |                         self.data.append(new_datum)
227 | 
228 |         # uid2datum
229 |         self.uid2datum = {}
230 |         for datum in self.data:
231 |             self.uid2datum[datum['uid']] = datum
232 | 
233 |     def evaluate(self, uid2ans: dict, pprint=False):
234 |         score = 0.
235 |         cnt = 0
236 |         dset2score = defaultdict(lambda: 0.)
237 |         dset2cnt = defaultdict(lambda: 0)
238 |         for uid, ans in uid2ans.items():
239 |             if uid not in self.uid2datum:   # Not a labeled datum
240 |                 continue
241 |             datum = self.uid2datum[uid]
242 |             label = datum['label']
243 |             dset = datum['dset']
244 |             if ans in label:
245 |                 score += label[ans]
246 |                 dset2score[dset] += label[ans]
247 |             cnt += 1
248 |             dset2cnt[dset] += 1
249 |         accu = score / cnt
250 |         dset2accu = {}
251 |         for dset in dset2cnt:
252 |             dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
253 | 
254 |         if pprint:
255 |             accu_str = "Overall Accu %0.4f, " % (accu)
256 |             sorted_keys = sorted(dset2accu.keys())
257 |             for key in sorted_keys:
258 |                 accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
259 |             print(accu_str)
260 | 
261 |         return accu, dset2accu
262 | 
263 |     def dump_result(self, uid2ans: dict, path):
264 |         raise NotImplementedError
--------------------------------------------------------------------------------
/src/pretrain/lxmert_pretrain.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 | 
4 | import collections
5 | import os
6 | import random
7 | 
8 | from tqdm import tqdm
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 | 
14 | from param import args
15 | from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
16 | from lxrt.entry import set_visual_config
17 | from lxrt.tokenization import BertTokenizer
18 | from lxrt.modeling import LXRTPretraining
19 | 
20 | DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')
21 | 
22 | 
23 | def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
24 |     # Decide which QA datasets would be used in pre-training.
25 |     # Options: vqa, gqa, visual7w
26 |     # Note: visual7w is a part of vgqa, we take the name here.
27 |     qa_sets = args.qa_sets
28 |     if qa_sets is not None:
29 |         qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))
30 | 
31 |     # Build dataset, data loader, and evaluator.
32 |     dset = LXMERTDataset(splits, qa_sets=qa_sets)
33 |     tset = LXMERTTorchDataset(dset, topk)
34 |     data_loader = DataLoader(
35 |         tset, batch_size=bs,
36 |         shuffle=shuffle, num_workers=args.num_workers,
37 |         collate_fn=lambda x: x,
38 |         drop_last=drop_last, pin_memory=True
39 |     )
40 |     evaluator = LXMERTEvaluator(dset)
41 |     print()
42 | 
43 |     return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
44 | 
45 | 
46 | train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
47 | valid_batch_size = 1024 if args.multiGPU else 512
48 | valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)
49 | 
50 | 
51 | class InputFeatures(object):
52 |     """A single set of features of data."""
53 | 
54 |     def __init__(self,
55 |                  input_ids, input_mask, segment_ids, lm_label_ids,
56 |                  visual_feats, obj_labels,
57 |                  is_matched, ans):
58 |         self.input_ids = input_ids
59 |         self.input_mask = input_mask
60 |         self.segment_ids = segment_ids
61 |         self.lm_label_ids = lm_label_ids
62 | 
63 |         self.visual_feats = visual_feats
64 |         self.obj_labels = obj_labels
65 | 
66 |         self.is_matched = is_matched
67 | 
68 |         self.ans = ans
69 | 
70 | 
71 | def random_word(tokens, tokenizer):
72 |     """
73 |     Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
74 |     :param tokens: list of str, tokenized sentence.
75 |     :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
76 |     :return: (list of str, list of int), masked tokens and related labels for LM prediction
77 |     """
78 |     output_label = []
79 | 
80 |     for i, token in enumerate(tokens):
81 |         prob = random.random()
82 |         # mask token with probability
83 |         ratio = args.word_mask_rate
84 |         if prob < ratio:
85 |             prob /= ratio
86 | 
87 |             # 80% randomly change token to mask token
88 |             if prob < 0.8:
89 |                 tokens[i] = "[MASK]"
90 | 
91 |             # 10% randomly change token to random token
92 |             elif prob < 0.9:
93 |                 tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
94 | 
95 |             # -> rest 10% randomly keep current token
96 | 
97 |             # append current token to output (we will predict these later)
98 |             try:
99 |                 output_label.append(tokenizer.vocab[token])
100 |             except KeyError:
101 |                 # For unknown words (should not occur with BPE vocab)
102 |                 output_label.append(tokenizer.vocab["[UNK]"])
103 |         else:
104 |             # no masking token (will be ignored by loss function later)
105 |             output_label.append(-1)
106 | 
107 |     return tokens, output_label
108 | 
109 | 
110 | def random_feat(feats):
111 |     mask_feats = feats.copy()
112 |     feat_mask = np.zeros(len(feats), dtype=np.float32)
113 |     for i in range(len(feats)):
114 |         prob = random.random()
115 |         # mask token with probability
116 |         if prob < args.obj_mask_rate:
117 |             prob /= args.obj_mask_rate
118 | 
119 |             # 80% randomly change token to zero feat
120 |             if prob < 0.8:
121 |                 mask_feats[i, :] = 0.
122 | 
123 |             # 10% randomly change token to random feat
124 |             elif prob < 0.9:
125 |                 mask_feats[i, :] = train_tuple.torchdset.random_feat()
126 |             # -> rest 10% randomly keep current feat
127 | 
128 |             # Need to predict this feat
129 |             feat_mask[i] = 1.
130 | 
131 |     return mask_feats, feat_mask
132 | 
133 | 
134 | def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
135 |     """
136 |     Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
137 |     IDs, LM labels, input_mask, CLS and SEP tokens etc.
138 |     :param example: InputExample, containing sentence input as strings and is_next label
139 |     :param max_seq_length: int, maximum length of sequence.
140 |     :param tokenizer: Tokenizer
141 |     :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
142 |     """
143 |     tokens = tokenizer.tokenize(example.sent.strip())
144 | 
145 |     # Account for [CLS] and [SEP] with "- 2"
146 |     if len(tokens) > max_seq_length - 2:
147 |         tokens = tokens[:(max_seq_length - 2)]
148 | 
149 |     # Get random words
150 |     masked_tokens, masked_label = random_word(tokens, tokenizer)
151 | 
152 |     # Concatenate lm labels and account for [CLS] and [SEP]
153 |     masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
154 |     input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
155 | 
156 |     # Mask & Segment Word
157 |     lm_label_ids = ([-1] + masked_label + [-1])
158 |     input_mask = [1] * len(input_ids)
159 |     segment_ids = [0] * len(input_ids)
160 | 
161 |     # Zero-pad up to the sequence length.
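    # Worked example (illustrative): with max_seq_length = 20 and a 5-token
    # sentence, the lists above hold 7 entries ([CLS] + 5 tokens + [SEP]), so
    # the loop below appends 13 pad ids / zero masks / -1 labels to reach 20.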
162 | while len(input_ids) < max_seq_length: 163 | input_ids.append(0) 164 | input_mask.append(0) 165 | segment_ids.append(0) 166 | lm_label_ids.append(-1) 167 | 168 | assert len(input_ids) == max_seq_length 169 | assert len(input_mask) == max_seq_length 170 | assert len(segment_ids) == max_seq_length 171 | assert len(lm_label_ids) == max_seq_length 172 | 173 | feat, boxes = example.visual_feats 174 | obj_labels, obj_confs = example.obj_labels 175 | attr_labels, attr_confs = example.attr_labels 176 | 177 | # Mask Image Features: 178 | masked_feat, feat_mask = random_feat(feat) 179 | 180 | # QA answer label 181 | if example.label is None or len(example.label) == 0 or example.is_matched != 1: 182 | # 1. No label 2. Label is pruned 3. unmatched visual + language pair 183 | ans = -1 184 | else: 185 | keys, values = zip(*example.label.items()) 186 | if len(keys) == 1: 187 | ans = keys[0] 188 | else: 189 | value_sum = sum(values) 190 | prob = [value / value_sum for value in values] 191 | choice = np.random.multinomial(1, prob).argmax() 192 | ans = keys[choice] 193 | 194 | features = InputFeatures( 195 | input_ids=input_ids, 196 | input_mask=input_mask, 197 | segment_ids=segment_ids, 198 | lm_label_ids=lm_label_ids, 199 | visual_feats=(masked_feat, boxes), 200 | obj_labels={ 201 | 'obj': (obj_labels, obj_confs), 202 | 'attr': (attr_labels, attr_confs), 203 | 'feat': (feat, feat_mask), 204 | }, 205 | is_matched=example.is_matched, 206 | ans=ans, 207 | ) 208 | return features 209 | 210 | 211 | LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA') 212 | 213 | 214 | class LXMERT: 215 | def __init__(self, max_seq_length): 216 | super().__init__() 217 | self.max_seq_length = max_seq_length 218 | 219 | self.tokenizer = BertTokenizer.from_pretrained( 220 | "bert-base-uncased", 221 | do_lower_case=True 222 | ) 223 | 224 | # Build model 225 | set_visual_config(args) 226 | self.model = LXRTPretraining.from_pretrained( 227 | "bert-base-uncased", 228 | task_mask_lm=args.task_mask_lm, 229 | task_obj_predict=args.task_obj_predict, 230 | task_matched=args.task_matched, 231 | task_qa=args.task_qa, 232 | visual_losses=args.visual_losses, 233 | num_answers=train_tuple.dataset.answer_table.num_answers, 234 | args = args 235 | ) 236 | 237 | # Weight initialization and loading 238 | if args.from_scratch: 239 | print("Train from Scratch: re-initialize all BERT weights.") 240 | self.model.apply(self.model.init_bert_weights) 241 | if args.load is not None: 242 | self.load(args.load) 243 | if args.load_lxmert is not None: 244 | # Load lxmert would not load the answer head. 
245 | self.load_lxmert(args.load_lxmert) 246 | # GPU Options 247 | self.model = self.model.cuda() 248 | if args.multiGPU: 249 | #self.model = nn.DataParallel(self.model) 250 | self.multiGPU = 1 251 | else: 252 | self.multiGPU = 0 253 | # self.model = nn.DistributedDataParallel(self.model) 254 | self.output = args.output 255 | os.makedirs(self.output, exist_ok=True) 256 | 257 | def forward(self, examples): 258 | train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer) 259 | for example in examples] 260 | 261 | # language Inputs 262 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda() 263 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda() 264 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda() 265 | 266 | # Visual Inputs 267 | feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda() 268 | pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda() 269 | 270 | # Language Prediction 271 | lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda() 272 | 273 | # Visual Prediction 274 | obj_labels = {} 275 | for key in ('obj', 'attr', 'feat'): 276 | visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda() 277 | visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda() 278 | assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1) 279 | obj_labels[key] = (visn_labels, visn_mask) 280 | 281 | # Joint Prediction 282 | matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda() 283 | ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda() 284 | 285 | """ 286 | forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 287 | visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None): 288 | """ 289 | loss, losses, ans_logit = self.model( 290 | input_ids, segment_ids, input_mask, lm_labels, 291 | feats, pos, obj_labels, matched_labels, ans 292 | ) 293 | if self.multiGPU: 294 | loss = loss.mean() 295 | losses = losses.mean(0) 296 | 297 | return loss, losses.detach().cpu(), ans_logit 298 | 299 | def train_batch(self, optim, batch): 300 | optim.zero_grad() 301 | loss, losses, ans_logit = self.forward(batch) 302 | # if args.multiGPU: 303 | # loss = loss.mean() 304 | # losses = losses.mean(0) 305 | if args.fp16: 306 | try: 307 | from apex import amp 308 | except ImportError: 309 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 310 | with amp.scale_loss(loss, optim) as scaled_loss: 311 | scaled_loss.backward() 312 | torch.nn.utils.clip_grad_norm_(amp.master_params(optim), 1.) 313 | else: 314 | loss.backward() 315 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 
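        # Step order per batch: zero_grad -> forward -> backward (amp-scaled
        # when fp16 is on) -> clip the global grad norm to 1.0 (clipping is done
        # here, outside BertAdam; see the note in optimization.py) -> step.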
316 | 
317 |         optim.step()
318 | 
319 |         return loss.item(), losses.cpu().numpy(), ans_logit
320 | 
321 |     def valid_batch(self, batch):
322 |         with torch.no_grad():
323 |             loss, losses, ans_logit = self.forward(batch)
324 |             # if args.multiGPU:
325 |             #     loss = loss.mean()
326 |             #     losses = losses.mean(0)
327 |         return loss.item(), losses.cpu().numpy(), ans_logit
328 | 
329 |     def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
330 |         train_ld = train_tuple.loader
331 | 
332 |         # Optimizer
333 |         from lxrt.optimization import BertAdam
334 |         batch_per_epoch = len(train_ld)
335 |         t_total = int(batch_per_epoch * args.epochs)
336 |         warmup_ratio = 0.05
337 |         warmup_iters = int(t_total * warmup_ratio)
338 |         print("Batch per epoch: %d" % batch_per_epoch)
339 |         print("Total Iters: %d" % t_total)
340 |         print("Warm up Iters: %d" % warmup_iters)
341 |         optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)
342 |         start_epoch = 0
343 | 
344 |         if args.fp16:
345 |             try:
346 |                 from apex import amp
347 |             except ImportError:
348 |                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
349 |             self.model, optim = amp.initialize(self.model, optim, opt_level='O1')
350 | 
351 |         # GPU Options
352 |         if args.multiGPU:
353 |             self.model = nn.DataParallel(self.model)
354 | 
355 |         if args.start_from > 0 and args.pretraining_index == 0:
356 |             start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
357 |             print('start training from {0}'.format(start_path))
358 |             state = torch.load(start_path)
359 |             self.model.load_state_dict(state['state_dict'])
360 |             optim.load_state_dict(state['optimizer'])
361 |             start_epoch = args.start_from
362 |             del state
363 |             torch.cuda.empty_cache()
364 |         elif args.start_from > 0 and args.pretraining_index == 1:
365 |             start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
366 |             print('start training from {0}'.format(start_path))
367 |             state = torch.load(start_path)
368 |             self.model.load_state_dict(state['state_dict'], strict=False)
369 |             del state
370 |             torch.cuda.empty_cache()
371 | 
372 |         # Train
373 |         best_eval_loss = 9595.
374 |         for epoch in range(start_epoch, args.epochs):
375 |             # Train
376 |             self.model.train()
377 |             total_loss = 0.
378 |             total_losses = 0.
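            # total_loss accumulates the scalar batch losses, and total_losses
            # accumulates the per-task loss vector aligned with LOSSES_NAME;
            # both are divided by batch_per_epoch below to report averages.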
379 | uid2ans = {} 380 | for batch in tqdm(train_ld, total=len(train_ld)): 381 | loss, losses, logit = self.train_batch(optim, batch) 382 | total_loss += loss 383 | total_losses += losses 384 | 385 | if args.task_qa: 386 | score, label = logit.max(1) 387 | for datum, l in zip(batch, label.cpu().numpy()): 388 | uid = datum.uid 389 | ans = train_tuple.dataset.answer_table.id2ans(l) 390 | uid2ans[uid] = ans 391 | 392 | print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch)) 393 | log_str = "\nThe training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch) 394 | losses_str = "\nThe losses are " 395 | log_str += "\nThe losses are " 396 | for name, loss in zip(LOSSES_NAME, total_losses): 397 | losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch) 398 | log_str += "\n %s: %0.4f " % (name, loss / batch_per_epoch) 399 | print(losses_str) 400 | with open(self.output + "/log.log", 'a') as f: 401 | f.write(log_str) 402 | f.flush() 403 | if args.task_qa: 404 | train_tuple.evaluator.evaluate(uid2ans, pprint=True) 405 | 406 | # Eval 407 | avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1) 408 | 409 | state = { 410 | 'state_dict': self.model.state_dict(), 411 | 'optimizer': optim.state_dict(), 412 | } 413 | 414 | # Save 415 | if avg_eval_loss < best_eval_loss: 416 | best_eval_loss = avg_eval_loss 417 | self.save("BEST_EVAL_LOSS",state) 418 | if args.pretraining_index == 0: 419 | self.save("Epoch%02d" % (epoch+1), state) 420 | elif args.pretraining_index == 1: 421 | self.save("Epoch%02d" % (epoch + 1 + args.start_from), state) 422 | 423 | def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1): 424 | self.model.eval() 425 | eval_ld = eval_tuple.loader 426 | total_loss = 0. 427 | total_losses = 0. 
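        # iters == -1 evaluates the whole loader; a non-negative value stops
        # after iters + 1 batches via the `if i == iters: break` below.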
428 | uid2ans = {} 429 | for i, batch in enumerate(eval_ld): 430 | loss, losses, logit = self.valid_batch(batch) 431 | total_loss += loss 432 | total_losses += losses 433 | if args.task_qa: 434 | score, label = logit.max(1) 435 | for datum, l in zip(batch, label.cpu().numpy()): 436 | uid = datum.uid 437 | ans = train_tuple.dataset.answer_table.id2ans(l) 438 | uid2ans[uid] = ans 439 | if i == iters: 440 | break 441 | 442 | print("The valid loss is %0.4f" % (total_loss / len(eval_ld))) 443 | log_str = "\nThe valid loss is %0.4f" % (total_loss / len(eval_ld)) 444 | losses_str = "\nThe losses are " 445 | log_str += "\nThe losses are " 446 | for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)): 447 | losses_str += "%s: %0.4f " % (name, loss) 448 | log_str += "\n%s: %0.4f " % (name, loss) 449 | print(losses_str) 450 | with open(self.output + "/log.log", 'a') as f: 451 | f.write(log_str) 452 | f.flush() 453 | 454 | if args.task_qa: 455 | eval_tuple.evaluator.evaluate(uid2ans, pprint=True) 456 | 457 | return total_loss / len(eval_ld) 458 | 459 | def save(self, name, state): 460 | torch.save(state, 461 | os.path.join(args.output, "%s_LXRT.pth" % name)) 462 | 463 | 464 | def load(self, path): 465 | print("Load BERT extractor from %s" % path) 466 | state_dict = torch.load("%s_LXRT.pth" % path) 467 | self.model.load_state_dict(state_dict) 468 | 469 | def load_lxmert(self, path): 470 | print("Load LXMERT model from %s" % path) 471 | state_dict = torch.load("%s_LXRT.pth" % path) 472 | 473 | # Do not load any answer head 474 | for key in list(state_dict.keys()): 475 | if 'answer' in key: 476 | state_dict.pop(key) 477 | 478 | # Change Multi GPU to single GPU 479 | new_state_dict = {} 480 | for key, value in state_dict.items(): 481 | if key.startswith("module."): 482 | new_state_dict[key[len("module."):]] = value 483 | state_dict = new_state_dict 484 | 485 | load_keys = set(state_dict.keys()) 486 | model_keys = set(self.model.state_dict().keys()) 487 | print() 488 | print("Keys in loaded but not in model:") 489 | for key in sorted(load_keys.difference(model_keys)): 490 | print(key) 491 | print() 492 | print("Keys in model but not in loaded:") 493 | for key in sorted(model_keys.difference(load_keys)): 494 | print(key) 495 | print() 496 | 497 | self.model.load_state_dict(state_dict, strict=False) 498 | 499 | 500 | if __name__ == "__main__": 501 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3" 502 | 503 | lxmert = LXMERT(max_seq_length=20) 504 | 505 | lxmert.train(train_tuple, valid_tuple) 506 | -------------------------------------------------------------------------------- /src/pretrain/lxmert_pretrain2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 3 | 4 | import collections 5 | import os 6 | import random 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | 14 | from param import args 15 | from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator 16 | from lxrt.entry import set_visual_config 17 | from lxrt.tokenization import BertTokenizer 18 | from lxrt.modeling import LXRTPretraining 19 | 20 | DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator') 21 | 22 | 23 | def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple: 24 | # Decide which QA datasets would be used in pre-training. 
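    # e.g. --qaSets "vqa,gqa" is parsed below into the set {'vqa', 'gqa'}.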
25 | # Options: vqa, gqa, visual7w 26 | # Note: visual7w is a part of vgqa, we take the name here. 27 | qa_sets = args.qa_sets 28 | if qa_sets is not None: 29 | qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(",")) 30 | 31 | # Build dataset, data loader, and evaluator. 32 | dset = LXMERTDataset(splits, qa_sets=qa_sets) 33 | tset = LXMERTTorchDataset(dset, topk) 34 | data_loader = DataLoader( 35 | tset, batch_size=bs, 36 | shuffle=shuffle, num_workers=args.num_workers, 37 | collate_fn=lambda x: x, 38 | drop_last=drop_last, pin_memory=True 39 | ) 40 | evaluator = LXMERTEvaluator(dset) 41 | print() 42 | 43 | return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator) 44 | 45 | 46 | train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True) 47 | valid_batch_size = 1024 if args.multiGPU else 512 48 | valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000) 49 | 50 | 51 | class InputFeatures(object): 52 | """A single set of features of data.""" 53 | 54 | def __init__(self, 55 | input_ids, input_mask, segment_ids, lm_label_ids, 56 | visual_feats, visual_masks, obj_labels, 57 | is_matched, ans, random_index): 58 | self.input_ids = input_ids 59 | self.input_mask = input_mask 60 | self.segment_ids = segment_ids 61 | self.lm_label_ids = lm_label_ids 62 | 63 | self.visual_feats = visual_feats 64 | self.visual_masks = visual_masks 65 | self.obj_labels = obj_labels 66 | 67 | self.is_matched = is_matched 68 | 69 | self.ans = ans 70 | self.random_index = random_index 71 | 72 | 73 | def random_word(tokens, tokenizer, random_index): 74 | """ 75 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper. 76 | :param tokens: list of str, tokenized sentence. 77 | :param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here) 78 | :param random_index: if 0, mask sentence, if 1, do not mask 79 | :return: (list of str, list of int), masked tokens and related labels for LM prediction 80 | """ 81 | output_label = [] 82 | 83 | for i, token in enumerate(tokens): 84 | prob = random.random() 85 | # mask token with probability 86 | ratio = args.word_mask_rate 87 | if prob < ratio and random_index == 0: 88 | prob /= ratio 89 | 90 | # 80% randomly change token to mask token 91 | if prob < 0.8: 92 | tokens[i] = "[MASK]" 93 | 94 | # 10% randomly change token to random token 95 | elif prob < 0.9: 96 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0] 97 | 98 | # -> rest 10% randomly keep current token 99 | 100 | # append current token to output (we will predict these later) 101 | try: 102 | output_label.append(tokenizer.vocab[token]) 103 | except KeyError: 104 | # For unknown words (should not occur with BPE vocab) 105 | output_label.append(tokenizer.vocab["[UNK]"]) 106 | else: 107 | # no masking token (will be ignored by loss function later) 108 | output_label.append(-1) 109 | 110 | return tokens, output_label 111 | 112 | 113 | def random_feat(feats, random_index): 114 | """ 115 | :param random_index: if 1, mask image, if 0, do not mask 116 | """ 117 | mask_feats = feats.copy() 118 | feat_mask = np.zeros(len(feats), dtype=np.float32) 119 | for i in range(len(feats)): 120 | prob = random.random() 121 | # mask token with probability 122 | if prob < args.obj_mask_rate and random_index == 1: 123 | prob /= args.obj_mask_rate 124 | 125 | # 80% randomly change token to zero feat 126 | if prob < 0.8: 127 | mask_feats[i, :] = 0. 
128 | 
129 |             # 10% randomly change token to random feat
130 |             elif prob < 0.9:
131 |                 mask_feats[i, :] = train_tuple.torchdset.random_feat()
132 |             # -> rest 10% randomly keep current feat
133 | 
134 |             # Need to predict this feat
135 |             feat_mask[i] = 1.
136 | 
137 |     return mask_feats, feat_mask
138 | 
139 | 
140 | def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
141 |     """
142 |     Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
143 |     IDs, LM labels, input_mask, CLS and SEP tokens etc.
144 |     :param example: InputExample, containing sentence input as strings and is_next label
145 |     :param max_seq_length: int, maximum length of sequence.
146 |     :param tokenizer: Tokenizer
147 |     :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
148 |     """
149 |     tokens = tokenizer.tokenize(example.sent.strip())
150 | 
151 |     # Account for [CLS] and [SEP] with "- 2"
152 |     if len(tokens) > max_seq_length - 2:
153 |         tokens = tokens[:(max_seq_length - 2)]
154 | 
155 |     # If random_index == 0, mask the sentence; if random_index == 1, mask the image.
156 |     random_index = np.random.randint(2)
157 | 
158 |     # Get random words
159 |     masked_tokens, masked_label = random_word(tokens, tokenizer, random_index)
160 | 
161 |     # Concatenate lm labels and account for [CLS] and [SEP]
162 |     masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
163 |     input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
164 | 
165 |     # Mask & Segment Word
166 |     lm_label_ids = ([-1] + masked_label + [-1])
167 |     input_mask = [1] * len(input_ids)
168 |     segment_ids = [0] * len(input_ids)
169 | 
170 |     # Zero-pad up to the sequence length.
171 |     while len(input_ids) < max_seq_length:
172 |         input_ids.append(0)
173 |         input_mask.append(0)
174 |         segment_ids.append(0)
175 |         lm_label_ids.append(-1)
176 | 
177 |     assert len(input_ids) == max_seq_length
178 |     assert len(input_mask) == max_seq_length
179 |     assert len(segment_ids) == max_seq_length
180 |     assert len(lm_label_ids) == max_seq_length
181 | 
182 |     feat, boxes = example.visual_feats
183 |     obj_labels, obj_confs = example.obj_labels
184 |     attr_labels, attr_confs = example.attr_labels
185 | 
186 |     # Mask Image Features:
187 |     masked_feat, feat_mask = random_feat(feat, random_index)
188 |     obj_confs = feat_mask   # reuse the feature mask as the obj/attr prediction weights, so only masked boxes are scored
189 |     attr_confs = feat_mask
190 |     visn_masks = 1-feat_mask
191 | 
192 |     # QA answer label
193 |     if example.label is None or len(example.label) == 0 or example.is_matched != 1:
194 |         # 1. No label 2. Label is pruned 3.
unmatched visual + language pair 195 | ans = -1 196 | else: 197 | keys, values = zip(*example.label.items()) 198 | if len(keys) == 1: 199 | ans = keys[0] 200 | else: 201 | value_sum = sum(values) 202 | prob = [value / value_sum for value in values] 203 | choice = np.random.multinomial(1, prob).argmax() 204 | ans = keys[choice] 205 | 206 | features = InputFeatures( 207 | input_ids=input_ids, 208 | input_mask=input_mask, 209 | segment_ids=segment_ids, 210 | lm_label_ids=lm_label_ids, 211 | visual_feats=(masked_feat, boxes), 212 | visual_masks = visn_masks, 213 | obj_labels={ 214 | 'obj': (obj_labels, obj_confs), 215 | 'attr': (attr_labels, attr_confs), 216 | 'feat': (feat, feat_mask), 217 | }, 218 | is_matched=example.is_matched, 219 | ans=ans, 220 | random_index=random_index 221 | ) 222 | return features 223 | 224 | 225 | LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA') 226 | 227 | 228 | class LXMERT: 229 | def __init__(self, max_seq_length): 230 | super().__init__() 231 | self.max_seq_length = max_seq_length 232 | 233 | self.tokenizer = BertTokenizer.from_pretrained( 234 | "bert-base-uncased", 235 | do_lower_case=True 236 | ) 237 | 238 | # Build model 239 | set_visual_config(args) 240 | self.model = LXRTPretraining.from_pretrained( 241 | "bert-base-uncased", 242 | task_mask_lm=args.task_mask_lm, 243 | task_obj_predict=args.task_obj_predict, 244 | task_matched=args.task_matched, 245 | task_qa=args.task_qa, 246 | visual_losses=args.visual_losses, 247 | num_answers=train_tuple.dataset.answer_table.num_answers, 248 | args = args 249 | ) 250 | 251 | # Weight initialization and loading 252 | if args.from_scratch: 253 | print("Train from Scratch: re-initialize all BERT weights.") 254 | self.model.apply(self.model.init_bert_weights) 255 | if args.load is not None: 256 | self.load(args.load) 257 | if args.load_lxmert is not None: 258 | # Load lxmert would not load the answer head. 
259 | self.load_lxmert(args.load_lxmert) 260 | # GPU Options 261 | self.model = self.model.cuda() 262 | if args.multiGPU: 263 | #self.model = nn.DataParallel(self.model) 264 | self.multiGPU = 1 265 | else: 266 | self.multiGPU = 0 267 | # self.model = nn.DistributedDataParallel(self.model) 268 | self.output = args.output 269 | os.makedirs(self.output, exist_ok=True) 270 | 271 | def forward(self, examples): 272 | train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer) 273 | for example in examples] 274 | 275 | # language Inputs 276 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda() 277 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda() 278 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda() 279 | 280 | # Visual Inputs 281 | feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda() 282 | pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda() 283 | visual_mask = torch.tensor([f.visual_masks for f in train_features], dtype=torch.long).cuda() 284 | 285 | # Language Prediction 286 | lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda() 287 | 288 | 289 | # Visual Prediction 290 | obj_labels = {} 291 | for key in ('obj', 'attr', 'feat'): 292 | visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda() 293 | visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda() 294 | assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1) 295 | obj_labels[key] = (visn_labels, visn_mask) 296 | 297 | # Joint Prediction 298 | matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda() 299 | ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda() 300 | 301 | """ 302 | forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 303 | visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None): 304 | """ 305 | loss, losses, ans_logit = self.model( 306 | input_ids, segment_ids, input_mask, lm_labels, 307 | feats, pos, obj_labels, matched_labels, ans, visual_mask 308 | ) 309 | if self.multiGPU: 310 | loss = loss.mean() 311 | losses = losses.mean(0) 312 | 313 | return loss, losses.detach().cpu(), ans_logit 314 | 315 | def train_batch(self, optim, batch): 316 | optim.zero_grad() 317 | loss, losses, ans_logit = self.forward(batch) 318 | # if args.multiGPU: 319 | # loss = loss.mean() 320 | # losses = losses.mean(0) 321 | if args.fp16: 322 | try: 323 | from apex import amp 324 | except ImportError: 325 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 326 | with amp.scale_loss(loss, optim) as scaled_loss: 327 | scaled_loss.backward() 328 | torch.nn.utils.clip_grad_norm_(amp.master_params(optim), 1.) 329 | else: 330 | loss.backward() 331 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 
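        # Same step sequence as lxmert_pretrain.py; what differs in this file
        # is upstream: each example masks either the words (random_index == 0)
        # or the visual features (random_index == 1), never both, and the
        # resulting visual_mask is forwarded to the model.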
332 | 333 | optim.step() 334 | 335 | return loss.item(), losses.cpu().numpy(), ans_logit 336 | 337 | def valid_batch(self, batch): 338 | with torch.no_grad(): 339 | loss, losses, ans_logit = self.forward(batch) 340 | # if args.multiGPU: 341 | # loss = loss.mean() 342 | # losses = losses.mean(0) 343 | return loss.item(), losses.cpu().numpy(), ans_logit 344 | 345 | def train(self, train_tuple: DataTuple, eval_tuple: DataTuple): 346 | train_ld = train_tuple.loader 347 | 348 | # Optimizer 349 | from lxrt.optimization import BertAdam 350 | batch_per_epoch = len(train_ld) 351 | t_total = int(batch_per_epoch * args.epochs) 352 | warmup_ratio = 0.05 353 | warmup_iters = int(t_total * warmup_ratio) 354 | print("Batch per epoch: %d" % batch_per_epoch) 355 | print("Total Iters: %d" % t_total) 356 | print("Warm up Iters: %d" % warmup_iters) 357 | optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total) 358 | start_epoch = 0 359 | 360 | if args.fp16: 361 | try: 362 | from apex import amp 363 | except ImportError: 364 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 365 | self.model, optim = amp.initialize(self.model, optim, opt_level='O1') 366 | 367 | 368 | # GPU Options 369 | if args.multiGPU: 370 | self.model = nn.DataParallel(self.model) 371 | 372 | if args.start_from >0 and args.pretraining_index == 0: 373 | start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02')) 374 | print('start training from {0}'.format(start_path)) 375 | state = torch.load(start_path) 376 | self.model.load_state_dict(state['state_dict'],strict=False) 377 | optim.load_state_dict(state['optimizer']) 378 | start_epoch = args.start_from 379 | del state 380 | torch.cuda.empty_cache() 381 | elif args.start_from >0 and args.pretraining_index == 1: 382 | start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02')) 383 | print('start training from {0}'.format(start_path)) 384 | state = torch.load(start_path) 385 | self.model.load_state_dict(state['state_dict'],strict=False) 386 | del state 387 | torch.cuda.empty_cache() 388 | 389 | # if args.start_from >0: 390 | # start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02')) 391 | # print('start training from {0}'.format(start_path)) 392 | # state = torch.load(start_path) 393 | # self.model.load_state_dict(state['state_dict']) 394 | # optim.load_state_dict(state['optimizer']) 395 | # start_epoch = args.start_from 396 | # del state 397 | # torch.cuda.empty_cache() 398 | 399 | 400 | # Train 401 | best_eval_loss = 9595. 402 | for epoch in range(start_epoch, args.epochs): 403 | # Train 404 | self.model.train() 405 | total_loss = 0. 406 | total_losses = 0. 
407 | uid2ans = {} 408 | for batch in tqdm(train_ld, total=len(train_ld)): 409 | loss, losses, logit = self.train_batch(optim, batch) 410 | total_loss += loss 411 | total_losses += losses 412 | 413 | if args.task_qa: 414 | score, label = logit.max(1) 415 | for datum, l in zip(batch, label.cpu().numpy()): 416 | uid = datum.uid 417 | ans = train_tuple.dataset.answer_table.id2ans(l) 418 | uid2ans[uid] = ans 419 | 420 | print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch)) 421 | log_str = "\nThe training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch) 422 | losses_str = "\nThe losses are " 423 | log_str += "\nThe losses are " 424 | for name, loss in zip(LOSSES_NAME, total_losses): 425 | losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch) 426 | log_str += "\n %s: %0.4f " % (name, loss / batch_per_epoch) 427 | print(losses_str) 428 | with open(self.output + "/log.log", 'a') as f: 429 | f.write(log_str) 430 | f.flush() 431 | if args.task_qa: 432 | train_tuple.evaluator.evaluate(uid2ans, pprint=True) 433 | 434 | # Eval 435 | avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1) 436 | 437 | state = { 438 | 'state_dict': self.model.state_dict(), 439 | 'optimizer': optim.state_dict(), 440 | } 441 | 442 | # Save 443 | if avg_eval_loss < best_eval_loss: 444 | best_eval_loss = avg_eval_loss 445 | self.save("BEST_EVAL_LOSS",state) 446 | 447 | self.save("Epoch%02d" % (epoch+1), state) 448 | 449 | def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1): 450 | self.model.eval() 451 | eval_ld = eval_tuple.loader 452 | total_loss = 0. 453 | total_losses = 0. 454 | uid2ans = {} 455 | for i, batch in enumerate(eval_ld): 456 | loss, losses, logit = self.valid_batch(batch) 457 | total_loss += loss 458 | total_losses += losses 459 | if args.task_qa: 460 | score, label = logit.max(1) 461 | for datum, l in zip(batch, label.cpu().numpy()): 462 | uid = datum.uid 463 | ans = train_tuple.dataset.answer_table.id2ans(l) 464 | uid2ans[uid] = ans 465 | if i == iters: 466 | break 467 | 468 | print("The valid loss is %0.4f" % (total_loss / len(eval_ld))) 469 | log_str = "\nThe valid loss is %0.4f" % (total_loss / len(eval_ld)) 470 | losses_str = "\nThe losses are " 471 | log_str += "\nThe losses are " 472 | for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)): 473 | losses_str += "%s: %0.4f " % (name, loss) 474 | log_str += "\n%s: %0.4f " % (name, loss) 475 | print(losses_str) 476 | with open(self.output + "/log.log", 'a') as f: 477 | f.write(log_str) 478 | f.flush() 479 | 480 | if args.task_qa: 481 | eval_tuple.evaluator.evaluate(uid2ans, pprint=True) 482 | 483 | return total_loss / len(eval_ld) 484 | 485 | def save(self, name, state): 486 | torch.save(state, 487 | os.path.join(args.output, "%s_LXRT.pth" % name)) 488 | 489 | 490 | def load(self, path): 491 | print("Load BERT extractor from %s" % path) 492 | state_dict = torch.load("%s_LXRT.pth" % path) 493 | self.model.load_state_dict(state_dict) 494 | 495 | def load_lxmert(self, path): 496 | print("Load LXMERT model from %s" % path) 497 | state_dict = torch.load("%s_LXRT.pth" % path) 498 | 499 | # Do not load any answer head 500 | for key in list(state_dict.keys()): 501 | if 'answer' in key: 502 | state_dict.pop(key) 503 | 504 | # Change Multi GPU to single GPU 505 | new_state_dict = {} 506 | for key, value in state_dict.items(): 507 | if key.startswith("module."): 508 | new_state_dict[key[len("module."):]] = value 509 | state_dict = new_state_dict 510 | 511 | load_keys = 
set(state_dict.keys()) 512 | model_keys = set(self.model.state_dict().keys()) 513 | print() 514 | print("Keys in loaded but not in model:") 515 | for key in sorted(load_keys.difference(model_keys)): 516 | print(key) 517 | print() 518 | print("Keys in model but not in loaded:") 519 | for key in sorted(model_keys.difference(load_keys)): 520 | print(key) 521 | print() 522 | 523 | self.model.load_state_dict(state_dict, strict=False) 524 | 525 | 526 | if __name__ == "__main__": 527 | # os.environ['CUDA_VISIBLE_DEVICES'] = "1" 528 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3" 529 | 530 | lxmert = LXMERT(max_seq_length=20) 531 | 532 | lxmert.train(train_tuple, valid_tuple) 533 | -------------------------------------------------------------------------------- /src/pretrain/qa_answer_table.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 3 | 4 | import json 5 | import torch 6 | 7 | 8 | class AnswerTable: 9 | ANS_CONVERT = { 10 | "a man": "man", 11 | "the man": "man", 12 | "a woman": "woman", 13 | "the woman": "woman", 14 | 'one': '1', 15 | 'two': '2', 16 | 'three': '3', 17 | 'four': '4', 18 | 'five': '5', 19 | 'six': '6', 20 | 'seven': '7', 21 | 'eight': '8', 22 | 'nine': '9', 23 | 'ten': '10', 24 | 'grey': 'gray', 25 | } 26 | 27 | def __init__(self, dsets=None): 28 | self.all_ans = json.load(open("data/lxmert/all_ans.json")) 29 | if dsets is not None: 30 | dsets = set(dsets) 31 | # If the answer is used in the dsets 32 | self.anss = [ans['ans'] for ans in self.all_ans if 33 | len(set(ans['dsets']) & dsets) > 0] 34 | else: 35 | self.anss = [ans['ans'] for ans in self.all_ans] 36 | self.ans_set = set(self.anss) 37 | 38 | self._id2ans_map = self.anss 39 | self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)} 40 | 41 | assert len(self._id2ans_map) == len(self._ans2id_map) 42 | for ans_id, ans in enumerate(self._id2ans_map): 43 | assert self._ans2id_map[ans] == ans_id 44 | 45 | def convert_ans(self, ans): 46 | if len(ans) == 0: 47 | return "" 48 | ans = ans.lower() 49 | if ans[-1] == '.': 50 | ans = ans[:-1].strip() 51 | if ans.startswith("a "): 52 | ans = ans[2:].strip() 53 | if ans.startswith("an "): 54 | ans = ans[3:].strip() 55 | if ans.startswith("the "): 56 | ans = ans[4:].strip() 57 | if ans in self.ANS_CONVERT: 58 | ans = self.ANS_CONVERT[ans] 59 | return ans 60 | 61 | def ans2id(self, ans): 62 | return self._ans2id_map[ans] 63 | 64 | def id2ans(self, ans_id): 65 | return self._id2ans_map[ans_id] 66 | 67 | def ans2id_map(self): 68 | return self._ans2id_map.copy() 69 | 70 | def id2ans_map(self): 71 | return self._id2ans_map.copy() 72 | 73 | def used(self, ans): 74 | return ans in self.ans_set 75 | 76 | def all_answers(self): 77 | return self.anss.copy() 78 | 79 | @property 80 | def num_answers(self): 81 | return len(self.anss) 82 | 83 | 84 | def load_lxmert_qa(path, model, label2ans): 85 | """ 86 | Load model weights from LXMERT pre-training. 87 | The answers in the fine-tuned QA task (indicated by label2ans) 88 | would also be properly initialized with LXMERT pre-trained 89 | QA heads. 90 | 91 | :param path: Path to LXMERT snapshot. 92 | :param model: LXRT model instance. 
93 | :param label2ans: The label2ans dict of fine-tuned QA datasets, like 94 | {0: 'cat', 1: 'dog', ...} 95 | :return: 96 | """ 97 | print("Load QA pre-trained LXMERT from %s " % path) 98 | # loaded_state_dict = torch.load("%s_LXRT.pth" % path) 99 | loaded_state_dict = torch.load("%s_LXRT.pth" % path)['state_dict'] 100 | 101 | model_state_dict = model.state_dict() 102 | 103 | # Handle Multi-GPU pre-training --> Single GPU fine-tuning 104 | for key in list(loaded_state_dict.keys()): 105 | loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key) 106 | 107 | # Isolate bert model 108 | bert_state_dict = {} 109 | for key, value in loaded_state_dict.items(): 110 | if key.startswith('bert.'): 111 | bert_state_dict[key] = value 112 | 113 | # Isolate answer head 114 | answer_state_dict = {} 115 | for key, value in loaded_state_dict.items(): 116 | if key.startswith("answer_head."): 117 | answer_state_dict[key.replace('answer_head.', '')] = value 118 | 119 | # Do surgery on answer state dict 120 | ans_weight = answer_state_dict['logit_fc.3.weight'] 121 | ans_bias = answer_state_dict['logit_fc.3.bias'] 122 | import copy 123 | new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight']) 124 | new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias']) 125 | answer_table = AnswerTable() 126 | loaded = 0 127 | unload = 0 128 | if type(label2ans) is list: 129 | label2ans = {label: ans for label, ans in enumerate(label2ans)} 130 | for label, ans in label2ans.items(): 131 | new_ans = answer_table.convert_ans(ans) 132 | if answer_table.used(new_ans): 133 | ans_id_9500 = answer_table.ans2id(new_ans) 134 | new_answer_weight[label] = ans_weight[ans_id_9500] 135 | new_answer_bias[label] = ans_bias[ans_id_9500] 136 | loaded += 1 137 | else: 138 | new_answer_weight[label] = 0. 139 | new_answer_bias[label] = 0. 140 | unload += 1 141 | print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload)) 142 | print() 143 | answer_state_dict['logit_fc.3.weight'] = new_answer_weight 144 | answer_state_dict['logit_fc.3.bias'] = new_answer_bias 145 | 146 | # Load Bert Weights 147 | bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys()) 148 | bert_loaded_keys = set(bert_state_dict.keys()) 149 | # assert len(bert_model_keys - bert_loaded_keys) == 0 150 | model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False) 151 | 152 | # Load Answer Logic FC Weights 153 | model_keys = set(model.state_dict().keys()) 154 | ans_loaded_keys = set(answer_state_dict.keys()) 155 | assert len(ans_loaded_keys - model_keys) == 0 156 | 157 | model.load_state_dict(answer_state_dict, strict=False) 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /src/tasks/gqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 
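# GQA fine-tuning entry point. A minimal usage sketch (all hyper-parameters,
# splits, and snapshot paths come from src/param.py; this mirrors the
# __main__ block at the bottom of the file):
#
#   gqa = GQA()                                  # data tuples, model, optimizer
#   gqa.train(gqa.train_tuple, gqa.valid_tuple)  # or gqa.predict(...) when args.test is set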
3 | 4 | import os 5 | import collections 6 | 7 | import torch 8 | from tqdm import tqdm 9 | import torch.nn as nn 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | from param import args 13 | from pretrain.qa_answer_table import load_lxmert_qa 14 | from tasks.gqa_model import GQAModel 15 | from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator 16 | 17 | 18 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') 19 | 20 | 21 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: 22 | dset = GQADataset(splits) 23 | tset = GQATorchDataset(dset) 24 | evaluator = GQAEvaluator(dset) 25 | data_loader = DataLoader( 26 | tset, batch_size=bs, 27 | shuffle=shuffle, num_workers=args.num_workers, 28 | drop_last=drop_last, pin_memory=True 29 | ) 30 | 31 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 32 | 33 | 34 | class GQA: 35 | def __init__(self): 36 | self.train_tuple = get_tuple( 37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True 38 | ) 39 | if args.valid != "": 40 | valid_bsize = 512 if args.multiGPU else 512 41 | self.valid_tuple = get_tuple( 42 | args.valid, bs=valid_bsize, 43 | shuffle=False, drop_last=False 44 | ) 45 | else: 46 | self.valid_tuple = None 47 | 48 | self.model = GQAModel(self.train_tuple.dataset.num_answers) 49 | 50 | # Load pre-trained weights 51 | if args.load_lxmert is not None: 52 | self.model.lxrt_encoder.load(args.load_lxmert) 53 | if args.load_lxmert_qa is not None: 54 | load_lxmert_qa(args.load_lxmert_qa, self.model, 55 | label2ans=self.train_tuple.dataset.label2ans) 56 | 57 | # GPU options 58 | self.model = self.model.cuda() 59 | if args.multiGPU: 60 | self.model.lxrt_encoder.multi_gpu() 61 | 62 | # Losses and optimizer 63 | self.bce_loss = nn.BCEWithLogitsLoss() 64 | self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1) 65 | if 'bert' in args.optim: 66 | batch_per_epoch = len(self.train_tuple.loader) 67 | t_total = int(batch_per_epoch * args.epochs) 68 | print("Total Iters: %d" % t_total) 69 | from lxrt.optimization import BertAdam 70 | self.optim = BertAdam(list(self.model.parameters()), 71 | lr=args.lr, 72 | warmup=0.1, 73 | t_total=t_total) 74 | else: 75 | self.optim = args.optimizer(list(self.model.parameters()), args.lr) 76 | 77 | self.output = args.output 78 | os.makedirs(self.output, exist_ok=True) 79 | 80 | def train(self, train_tuple, eval_tuple): 81 | dset, loader, evaluator = train_tuple 82 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) 83 | 84 | best_valid = 0. 85 | for epoch in range(args.epochs): 86 | quesid2ans = {} 87 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): 88 | 89 | self.model.train() 90 | self.optim.zero_grad() 91 | 92 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() 93 | logit = self.model(feats, boxes, sent) 94 | assert logit.dim() == target.dim() == 2 95 | if args.mce_loss: 96 | max_value, target = target.max(1) 97 | loss = self.mce_loss(logit, target) * logit.size(1) 98 | else: 99 | loss = self.bce_loss(logit, target) 100 | loss = loss * logit.size(1) 101 | 102 | loss.backward() 103 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 104 | self.optim.step() 105 | 106 | score, label = logit.max(1) 107 | for qid, l in zip(ques_id, label.cpu().numpy()): 108 | ans = dset.label2ans[l] 109 | quesid2ans[qid] = ans 110 | 111 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) 
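        # Note that this Train score is accumulated from argmax predictions
        # made while the weights were still being updated, so it slightly lags
        # a fresh end-of-epoch evaluation; it is a cheap progress signal
        # rather than an exact training accuracy.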
112 | 113 | if self.valid_tuple is not None: # Do Validation 114 | valid_score = self.evaluate(eval_tuple) 115 | if valid_score > best_valid: 116 | best_valid = valid_score 117 | self.save("BEST") 118 | 119 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ 120 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 121 | 122 | print(log_str, end='') 123 | 124 | with open(self.output + "/log.log", 'a') as f: 125 | f.write(log_str) 126 | f.flush() 127 | 128 | self.save("LAST") 129 | 130 | def predict(self, eval_tuple: DataTuple, dump=None): 131 | self.model.eval() 132 | dset, loader, evaluator = eval_tuple 133 | quesid2ans = {} 134 | for i, datum_tuple in enumerate(loader): 135 | if i%100 == 0: 136 | print(i) 137 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target 138 | with torch.no_grad(): 139 | feats, boxes = feats.cuda(), boxes.cuda() 140 | logit = self.model(feats, boxes, sent) 141 | score, label = logit.max(1) 142 | for qid, l in zip(ques_id, label.cpu().numpy()): 143 | ans = dset.label2ans[l] 144 | quesid2ans[qid] = ans 145 | if dump is not None: 146 | evaluator.dump_result(quesid2ans, dump) 147 | return quesid2ans 148 | 149 | def evaluate(self, eval_tuple: DataTuple, dump=None): 150 | dset, loader, evaluator = eval_tuple 151 | quesid2ans = self.predict(eval_tuple, dump) 152 | return evaluator.evaluate(quesid2ans) 153 | 154 | @staticmethod 155 | def oracle_score(data_tuple): 156 | dset, loader, evaluator = data_tuple 157 | quesid2ans = {} 158 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): 159 | _, label = target.max(1) 160 | for qid, l in zip(ques_id, label.cpu().numpy()): 161 | ans = dset.label2ans[l] 162 | quesid2ans[qid] = ans 163 | return evaluator.evaluate(quesid2ans) 164 | 165 | def save(self, name): 166 | torch.save(self.model.state_dict(), 167 | os.path.join(self.output, "%s.pth" % name)) 168 | 169 | def load(self, path): 170 | print("Load model from %s" % path) 171 | state_dict = torch.load("%s.pth" % path) 172 | for key in list(state_dict.keys()): 173 | if '.module' in key: 174 | state_dict[key.replace('.module', '')] = state_dict.pop(key) 175 | self.model.load_state_dict(state_dict, strict=False) 176 | 177 | 178 | if __name__ == "__main__": 179 | # Build Class 180 | gqa = GQA() 181 | 182 | # Load Model 183 | if args.load is not None: 184 | gqa.load(args.load) 185 | 186 | # Test or Train 187 | if args.test is not None: 188 | args.fast = args.tiny = False # Always loading all data in test 189 | if 'submit' in args.test: 190 | gqa.predict( 191 | get_tuple(args.test, bs=args.batch_size, 192 | shuffle=False, drop_last=False), 193 | dump=os.path.join(args.output, 'submit_predict.json') 194 | ) 195 | if 'testdev' in args.test: 196 | result = gqa.evaluate( 197 | get_tuple('testdev', bs=args.batch_size, 198 | shuffle=False, drop_last=False), 199 | dump=os.path.join(args.output, 'testdev_predict.json') 200 | ) 201 | print(result) 202 | else: 203 | # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100)) 204 | print('Splits in Train data:', gqa.train_tuple.dataset.splits) 205 | if gqa.valid_tuple is not None: 206 | print('Splits in Valid data:', gqa.valid_tuple.dataset.splits) 207 | print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100)) 208 | else: 209 | print("DO NOT USE VALIDATION") 210 | gqa.train(gqa.train_tuple, gqa.valid_tuple) 211 | 212 | 213 | -------------------------------------------------------------------------------- /src/tasks/gqa_data.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 3 | 4 | import json 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | from param import args 11 | from utils import load_obj_tsv 12 | 13 | # Load part of the dataset for fast checking. 14 | # Notice that here is the number of images instead of the number of data, 15 | # which means all related data to the images would be used. 16 | TINY_IMG_NUM = 512 17 | FAST_IMG_NUM = 5000 18 | 19 | 20 | class GQADataset: 21 | """ 22 | A GQA data example in json file: 23 | { 24 | "img_id": "2375429", 25 | "label": { 26 | "pipe": 1.0 27 | }, 28 | "question_id": "07333408", 29 | "sent": "What is on the white wall?" 30 | } 31 | """ 32 | def __init__(self, splits: str): 33 | self.name = splits 34 | self.splits = splits.split(',') 35 | 36 | # Loading datasets to data 37 | self.data = [] 38 | for split in self.splits: 39 | self.data.extend(json.load(open("data/gqa/%s.json" % split))) 40 | print("Load %d data from split(s) %s." % (len(self.data), self.name)) 41 | 42 | # List to dict (for evaluation and others) 43 | self.id2datum = { 44 | datum['question_id']: datum 45 | for datum in self.data 46 | } 47 | 48 | # Answers 49 | self.ans2label = json.load(open("data/gqa/trainval_ans2label.json")) 50 | self.label2ans = json.load(open("data/gqa/trainval_label2ans.json")) 51 | assert len(self.ans2label) == len(self.label2ans) 52 | for ans, label in self.ans2label.items(): 53 | assert self.label2ans[label] == ans 54 | 55 | @property 56 | def num_answers(self): 57 | return len(self.ans2label) 58 | 59 | def __len__(self): 60 | return len(self.data) 61 | 62 | 63 | class GQABufferLoader(): 64 | def __init__(self): 65 | self.key2data = {} 66 | 67 | def load_data(self, name, number): 68 | if name == 'testdev': 69 | path = "data/vg_gqa_imgfeat/gqa_testdev_obj64.tsv" 70 | else: 71 | path = "data/vg_gqa_imgfeat/vg_gqa_obj64.tsv" 72 | key = "%s_%d" % (path, number) 73 | if key not in self.key2data: 74 | self.key2data[key] = load_obj_tsv( 75 | path, 76 | topk=number 77 | ) 78 | return self.key2data[key] 79 | 80 | 81 | gqa_buffer_loader = GQABufferLoader() 82 | 83 | 84 | """ 85 | Example in obj tsv: 86 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", 87 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] 88 | """ 89 | class GQATorchDataset(Dataset): 90 | def __init__(self, dataset: GQADataset): 91 | super().__init__() 92 | self.raw_dataset = dataset 93 | 94 | if args.tiny: 95 | topk = TINY_IMG_NUM 96 | elif args.fast: 97 | topk = FAST_IMG_NUM 98 | else: 99 | topk = -1 100 | 101 | # Loading detection features to img_data 102 | # Since images in train and valid both come from Visual Genome, 103 | # buffer the image loading to save memory. 
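        # (GQABufferLoader keys its cache on "path_topk", so every
        # GQATorchDataset built from Visual Genome splits with the same topk
        # reuses one in-memory copy of the tsv features instead of re-reading
        # the file.)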
104 |         img_data = []
105 |         if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits:     # Always loading all the data in testdev
106 |             img_data.extend(gqa_buffer_loader.load_data('testdev', -1))
107 |         else:
108 |             img_data.extend(gqa_buffer_loader.load_data('train', topk))
109 |         self.imgid2img = {}
110 |         for img_datum in img_data:
111 |             self.imgid2img[img_datum['img_id']] = img_datum
112 | 
113 |         # Only keep the data with loaded image features
114 |         self.data = []
115 |         for datum in self.raw_dataset.data:
116 |             if datum['img_id'] in self.imgid2img:
117 |                 self.data.append(datum)
118 |         print("Use %d data in torch dataset" % (len(self.data)))
119 |         print()
120 | 
121 |     def __len__(self):
122 |         return len(self.data)
123 | 
124 |     def __getitem__(self, item: int):
125 |         datum = self.data[item]
126 | 
127 |         img_id = datum['img_id']
128 |         ques_id = datum['question_id']
129 |         ques = datum['sent']
130 | 
131 |         # Get image info
132 |         img_info = self.imgid2img[img_id]
133 |         obj_num = img_info['num_boxes']
134 |         boxes = img_info['boxes'].copy()
135 |         feats = img_info['features'].copy()
136 |         assert len(boxes) == len(feats) == obj_num
137 | 
138 |         # Normalize the boxes (to 0 ~ 1)
139 |         img_h, img_w = img_info['img_h'], img_info['img_w']
140 |         boxes = boxes.copy()
141 |         boxes[:, (0, 2)] /= img_w
142 |         boxes[:, (1, 3)] /= img_h
143 |         np.testing.assert_array_less(boxes, 1+1e-5)
144 |         np.testing.assert_array_less(-boxes, 0+1e-5)
145 | 
146 |         # Create target
147 |         if 'label' in datum:
148 |             label = datum['label']
149 |             target = torch.zeros(self.raw_dataset.num_answers)
150 |             for ans, score in label.items():
151 |                 if ans in self.raw_dataset.ans2label:
152 |                     target[self.raw_dataset.ans2label[ans]] = score
153 |             return ques_id, feats, boxes, ques, target
154 |         else:
155 |             return ques_id, feats, boxes, ques
156 | 
157 | 
158 | class GQAEvaluator:
159 |     def __init__(self, dataset: GQADataset):
160 |         self.dataset = dataset
161 | 
162 |     def evaluate(self, quesid2ans: dict):
163 |         score = 0.
164 |         for quesid, ans in quesid2ans.items():
165 |             datum = self.dataset.id2datum[quesid]
166 |             label = datum['label']
167 |             if ans in label:
168 |                 score += label[ans]
169 |         return score / len(quesid2ans)
170 | 
171 |     def dump_result(self, quesid2ans: dict, path):
172 |         """
173 |         Dump the result to a GQA-challenge submittable json file.
174 |         GQA json file submission requirement:
175 |             results = [result]
176 |             result = {
177 |                 "questionId": str,      # Note: it's actually an int, but the server requires a str.
178 |                 "prediction": str
179 |             }
180 | 
181 |         :param quesid2ans: A dict mapping question id to its predicted answer.
182 |         :param path: The file path to save the json file.
183 |         :return:
184 |         """
185 |         with open(path, 'w') as f:
186 |             result = []
187 |             for ques_id, ans in quesid2ans.items():
188 |                 result.append({
189 |                     'questionId': ques_id,
190 |                     'prediction': ans
191 |                 })
192 |             json.dump(result, f, indent=4, sort_keys=True)
193 | 
194 | 
195 | 
-------------------------------------------------------------------------------- /src/tasks/gqa_model.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
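# GQAModel is the LXRT cross-modality encoder plus a small classification
# head: pooled output (b, hid_dim) -> Linear -> GeLU -> LayerNorm -> Linear,
# giving one logit per candidate answer in the GQA answer table.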
3 | 
4 | import torch.nn as nn
5 | 
6 | from param import args
7 | from lxrt.entry import LXRTEncoder
8 | from lxrt.modeling import BertLayerNorm, GeLU
9 | 
10 | # Max length including <bos> and <eos>
11 | MAX_GQA_LENGTH = 20
12 | 
13 | 
14 | class GQAModel(nn.Module):
15 |     def __init__(self, num_answers):
16 |         super().__init__()
17 |         self.lxrt_encoder = LXRTEncoder(
18 |             args,
19 |             max_seq_length=MAX_GQA_LENGTH
20 |         )
21 |         hid_dim = self.lxrt_encoder.dim
22 |         self.logit_fc = nn.Sequential(
23 |             nn.Linear(hid_dim, hid_dim * 2),
24 |             GeLU(),
25 |             BertLayerNorm(hid_dim * 2, eps=1e-12),
26 |             nn.Linear(hid_dim * 2, num_answers)
27 |         )
28 |         self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
29 | 
30 |     def forward(self, feat, pos, sent):
31 |         """
32 |         b -- batch_size, o -- object_number, f -- visual_feature_size
33 | 
34 |         :param feat: (b, o, f)
35 |         :param pos: (b, o, 4)
36 |         :param sent: (b,) Type -- list of string
37 | 
38 |         :return: (b, num_answer) The logit for each answer.
39 |         """
40 |         x = self.lxrt_encoder(sent, (feat, pos))
41 |         logit = self.logit_fc(x)
42 | 
43 |         return logit
44 | 
45 | 
46 | 
-------------------------------------------------------------------------------- /src/tasks/nlvr2.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 | 
4 | import os
5 | import collections
6 | 
7 | from tqdm import tqdm
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data.dataloader import DataLoader
11 | 
12 | from param import args
13 | from tasks.nlvr2_model import NLVR2Model
14 | from tasks.nlvr2_data import NLVR2Dataset, NLVR2TorchDataset, NLVR2Evaluator
15 | 
16 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
17 | 
18 | 
19 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
20 |     dset = NLVR2Dataset(splits)
21 |     tset = NLVR2TorchDataset(dset)
22 |     evaluator = NLVR2Evaluator(dset)
23 |     data_loader = DataLoader(
24 |         tset, batch_size=bs,
25 |         shuffle=shuffle, num_workers=args.num_workers,
26 |         drop_last=drop_last, pin_memory=True
27 |     )
28 | 
29 |     return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
30 | 
31 | 
32 | class NLVR2:
33 |     def __init__(self):
34 |         self.train_tuple = get_tuple(
35 |             args.train, bs=args.batch_size, shuffle=True, drop_last=True
36 |         )
37 |         if args.valid != "":
38 |             valid_bsize = 256   # same batch size with or without multi-GPU
39 |             self.valid_tuple = get_tuple(
40 |                 args.valid, bs=valid_bsize,
41 |                 shuffle=False, drop_last=False
42 |             )
43 |         else:
44 |             self.valid_tuple = None
45 | 
46 |         self.model = NLVR2Model()
47 | 
48 |         # Load pre-trained weights
49 |         if args.load_lxmert is not None:
50 |             self.model.lxrt_encoder.load(args.load_lxmert)
51 | 
52 |         # GPU options
53 |         if args.multiGPU:
54 |             self.model.lxrt_encoder.multi_gpu()
55 |         self.model = self.model.cuda()
56 | 
57 |         # Losses and optimizer
58 |         self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
59 |         if 'bert' in args.optim:
60 |             batch_per_epoch = len(self.train_tuple.loader)
61 |             t_total = int(batch_per_epoch * args.epochs)
62 |             print("Total Iters: %d" % t_total)
63 |             from lxrt.optimization import BertAdam
64 |             self.optim = BertAdam(list(self.model.parameters()),
65 |                                   lr=args.lr,
66 |                                   warmup=0.1,
67 |                                   t_total=t_total)
68 |         else:
69 |             self.optim = args.optimizer(list(self.model.parameters()), args.lr)
70 | 
71 |         self.output = args.output
72 |         os.makedirs(self.output, exist_ok=True)
73 | 
74 |     def train(self, train_tuple, eval_tuple):
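        """
        Fine-tuning loop: cross-entropy over the two-way true/false logit,
        gradient clipping at norm 5.0, and per-epoch validation that
        snapshots the best checkpoint as BEST.pth.
        """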
75 | dset, loader, evaluator = train_tuple 76 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) 77 | 78 | best_valid = 0. 79 | for epoch in range(args.epochs): 80 | quesid2ans = {} 81 | for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)): 82 | self.model.train() 83 | 84 | self.optim.zero_grad() 85 | feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda() 86 | logit = self.model(feats, boxes, sent) 87 | 88 | loss = self.mce_loss(logit, label) 89 | 90 | loss.backward() 91 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 92 | self.optim.step() 93 | 94 | score, predict = logit.max(1) 95 | for qid, l in zip(ques_id, predict.cpu().numpy()): 96 | quesid2ans[qid] = l 97 | 98 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) 99 | 100 | if self.valid_tuple is not None: # Do Validation 101 | valid_score = self.evaluate(eval_tuple) 102 | if valid_score > best_valid: 103 | best_valid = valid_score 104 | self.save("BEST") 105 | 106 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ 107 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 108 | 109 | print(log_str, end='') 110 | 111 | with open(self.output + "/log.log", 'a') as f: 112 | f.write(log_str) 113 | f.flush() 114 | 115 | self.save("LAST") 116 | 117 | def predict(self, eval_tuple: DataTuple, dump=None): 118 | self.model.eval() 119 | dset, loader, evaluator = eval_tuple 120 | quesid2ans = {} 121 | for i, datum_tuple in enumerate(loader): 122 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target 123 | with torch.no_grad(): 124 | feats, boxes = feats.cuda(), boxes.cuda() 125 | logit = self.model(feats, boxes, sent) 126 | score, predict = logit.max(1) 127 | for qid, l in zip(ques_id, predict.cpu().numpy()): 128 | quesid2ans[qid] = l 129 | if dump is not None: 130 | evaluator.dump_result(quesid2ans, dump) 131 | return quesid2ans 132 | 133 | def evaluate(self, eval_tuple: DataTuple, dump=None): 134 | dset, loader, evaluator = eval_tuple 135 | quesid2ans = self.predict(eval_tuple, dump) 136 | return evaluator.evaluate(quesid2ans) 137 | 138 | def save(self, name): 139 | torch.save(self.model.state_dict(), 140 | os.path.join(self.output, "%s.pth" % name)) 141 | 142 | # def load(self, path): 143 | # print("Load model from %s" % path) 144 | # state_dict = torch.load("%s.pth" % path) 145 | # self.model.load_state_dict(state_dict) 146 | 147 | def load(self, path): 148 | print("Load model from %s" % path) 149 | state_dict = torch.load("%s.pth" % path) 150 | for key in list(state_dict.keys()): 151 | if '.module' in key: 152 | state_dict[key.replace('.module', '')] = state_dict.pop(key) 153 | self.model.load_state_dict(state_dict, strict=False) 154 | 155 | 156 | if __name__ == "__main__": 157 | # Build Class 158 | nlvr2 = NLVR2() 159 | 160 | # Load Model 161 | if args.load is not None: 162 | nlvr2.load(args.load) 163 | 164 | # Test or Train 165 | if args.test is not None: 166 | args.fast = args.tiny = False # Always loading all data in test 167 | if 'hidden' in args.test: 168 | nlvr2.predict( 169 | get_tuple(args.test, bs=args.batch_size, 170 | shuffle=False, drop_last=False), 171 | dump=os.path.join(args.output, 'hidden_predict.csv') 172 | ) 173 | elif 'test' in args.test or 'valid' in args.test: 174 | result = nlvr2.evaluate( 175 | get_tuple(args.test, bs=args.batch_size, 176 | shuffle=False, drop_last=False), 177 | dump=os.path.join(args.output, '%s_predict.csv' % args.test) 178 | ) 179 | 
print(result) 180 | else: 181 | assert False, "No such test option for %s" % args.test 182 | else: 183 | print('Splits in Train data:', nlvr2.train_tuple.dataset.splits) 184 | if nlvr2.valid_tuple is not None: 185 | print('Splits in Valid data:', nlvr2.valid_tuple.dataset.splits) 186 | else: 187 | print("DO NOT USE VALIDATION") 188 | nlvr2.train(nlvr2.train_tuple, nlvr2.valid_tuple) 189 | 190 | 191 | -------------------------------------------------------------------------------- /src/tasks/nlvr2_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 3 | 4 | import json 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | 9 | from param import args 10 | from utils import load_obj_tsv 11 | 12 | # Load part of the dataset for fast checking. 13 | # Notice that here is the number of images instead of the number of data, 14 | # which means all related data to the images would be used. 15 | TINY_IMG_NUM = 512 16 | FAST_IMG_NUM = 5000 17 | 18 | 19 | class NLVR2Dataset: 20 | """ 21 | An NLVR2 data example in json file: 22 | { 23 | "identifier": "train-10171-0-0", 24 | "img0": "train-10171-0-img0", 25 | "img1": "train-10171-0-img1", 26 | "label": 0, 27 | "sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside. 28 | ", 29 | "uid": "nlvr2_train_0" 30 | } 31 | """ 32 | def __init__(self, splits: str): 33 | self.name = splits 34 | self.splits = splits.split(',') 35 | 36 | # Loading datasets to data 37 | self.data = [] 38 | for split in self.splits: 39 | self.data.extend(json.load(open("data/nlvr2/%s.json" % split))) 40 | print("Load %d data from split(s) %s." % (len(self.data), self.name)) 41 | 42 | # List to dict (for evaluation and others) 43 | self.id2datum = { 44 | datum['uid']: datum 45 | for datum in self.data 46 | } 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | 52 | """ 53 | An example in obj36 tsv: 54 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", 55 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] 56 | FIELDNAMES would be keys in the dict returned by load_obj_tsv. 
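A decoded entry therefore looks roughly like this (for the 64-box features
used here, with 2048-d Faster R-CNN vectors):
    {"img_id": "train-10171-0-img0", "num_boxes": 64,
     "boxes": float32 array (64, 4), "features": float32 array (64, 2048), ...}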
57 | """ 58 | class NLVR2TorchDataset(Dataset): 59 | def __init__(self, dataset: NLVR2Dataset): 60 | super().__init__() 61 | self.raw_dataset = dataset 62 | 63 | if args.tiny: 64 | topk = TINY_IMG_NUM 65 | elif args.fast: 66 | topk = FAST_IMG_NUM 67 | else: 68 | topk = -1 69 | 70 | # Loading detection features to img_data 71 | img_data = [] 72 | if 'train' in dataset.splits: 73 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj64.tsv', topk=topk)) 74 | if 'valid' in dataset.splits: 75 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj64.tsv', topk=topk)) 76 | if 'test' in dataset.name: 77 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj64.tsv', topk=topk)) 78 | self.imgid2img = {} 79 | for img_datum in img_data: 80 | self.imgid2img[img_datum['img_id']] = img_datum 81 | 82 | # Filter out the dataset 83 | self.data = [] 84 | for datum in self.raw_dataset.data: 85 | if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img: 86 | self.data.append(datum) 87 | print("Use %d data in torch dataset" % (len(self.data))) 88 | print() 89 | 90 | def __len__(self): 91 | return len(self.data) 92 | 93 | def __getitem__(self, item: int): 94 | datum = self.data[item] 95 | 96 | ques_id = datum['uid'] 97 | ques = datum['sent'] 98 | 99 | # Get image info 100 | boxes2 = [] 101 | feats2 = [] 102 | for key in ['img0', 'img1']: 103 | img_id = datum[key] 104 | img_info = self.imgid2img[img_id] 105 | boxes = img_info['boxes'].copy() 106 | feats = img_info['features'].copy() 107 | assert len(boxes) == len(feats) 108 | 109 | # Normalize the boxes (to 0 ~ 1) 110 | img_h, img_w = img_info['img_h'], img_info['img_w'] 111 | boxes[..., (0, 2)] /= img_w 112 | boxes[..., (1, 3)] /= img_h 113 | np.testing.assert_array_less(boxes, 1+1e-5) 114 | np.testing.assert_array_less(-boxes, 0+1e-5) 115 | 116 | boxes2.append(boxes) 117 | feats2.append(feats) 118 | feats = np.stack(feats2) 119 | boxes = np.stack(boxes2) 120 | 121 | # Create target 122 | if 'label' in datum: 123 | label = datum['label'] 124 | return ques_id, feats, boxes, ques, label 125 | else: 126 | return ques_id, feats, boxes, ques 127 | 128 | 129 | class NLVR2Evaluator: 130 | def __init__(self, dataset: NLVR2Dataset): 131 | self.dataset = dataset 132 | 133 | def evaluate(self, quesid2ans: dict): 134 | score = 0. 135 | for quesid, ans in quesid2ans.items(): 136 | datum = self.dataset.id2datum[quesid] 137 | label = datum['label'] 138 | if ans == label: 139 | score += 1 140 | return score / len(quesid2ans) 141 | 142 | def dump_result(self, quesid2ans: dict, path): 143 | """ 144 | Dump result to a CSV file, which is compatible with NLVR2 evaluation system. 145 | NLVR2 CSV file requirement: 146 | Each line contains: identifier, answer 147 | 148 | :param quesid2ans: nlvr2 uid to ans (either "True" or "False") 149 | :param path: The desired path of saved file. 150 | :return: 151 | """ 152 | with open(path, 'w') as f: 153 | for uid, ans in quesid2ans.items(): 154 | idt = self.dataset.id2datum[uid]["identifier"] 155 | ans = 'True' if ans == 1 else 'False' 156 | f.write("%s,%s\n" % (idt, ans)) 157 | 158 | -------------------------------------------------------------------------------- /src/tasks/nlvr2_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 
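# NLVR2 pairs one sentence with two images. The model below duplicates each
# sentence, folds the image pair into the batch dimension, encodes the 2*b
# (sentence, image) pairs in a single pass, and concatenates each pair of
# pooled vectors before the two-way true/false classifier.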
3 | 4 | import torch.nn as nn 5 | from lxrt.modeling import GeLU, BertLayerNorm 6 | from lxrt.entry import LXRTEncoder 7 | from param import args 8 | 9 | 10 | class NLVR2Model(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | self.lxrt_encoder = LXRTEncoder( 14 | args, 15 | max_seq_length=20 16 | ) 17 | self.hid_dim = hid_dim = self.lxrt_encoder.dim 18 | self.logit_fc = nn.Sequential( 19 | nn.Linear(hid_dim * 2, hid_dim * 2), 20 | GeLU(), 21 | BertLayerNorm(hid_dim * 2, eps=1e-12), 22 | nn.Linear(hid_dim * 2, 2) 23 | ) 24 | self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights) 25 | 26 | def forward(self, feat, pos, sent): 27 | """ 28 | :param feat: b, 2, o, f 29 | :param pos: b, 2, o, 4 30 | :param sent: b, (string) 31 | :param leng: b, (numpy, int) 32 | :return: 33 | """ 34 | # Pairing images and sentences: 35 | # The input of NLVR2 is two images and one sentence. In batch level, they are saved as 36 | # [ [img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...] 37 | # Here, we flat them to 38 | # feat/pos = [ img0_0, img0_1, img1_0, img1_1, ...] 39 | # sent = [ sent0, sent0, sent1, sent1, ...] 40 | sent = sum(zip(sent, sent), ()) 41 | batch_size, img_num, obj_num, feat_size = feat.size() 42 | assert img_num == 2 and obj_num == 64 and feat_size == 2048 43 | feat = feat.view(batch_size * 2, obj_num, feat_size) 44 | pos = pos.view(batch_size * 2, obj_num, 4) 45 | 46 | # Extract feature --> Concat 47 | x = self.lxrt_encoder(sent, (feat, pos)) 48 | x = x.view(-1, self.hid_dim*2) 49 | 50 | # Compute logit of answers 51 | logit = self.logit_fc(x) 52 | 53 | return logit 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/tasks/vqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2019 project LXRT. 
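# VQA fine-tuning entry point, structured like gqa.py: build a VQA() object,
# then either train (args.test unset) or dump predictions for a test/minival
# split (args.test set); see the __main__ block at the bottom of the file.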
3 | 4 | import os 5 | import collections 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data.dataloader import DataLoader 10 | from tqdm import tqdm 11 | 12 | from param import args 13 | from pretrain.qa_answer_table import load_lxmert_qa 14 | from tasks.vqa_model import VQAModel 15 | from tasks.vqa_data import VQADataset, VQATorchDataset, VQAEvaluator 16 | 17 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') 18 | 19 | 20 | def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: 21 | dset = VQADataset(splits) 22 | tset = VQATorchDataset(dset) 23 | evaluator = VQAEvaluator(dset) 24 | data_loader = DataLoader( 25 | tset, batch_size=bs, 26 | shuffle=shuffle, num_workers=args.num_workers, 27 | drop_last=drop_last, pin_memory=True 28 | ) 29 | 30 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 31 | 32 | 33 | class VQA: 34 | def __init__(self): 35 | # Datasets 36 | self.train_tuple = get_data_tuple( 37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True 38 | ) 39 | if args.bert_type == 'ft': 40 | bs_infer = 256 41 | else: 42 | bs_infer = 1024 43 | if args.valid != "": 44 | self.valid_tuple = get_data_tuple( 45 | args.valid, bs=bs_infer, 46 | shuffle=False, drop_last=False 47 | ) 48 | else: 49 | self.valid_tuple = None 50 | 51 | print("args.lr is {0}".format(args.lr)) 52 | 53 | # Model 54 | self.model = VQAModel(self.train_tuple.dataset.num_answers) 55 | 56 | # Load pre-trained weights 57 | if args.load_lxmert is not None: 58 | self.model.lxrt_encoder.load(args.load_lxmert) 59 | if args.load_lxmert_qa is not None: 60 | load_lxmert_qa(args.load_lxmert_qa, self.model, 61 | label2ans=self.train_tuple.dataset.label2ans) 62 | 63 | # GPU options 64 | self.model = self.model.cuda() 65 | if args.multiGPU: 66 | self.model.lxrt_encoder.multi_gpu() 67 | 68 | # Loss and Optimizer 69 | self.bce_loss = nn.BCEWithLogitsLoss() 70 | if 'bert' in args.optim: 71 | # if type(args.lr) == type("sdfg"): 72 | # args.lr = float(args.lr) 73 | 74 | batch_per_epoch = len(self.train_tuple.loader) 75 | t_total = int(batch_per_epoch * args.epochs) 76 | print("BertAdam Total Iters: %d" % t_total) 77 | from lxrt.optimization import BertAdam 78 | self.optim = BertAdam(list(self.model.parameters()), 79 | lr=args.lr, 80 | warmup=0.1, 81 | t_total=t_total, 82 | schedule=args.lr_schedule, 83 | args=args) 84 | 85 | else: 86 | self.optim = args.optimizer(self.model.parameters(), args.lr) 87 | 88 | # Output Directory 89 | self.output = args.output 90 | os.makedirs(self.output, exist_ok=True) 91 | 92 | def train(self, train_tuple, eval_tuple): 93 | dset, loader, evaluator = train_tuple 94 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) 95 | 96 | best_valid = 0. 97 | for epoch in range(args.epochs): 98 | quesid2ans = {} 99 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): 100 | 101 | self.model.train() 102 | # pytorch_total_params = sum(p.numel() for p in self.model.parameters()) 103 | # print('total parameters are:',pytorch_total_params) 104 | self.optim.zero_grad() 105 | 106 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() 107 | logit = self.model(feats, boxes, sent) 108 | assert logit.dim() == target.dim() == 2 109 | loss = self.bce_loss(logit, target) 110 | loss = loss * logit.size(1) 111 | 112 | loss.backward() 113 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 
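                # The BCE loss above is averaged over both the batch and the
                # answer vocabulary by default, so multiplying by logit.size(1)
                # restores a per-example sum over answers (a factor of 3129 for
                # the standard VQA trainval label set); gradients are then
                # clipped at norm 5.0 before this parameter update.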
114 | self.optim.step() 115 | 116 | score, label = logit.max(1) 117 | for qid, l in zip(ques_id, label.cpu().numpy()): 118 | ans = dset.label2ans[l] 119 | quesid2ans[qid.item()] = ans 120 | 121 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) 122 | 123 | if self.valid_tuple is not None: # Do Validation 124 | valid_score = self.evaluate(eval_tuple) 125 | if valid_score > best_valid: 126 | best_valid = valid_score 127 | self.save("BEST") 128 | 129 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ 130 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 131 | 132 | print(log_str, end='') 133 | 134 | with open(self.output + "/log.log", 'a') as f: 135 | f.write(log_str) 136 | f.flush() 137 | 138 | self.save("LAST") 139 | 140 | def predict(self, eval_tuple: DataTuple, dump=None): 141 | """ 142 | Predict the answers to questions in a data split. 143 | 144 | :param eval_tuple: The data tuple to be evaluated. 145 | :param dump: The path of saved file to dump results. 146 | :return: A dict of question_id to answer. 147 | """ 148 | self.model.eval() 149 | dset, loader, evaluator = eval_tuple 150 | quesid2ans = {} 151 | for i, datum_tuple in enumerate(loader): 152 | ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth 153 | with torch.no_grad(): 154 | feats, boxes = feats.cuda(), boxes.cuda() 155 | logit = self.model(feats, boxes, sent) 156 | score, label = logit.max(1) 157 | for qid, l in zip(ques_id, label.cpu().numpy()): 158 | ans = dset.label2ans[l] 159 | quesid2ans[qid.item()] = ans 160 | if dump is not None: 161 | evaluator.dump_result(quesid2ans, dump) 162 | return quesid2ans 163 | 164 | def evaluate(self, eval_tuple: DataTuple, dump=None): 165 | """Evaluate all data in data_tuple.""" 166 | quesid2ans = self.predict(eval_tuple, dump) 167 | return eval_tuple.evaluator.evaluate(quesid2ans) 168 | 169 | @staticmethod 170 | def oracle_score(data_tuple): 171 | dset, loader, evaluator = data_tuple 172 | quesid2ans = {} 173 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): 174 | _, label = target.max(1) 175 | for qid, l in zip(ques_id, label.cpu().numpy()): 176 | ans = dset.label2ans[l] 177 | quesid2ans[qid.item()] = ans 178 | return evaluator.evaluate(quesid2ans) 179 | 180 | def save(self, name): 181 | torch.save(self.model.state_dict(), 182 | os.path.join(self.output, "%s.pth" % name)) 183 | 184 | # def load(self, path): 185 | # print("Load model from %s" % path) 186 | # state_dict = torch.load("%s.pth" % path) 187 | # self.model.load_state_dict(state_dict) 188 | def load(self, path): 189 | print("Load model from %s" % path) 190 | state_dict = torch.load("%s.pth" % path) 191 | for key in list(state_dict.keys()): 192 | if '.module' in key: 193 | state_dict[key.replace('.module', '')] = state_dict.pop(key) 194 | self.model.load_state_dict(state_dict, strict=False) 195 | 196 | 197 | if __name__ == "__main__": 198 | # Build Class 199 | vqa = VQA() 200 | 201 | # Load VQA model weights 202 | # Note: It is different from loading LXMERT pre-trained weights. 
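    # (args.load restores a full fine-tuned VQA snapshot written by
    # self.save(), whereas args.load_lxmert / args.load_lxmert_qa inside
    # VQA.__init__ only initialize the encoder and, optionally, remap the
    # pre-trained answer head.)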
203 | if args.load is not None:
204 |     vqa.load(args.load)
205 | 
206 | # Test or Train
207 | if args.test is not None:
208 |     args.fast = args.tiny = False       # Always loading all data in test
209 |     if args.bert_type == 'ft':
210 |         bs_infer = 128
211 |     else:
212 |         bs_infer = 128
213 |     if 'test' in args.test:
214 |         vqa.predict(
215 |             get_data_tuple(args.test, bs=bs_infer,
216 |                            shuffle=False, drop_last=False),
217 |             dump=os.path.join(args.output, 'test_predict.json')
218 |         )
219 |     elif 'val' in args.test:
220 |         # Since part of the validation data is used in pre-training/fine-tuning,
221 |         # only validate on the minival set.
222 |         result = vqa.evaluate(
223 |             get_data_tuple('minival', bs=bs_infer,
224 |                            shuffle=False, drop_last=False),
225 |             dump=os.path.join(args.output, 'minival_predict.json')
226 |         )
227 |         print(result)
228 |     else:
229 |         assert False, "No such test option for %s" % args.test
230 | else:
231 |     print('Splits in Train data:', vqa.train_tuple.dataset.splits)
232 |     if vqa.valid_tuple is not None:
233 |         print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
234 |         print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
235 |     else:
236 |         print("DO NOT USE VALIDATION")
237 |     vqa.train(vqa.train_tuple, vqa.valid_tuple)
238 | 
239 | 
240 | 
-------------------------------------------------------------------------------- /src/tasks/vqa_constant.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 | 
4 | import os
5 | import collections
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data.dataloader import DataLoader
10 | from tqdm import tqdm
11 | 
12 | from param import args
13 | from pretrain.qa_answer_table import load_lxmert_qa
14 | from tasks.vqa_model import VQAModel
15 | from tasks.vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
16 | 
17 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
18 | 
19 | 
20 | def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
21 |     dset = VQADataset(splits)
22 |     tset = VQATorchDataset(dset)
23 |     evaluator = VQAEvaluator(dset)
24 |     data_loader = DataLoader(
25 |         tset, batch_size=bs,
26 |         shuffle=shuffle, num_workers=args.num_workers,
27 |         drop_last=drop_last, pin_memory=True
28 |     )
29 | 
30 |     return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
31 | 
32 | 
33 | class VQA:
34 |     def __init__(self):
35 |         # Datasets
36 |         self.train_tuple = get_data_tuple(
37 |             args.train, bs=args.batch_size, shuffle=True, drop_last=True
38 |         )
39 |         if args.bert_type == 'ft':
40 |             bs_infer = 256
41 |         else:
42 |             bs_infer = 1024
43 |         if args.valid != "":
44 |             self.valid_tuple = get_data_tuple(
45 |                 args.valid, bs=bs_infer,
46 |                 shuffle=False, drop_last=False
47 |             )
48 |         else:
49 |             self.valid_tuple = None
50 | 
51 |         print("args.lr is {0}".format(args.lr))
52 | 
53 |         # Model
54 |         self.model = VQAModel(self.train_tuple.dataset.num_answers)
55 | 
56 |         # Load pre-trained weights
57 |         if args.load_lxmert is not None:
58 |             self.model.lxrt_encoder.load(args.load_lxmert)
59 |         if args.load_lxmert_qa is not None:
60 |             load_lxmert_qa(args.load_lxmert_qa, self.model,
61 |                            label2ans=self.train_tuple.dataset.label2ans)
62 | 
63 |         # GPU options
64 |         self.model = self.model.cuda()
65 |         if args.multiGPU:
66 |             self.model.lxrt_encoder.multi_gpu()
67 | 
68 |         # Loss and Optimizer
69 |         self.bce_loss = nn.BCEWithLogitsLoss()
70 |         if 'bert' in args.optim:
71 |             # if type(args.lr) == type("sdfg"):
72 |
# args.lr = float(args.lr) 73 | 74 | batch_per_epoch = len(self.train_tuple.loader) 75 | t_total = int(batch_per_epoch * args.epochs) 76 | print("BertAdam Total Iters: %d" % t_total) 77 | from lxrt.optimization import BertAdam 78 | # self.optim = BertAdam(list(self.model.parameters()), 79 | # lr=args.lr, 80 | # warmup=0.1, 81 | # t_total=t_total, 82 | # schedule=args.lr_schedule, 83 | # args=args) 84 | # suggested by yang xiaofeng 85 | self.optim = BertAdam(list(self.model.parameters()), 86 | lr=args.lr, 87 | warmup=0.1, 88 | schedule = 'warmup_constant', 89 | t_total=t_total) 90 | else: 91 | self.optim = args.optimizer(self.model.parameters(), args.lr) 92 | 93 | # Output Directory 94 | self.output = args.output 95 | os.makedirs(self.output, exist_ok=True) 96 | 97 | def train(self, train_tuple, eval_tuple): 98 | dset, loader, evaluator = train_tuple 99 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) 100 | 101 | best_valid = 0. 102 | for epoch in range(args.epochs): 103 | 104 | #suggested by yangxiaofeng 105 | if (epoch == 4): 106 | for g in self.optim.param_groups: 107 | g['lr'] = g['lr']/10 108 | if (epoch == 6): 109 | for g in self.optim.param_groups: 110 | g['lr'] = g['lr']/10 111 | 112 | quesid2ans = {} 113 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): 114 | 115 | self.model.train() 116 | self.optim.zero_grad() 117 | 118 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() 119 | logit = self.model(feats, boxes, sent) 120 | assert logit.dim() == target.dim() == 2 121 | loss = self.bce_loss(logit, target) 122 | loss = loss * logit.size(1) 123 | 124 | loss.backward() 125 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 126 | self.optim.step() 127 | 128 | score, label = logit.max(1) 129 | for qid, l in zip(ques_id, label.cpu().numpy()): 130 | ans = dset.label2ans[l] 131 | quesid2ans[qid.item()] = ans 132 | 133 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) 134 | 135 | if self.valid_tuple is not None: # Do Validation 136 | valid_score = self.evaluate(eval_tuple) 137 | if valid_score > best_valid: 138 | best_valid = valid_score 139 | self.save("BEST") 140 | 141 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ 142 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 143 | 144 | print(log_str, end='') 145 | 146 | with open(self.output + "/log.log", 'a') as f: 147 | f.write(log_str) 148 | f.flush() 149 | 150 | self.save("LAST") 151 | 152 | def predict(self, eval_tuple: DataTuple, dump=None): 153 | """ 154 | Predict the answers to questions in a data split. 155 | 156 | :param eval_tuple: The data tuple to be evaluated. 157 | :param dump: The path of saved file to dump results. 158 | :return: A dict of question_id to answer. 
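        For example: {458752000: "net", ...} (the id comes from the
        VQADataset docstring).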
159 |         """
160 |         self.model.eval()
161 |         dset, loader, evaluator = eval_tuple
162 |         quesid2ans = {}
163 |         for i, datum_tuple in enumerate(loader):
164 |             ques_id, feats, boxes, sent = datum_tuple[:4]   # Avoid seeing ground truth
165 |             with torch.no_grad():
166 |                 feats, boxes = feats.cuda(), boxes.cuda()
167 |                 logit = self.model(feats, boxes, sent)
168 |                 score, label = logit.max(1)
169 |                 for qid, l in zip(ques_id, label.cpu().numpy()):
170 |                     ans = dset.label2ans[l]
171 |                     quesid2ans[qid.item()] = ans
172 |         if dump is not None:
173 |             evaluator.dump_result(quesid2ans, dump)
174 |         return quesid2ans
175 | 
176 |     def evaluate(self, eval_tuple: DataTuple, dump=None):
177 |         """Evaluate all data in data_tuple."""
178 |         quesid2ans = self.predict(eval_tuple, dump)
179 |         return eval_tuple.evaluator.evaluate(quesid2ans)
180 | 
181 |     @staticmethod
182 |     def oracle_score(data_tuple):
183 |         dset, loader, evaluator = data_tuple
184 |         quesid2ans = {}
185 |         for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
186 |             _, label = target.max(1)
187 |             for qid, l in zip(ques_id, label.cpu().numpy()):
188 |                 ans = dset.label2ans[l]
189 |                 quesid2ans[qid.item()] = ans
190 |         return evaluator.evaluate(quesid2ans)
191 | 
192 |     def save(self, name):
193 |         torch.save(self.model.state_dict(),
194 |                    os.path.join(self.output, "%s.pth" % name))
195 | 
196 |     def load(self, path):
197 |         print("Load model from %s" % path)
198 |         state_dict = torch.load("%s.pth" % path)
199 |         self.model.load_state_dict(state_dict)
200 | 
201 | 
202 | if __name__ == "__main__":
203 |     # Build Class
204 |     vqa = VQA()
205 | 
206 |     # Load VQA model weights
207 |     # Note: It is different from loading LXMERT pre-trained weights.
208 |     if args.load is not None:
209 |         vqa.load(args.load)
210 | 
211 |     # Test or Train
212 |     if args.test is not None:
213 |         args.fast = args.tiny = False       # Always loading all data in test
214 |         if args.bert_type == 'ft':
215 |             bs_infer = 256
216 |         else:
217 |             bs_infer = 950
218 |         if 'test' in args.test:
219 |             vqa.predict(
220 |                 get_data_tuple(args.test, bs=bs_infer,
221 |                                shuffle=False, drop_last=False),
222 |                 dump=os.path.join(args.output, 'test_predict.json')
223 |             )
224 |         elif 'val' in args.test:
225 |             # Since part of the validation data is used in pre-training/fine-tuning,
226 |             # only validate on the minival set.
227 |             result = vqa.evaluate(
228 |                 get_data_tuple('minival', bs=bs_infer,
229 |                                shuffle=False, drop_last=False),
230 |                 dump=os.path.join(args.output, 'minival_predict.json')
231 |             )
232 |             print(result)
233 |         else:
234 |             assert False, "No such test option for %s" % args.test
235 |     else:
236 |         print('Splits in Train data:', vqa.train_tuple.dataset.splits)
237 |         if vqa.valid_tuple is not None:
238 |             print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
239 |             print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
240 |         else:
241 |             print("DO NOT USE VALIDATION")
242 |         vqa.train(vqa.train_tuple, vqa.valid_tuple)
243 | 
244 | 
245 | 
-------------------------------------------------------------------------------- /src/tasks/vqa_data.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 | 
4 | import json
5 | import os
6 | import pickle
7 | 
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 | 
12 | from param import args
13 | from utils import load_obj_tsv
14 | 
15 | # Load part of the dataset for fast checking.
16 | # Notice that here is the number of images instead of the number of data, 17 | # which means all related data to the images would be used. 18 | TINY_IMG_NUM = 512 19 | FAST_IMG_NUM = 5000 20 | 21 | # The path to data and image features. 22 | VQA_DATA_ROOT = 'data/vqa/' 23 | MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/' 24 | SPLIT2NAME = { 25 | 'train': 'train2014', 26 | 'valid': 'val2014', 27 | 'minival': 'val2014', 28 | 'nominival': 'val2014', 29 | 'test': 'test2015', 30 | } 31 | 32 | 33 | class VQADataset: 34 | """ 35 | A VQA data example in json file: 36 | { 37 | "answer_type": "other", 38 | "img_id": "COCO_train2014_000000458752", 39 | "label": { 40 | "net": 1 41 | }, 42 | "question_id": 458752000, 43 | "question_type": "what is this", 44 | "sent": "What is this photo taken looking through?" 45 | } 46 | """ 47 | def __init__(self, splits: str): 48 | self.name = splits 49 | self.splits = splits.split(',') 50 | 51 | # Loading datasets 52 | self.data = [] 53 | for split in self.splits: 54 | self.data.extend(json.load(open("data/vqa/%s.json" % split))) 55 | print("Load %d data from split(s) %s." % (len(self.data), self.name)) 56 | 57 | # Convert list to dict (for evaluation) 58 | self.id2datum = { 59 | datum['question_id']: datum 60 | for datum in self.data 61 | } 62 | 63 | # Answers 64 | self.ans2label = json.load(open("data/vqa/trainval_ans2label.json")) 65 | self.label2ans = json.load(open("data/vqa/trainval_label2ans.json")) 66 | assert len(self.ans2label) == len(self.label2ans) 67 | 68 | @property 69 | def num_answers(self): 70 | return len(self.ans2label) 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 75 | 76 | """ 77 | An example in obj36 tsv: 78 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", 79 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] 80 | FIELDNAMES would be keys in the dict returned by load_obj_tsv. 81 | """ 82 | class VQATorchDataset(Dataset): 83 | def __init__(self, dataset: VQADataset): 84 | super().__init__() 85 | self.raw_dataset = dataset 86 | 87 | if args.tiny: 88 | topk = TINY_IMG_NUM 89 | elif args.fast: 90 | topk = FAST_IMG_NUM 91 | else: 92 | topk = None 93 | 94 | # Loading detection features to img_data 95 | img_data = [] 96 | for split in dataset.splits: 97 | # Minival is 5K images in MS COCO, which is used in evaluating VQA/LXMERT-pre-training. 
82 | class VQATorchDataset(Dataset):
83 |     def __init__(self, dataset: VQADataset):
84 |         super().__init__()
85 |         self.raw_dataset = dataset
86 | 
87 |         if args.tiny:
88 |             topk = TINY_IMG_NUM
89 |         elif args.fast:
90 |             topk = FAST_IMG_NUM
91 |         else:
92 |             topk = None
93 | 
94 |         # Loading detection features to img_data
95 |         img_data = []
96 |         for split in dataset.splits:
97 |             # minival is a 5K-image subset of MS COCO, used in evaluating VQA fine-tuning and LXMERT pre-training.
98 |             # It is saved as the first 5K features in val2014_***.tsv
99 |             load_topk = 5000 if (split == 'minival' and topk is None) else topk
100 |             # img_data.extend(load_obj_tsv(
101 |             #     os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])),
102 |             #     topk=load_topk))
103 |             img_data.extend(load_obj_tsv(  # this fork loads 64-box features instead of the original obj36
104 |                 os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj64.tsv' % (SPLIT2NAME[split])),
105 |                 topk=load_topk))
106 |             print('load from {0}_obj64.tsv'.format(SPLIT2NAME[split]))
107 | 
108 |         # Convert img list to dict
109 |         self.imgid2img = {}
110 |         for img_datum in img_data:
111 |             self.imgid2img[img_datum['img_id']] = img_datum
112 | 
113 |         # Only keep the data whose image features were loaded
114 |         self.data = []
115 |         for datum in self.raw_dataset.data:
116 |             if datum['img_id'] in self.imgid2img:
117 |                 self.data.append(datum)
118 |         print("Use %d data in torch dataset" % (len(self.data)))
119 |         print()
120 | 
121 |     def __len__(self):
122 |         return len(self.data)
123 | 
124 |     def __getitem__(self, item: int):
125 |         datum = self.data[item]
126 | 
127 |         img_id = datum['img_id']
128 |         ques_id = datum['question_id']
129 |         ques = datum['sent']
130 | 
131 |         # Get image info
132 |         img_info = self.imgid2img[img_id]
133 |         obj_num = img_info['num_boxes']
134 |         feats = img_info['features'].copy()
135 |         boxes = img_info['boxes'].copy()
136 |         assert obj_num == len(boxes) == len(feats)
137 | 
138 |         # Normalize the boxes (to 0 ~ 1)
139 |         img_h, img_w = img_info['img_h'], img_info['img_w']
140 |         boxes = boxes.copy()
141 |         boxes[:, (0, 2)] /= img_w
142 |         boxes[:, (1, 3)] /= img_h
143 |         np.testing.assert_array_less(boxes, 1+1e-5)
144 |         np.testing.assert_array_less(-boxes, 0+1e-5)
145 | 
146 |         # Provide label (target)
147 |         if 'label' in datum:
148 |             label = datum['label']
149 |             target = torch.zeros(self.raw_dataset.num_answers)
150 |             for ans, score in label.items():
151 |                 target[self.raw_dataset.ans2label[ans]] = score
152 |             return ques_id, feats, boxes, ques, target
153 |         else:
154 |             return ques_id, feats, boxes, ques
155 | 
156 | 
157 | class VQAEvaluator:
158 |     def __init__(self, dataset: VQADataset):
159 |         self.dataset = dataset
160 | 
161 |     def evaluate(self, quesid2ans: dict):
162 |         score = 0.
163 |         for quesid, ans in quesid2ans.items():
164 |             datum = self.dataset.id2datum[quesid]
165 |             label = datum['label']
166 |             if ans in label:
167 |                 score += label[ans]
168 |         return score / len(quesid2ans)
169 | 
170 |     def dump_result(self, quesid2ans: dict, path):
171 |         """
172 |         Dump results to a json file, which can be submitted to the VQA online evaluation server.
173 |         VQA json file submission requirement:
174 |             results = [result]
175 |             result = {
176 |                 "question_id": int,
177 |                 "answer": str
178 |             }
179 | 
180 |         :param quesid2ans: dict of quesid --> ans
181 |         :param path: The desired path of the saved file.
182 |         """
183 |         with open(path, 'w') as f:
184 |             result = []
185 |             for ques_id, ans in quesid2ans.items():
186 |                 result.append({
187 |                     'question_id': ques_id,
188 |                     'answer': ans
189 |                 })
190 |             json.dump(result, f, indent=4, sort_keys=True)
191 | 
192 | 
193 | 
--------------------------------------------------------------------------------
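For concreteness, dump_result writes a flat json list in exactly the format the VQA evaluation server expects; with the hypothetical predictions below it would produce:

quesid2ans = {458752000: "net", 458752001: "yes"}   # hypothetical predictions
# dump_result(quesid2ans, path) then writes (keys sorted, 4-space indent):
# [
#     {
#         "answer": "net",
#         "question_id": 458752000
#     },
#     {
#         "answer": "yes",
#         "question_id": 458752001
#     }
# ]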
/src/tasks/vqa_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 | 
4 | import torch.nn as nn
5 | 
6 | from param import args
7 | from lxrt.entry import LXRTEncoder
8 | from lxrt.modeling import BertLayerNorm, GeLU
9 | 
10 | # from src.param import args
11 | # from src.lxrt.entry import LXRTEncoder
12 | # from src.lxrt.modeling import BertLayerNorm, GeLU
13 | 
14 | # Max length, including the <bos> and <eos> tokens
15 | MAX_VQA_LENGTH = 20
16 | 
17 | 
18 | class VQAModel(nn.Module):
19 |     def __init__(self, num_answers):
20 |         super().__init__()
21 | 
22 |         # Build LXRT encoder
23 |         self.lxrt_encoder = LXRTEncoder(
24 |             args,
25 |             max_seq_length=MAX_VQA_LENGTH
26 |         )
27 |         hid_dim = self.lxrt_encoder.dim
28 | 
29 |         # VQA answer head
30 |         self.logit_fc = nn.Sequential(
31 |             nn.Linear(hid_dim, hid_dim * 2),
32 |             GeLU(),
33 |             BertLayerNorm(hid_dim * 2, eps=1e-12),
34 |             nn.Linear(hid_dim * 2, num_answers)
35 |         )
36 |         self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
37 | 
38 |     def forward(self, feat, pos, sent):
39 |         """
40 |         b -- batch_size, o -- object_number, f -- visual_feature_size
41 | 
42 |         :param feat: (b, o, f)
43 |         :param pos: (b, o, 4)
44 |         :param sent: (b,) Type -- list of string
45 |         :return: (b, num_answers) The logits over all candidate answers.
46 |         """
47 |         x = self.lxrt_encoder(sent, (feat, pos))
48 |         logit = self.logit_fc(x)
49 | 
50 |         return logit
51 | 
52 | 
53 | 
--------------------------------------------------------------------------------
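A shape-level sketch of a VQAModel forward pass. This assumes src/ is on PYTHONPATH and that param.args is configured so LXRTEncoder can build (and download) its pretrained backbone; the 36 objects with 2048-d features per image are illustrative assumptions (the loader above actually serves 64-box obj64 features), and 3129 is the answer-vocabulary size commonly used for VQA trainval, also an assumption here:

import torch
from tasks.vqa_model import VQAModel

model = VQAModel(num_answers=3129)
feats = torch.randn(2, 36, 2048)             # (b, o, f): 2 images, 36 objects, 2048-d features
boxes = torch.rand(2, 36, 4)                 # (b, o, 4): normalized box coordinates in [0, 1]
sents = ["What is this?", "Is it raining?"]  # (b,) raw question strings
logit = model(feats, boxes, sents)           # (b, num_answers) answer logits
print(logit.shape)                           # torch.Size([2, 3129])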
/src/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 Project LXRT
3 | 
4 | import sys
5 | import csv
6 | import base64
7 | import time
8 | 
9 | import numpy as np
10 | 
11 | csv.field_size_limit(sys.maxsize)
12 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
13 |               "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
14 | 
15 | 
16 | def load_obj_tsv(fname, topk=None, fp16=False):
17 |     """Load object features from a tsv file.
18 | 
19 |     :param fname: The path to the tsv file.
20 |     :param topk: Only load features for the top K images (lines) in the tsv file.
21 |         Will load all the features if topk is either -1 or None.
22 |     :param fp16: If True, store the float32 features as float16 to halve memory usage.
23 |     :return: A list of image object features, where each feature is a dict.
24 |         See FIELDNAMES above for the keys in the feature dict.
25 |     """
26 |     data = []
27 |     start_time = time.time()
28 |     print("Start to load Faster R-CNN detected objects from %s" % fname)
29 |     with open(fname) as f:
30 |         reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
31 |         for i, item in enumerate(reader):
32 |             if i % 1000 == 0:
33 |                 print(i)  # simple progress indicator
34 | 
35 |             for key in ['img_h', 'img_w', 'num_boxes']:
36 |                 item[key] = int(item[key])
37 | 
38 |             boxes = item['num_boxes']
39 |             decode_config = [
40 |                 ('objects_id', (boxes, ), np.int64),
41 |                 ('objects_conf', (boxes, ), np.float32),
42 |                 ('attrs_id', (boxes, ), np.int64),
43 |                 ('attrs_conf', (boxes, ), np.float32),
44 |                 ('boxes', (boxes, 4), np.float32),
45 |                 ('features', (boxes, -1), np.float32),
46 |             ]
47 |             for key, shape, dtype in decode_config:
48 |                 item[key] = np.frombuffer(base64.b64decode(item[key]), dtype=dtype)
49 | 
50 |                 if fp16 and item[key].dtype == np.float32:
51 |                     item[key] = item[key].astype(np.float16)  # Save features as half-precision in memory.
52 |                 item[key] = item[key].reshape(shape)
53 |                 item[key].setflags(write=False)
54 | 
55 |             data.append(item)
56 |             if topk is not None and len(data) == topk:
57 |                 break
58 |     elapsed_time = time.time() - start_time
59 |     print("Loaded %d images in file %s in %d seconds." % (len(data), fname, elapsed_time))
60 |     return data
61 | 
62 | 
--------------------------------------------------------------------------------
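Finally, a minimal usage sketch for load_obj_tsv, assuming the default feature path that vqa_data.py above reads from:

from utils import load_obj_tsv  # assumes src/ is on PYTHONPATH

data = load_obj_tsv('data/mscoco_imgfeat/val2014_obj64.tsv', topk=100, fp16=True)
img = data[0]
print(img['img_id'], img['num_boxes'])
print(img['features'].shape, img['features'].dtype)  # (num_boxes, feat_dim), float16
print(img['boxes'].shape)                            # (num_boxes, 4)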