├── .gitignore
├── .gitmodules
├── .idea
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── lxmert.iml
│   ├── misc.xml
│   └── modules.xml
├── LICENSE
├── README.md
├── experience_in_pretraining.md
├── extract_gqa_image.py
├── find_example.py
├── get_test.py
├── get_testdev.py
├── gqa_debug.py
├── requirements.txt
├── run
│   ├── README.md
│   ├── fsb2.bash
│   ├── gqa_finetuneft.bash
│   ├── nlvr2_ft.bash
│   └── vqa_finetuneft.bash
└── src
    ├── lxrt
    │   ├── entry.py
    │   ├── file_utils.py
    │   ├── modeling.py
    │   ├── modeling_big.py
    │   ├── optimization.py
    │   └── tokenization.py
    ├── param.py
    ├── pretrain
    │   ├── lxmert_data.py
    │   ├── lxmert_pretrain.py
    │   ├── lxmert_pretrain2.py
    │   └── qa_answer_table.py
    ├── tasks
    │   ├── gqa.py
    │   ├── gqa_data.py
    │   ├── gqa_model.py
    │   ├── nlvr2.py
    │   ├── nlvr2_data.py
    │   ├── nlvr2_model.py
    │   ├── vqa.py
    │   ├── vqa_constant.py
    │   ├── vqa_data.py
    │   └── vqa_model.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.caffemodel
2 | *.tsv
3 | /snap
4 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data/nlvr2/nlvr"]
2 | path = data/nlvr2/nlvr
3 | url = https://github.com/lil-lab/nlvr.git
4 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/lxmert.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Hao Tan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Most Important Thing.
2 | ## Our code is developed based on:
3 | ## LXMERT: Learning Cross-Modality Encoder Representations from Transformers (https://github.com/airsplay/lxmert)
4 | ## If you think our work is useful, please also cite their work!
5 |
6 | ## Introduction
7 | PyTorch code for the CVPR 2021 paper ["Causal Attention for Vision-Language Tasks"](https://arxiv.org/pdf/2103.03493.pdf).
8 | Slides of the LXMERT EMNLP 2019 talk are available [here](http://www.cs.unc.edu/~airsplay/EMNLP_2019_LXMERT_slides.pdf).
9 | For experiment settings, such as the PyTorch version and GPU configuration, please refer to LXMERT (https://github.com/airsplay/lxmert).
10 |
11 | ## Results (36-RoI version)
12 |
13 | | Split | [VQA](https://visualqa.org/) | [GQA](https://cs.stanford.edu/people/dorarad/gqa/) | [NLVR2](http://lil.nlp.cornell.edu/nlvr/) |
14 | |----------- |:----: |:---: |:------:|
15 | | Local Validation | 70.40% | 60.90% | 76.40% |
16 | | Test-Dev | 72.81% | 60.84% | 76.40% (Test-P) |
17 | | Test-Standard | 73.04% | 61.17% | 76.00% (Test-U) |
18 |
19 | ## Results (64-RoI version)
20 | ## Extracting more RoI visual features per image largely improves performance!
21 |
22 | | Split | [VQA](https://visualqa.org/) | [GQA](https://cs.stanford.edu/people/dorarad/gqa/) | [NLVR2](http://lil.nlp.cornell.edu/nlvr/) |
23 | |----------- |:----: |:---: |:------:|
24 | | Test-Dev | 73.54% | 61.87% | 77.27% (Test-P) |
25 | | Test-Standard | 73.63% | 62.07% | 77.23% (Test-U) |
26 |
27 | ## Pre-training
28 | ## Notice that this part is the same as LXMERT: https://github.com/airsplay/lxmert. We include it here to keep this README self-contained.
29 |
30 | 1. Download the aggregated LXMERT dataset from MS COCO, Visual Genome, VQA, and GQA (around 700MB in total). The joint answer labels are saved in `data/lxmert/all_ans.json`.
31 | ```bash
32 | mkdir -p data/lxmert
33 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_train.json -P data/lxmert/
34 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_nominival.json -P data/lxmert/
35 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/vgnococo.json -P data/lxmert/
36 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/lxmert/mscoco_minival.json -P data/lxmert/
37 | ```
38 |
39 | 2. Download the detection features for MS COCO images from LXMERT.
40 | ```bash
41 | mkdir -p data/mscoco_imgfeat
42 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/mscoco_imgfeat/train2014_obj36.zip -P data/mscoco_imgfeat
43 | unzip data/mscoco_imgfeat/train2014_obj36.zip -d data && rm data/mscoco_imgfeat/train2014_obj36.zip
44 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/mscoco_imgfeat/val2014_obj36.zip -P data/mscoco_imgfeat
45 | unzip data/mscoco_imgfeat/val2014_obj36.zip -d data && rm data/mscoco_imgfeat/val2014_obj36.zip
46 | ```
47 |
48 | 3. Download the detection features for Visual Genome images.
49 | ```bash
50 | mkdir -p data/vg_gqa_imgfeat
51 | wget --no-check-certificate https://nlp1.cs.unc.edu/data/lxmert_data/vg_gqa_imgfeat/vg_gqa_obj36.zip -P data/vg_gqa_imgfeat
52 | unzip data/vg_gqa_imgfeat/vg_gqa_obj36.zip -d data && rm data/vg_gqa_imgfeat/vg_gqa_obj36.zip
53 | ```
54 |
55 | 4. Test on a small split of the MS COCO + Visual Genome datasets:
56 | ```bash
57 | bash run/lxmert_pretrain.bash 0,1,2,3 --multiGPU --tiny
58 | ```
59 |
60 | 5. Run on the whole [MS COCO](http://cocodataset.org) and [Visual Genome](https://visualgenome.org/) related datasets (i.e., [VQA](https://visualqa.org/), [GQA](https://cs.stanford.edu/people/dorarad/gqa/index.html), [COCO caption](http://cocodataset.org/#captions-2015), [VG Caption](https://visualgenome.org/), [VG QA](https://github.com/yukezhu/visual7w-toolkit)).
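The corresponding command in the upstream LXMERT README (the same script referenced in step 4, without `--tiny`; this repo's own pre-training entry point is given below) is:
```bash
bash run/lxmert_pretrain.bash 0,1,2,3 --multiGPU
```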
61 |
62 | ## This part is ours:
63 | The pre-training command is:
64 | ```bash
65 | bash run/fsb2.bash 0,1,2,3 --multiGPU
66 | ```
67 | After pre-training, the fine-tuning commands for VQA, GQA, and NLVR2 are:
68 | ```bash
69 | bash run/vqa_finetuneft.bash 0 0.00004 0.00004
70 | bash run/gqa_finetuneft.bash 0 0.000001 0.000001
71 | bash run/nlvr2_ft.bash 0 0.00003 0.00003
72 | ```
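Note on arguments: in the released scripts, the first argument selects the GPU and the second is taken as the experiment name (`name=$2`); `run/nlvr2_ft.bash` additionally reads its third argument as the learning rate (`--lr $3`), while the VQA and GQA scripts fix the learning rate inside the script and forward any extra arguments via `${@:3}`.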
73 |
74 |
75 |
--------------------------------------------------------------------------------
/experience_in_pretraining.md:
--------------------------------------------------------------------------------
1 | # Experience in Pre-training
2 | Since I finished this project with quite limited computational resources, I would like to share some experience. If you are also in a small group and plan to pre-train backbone models for fun, I hope it helps.
3 |
4 | ## Workflow
5 | 1. Design a model and its pre-training strategies.
6 | 2. Test whether the code is correct by over-fitting a very small split (typically 5000 images) of the aggregated data.
7 | 3. Pre-train it on **all aggregated pre-training data** for around 3 to 4 epochs. (At least make sure that all the images are included!)
8 | 4. Test the pre-training performance on a **small split** of the fine-tuning tasks. I used 5000 images by setting the `--fast` option (see the sketch after this list). Note that the number of epochs should be increased from 4 to around 20-50.
9 | 5. If the accuracy (i.e., the results) on the fine-tuning tasks keeps growing, the pre-training is effective!
10 | 6. Compare the results on the **full fine-tuning data** when the 3-4-epoch pre-training finishes, and select the best pre-training strategy.
11 | 7. Train on **full aggregated data** and have a good one-week sleep ;).
12 |
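For example, a hypothetical step-4 run (the trailing arguments are forwarded to the task script via `${@:3}` in `run/vqa_finetuneft.bash`, and with `argparse` a later `--epochs` overrides the script's default):
```bash
bash run/vqa_finetuneft.bash 0 fast_check --fast --epochs 20
```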
13 |
14 | ## Tips
15 | - **Do not** validate pre-training strategies (pre-training tasks, pre-training model) on a **small split** of the data. The behavior of pre-training on a small split is significantly different from the full pre-training dataset.
16 | - Do not over-tune the pre-training hyperparameters. Keep in mind that a good idea will overshadow all these cherry-picked hyperparameters. Anyway, you would not have enough GPUs to do that.
17 | - Add one component at a time; have a plan for it.
18 | - Pipeline everything.
19 | - You can rest, but GPUs never rest; GPUs sometimes break, but you never give up.
20 |
--------------------------------------------------------------------------------
/extract_gqa_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # The root of bottom-up-attention repo. Do not need to change if using provided docker file.
4 | BUTD_ROOT = '/opt/butd/'
5 |
6 | import os, sys
7 | sys.path.insert(0, BUTD_ROOT + "/tools")
8 | os.environ['GLOG_minloglevel'] = '2'
9 |
10 | import _init_paths
11 | from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list
12 | from fast_rcnn.test import im_detect, _get_blobs
13 | from fast_rcnn.nms_wrapper import nms
14 |
15 | import caffe
16 | import argparse
17 | import pprint
18 | import base64
19 | import numpy as np
20 | import cv2
21 | import csv
22 | from tqdm import tqdm
23 |
24 | csv.field_size_limit(sys.maxsize)
25 |
26 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
27 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
28 |
29 | # Settings for the number of features per image. To re-create pretrained features with 36 features
30 | # per image, set both values to 36.
31 | MIN_BOXES = 100
32 | MAX_BOXES = 100
33 |
34 |
35 | def load_image_ids(img_root):
36 | pathXid = []
37 | for name in os.listdir(img_root):
38 | idx = name.split(".")[0]
39 | pathXid.append(
40 | (
41 | os.path.join(img_root, name),
42 | idx))
43 | return pathXid
44 |
45 | def generate_tsv(prototxt, weights, image_ids, outfile):
46 | # First check if file exists, and if it is complete
47 | # NOTE: never rely on a set for ordering; it loses the original order. Sets here are used for membership tests only.
48 | wanted_ids = set([image_id[1] for image_id in image_ids])
49 | found_ids = set()
50 | if os.path.exists(outfile):
51 | with open(outfile, "r") as tsvfile:
52 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
53 | for item in reader:
54 | found_ids.add(item['img_id'])
55 | missing = wanted_ids - found_ids
56 | if len(missing) == 0:
57 | print('already completed {:d}'.format(len(image_ids)))
58 | else:
59 | print('missing {:d}/{:d}'.format(len(missing), len(image_ids)))
60 | if len(missing) > 0:
61 | caffe.set_mode_gpu()
62 | caffe.set_device(0)
63 | net = caffe.Net(prototxt, caffe.TEST, weights=weights)
64 | with open(outfile, 'ab') as tsvfile:
65 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
66 | for im_file, image_id in tqdm(image_ids):
67 | if image_id in missing:
68 | try:
69 | writer.writerow(get_detections_from_im(net, im_file, image_id))
70 | except Exception as e:
71 | print(e)
72 |
73 |
74 | def get_detections_from_im(net, im_file, image_id, conf_thresh=0.2):
75 | """
76 | :param net:
77 | :param im_file: full path to an image
78 | :param image_id:
79 | :param conf_thresh:
80 | :return: all information from detection and attr prediction
81 | """
82 | im = cv2.imread(im_file)
83 | scores, boxes, attr_scores, rel_scores = im_detect(net, im)
84 |
85 | # Keep the original boxes, don't worry about the regression bbox outputs
86 | rois = net.blobs['rois'].data.copy()
87 | # unscale back to raw image space
88 | blobs, im_scales = _get_blobs(im, None)
89 |
90 | cls_boxes = rois[:, 1:5] / im_scales[0]
91 | cls_prob = net.blobs['cls_prob'].data
92 | attr_prob = net.blobs['attr_prob'].data
93 | pool5 = net.blobs['pool5_flat'].data
94 |
95 | # Keep only the best detections
96 | max_conf = np.zeros((rois.shape[0]))
97 | for cls_ind in range(1, cls_prob.shape[1]):
98 | cls_scores = scores[:, cls_ind]
99 | dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
100 | keep = np.array(nms(dets, cfg.TEST.NMS))
101 | max_conf[keep] = np.where(cls_scores[keep] > max_conf[keep], cls_scores[keep], max_conf[keep])
102 |
103 | keep_boxes = np.where(max_conf >= conf_thresh)[0]
104 | if len(keep_boxes) < MIN_BOXES:
105 | keep_boxes = np.argsort(max_conf)[::-1][:MIN_BOXES]
106 | elif len(keep_boxes) > MAX_BOXES:
107 | keep_boxes = np.argsort(max_conf)[::-1][:MAX_BOXES]
108 |
109 | objects = np.argmax(cls_prob[keep_boxes][:, 1:], axis=1)
110 | objects_conf = np.max(cls_prob[keep_boxes][:, 1:], axis=1)
111 | attrs = np.argmax(attr_prob[keep_boxes][:, 1:], axis=1)
112 | attrs_conf = np.max(attr_prob[keep_boxes][:, 1:], axis=1)
113 |
114 | return {
115 | "img_id": image_id,
116 | "img_h": np.size(im, 0),
117 | "img_w": np.size(im, 1),
118 | "objects_id": base64.b64encode(objects), # int64
119 | "objects_conf": base64.b64encode(objects_conf), # float32
120 | "attrs_id": base64.b64encode(attrs), # int64
121 | "attrs_conf": base64.b64encode(attrs_conf), # float32
122 | "num_boxes": len(keep_boxes),
123 | "boxes": base64.b64encode(cls_boxes[keep_boxes]), # float32
124 | "features": base64.b64encode(pool5[keep_boxes]) # float32
125 | }
126 |
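# Illustrative decoding sketch (an assumption, not part of this script): LXMERT-style
# TSV loaders recover the arrays from a row roughly as follows:
#   n = int(item["num_boxes"])
#   boxes = np.frombuffer(base64.b64decode(item["boxes"]), dtype=np.float32).reshape(n, 4)
#   feats = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(n, -1)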
127 |
128 | def parse_args():
129 | """
130 | Parse input arguments
131 | """
132 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network')
133 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use',
134 | default='0', type=str)
135 | parser.add_argument('--def', dest='prototxt',
136 | help='prototxt file defining the network',
137 | default=None, type=str)
138 | parser.add_argument('--out', dest='outfile',
139 | help='output filepath',
140 | default=None, type=str)
141 | parser.add_argument('--cfg', dest='cfg_file',
142 | help='optional config file', default=None, type=str)
143 | parser.add_argument('--set', dest='set_cfgs',
144 | help='set config keys', default=None,
145 | nargs=argparse.REMAINDER)
146 | parser.add_argument('--imgroot', type=str, default='/workspace/images/')
147 | parser.add_argument('--split', type=str, default='valid')
148 | parser.add_argument('--caffemodel', type=str, default='pretrained/resnet101_faster_rcnn_final_iter_320000.caffemodel')
149 |
150 | args = parser.parse_args()
151 | return args
152 |
153 |
154 | if __name__ == '__main__':
155 | # Setup the configuration, normally do not need to touch these:
156 | args = parse_args()
157 |
158 |
159 | args.cfg_file = BUTD_ROOT + "experiments/cfgs/faster_rcnn_end2end_resnet.yml" # s = 500
160 | args.prototxt = BUTD_ROOT + "models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt"
161 | args.outfile = "%s_obj100.tsv" % "vg_gqa"
162 |
163 | print('Called with args:')
164 | print(args)
165 |
166 | if args.cfg_file is not None:
167 | cfg_from_file(args.cfg_file)
168 |
169 | print('Using config:')
170 | pprint.pprint(cfg)
171 | assert cfg.TEST.HAS_RPN
172 |
173 | # Load image ids, need modification for new datasets.
174 | image_ids = load_image_ids(args.imgroot)
175 |
176 | # Generate TSV files, normally do not need to modify
177 | generate_tsv(args.prototxt, args.caffemodel, image_ids, args.outfile)
178 |
--------------------------------------------------------------------------------
/find_example.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | data = json.load(open('/data2/yangxu/lxmert/data/vqa/nominival.json'))
4 | N = len(data)
5 | data_new = []
6 | do_txt = open("gender.txt", "w")
7 | play_txt = open("play.txt", "w")
8 |
9 | for i in range(N):
10 | if data[i]['sent'].find('gender')>=0:
11 | data_new.append(data[i])
12 | print(data[i]['img_id'])
13 | print(data[i]['sent'])
14 | print(data[i]['label'])
15 | do_txt.write("{0}\n".format(data[i]['img_id']))
16 | do_txt.write("{0}\n".format(data[i]['sent']))
17 | do_txt.write("{0}\n".format(data[i]['label']))
18 | # if data[i]['sent'].find('gender')>=0:
19 | # data_new.append(data[i])
20 | # print(data[i]['img_id'])
21 | # print(data[i]['sent'])
22 | # print(data[i]['label'])
23 | # play_txt.write("{0}\n".format(data[i]['img_id']))
24 | # play_txt.write("{0}\n".format(data[i]['sent']))
25 | # play_txt.write("{0}\n".format(data[i]['label']))
26 |
27 | do_txt.close()
28 | # play_txt.close()
29 |
30 | # data = json.load(open('/data2/yangxu/lxmert/data/vqa/train.json'))
31 | # N = len(data)
32 | # c0 = 0
33 | # w0 = 'surfboard'
34 | # c1 = 0
35 | # w1 = 'man'
36 | # c2 = 0
37 | # w2 = 'woman'
38 | #
39 | # for i in range(N):
40 | # if data[i]['sent'].find(w0)>=0:
41 | # c0 += 1
42 | # if data[i]['sent'].find(w1) >= 0:
43 | # c1 += 1
44 | # if data[i]['sent'].find(w2) >= 0:
45 | # c2 += 1
46 | # print(w0,c0)
47 | # print(w1,c1)
48 | # print(w2,c2)
--------------------------------------------------------------------------------
/get_test.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | import csv
7 | import sys
8 | csv.field_size_limit(sys.maxsize)
9 |
10 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
11 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
12 |
13 | id_file = 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv'
14 | coco_file = 'data/mscoco_imgfeat/test2015_obj64.tsv'
15 | outfile = 'data/vg_gqa_imgfeat/vg_gqa_obj64.tsv'
16 | all_id = []
17 | all_id_number = []
18 | coco_data = []
19 |
20 | with open(coco_file) as f:
21 | coco_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
22 | for i, item in enumerate(coco_reader):
23 | if i % 1000 == 0:
24 | print(i)
25 | coco_data.append(item)
26 | all_id_number.append(int(item['img_id'][14:]))  # strip the "COCO_test2015_" prefix to get the numeric id
27 |
28 | N = 0
29 | with open(outfile, 'a+') as tsvfile:
30 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
31 | with open(id_file) as f:
32 | id_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
33 | for i, item in enumerate(id_reader):
34 | if i % 1000 == 0:
35 | print(i)
36 | nlvr2_id = item['img_id']
37 | if nlvr2_id[0] == 'n':
38 | N+=1
39 | nlvr2_id_number=int(nlvr2_id[1:])
40 | index_id = all_id_number.index(nlvr2_id_number)
41 | coco_datum=coco_data[index_id]
42 | coco_datum['img_id'] = item['img_id']
43 | writer.writerow(coco_datum)
44 | print(N)
--------------------------------------------------------------------------------
/get_testdev.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | import csv
7 | import sys
8 | csv.field_size_limit(sys.maxsize)
9 |
10 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
11 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
12 |
13 | id_file = 'data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv'
14 | coco_file = 'data/mscoco_imgfeat/test2015_obj64.tsv'
15 | outfile = 'data/vg_gqa_imgfeat/gqa_testdev_obj64.tsv'
16 | all_id = []
17 | all_id_number = []
18 | coco_data = []
19 |
20 | with open(coco_file) as f:
21 | coco_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
22 | for i, item in enumerate(coco_reader):
23 | if i % 1000 == 0:
24 | print(i)
25 | coco_data.append(item)
26 | all_id_number.append(int(item['img_id'][14:]))  # strip the "COCO_test2015_" prefix to get the numeric id
27 |
28 | with open(outfile, 'w') as tsvfile:
29 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
30 | with open(id_file) as f:
31 | id_reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
32 | for i, item in enumerate(id_reader):
33 | nlvr2_id = item['img_id']
34 | nlvr2_id_number=int(nlvr2_id[1:])
35 | index_id = all_id_number.index(nlvr2_id_number)
36 | coco_datum=coco_data[index_id]
37 | coco_datum['img_id'] = item['img_id']
38 | writer.writerow(coco_datum)
39 |
--------------------------------------------------------------------------------
/gqa_debug.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import os
5 | import collections
6 |
7 | import torch
8 | from tqdm import tqdm
9 | import torch.nn as nn
10 | from torch.utils.data.dataloader import DataLoader
11 |
12 | from src.param import args
13 | from src.pretrain.qa_answer_table import load_lxmert_qa
14 | from src.tasks.gqa_model import GQAModel
15 | from src.tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator
16 |
17 |
18 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
19 |
20 |
21 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
22 | dset = GQADataset(splits)
23 | tset = GQATorchDataset(dset)
24 | evaluator = GQAEvaluator(dset)
25 | data_loader = DataLoader(
26 | tset, batch_size=bs,
27 | shuffle=shuffle, num_workers=args.num_workers,
28 | drop_last=drop_last, pin_memory=True
29 | )
30 |
31 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
32 |
33 |
34 | class GQA:
35 | def __init__(self):
36 | self.train_tuple = get_tuple(
37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True
38 | )
39 | if args.valid != "":
40 | valid_bsize = 512  # same batch size with or without multi-GPU
41 | self.valid_tuple = get_tuple(
42 | args.valid, bs=valid_bsize,
43 | shuffle=False, drop_last=False
44 | )
45 | else:
46 | self.valid_tuple = None
47 |
48 | self.model = GQAModel(self.train_tuple.dataset.num_answers)
49 |
50 | # Load pre-trained weights
51 | if args.load_lxmert is not None:
52 | self.model.lxrt_encoder.load(args.load_lxmert)
53 | if args.load_lxmert_qa is not None:
54 | load_lxmert_qa(args.load_lxmert_qa, self.model,
55 | label2ans=self.train_tuple.dataset.label2ans)
56 |
57 | # GPU options
58 | self.model = self.model.cuda()
59 | if args.multiGPU:
60 | self.model.lxrt_encoder.multi_gpu()
61 |
62 | # Losses and optimizer
63 | self.bce_loss = nn.BCEWithLogitsLoss()
64 | self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
65 | if 'bert' in args.optim:
66 | batch_per_epoch = len(self.train_tuple.loader)
67 | t_total = int(batch_per_epoch * args.epochs)
68 | print("Total Iters: %d" % t_total)
69 | from lxrt.optimization import BertAdam
70 | self.optim = BertAdam(list(self.model.parameters()),
71 | lr=args.lr,
72 | warmup=0.1,
73 | t_total=t_total)
74 | else:
75 | self.optim = args.optimizer(list(self.model.parameters()), args.lr)
76 |
77 | self.output = args.output
78 | os.makedirs(self.output, exist_ok=True)
79 |
80 | def train(self, train_tuple, eval_tuple):
81 | dset, loader, evaluator = train_tuple
82 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
83 |
84 | best_valid = 0.
85 | for epoch in range(args.epochs):
86 | quesid2ans = {}
87 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
88 |
89 | self.model.train()
90 | self.optim.zero_grad()
91 |
92 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
93 | logit = self.model(feats, boxes, sent)
94 | assert logit.dim() == target.dim() == 2
95 | if args.mce_loss:
96 | max_value, target = target.max(1)
97 | loss = self.mce_loss(logit, target) * logit.size(1)
98 | else:
99 | loss = self.bce_loss(logit, target)
100 | loss = loss * logit.size(1)
101 |
102 | loss.backward()
103 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
104 | self.optim.step()
105 |
106 | score, label = logit.max(1)
107 | for qid, l in zip(ques_id, label.cpu().numpy()):
108 | ans = dset.label2ans[l]
109 | quesid2ans[qid] = ans
110 |
111 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
112 |
113 | if self.valid_tuple is not None: # Do Validation
114 | valid_score = self.evaluate(eval_tuple)
115 | if valid_score > best_valid:
116 | best_valid = valid_score
117 | self.save("BEST")
118 |
119 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
120 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
121 |
122 | print(log_str, end='')
123 |
124 | with open(self.output + "/log.log", 'a') as f:
125 | f.write(log_str)
126 | f.flush()
127 |
128 | self.save("LAST")
129 |
130 | def predict(self, eval_tuple: DataTuple, dump=None):
131 | self.model.eval()
132 | dset, loader, evaluator = eval_tuple
133 | quesid2ans = {}
134 | for i, datum_tuple in enumerate(loader):
135 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target
136 | with torch.no_grad():
137 | feats, boxes = feats.cuda(), boxes.cuda()
138 | logit = self.model(feats, boxes, sent)
139 | score, label = logit.max(1)
140 | for qid, l in zip(ques_id, label.cpu().numpy()):
141 | ans = dset.label2ans[l]
142 | quesid2ans[qid] = ans
143 | if dump is not None:
144 | evaluator.dump_result(quesid2ans, dump)
145 | return quesid2ans
146 |
147 | def evaluate(self, eval_tuple: DataTuple, dump=None):
148 | dset, loader, evaluator = eval_tuple
149 | quesid2ans = self.predict(eval_tuple, dump)
150 | return evaluator.evaluate(quesid2ans)
151 |
152 | @staticmethod
153 | def oracle_score(data_tuple):
154 | dset, loader, evaluator = data_tuple
155 | quesid2ans = {}
156 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
157 | _, label = target.max(1)
158 | for qid, l in zip(ques_id, label.cpu().numpy()):
159 | ans = dset.label2ans[l]
160 | quesid2ans[qid] = ans
161 | return evaluator.evaluate(quesid2ans)
162 |
163 | def save(self, name):
164 | torch.save(self.model.state_dict(),
165 | os.path.join(self.output, "%s.pth" % name))
166 |
167 | def load(self, path):
168 | print("Load model from %s" % path)
169 | state_dict = torch.load("%s.pth" % path)
170 | for key in list(state_dict.keys()):
171 | if '.module' in key:
172 | state_dict[key.replace('.module', '')] = state_dict.pop(key)
173 | self.model.load_state_dict(state_dict, strict=False)
174 |
175 |
176 | if __name__ == "__main__":
177 | # Build Class
178 | gqa = GQA()
179 |
180 | # Load Model
181 | if args.load is not None:
182 | gqa.load(args.load)
183 |
184 | # Test or Train
185 | if args.test is not None:
186 | args.fast = args.tiny = False # Always loading all data in test
187 | if 'submit' in args.test:
188 | gqa.predict(
189 | get_tuple(args.test, bs=args.batch_size,
190 | shuffle=False, drop_last=False),
191 | dump=os.path.join(args.output, 'submit_predict.json')
192 | )
193 | if 'testdev' in args.test:
194 | result = gqa.evaluate(
195 | get_tuple('testdev', bs=args.batch_size,
196 | shuffle=False, drop_last=False),
197 | dump=os.path.join(args.output, 'testdev_predict.json')
198 | )
199 | print(result)
200 | else:
201 | # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
202 | print('Splits in Train data:', gqa.train_tuple.dataset.splits)
203 | if gqa.valid_tuple is not None:
204 | print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
205 | print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
206 | else:
207 | print("DO NOT USE VALIDATION")
208 | gqa.train(gqa.train_tuple, gqa.valid_tuple)
209 |
210 |
211 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/huggingface/pytorch-transformers
2 |
3 | # PyTorch
4 | torch>=1.0.0
5 | # Progress bars in model download and training scripts
6 | tqdm
7 | # Accessing files from S3 directly.
8 | boto3
9 | # Used for downloading models over HTTP
10 | requests
11 |
12 |
--------------------------------------------------------------------------------
/run/README.md:
--------------------------------------------------------------------------------
1 | # Running Script Arguments
2 |
3 | ```
4 | Data Splits:
5 | --train [str,str,...]: use the splits (separated by comma) in training.
6 | --valid [str,str,...]: use the splits (separated by comma) in validation.
7 | --test [str,str,...]: use the splits (separated by comma) in testing.
8 | Model Architecture:
9 | --llayers [int]: number of layers in language encoder.
10 | --xlayers [int]: number of layers in cross-modality encoder.
11 | --rlayers [int]: number of layers in object relationship encoder.
12 | Load Weights:
13 | --load [str='path/to/saved_model']: load fine-tuned model path/to/saved_model.pth.
14 | --loadLXMERT [str='path/to/saved_model']: load pre-trained model without answer heads from path/to/saved_model_LXRT.pth.
15 | --loadLXMERTQA [str='path/to/saved_model']: load pre-trained model with answer head path/to/saved_model_LXRT.pth.
16 | --fromScratch: If none of the above loading parameters are set, the default mode would
17 | load the pre-trained BERT weights.
18 | As we promised to EMNLP reviewers, the language encoder would be re-initialized with this one-line argument to test the performance without BERT weights.
19 | Training Hyper Parameters:
20 | --batchSize [int]: batch size.
21 | --optim [str]: optimizers.
22 | --lr [float]: peak learning rate.
23 | --epochs [int]: training epochs.
24 | Debugging:
25 | --tiny: Load 512 images for each data split. (Note: number of images might be changed due to dataset specification)
26 | --fast: Load 5000 images for each data split. (Note: number of images might be changed due to dataset specification)
27 | ```
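
For instance, a typical fine-tuning invocation combining these options (a sketch mirroring `run/vqa_finetuneft.bash`; the output directory is just an example):
```
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$PYTHONPATH:./src \
python src/tasks/vqa.py \
  --train train,nominival --valid minival \
  --llayers 9 --xlayers 5 --rlayers 5 \
  --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
  --tqdm --output snap/vqa/example
```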
28 |
29 | # Pre-training-Specific Arguments
30 | ```
31 | Pre-training Tasks:
32 | --taskMaskLM: use the masked language model task.
33 | --taskObjPredict: use the masked object prediction task.
34 | --taskMatched: use the cross-modality matched task.
35 | --taskQA: use the image QA task.
36 | Visual Pre-training Losses (Tasks):
37 | --visualLosses [str,str,...]: The sub-tasks in pre-training visual modality. Each one is from 'obj,attr,feat'.
38 | obj: detected-object-label classification.
39 | attr: detected-object-attribute classification.
40 | feat: RoI-feature regression.
41 | Mask Rate in Pre-training:
42 | --wordMaskRate [float]: The prob of masking a word.
43 | --objMaskRate [float]: The prob of masking an object.
44 | Initialization:
45 | --fromScratch: The default mode would load the pre-trained BERT weights into the model.
46 | As we promised to EMNLP reviewers, this option would re-initialize the language encoder.
47 | ```
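
These flags appear together in `run/fsb2.bash`, e.g.:
```
python src/pretrain/lxmert_pretrain2.py \
  --taskMaskLM --taskObjPredict --taskMatched --taskQA \
  --visualLosses obj,attr,feat \
  --wordMaskRate 0.15 --objMaskRate 0.15 \
  --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival
```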
48 |
49 |
50 |
--------------------------------------------------------------------------------
/run/fsb2.bash:
--------------------------------------------------------------------------------
1 | # The name of experiment
2 | name=fsb2
3 |
4 | # Create dirs and make backup
5 | output=snap/pretrain/$name
6 | mkdir -p $output/src
7 | cp -r src/* $output/src/
8 | cp $0 $output/run.bash
9 |
10 | # Pre-training
11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
12 | python src/pretrain/lxmert_pretrain2.py \
13 | --taskMaskLM --taskObjPredict --taskMatched --taskQA \
14 | --visualLosses obj,attr,feat \
15 | --wordMaskRate 0.15 --objMaskRate 0.15 \
16 | --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival \
17 | --llayers 9 --xlayers 5 --rlayers 5 \
18 | --batchSize 196 --optim bert --lr 5e-5 --epochs 20 \
19 | --MV_size 500 --ML_size 500 \
20 | --tqdm --output $output --bert_type ft_same --multiGPU ${@:2}
21 |
22 |
--------------------------------------------------------------------------------
/run/gqa_finetuneft.bash:
--------------------------------------------------------------------------------
1 | # The name of this experiment.
2 | name=$2
3 |
4 | # Save logs and models under snap/gqa; make backup.
5 | output=snap/gqa/fb4/Epoch20/$name
6 | mkdir -p $output/src
7 | cp -r src/* $output/src/
8 | cp $0 $output/run.bash
9 |
10 | # See Readme.md for option details.
11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
12 | python src/tasks/gqa.py \
13 | --train train,valid --valid testdev \
14 | --llayers 9 --xlayers 5 --rlayers 5 \
15 | --loadLXMERTQA snap/pretrain/fb4/Epoch20 \
16 | --MV_size 800 --ML_size 800 \
17 | --bert_type ft \
18 | --batchSize 32 --optim bert --lr 5e-6 --epochs 4 \
19 | --tqdm --output $output ${@:3}
20 |
--------------------------------------------------------------------------------
/run/nlvr2_ft.bash:
--------------------------------------------------------------------------------
1 | # The name of this experiment.
2 | name=$2
3 |
4 | # Save logs and models under snap/nlvr2; Make backup.
5 | output=snap/nlvr2/nfsb2/Epoch20/$name
6 | mkdir -p $output/src
7 | cp -r src/* $output/src/
8 | cp $0 $output/run.bash
9 |
10 | # See run/Readme.md for option details.
11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
12 | python src/tasks/nlvr2.py \
13 | --train train --valid valid \
14 | --llayers 9 --xlayers 5 --rlayers 5 \
15 | --loadLXMERT snap/pretrain/nfsb2/Epoch20 \
16 | --MV_size 500 --ML_size 500 \
17 | --bert_type ft_same \
18 | --batchSize 32 --optim bert --lr $3 --epochs 4 \
19 | --tqdm --multiGPU --output $output ${@:4}
20 |
21 |
--------------------------------------------------------------------------------
/run/vqa_finetuneft.bash:
--------------------------------------------------------------------------------
1 | # The name of this experiment.
2 | name=$2
3 |
4 | # Save logs and models under snap/vqa; make backup.
5 | output=snap/vqa/fb3/Epoch20/$name
6 | mkdir -p $output/src
7 | cp -r src/* $output/src/
8 | cp $0 $output/run.bash
9 |
10 | # See Readme.md for option details.
11 | CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
12 | python src/tasks/vqa.py \
13 | --train train,nominival --valid minival \
14 | --llayers 9 --xlayers 5 --rlayers 5 \
15 | --loadLXMERTQA snap/pretrain/fb3/Epoch20 \
16 | --batchSize 32 --optim bert \
17 | --MV_size 800 --ML_size 800 \
18 | --lr_schedule warmup_linear_yx --lr 5e-5 --lr_min 0 --epochs 4 \
19 | --tqdm --output $output --bert_type ft ${@:3}
20 |
21 | # output=snap/vqa/ft20/$name
22 | #--loadLXMERTQA snap/pretrain/ftlxmert/Epoch20 \
23 |
--------------------------------------------------------------------------------
/src/lxrt/entry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 project LXRT.
3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import os
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 | from lxrt.tokenization import BertTokenizer
24 | from lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG
25 |
26 |
27 | class InputFeatures(object):
28 | """A single set of features of data."""
29 |
30 | def __init__(self, input_ids, input_mask, segment_ids):
31 | self.input_ids = input_ids
32 | self.input_mask = input_mask
33 | self.segment_ids = segment_ids
34 |
35 |
36 | def convert_sents_to_features(sents, max_seq_length, tokenizer):
37 | """Loads a data file into a list of `InputBatch`s."""
38 |
39 | features = []
40 | for (i, sent) in enumerate(sents):
41 | tokens_a = tokenizer.tokenize(sent.strip())
42 |
43 | # Account for [CLS] and [SEP] with "- 2"
44 | if len(tokens_a) > max_seq_length - 2:
45 | tokens_a = tokens_a[:(max_seq_length - 2)]
46 |
47 | # Keep segment id which allows loading BERT-weights.
48 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
49 | segment_ids = [0] * len(tokens)
50 |
51 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
52 |
53 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
54 | # tokens are attended to.
55 | input_mask = [1] * len(input_ids)
56 |
57 | # Zero-pad up to the sequence length.
58 | padding = [0] * (max_seq_length - len(input_ids))
59 | input_ids += padding
60 | input_mask += padding
61 | segment_ids += padding
62 |
63 | assert len(input_ids) == max_seq_length
64 | assert len(input_mask) == max_seq_length
65 | assert len(segment_ids) == max_seq_length
66 |
67 | features.append(
68 | InputFeatures(input_ids=input_ids,
69 | input_mask=input_mask,
70 | segment_ids=segment_ids))
71 | return features
72 |
73 |
74 | def set_visual_config(args):
75 | VISUAL_CONFIG.l_layers = args.llayers
76 | VISUAL_CONFIG.x_layers = args.xlayers
77 | VISUAL_CONFIG.r_layers = args.rlayers
78 |
79 |
80 | class LXRTEncoder(nn.Module):
81 | def __init__(self, args, max_seq_length, mode='x'):
82 | super().__init__()
83 | self.max_seq_length = max_seq_length
84 | set_visual_config(args)
85 |
86 | # Using the bert tokenizer
87 | self.tokenizer = BertTokenizer.from_pretrained(
88 | "bert-base-uncased",
89 | do_lower_case=True
90 | )
91 |
92 | # Build LXRT Model
93 | self.model = VisualBertForLXRFeature.from_pretrained(
94 | "bert-base-uncased",
95 | args=args,
96 | mode=mode
97 | )
98 |
99 | if args.from_scratch:
100 | print("initializing all the weights")
101 | self.model.apply(self.model.init_bert_weights)
102 |
103 | def multi_gpu(self):
104 | self.model = nn.DataParallel(self.model)
105 |
106 | @property
107 | def dim(self):
108 | return 768
109 |
110 | def forward(self, sents, feats, visual_attention_mask=None):
111 | train_features = convert_sents_to_features(
112 | sents, self.max_seq_length, self.tokenizer)
113 |
114 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
115 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
116 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()
117 |
118 | output = self.model(input_ids, segment_ids, input_mask,
119 | visual_feats=feats,
120 | visual_attention_mask=visual_attention_mask)
121 | return output
122 |
123 | def save(self, path):
124 | torch.save(self.model.state_dict(),
125 | os.path.join("%s_LXRT.pth" % path))
126 |
127 | def load(self, path):
128 | # Load state_dict from snapshot file
129 | print("Load LXMERT pre-trained model from %s" % path)
130 | # state_dict = torch.load("%s_LXRT.pth" % path)
131 | state_dict = torch.load("%s_LXRT.pth" % path)['state_dict']
132 | new_state_dict = {}
133 | for key, value in state_dict.items():
134 | if key.startswith("module."):
135 | new_state_dict[key[len("module."):]] = value
136 | else:
137 | new_state_dict[key] = value
138 | state_dict = new_state_dict
139 |
140 | # Print out the differences of pre-trained and model weights.
141 | load_keys = set(state_dict.keys())
142 | model_keys = set(self.model.state_dict().keys())
143 | print()
144 | print("Weights in loaded but not in model:")
145 | for key in sorted(load_keys.difference(model_keys)):
146 | print(key)
147 | print()
148 | print("Weights in model but not in loaded:")
149 | for key in sorted(model_keys.difference(load_keys)):
150 | print(key)
151 | print()
152 |
153 | # Load weights to model
154 | self.model.load_state_dict(state_dict, strict=False)
156 |
157 |
158 |
159 |
160 |
--------------------------------------------------------------------------------
/src/lxrt/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | import json
7 | import logging
8 | import os
9 | import shutil
10 | import tempfile
11 | from functools import wraps
12 | from hashlib import sha256
13 | import sys
14 | from io import open
15 |
16 | import boto3
17 | import requests
18 | from botocore.exceptions import ClientError
19 | from tqdm import tqdm
20 |
21 | try:
22 | from urllib.parse import urlparse
23 | except ImportError:
24 | from urlparse import urlparse
25 |
26 | try:
27 | from pathlib import Path
28 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
29 | Path.home() / '.pytorch_pretrained_bert'))
30 | print(PYTORCH_PRETRAINED_BERT_CACHE)
31 | except (AttributeError, ImportError):
32 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
33 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
34 | print(PYTORCH_PRETRAINED_BERT_CACHE)
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 | parsed = urlparse(url_or_filename)
98 |
99 | if parsed.scheme in ('http', 'https', 's3'):
100 | # URL, so get it from the cache (downloading if necessary)
101 | return get_from_cache(url_or_filename, cache_dir)
102 | elif os.path.exists(url_or_filename):
103 | # File, and it exists.
104 | return url_or_filename
105 | elif parsed.scheme == '':
106 | # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename))
108 | else:
109 | # Something unknown
110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111 |
112 |
113 | def split_s3_path(url):
114 | """Split a full s3 path into the bucket name and path."""
115 | parsed = urlparse(url)
116 | if not parsed.netloc or not parsed.path:
117 | raise ValueError("bad s3 path {}".format(url))
118 | bucket_name = parsed.netloc
119 | s3_path = parsed.path
120 | # Remove '/' at beginning of path.
121 | if s3_path.startswith("/"):
122 | s3_path = s3_path[1:]
123 | return bucket_name, s3_path
124 |
125 |
126 | def s3_request(func):
127 | """
128 | Wrapper function for s3 requests in order to create more helpful error
129 | messages.
130 | """
131 |
132 | @wraps(func)
133 | def wrapper(url, *args, **kwargs):
134 | try:
135 | return func(url, *args, **kwargs)
136 | except ClientError as exc:
137 | if int(exc.response["Error"]["Code"]) == 404:
138 | raise EnvironmentError("file {} not found".format(url))
139 | else:
140 | raise
141 |
142 | return wrapper
143 |
144 |
145 | @s3_request
146 | def s3_etag(url):
147 | """Check ETag on S3 object."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_object = s3_resource.Object(bucket_name, s3_path)
151 | return s3_object.e_tag
152 |
153 |
154 | @s3_request
155 | def s3_get(url, temp_file):
156 | """Pull a file directly from S3."""
157 | s3_resource = boto3.resource("s3")
158 | bucket_name, s3_path = split_s3_path(url)
159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160 |
161 |
162 | def http_get(url, temp_file):
163 | req = requests.get(url, stream=True)
164 | content_length = req.headers.get('Content-Length')
165 | total = int(content_length) if content_length is not None else None
166 | progress = tqdm(unit="B", total=total)
167 | for chunk in req.iter_content(chunk_size=1024):
168 | if chunk: # filter out keep-alive new chunks
169 | progress.update(len(chunk))
170 | temp_file.write(chunk)
171 | progress.close()
172 |
173 |
174 | def get_from_cache(url, cache_dir=None):
175 | """
176 | Given a URL, look for the corresponding dataset in the local cache.
177 | If it's not there, download it. Then return the path to the cached file.
178 | """
179 | if cache_dir is None:
180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182 | cache_dir = str(cache_dir)
183 |
184 | if not os.path.exists(cache_dir):
185 | os.makedirs(cache_dir)
186 |
187 | # Get eTag to add to filename, if it exists.
188 | if url.startswith("s3://"):
189 | etag = s3_etag(url)
190 | else:
191 | response = requests.head(url, allow_redirects=True)
192 | if response.status_code != 200:
193 | raise IOError("HEAD request failed for url {} with status code {}"
194 | .format(url, response.status_code))
195 | etag = response.headers.get("ETag")
196 |
197 | filename = url_to_filename(url, etag)
198 |
199 | # get cache path to put the file
200 | cache_path = os.path.join(cache_dir, filename)
201 |
202 | if not os.path.exists(cache_path):
203 | # Download to temporary file, then copy to cache dir once finished.
204 | # Otherwise you get corrupt cache entries if the download gets interrupted.
205 | with tempfile.NamedTemporaryFile() as temp_file:
206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 |
208 | # GET file object
209 | if url.startswith("s3://"):
210 | s3_get(url, temp_file)
211 | else:
212 | http_get(url, temp_file)
213 |
214 | # we are copying the file before closing it, so flush to avoid truncation
215 | temp_file.flush()
216 | # shutil.copyfileobj() starts at the current position, so go to the start
217 | temp_file.seek(0)
218 |
219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 | with open(cache_path, 'wb') as cache_file:
221 | shutil.copyfileobj(temp_file, cache_file)
222 |
223 | logger.info("creating metadata file for %s", cache_path)
224 | meta = {'url': url, 'etag': etag}
225 | meta_path = cache_path + '.json'
226 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 | json.dump(meta, meta_file)
228 |
229 | logger.info("removing temp file %s", temp_file.name)
230 |
231 | return cache_path
232 |
233 |
234 | def read_set_from_file(filename):
235 | '''
236 | Extract a de-duped collection (set) of text from a file.
237 | Expected file format is one item per line.
238 | '''
239 | collection = set()
240 | with open(filename, 'r', encoding='utf-8') as file_:
241 | for line in file_:
242 | collection.add(line.rstrip())
243 | return collection
244 |
245 |
246 | def get_file_extension(path, dot=True, lower=True):
247 | ext = os.path.splitext(path)[1]
248 | ext = ext if dot else ext[1:]
249 | return ext.lower() if lower else ext
250 |
--------------------------------------------------------------------------------
/src/lxrt/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 project LXRT
3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch optimization for BERT model."""
17 |
18 | import math
19 | import torch
20 | from torch.optim import Optimizer
21 | from torch.optim.optimizer import required
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002, opt = None):
27 | if x < warmup:
28 | return x/warmup
29 | return 0.5 * (1.0 + math.cos(math.pi * x))  # math.cos, since x is a Python float rather than a tensor
30 |
31 | def warmup_constant(x, warmup=0.002, opt = None):
32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
33 | Learning rate is 1. afterwards. """
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0
37 |
38 | def warmup_linear(x, warmup=0.002, opt = None):
39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
40 | After `t_total`-th training step, learning rate is zero. """
41 | if x < warmup:
42 | return x/warmup
43 | return max((x-1.)/(warmup-1.), 0)
44 |
45 | def warmup_linear_yx(x, warmup=0.002, opt = None):
46 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
47 | After `t_total`-th training step, learning rate is zero. """
48 | if x < warmup:
49 | return x/warmup
50 | lr_max = 1
51 | lr_min = opt.lr_min/opt.lr
52 | b = (lr_max - lr_min*warmup)/(1-warmup)
53 | a = (lr_min - lr_max)/(1-warmup)
54 | return a*x+b
55 |
56 | def warmup_stair(x, warmup=0.002, opt = None):
57 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
58 | After `t_total`-th training step, learning rate is zero. """
59 | if x < warmup:
60 | return x/warmup
61 | return max((x-1.)/(warmup-1.), 0)
62 |
63 | SCHEDULES = {
64 | 'warmup_cosine': warmup_cosine,
65 | 'warmup_constant': warmup_constant,
66 | 'warmup_linear': warmup_linear,
67 | 'warmup_linear_yx': warmup_linear_yx,
68 | 'warmup_stair': warmup_stair,
69 | }
70 |
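# Usage sketch (illustrative, not from the original file): BertAdam multiplies the base lr
# by SCHEDULES[schedule](step / t_total, warmup, args) at every step, e.g.
#   optim = BertAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=10000,
#                    schedule='warmup_linear_yx', args=args)  # args must provide lr and lr_min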
71 |
72 | class BertAdam(Optimizer):
73 | """Implements BERT version of Adam algorithm with weight decay fix.
74 | Params:
75 | lr: learning rate
76 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
77 | t_total: total number of training steps for the learning
78 | rate schedule, -1 means constant learning rate. Default: -1
79 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
80 | b1: Adams b1. Default: 0.9
81 | b2: Adams b2. Default: 0.999
82 | e: Adams epsilon. Default: 1e-6
83 | weight_decay: Weight decay. Default: 0.01
84 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
85 | """
86 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
87 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
88 | max_grad_norm=1.0, args=None):
89 | if lr is not required and lr < 0.0:
90 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
91 | if schedule not in SCHEDULES:
92 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
93 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
94 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
95 | if not 0.0 <= b1 < 1.0:
96 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
97 | if not 0.0 <= b2 < 1.0:
98 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
99 | if not e >= 0.0:
100 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
101 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
102 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
103 | max_grad_norm=max_grad_norm)
104 | self.args = args
105 | super(BertAdam, self).__init__(params, defaults)
106 |
107 | def get_lr(self):
108 | lr = []
109 | for group in self.param_groups:
110 | for p in group['params']:
111 | state = self.state[p]
112 | if len(state) == 0:
113 | return [0]
114 | if group['t_total'] != -1:
115 | schedule_fct = SCHEDULES[group['schedule']]
116 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'], self.args)
117 | else:
118 | lr_scheduled = group['lr']
119 | lr.append(lr_scheduled)
120 | return lr
121 |
122 | def step(self, closure=None):
123 | """Performs a single optimization step.
124 |
125 | Arguments:
126 | closure (callable, optional): A closure that reevaluates the model
127 | and returns the loss.
128 | """
129 | loss = None
130 | if closure is not None:
131 | loss = closure()
132 |
133 | warned_for_t_total = False
134 |
135 | for group in self.param_groups:
136 | for p in group['params']:
137 | if p.grad is None:
138 | continue
139 | grad = p.grad.data
140 | if grad.is_sparse:
141 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
142 |
143 | state = self.state[p]
144 |
145 | # State initialization
146 | if len(state) == 0:
147 | state['step'] = 0
148 | # Exponential moving average of gradient values
149 | state['next_m'] = torch.zeros_like(p.data)
150 | # Exponential moving average of squared gradient values
151 | state['next_v'] = torch.zeros_like(p.data)
152 |
153 | next_m, next_v = state['next_m'], state['next_v']
154 | beta1, beta2 = group['b1'], group['b2']
155 |
156 | # LXRT: grad is clipped outside.
157 | # Add grad clipping
158 | # if group['max_grad_norm'] > 0:
159 | # clip_grad_norm_(p, group['max_grad_norm'])
160 |
161 | # Decay the first and second moment running average coefficient
162 | # In-place operations to update the averages at the same time
163 | next_m.mul_(beta1).add_(1 - beta1, grad)
164 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
165 | update = next_m / (next_v.sqrt() + group['e'])
166 |
167 | # Just adding the square of the weights to the loss function is *not*
168 | # the correct way of using L2 regularization/weight decay with Adam,
169 | # since that will interact with the m and v parameters in strange ways.
170 | #
171 | # Instead we want to decay the weights in a manner that doesn't interact
172 | # with the m/v parameters. This is equivalent to adding the square
173 | # of the weights to the loss with plain (non-momentum) SGD.
174 | if group['weight_decay'] > 0.0:
175 | update += group['weight_decay'] * p.data
176 |
177 | if group['t_total'] != -1:
178 | schedule_fct = SCHEDULES[group['schedule']]
179 | progress = state['step']/group['t_total']
180 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'], self.args)
181 |                     # warning for exceeding t_total (only active with 'warmup_linear')
182 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
183 | logger.warning(
184 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
185 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
186 | warned_for_t_total = True
187 | # end warning
188 | else:
189 | lr_scheduled = group['lr']
190 |
191 | update_with_lr = lr_scheduled * update
192 | p.data.add_(-update_with_lr)
193 |
194 | state['step'] += 1
195 |
196 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
197 | # No bias correction
198 | # bias_correction1 = 1 - beta1 ** state['step']
199 | # bias_correction2 = 1 - beta2 ** state['step']
200 |
201 | return loss
202 |
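# Illustrative sketch (hypothetical helper, not part of the original file):
# how BertAdam above is typically wired up with a warmup schedule. `model`,
# `batch_per_epoch`, and `epochs` are placeholders; lxmert_pretrain.py
# builds its optimizer the same way.
def _demo_bert_adam_setup(model, batch_per_epoch, epochs, args=None):
    """Minimal sketch: BertAdam with 5% linear warmup over all steps."""
    t_total = batch_per_epoch * epochs        # total optimization steps
    return BertAdam(model.parameters(), lr=1e-4,
                    warmup=0.05,              # warm up over first 5% of steps
                    t_total=t_total,
                    schedule='warmup_linear',
                    args=args)                # forwarded to the schedule fns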
--------------------------------------------------------------------------------
/src/lxrt/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | import collections
18 | import logging
19 | import os
20 | import unicodedata
21 | from io import open
22 |
23 | from .file_utils import cached_path
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
28 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
29 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
30 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
31 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
32 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
33 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
34 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
35 | }
36 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
37 | 'bert-base-uncased': 512,
38 | 'bert-large-uncased': 512,
39 | 'bert-base-cased': 512,
40 | 'bert-large-cased': 512,
41 | 'bert-base-multilingual-uncased': 512,
42 | 'bert-base-multilingual-cased': 512,
43 | 'bert-base-chinese': 512,
44 | }
45 | VOCAB_NAME = 'vocab.txt'
46 |
47 |
48 | def load_vocab(vocab_file):
49 | """Loads a vocabulary file into a dictionary."""
50 | vocab = collections.OrderedDict()
51 | index = 0
52 | with open(vocab_file, "r", encoding="utf-8") as reader:
53 | while True:
54 | token = reader.readline()
55 | if not token:
56 | break
57 | token = token.strip()
58 | vocab[token] = index
59 | index += 1
60 | return vocab
61 |
62 |
63 | def whitespace_tokenize(text):
64 | """Runs basic whitespace cleaning and splitting on a piece of text."""
65 | text = text.strip()
66 | if not text:
67 | return []
68 | tokens = text.split()
69 | return tokens
70 |
71 |
72 | class BertTokenizer(object):
73 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
74 |
75 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
76 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
77 | """Constructs a BertTokenizer.
78 |
79 | Args:
80 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
81 | do_lower_case: Whether to lower case the input
82 | Only has an effect when do_wordpiece_only=False
83 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
84 | max_len: An artificial maximum length to truncate tokenized sequences to;
85 | Effective maximum length is always the minimum of this
86 | value (if specified) and the underlying BERT model's
87 | sequence length.
88 | never_split: List of tokens which will never be split during tokenization.
89 | Only has an effect when do_wordpiece_only=False
90 | """
91 | if not os.path.isfile(vocab_file):
92 | raise ValueError(
93 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
94 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
95 | self.vocab = load_vocab(vocab_file)
96 | self.ids_to_tokens = collections.OrderedDict(
97 | [(ids, tok) for tok, ids in self.vocab.items()])
98 | self.do_basic_tokenize = do_basic_tokenize
99 | if do_basic_tokenize:
100 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
101 | never_split=never_split)
102 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
103 | self.max_len = max_len if max_len is not None else int(1e12)
104 |
105 | def tokenize(self, text):
106 | if self.do_basic_tokenize:
107 | split_tokens = []
108 | for token in self.basic_tokenizer.tokenize(text):
109 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
110 | split_tokens.append(sub_token)
111 | else:
112 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
113 | return split_tokens
114 |
115 | def convert_tokens_to_ids(self, tokens):
116 | """Converts a sequence of tokens into ids using the vocab."""
117 | ids = []
118 | for token in tokens:
119 | ids.append(self.vocab[token])
120 | if len(ids) > self.max_len:
121 | logger.warning(
122 | "Token indices sequence length is longer than the specified maximum "
123 |                 "sequence length for this BERT model ({} > {}). Running this"
124 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
125 | )
126 | return ids
127 |
128 | def convert_ids_to_tokens(self, ids):
129 |         """Converts a sequence of ids into wordpiece tokens using the vocab."""
130 | tokens = []
131 | for i in ids:
132 | tokens.append(self.ids_to_tokens[i])
133 | return tokens
134 |
135 | @classmethod
136 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
137 | """
138 | Instantiate a PreTrainedBertModel from a pre-trained model file.
139 | Download and cache the pre-trained model file if needed.
140 | """
141 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
142 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
143 | else:
144 | vocab_file = pretrained_model_name_or_path
145 | if os.path.isdir(vocab_file):
146 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
147 | # redirect to the cache, if necessary
148 | try:
149 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
150 | except EnvironmentError:
151 | logger.error(
152 | "Model name '{}' was not found in model name list ({}). "
153 | "We assumed '{}' was a path or url but couldn't find any file "
154 | "associated to this path or url.".format(
155 | pretrained_model_name_or_path,
156 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
157 | vocab_file))
158 | return None
159 | if resolved_vocab_file == vocab_file:
160 | logger.info("loading vocabulary file {}".format(vocab_file))
161 | else:
162 | logger.info("loading vocabulary file {} from cache at {}".format(
163 | vocab_file, resolved_vocab_file))
164 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
165 |             # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
166 | # than the number of positional embeddings
167 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
168 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
169 | # Instantiate tokenizer.
170 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
171 | return tokenizer
172 |
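# Illustrative sketch (hypothetical helper, not part of the original file):
# end-to-end use of BertTokenizer. 'bert-base-uncased' resolves through
# PRETRAINED_VOCAB_ARCHIVE_MAP above and is downloaded and cached on first
# use, so network access is assumed here.
def _demo_bert_tokenizer(text="Is the dog chasing the ball?"):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True)
    tokens = tokenizer.tokenize(text)          # basic split, then wordpiece
    ids = tokenizer.convert_tokens_to_ids(tokens)
    assert tokenizer.convert_ids_to_tokens(ids) == tokens  # round trip
    return tokens, ids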
173 |
174 | class BasicTokenizer(object):
175 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
176 |
177 | def __init__(self,
178 | do_lower_case=True,
179 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
180 | """Constructs a BasicTokenizer.
181 |
182 | Args:
183 | do_lower_case: Whether to lower case the input.
184 | """
185 | self.do_lower_case = do_lower_case
186 | self.never_split = never_split
187 |
188 | def tokenize(self, text):
189 | """Tokenizes a piece of text."""
190 | text = self._clean_text(text)
191 | # This was added on November 1st, 2018 for the multilingual and Chinese
192 | # models. This is also applied to the English models now, but it doesn't
193 | # matter since the English models were not trained on any Chinese data
194 | # and generally don't have any Chinese data in them (there are Chinese
195 | # characters in the vocabulary because Wikipedia does have some Chinese
196 | # words in the English Wikipedia.).
197 | text = self._tokenize_chinese_chars(text)
198 | orig_tokens = whitespace_tokenize(text)
199 | split_tokens = []
200 | for token in orig_tokens:
201 | if self.do_lower_case and token not in self.never_split:
202 | token = token.lower()
203 | token = self._run_strip_accents(token)
204 | split_tokens.extend(self._run_split_on_punc(token))
205 |
206 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
207 | return output_tokens
208 |
209 | def _run_strip_accents(self, text):
210 | """Strips accents from a piece of text."""
211 | text = unicodedata.normalize("NFD", text)
212 | output = []
213 | for char in text:
214 | cat = unicodedata.category(char)
215 | if cat == "Mn":
216 | continue
217 | output.append(char)
218 | return "".join(output)
219 |
220 | def _run_split_on_punc(self, text):
221 | """Splits punctuation on a piece of text."""
222 | if text in self.never_split:
223 | return [text]
224 | chars = list(text)
225 | i = 0
226 | start_new_word = True
227 | output = []
228 | while i < len(chars):
229 | char = chars[i]
230 | if _is_punctuation(char):
231 | output.append([char])
232 | start_new_word = True
233 | else:
234 | if start_new_word:
235 | output.append([])
236 | start_new_word = False
237 | output[-1].append(char)
238 | i += 1
239 |
240 | return ["".join(x) for x in output]
241 |
242 | def _tokenize_chinese_chars(self, text):
243 | """Adds whitespace around any CJK character."""
244 | output = []
245 | for char in text:
246 | cp = ord(char)
247 | if self._is_chinese_char(cp):
248 | output.append(" ")
249 | output.append(char)
250 | output.append(" ")
251 | else:
252 | output.append(char)
253 | return "".join(output)
254 |
255 | def _is_chinese_char(self, cp):
256 | """Checks whether CP is the codepoint of a CJK character."""
257 | # This defines a "chinese character" as anything in the CJK Unicode block:
258 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
259 | #
260 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
261 | # despite its name. The modern Korean Hangul alphabet is a different block,
262 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
263 | # space-separated words, so they are not treated specially and handled
264 |         # like all of the other languages.
265 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
266 | (cp >= 0x3400 and cp <= 0x4DBF) or #
267 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
268 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
269 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
270 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
271 | (cp >= 0xF900 and cp <= 0xFAFF) or #
272 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
273 | return True
274 |
275 | return False
276 |
277 | def _clean_text(self, text):
278 | """Performs invalid character removal and whitespace cleanup on text."""
279 | output = []
280 | for char in text:
281 | cp = ord(char)
282 | if cp == 0 or cp == 0xfffd or _is_control(char):
283 | continue
284 | if _is_whitespace(char):
285 | output.append(" ")
286 | else:
287 | output.append(char)
288 | return "".join(output)
289 |
290 |
291 | class WordpieceTokenizer(object):
292 | """Runs WordPiece tokenization."""
293 |
294 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
295 | self.vocab = vocab
296 | self.unk_token = unk_token
297 | self.max_input_chars_per_word = max_input_chars_per_word
298 |
299 | def tokenize(self, text):
300 | """Tokenizes a piece of text into its word pieces.
301 |
302 | This uses a greedy longest-match-first algorithm to perform tokenization
303 | using the given vocabulary.
304 |
305 | For example:
306 | input = "unaffable"
307 | output = ["un", "##aff", "##able"]
308 |
309 | Args:
310 | text: A single token or whitespace separated tokens. This should have
311 | already been passed through `BasicTokenizer`.
312 |
313 | Returns:
314 | A list of wordpiece tokens.
315 | """
316 |
317 | output_tokens = []
318 | for token in whitespace_tokenize(text):
319 | chars = list(token)
320 | if len(chars) > self.max_input_chars_per_word:
321 | output_tokens.append(self.unk_token)
322 | continue
323 |
324 | is_bad = False
325 | start = 0
326 | sub_tokens = []
327 | while start < len(chars):
328 | end = len(chars)
329 | cur_substr = None
330 | while start < end:
331 | substr = "".join(chars[start:end])
332 | if start > 0:
333 | substr = "##" + substr
334 | if substr in self.vocab:
335 | cur_substr = substr
336 | break
337 | end -= 1
338 | if cur_substr is None:
339 | is_bad = True
340 | break
341 | sub_tokens.append(cur_substr)
342 | start = end
343 |
344 | if is_bad:
345 | output_tokens.append(self.unk_token)
346 | else:
347 | output_tokens.extend(sub_tokens)
348 | return output_tokens
349 |
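# Illustrative sketch (hypothetical helper, not part of the original file):
# the greedy longest-match-first behaviour documented in
# WordpieceTokenizer.tokenize, demonstrated on a toy vocabulary instead of
# a real BERT vocab.
def _demo_wordpiece_greedy_match():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wp = WordpieceTokenizer(vocab=toy_vocab)
    assert wp.tokenize("unaffable") == ["un", "##aff", "##able"]
    assert wp.tokenize("xyz") == ["[UNK]"]  # no match at all -> unknown token
    return wp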
350 |
351 | def _is_whitespace(char):
352 | """Checks whether `chars` is a whitespace character."""
353 |     # \t, \n, and \r are technically control characters but we treat them
354 | # as whitespace since they are generally considered as such.
355 | if char == " " or char == "\t" or char == "\n" or char == "\r":
356 | return True
357 | cat = unicodedata.category(char)
358 | if cat == "Zs":
359 | return True
360 | return False
361 |
362 |
363 | def _is_control(char):
364 | """Checks whether `chars` is a control character."""
365 | # These are technically control characters but we count them as whitespace
366 | # characters.
367 | if char == "\t" or char == "\n" or char == "\r":
368 | return False
369 | cat = unicodedata.category(char)
370 | if cat.startswith("C"):
371 | return True
372 | return False
373 |
374 |
375 | def _is_punctuation(char):
376 | """Checks whether `chars` is a punctuation character."""
377 | cp = ord(char)
378 | # We treat all non-letter/number ASCII as punctuation.
379 | # Characters such as "^", "$", and "`" are not in the Unicode
380 | # Punctuation class but we treat them as punctuation anyways, for
381 | # consistency.
382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
384 | return True
385 | cat = unicodedata.category(char)
386 | if cat.startswith("P"):
387 | return True
388 | return False
389 |
--------------------------------------------------------------------------------
/src/param.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import argparse
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def get_optimizer(optim):
12 | # Bind the optimizer
13 | if optim == 'rms':
14 | print("Optimizer: Using RMSProp")
15 | optimizer = torch.optim.RMSprop
16 | elif optim == 'adam':
17 | print("Optimizer: Using Adam")
18 | optimizer = torch.optim.Adam
19 | elif optim == 'adamax':
20 | print("Optimizer: Using Adamax")
21 | optimizer = torch.optim.Adamax
22 | elif optim == 'sgd':
23 |         print("Optimizer: Using SGD")
24 | optimizer = torch.optim.SGD
25 | elif 'bert' in optim:
26 |         optimizer = 'bert'      # The BERT optimizer will be bound later.
27 | else:
28 |         assert False, "Please add your optimizer %s to the list." % optim
29 |
30 | return optimizer
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser()
35 |
36 | # Data Splits
37 | parser.add_argument("--train", default='train')
38 | parser.add_argument("--valid", default='valid')
39 | parser.add_argument("--test", default=None)
40 |
41 | # Training Hyper-parameters
42 | parser.add_argument('--batchSize', dest='batch_size', type=int, default=256)
43 | parser.add_argument('--optim', default='bert')
44 | parser.add_argument('--lr', type=float, default=1e-4)
45 | # parser.add_argument('--lr', type=str, default="1e-4")
46 | parser.add_argument('--epochs', type=int, default=10)
47 | parser.add_argument('--dropout', type=float, default=0.1)
48 | parser.add_argument('--seed', type=int, default=9595, help='random seed')
49 |
50 | # Debugging
51 | parser.add_argument('--output', type=str, default='snap/test')
52 | parser.add_argument("--fast", action='store_const', default=False, const=True)
53 | parser.add_argument("--tiny", action='store_const', default=False, const=True)
54 | parser.add_argument("--tqdm", action='store_const', default=False, const=True)
55 |
56 | # Model Loading
57 | parser.add_argument('--load', type=str, default=None,
58 | help='Load the model (usually the fine-tuned model).')
59 | parser.add_argument('--loadLXMERT', dest='load_lxmert', type=str, default=None,
60 | help='Load the pre-trained LXMERT model.')
61 | parser.add_argument('--loadLXMERTQA', dest='load_lxmert_qa', type=str, default=None,
62 | help='Load the pre-trained LXMERT model with QA answer head.')
63 | parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True,
64 | help='If none of the --load, --loadLXMERT, --loadLXMERTQA is set, '
65 | 'the model would be trained from scratch. If --fromScratch is'
66 | ' not specified, the model would load BERT-pre-trained weights by'
67 | ' default. ')
68 |
69 | # Optimization
70 | parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True)
71 |
72 | # LXRT Model Config
73 | # Note: LXRT = L, X, R (three encoders), Transformer
74 | parser.add_argument("--llayers", default=9, type=int, help='Number of Language layers')
75 | parser.add_argument("--xlayers", default=5, type=int, help='Number of CROSS-modality layers.')
76 | parser.add_argument("--rlayers", default=5, type=int, help='Number of object Relationship layers.')
77 |
78 | # LXMERT Pre-training Config
79 | parser.add_argument("--taskMatched", dest='task_matched', action='store_const', default=False, const=True)
80 | parser.add_argument("--taskMaskLM", dest='task_mask_lm', action='store_const', default=False, const=True)
81 | parser.add_argument("--taskObjPredict", dest='task_obj_predict', action='store_const', default=False, const=True)
82 | parser.add_argument("--taskQA", dest='task_qa', action='store_const', default=False, const=True)
83 | parser.add_argument("--visualLosses", dest='visual_losses', default='obj,attr,feat', type=str)
84 | parser.add_argument("--qaSets", dest='qa_sets', default=None, type=str)
85 | parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float)
86 | parser.add_argument("--objMaskRate", dest='obj_mask_rate', default=0.15, type=float)
87 |
88 | # Training configuration
89 | parser.add_argument("--fp16", action='store_const', default=False, const=True)
90 | parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
91 |     parser.add_argument("--numWorkers", dest='num_workers', type=int, default=0)
92 |
93 | # whether ft
94 | parser.add_argument("--bert_type", type=str, default='bert')
95 | parser.add_argument("--cross_type", type=str, default='g2g')
96 | parser.add_argument("--ft_large", type=int, default=0)
97 | parser.add_argument("--MV_size", type=int, default=500)
98 | parser.add_argument("--ML_size", type=int, default=500)
99 |
100 | #start from check point
101 | parser.add_argument("--start_from", type=int, default=0)
102 |
103 | #distributed running
104 | parser.add_argument('--local_rank', default=0, type=int,
105 | help='node rank for distributed training')
106 |
107 | #optimizer related
108 | parser.add_argument("--lr_schedule", type=str, default='warmup_linear')
109 | parser.add_argument("--lr_max", type=float, default=5e-5)
110 | parser.add_argument("--lr_min", type=float, default=6e-5)
111 |
112 | parser.add_argument("--pretraining_index", type=int, default=0)
113 |
114 | #init dict
115 | parser.add_argument("--kmeans", type=int, default=1)
116 | parser.add_argument("--mv_path", type=str, default='data2/lxmert/mv_path.npy')
117 | parser.add_argument("--ml_path", type=str, default='data2/lxmert/ml_path.npy')
118 |
119 | # Parse the arguments.
120 | args = parser.parse_args()
121 |
122 | # Bind optimizer class.
123 | args.optimizer = get_optimizer(args.optim)
124 |
125 | # Set seeds
126 | torch.manual_seed(args.seed)
127 | random.seed(args.seed)
128 | np.random.seed(args.seed)
129 |
130 | return args
131 |
132 |
133 | args = parse_args()
134 |
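# Illustrative note (not part of the original file): `args` is parsed once
# at import time, so every module in the process shares this single
# namespace, e.g.
#
#     from param import args
#     print(args.batch_size, args.llayers, args.xlayers, args.rlayers)
#
# A side effect is that importing `param` from a script with unrelated
# command-line flags will make argparse reject those flags.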
--------------------------------------------------------------------------------
/src/pretrain/lxmert_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | from collections import defaultdict
5 | import json
6 | import random
7 |
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 |
11 | from param import args
12 | from pretrain.qa_answer_table import AnswerTable
13 | from utils import load_obj_tsv
14 |
15 | TINY_IMG_NUM = 500
16 | FAST_IMG_NUM = 5000
17 | #
18 | # Split2ImgFeatPath = {
19 | # 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
20 | # 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
21 | # 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
22 | # 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
23 | # }
24 |
25 | Split2ImgFeatPath = {
26 | 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj64.tsv',
27 | 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj64.tsv',
28 | 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj64.tsv',
29 | 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj64.tsv',
30 | }
31 |
32 |
33 | class InputExample(object):
34 | """A single training/test example for the language model."""
35 | def __init__(self, uid, sent, visual_feats=None,
36 | obj_labels=None, attr_labels=None,
37 | is_matched=None, label=None):
38 | self.uid = uid
39 | self.sent = sent
40 | self.visual_feats = visual_feats
41 | self.obj_labels = obj_labels
42 | self.attr_labels = attr_labels
43 | self.is_matched = is_matched # whether the visual and obj matched
44 | self.label = label
45 |
46 |
47 | class LXMERTDataset:
48 | def __init__(self, splits: str, qa_sets=None):
49 | """
50 | :param splits: The data sources to be loaded
51 | :param qa_sets: if None, no action
52 |                         otherwise, only take the answers appearing in these dsets
53 |                         and remove all unlabeled data (MSCOCO captions)
54 | """
55 | self.name = splits
56 | self.sources = splits.split(',')
57 |
58 | # Loading datasets to data
59 | self.data = []
60 | for source in self.sources:
61 | self.data.extend(json.load(open("data/lxmert/%s.json" % source)))
62 | print("Load %d data from %s" % (len(self.data), self.name))
63 |
64 | # Create answer table according to the qa_sets
65 | self.answer_table = AnswerTable(qa_sets)
66 | print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))
67 |
68 | # Modify the answers
69 | for datum in self.data:
70 | labelf = datum['labelf']
71 | for cat, labels in labelf.items():
72 | for label in labels:
73 | for ans in list(label.keys()):
74 | new_ans = self.answer_table.convert_ans(ans)
75 | if self.answer_table.used(new_ans):
76 | if ans != new_ans:
77 | label[new_ans] = label.pop(ans)
78 | else:
79 | label.pop(ans)
80 |
81 | def __len__(self):
82 | return len(self.data)
83 |
84 |
85 | def make_uid(img_id, dset, sent_idx):
86 |     return "%s_%s_%03d" % (img_id, dset, sent_idx)
87 |
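# Illustrative note (not part of the original file): make_uid joins image
# id, dataset name, and a zero-padded sentence index; the image id below is
# a made-up example:
#
#     make_uid('COCO_train2014_000000000042', 'mscoco', 2)
#     -> 'COCO_train2014_000000000042_mscoco_002'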
88 |
89 | """
90 | Example in obj tsv:
91 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
92 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
93 | """
94 | class LXMERTTorchDataset(Dataset):
95 | def __init__(self, dataset: LXMERTDataset, topk=-1):
96 | super().__init__()
97 | self.raw_dataset = dataset
98 | self.task_matched = args.task_matched
99 |
100 | if args.tiny:
101 | topk = TINY_IMG_NUM
102 | elif args.fast:
103 | topk = FAST_IMG_NUM
104 |
105 | # Load the dataset
106 | img_data = []
107 | for source in self.raw_dataset.sources:
108 | img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk, args.fp16))
109 |
110 | self.imgid2img = {}
111 | for img_datum in img_data:
112 | self.imgid2img[img_datum['img_id']] = img_datum
113 |
114 | # Filter out the dataset
115 | used_data = []
116 | for datum in self.raw_dataset.data:
117 | if datum['img_id'] in self.imgid2img:
118 | used_data.append(datum)
119 |
120 | # Flatten the dataset (into one sent + one image entries)
121 | self.data = []
122 | for datum in used_data:
123 | sentf = datum['sentf']
124 | for sents_cat, sents in sentf.items():
125 | if sents_cat in datum['labelf']:
126 | labels = datum['labelf'][sents_cat]
127 | else:
128 | labels = None
129 | for sent_idx, sent in enumerate(sents):
130 | new_datum = {
131 | 'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
132 | 'img_id': datum['img_id'],
133 | 'sent': sent
134 | }
135 | if labels is not None:
136 | new_datum['label'] = labels[sent_idx]
137 | self.data.append(new_datum)
138 | print("Use %d data in torch dataset" % (len(self.data)))
139 |
140 | def __len__(self):
141 | return len(self.data)
142 |
143 | def random_feat(self):
144 | """Get a random obj feat from the dataset."""
145 | datum = self.data[random.randint(0, len(self.data)-1)]
146 | img_id = datum['img_id']
147 | img_info = self.imgid2img[img_id]
148 |         feat = img_info['features'][random.randint(0, img_info['num_boxes'] - 1)]  # sample within the image's actual box count
149 | return feat
150 |
151 | def __getitem__(self, item: int):
152 | datum = self.data[item]
153 |
154 | uid = datum['uid']
155 | img_id = datum['img_id']
156 |
157 | # Get image info
158 | img_info = self.imgid2img[img_id]
159 | obj_num = img_info['num_boxes']
160 | feats = img_info['features'].copy()
161 |         if not args.fp16 and feats.dtype != np.float32:  # features may be cached as float16 to save CPU memory; cast back to float32 for training
162 | feats = feats.astype(np.float32)
163 | boxes = img_info['boxes'].copy()
164 | obj_labels = img_info['objects_id'].copy()
165 | obj_confs = img_info['objects_conf'].copy()
166 | attr_labels = img_info['attrs_id'].copy()
167 | attr_confs = img_info['attrs_conf'].copy()
168 | assert obj_num == len(boxes) == len(feats)
169 |
170 | # Normalize the boxes (to 0 ~ 1)
171 | img_h, img_w = img_info['img_h'], img_info['img_w']
172 | boxes = boxes.copy()
173 | boxes[:, (0, 2)] /= img_w
174 | boxes[:, (1, 3)] /= img_h
175 | np.testing.assert_array_less(boxes, 1+1e-5)
176 | np.testing.assert_array_less(-boxes, 0+1e-5)
177 |
178 |         # If calculating the matched loss, replace the sentence with a sentence
179 |         # corresponding to another image.
180 | is_matched = 1
181 | sent = datum['sent']
182 | if self.task_matched:
183 | if random.random() < 0.5:
184 | is_matched = 0
185 | other_datum = self.data[random.randint(0, len(self.data)-1)]
186 | while other_datum['img_id'] == img_id:
187 | other_datum = self.data[random.randint(0, len(self.data)-1)]
188 | sent = other_datum['sent']
189 |
190 | # Label, convert answer to id
191 | if 'label' in datum:
192 | label = datum['label'].copy()
193 | for ans in list(label.keys()):
194 | label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
195 | else:
196 | label = None
197 |
198 | # Create target
199 | example = InputExample(
200 | uid, sent, (feats, boxes),
201 | (obj_labels, obj_confs), (attr_labels, attr_confs),
202 | is_matched, label
203 | )
204 | return example
205 |
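# Illustrative sketch (hypothetical helper, not part of the original file):
# the box normalization done in __getitem__, shown on one hand-written box.
# Boxes are (x1, y1, x2, y2) in pixels and are scaled into [0, 1] by the
# image size.
def _demo_box_normalization():
    boxes = np.array([[64., 32., 512., 256.]], dtype=np.float32)
    img_w, img_h = 640, 480
    boxes[:, (0, 2)] /= img_w  # x1, x2 -> fraction of width
    boxes[:, (1, 3)] /= img_h  # y1, y2 -> fraction of height
    return boxes               # approximately [[0.1, 0.0667, 0.8, 0.5333]]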
206 |
207 | class LXMERTEvaluator:
208 | def __init__(self, dataset: LXMERTDataset):
209 | self.raw_dataset = dataset
210 |
211 | # Create QA Eval Data
212 | self.data = []
213 | for datum in self.raw_dataset.data:
214 | sentf = datum['sentf']
215 | for sents_cat, sents in sentf.items():
216 | if sents_cat in datum['labelf']: # A labeled dataset
217 | labels = datum['labelf'][sents_cat]
218 | for sent_idx, sent in enumerate(sents):
219 | new_datum = {
220 | 'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
221 | 'img_id': datum['img_id'],
222 | 'sent': sent,
223 | 'dset': sents_cat,
224 | 'label': labels[sent_idx]
225 | }
226 | self.data.append(new_datum)
227 |
228 | # uid2datum
229 | self.uid2datum = {}
230 | for datum in self.data:
231 | self.uid2datum[datum['uid']] = datum
232 |
233 | def evaluate(self, uid2ans: dict, pprint=False):
234 | score = 0.
235 | cnt = 0
236 | dset2score = defaultdict(lambda: 0.)
237 | dset2cnt = defaultdict(lambda: 0)
238 | for uid, ans in uid2ans.items():
239 | if uid not in self.uid2datum: # Not a labeled data
240 | continue
241 | datum = self.uid2datum[uid]
242 | label = datum['label']
243 | dset = datum['dset']
244 | if ans in label:
245 | score += label[ans]
246 | dset2score[dset] += label[ans]
247 | cnt += 1
248 | dset2cnt[dset] += 1
249 | accu = score / cnt
250 | dset2accu = {}
251 | for dset in dset2cnt:
252 | dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
253 |
254 | if pprint:
255 | accu_str = "Overall Accu %0.4f, " % (accu)
256 | sorted_keys = sorted(dset2accu.keys())
257 | for key in sorted_keys:
258 | accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
259 | print(accu_str)
260 |
261 | return accu, dset2accu
262 |
263 | def dump_result(self, uid2ans: dict, path):
264 |         raise NotImplementedError
265 |
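# Illustrative note (not part of the original file): evaluate() consumes a
# {uid: answer} dict, typically filled from model argmax predictions, and
# returns overall and per-dataset accuracy. The names below are
# placeholders:
#
#     evaluator = LXMERTEvaluator(dset)
#     accu, dset2accu = evaluator.evaluate(uid2ans, pprint=True)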
--------------------------------------------------------------------------------
/src/pretrain/lxmert_pretrain.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import collections
5 | import os
6 | import random
7 |
8 | from tqdm import tqdm
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 |
14 | from param import args
15 | from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
16 | from lxrt.entry import set_visual_config
17 | from lxrt.tokenization import BertTokenizer
18 | from lxrt.modeling import LXRTPretraining
19 |
20 | DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')
21 |
22 |
23 | def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
24 | # Decide which QA datasets would be used in pre-training.
25 | # Options: vqa, gqa, visual7w
26 | # Note: visual7w is a part of vgqa, we take the name here.
27 | qa_sets = args.qa_sets
28 | if qa_sets is not None:
29 | qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))
30 |
31 | # Build dataset, data loader, and evaluator.
32 | dset = LXMERTDataset(splits, qa_sets=qa_sets)
33 | tset = LXMERTTorchDataset(dset, topk)
34 | data_loader = DataLoader(
35 | tset, batch_size=bs,
36 | shuffle=shuffle, num_workers=args.num_workers,
37 | collate_fn=lambda x: x,
38 | drop_last=drop_last, pin_memory=True
39 | )
40 | evaluator = LXMERTEvaluator(dset)
41 | print()
42 |
43 | return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
44 |
45 |
46 | train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
47 | valid_batch_size = 1024 if args.multiGPU else 512
48 | valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)
49 |
50 |
51 | class InputFeatures(object):
52 | """A single set of features of data."""
53 |
54 | def __init__(self,
55 | input_ids, input_mask, segment_ids, lm_label_ids,
56 | visual_feats, obj_labels,
57 | is_matched, ans):
58 | self.input_ids = input_ids
59 | self.input_mask = input_mask
60 | self.segment_ids = segment_ids
61 | self.lm_label_ids = lm_label_ids
62 |
63 | self.visual_feats = visual_feats
64 | self.obj_labels = obj_labels
65 |
66 | self.is_matched = is_matched
67 |
68 | self.ans = ans
69 |
70 |
71 | def random_word(tokens, tokenizer):
72 | """
73 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
74 | :param tokens: list of str, tokenized sentence.
75 |     :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
76 | :return: (list of str, list of int), masked tokens and related labels for LM prediction
77 | """
78 | output_label = []
79 |
80 | for i, token in enumerate(tokens):
81 | prob = random.random()
82 | # mask token with probability
83 | ratio = args.word_mask_rate
84 | if prob < ratio:
85 | prob /= ratio
86 |
87 | # 80% randomly change token to mask token
88 | if prob < 0.8:
89 | tokens[i] = "[MASK]"
90 |
91 | # 10% randomly change token to random token
92 | elif prob < 0.9:
93 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
94 |
95 | # -> rest 10% randomly keep current token
96 |
97 | # append current token to output (we will predict these later)
98 | try:
99 | output_label.append(tokenizer.vocab[token])
100 | except KeyError:
101 | # For unknown words (should not occur with BPE vocab)
102 | output_label.append(tokenizer.vocab["[UNK]"])
103 | else:
104 | # no masking token (will be ignored by loss function later)
105 | output_label.append(-1)
106 |
107 | return tokens, output_label
108 |
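# Illustrative worked example (not part of the original file) of the
# 80/10/10 rule above. Suppose tokens = ['the', 'dog', 'runs'],
# word_mask_rate = 0.15, and the three random draws are 0.90, 0.05, 0.60:
# only 'dog' falls under the mask rate, its renormalized draw
# 0.05 / 0.15 = 0.33 is below 0.8, so tokens become
# ['the', '[MASK]', 'runs'] and output_label is [-1, vocab['dog'], -1]
# (the -1 positions are ignored by the LM loss).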
109 |
110 | def random_feat(feats):
111 | mask_feats = feats.copy()
112 | feat_mask = np.zeros(len(feats), dtype=np.float32)
113 | for i in range(len(feats)):
114 | prob = random.random()
115 | # mask token with probability
116 | if prob < args.obj_mask_rate:
117 | prob /= args.obj_mask_rate
118 |
119 | # 80% randomly change token to zero feat
120 | if prob < 0.8:
121 | mask_feats[i, :] = 0.
122 |
123 | # 10% randomly change token to random feat
124 | elif prob < 0.9:
125 | mask_feats[i, :] = train_tuple.torchdset.random_feat()
126 | # -> rest 10% randomly keep current feat
127 |
128 | # Need to predict this feat
129 | feat_mask[i] = 1.
130 |
131 | return mask_feats, feat_mask
132 |
133 |
134 | def convert_example_to_features(example: InputExample, max_seq_length, tokenizer)->InputFeatures:
135 | """
136 | Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
137 | IDs, LM labels, input_mask, CLS and SEP tokens etc.
138 | :param example: InputExample, containing sentence input as strings and is_next label
139 | :param max_seq_length: int, maximum length of sequence.
140 | :param tokenizer: Tokenizer
141 | :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
142 | """
143 | tokens = tokenizer.tokenize(example.sent.strip())
144 |
145 | # Account for [CLS] and [SEP] with "- 2"
146 | if len(tokens) > max_seq_length - 2:
147 | tokens = tokens[:(max_seq_length - 2)]
148 |
149 |     # Get randomly masked words
150 | masked_tokens, masked_label = random_word(tokens, tokenizer)
151 |
152 | # concatenate lm labels and account for CLS, SEP, SEP
153 | masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
154 | input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
155 |
156 | # Mask & Segment Word
157 | lm_label_ids = ([-1] + masked_label + [-1])
158 | input_mask = [1] * len(input_ids)
159 | segment_ids = [0] * len(input_ids)
160 |
161 | # Zero-pad up to the sequence length.
162 | while len(input_ids) < max_seq_length:
163 | input_ids.append(0)
164 | input_mask.append(0)
165 | segment_ids.append(0)
166 | lm_label_ids.append(-1)
167 |
168 | assert len(input_ids) == max_seq_length
169 | assert len(input_mask) == max_seq_length
170 | assert len(segment_ids) == max_seq_length
171 | assert len(lm_label_ids) == max_seq_length
172 |
173 | feat, boxes = example.visual_feats
174 | obj_labels, obj_confs = example.obj_labels
175 | attr_labels, attr_confs = example.attr_labels
176 |
177 | # Mask Image Features:
178 | masked_feat, feat_mask = random_feat(feat)
179 |
180 | # QA answer label
181 | if example.label is None or len(example.label) == 0 or example.is_matched != 1:
182 | # 1. No label 2. Label is pruned 3. unmatched visual + language pair
183 | ans = -1
184 | else:
185 | keys, values = zip(*example.label.items())
186 | if len(keys) == 1:
187 | ans = keys[0]
188 | else:
189 | value_sum = sum(values)
190 | prob = [value / value_sum for value in values]
191 | choice = np.random.multinomial(1, prob).argmax()
192 | ans = keys[choice]
193 |
194 | features = InputFeatures(
195 | input_ids=input_ids,
196 | input_mask=input_mask,
197 | segment_ids=segment_ids,
198 | lm_label_ids=lm_label_ids,
199 | visual_feats=(masked_feat, boxes),
200 | obj_labels={
201 | 'obj': (obj_labels, obj_confs),
202 | 'attr': (attr_labels, attr_confs),
203 | 'feat': (feat, feat_mask),
204 | },
205 | is_matched=example.is_matched,
206 | ans=ans,
207 | )
208 | return features
209 |
210 |
211 | LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')
212 |
213 |
214 | class LXMERT:
215 | def __init__(self, max_seq_length):
216 | super().__init__()
217 | self.max_seq_length = max_seq_length
218 |
219 | self.tokenizer = BertTokenizer.from_pretrained(
220 | "bert-base-uncased",
221 | do_lower_case=True
222 | )
223 |
224 | # Build model
225 | set_visual_config(args)
226 | self.model = LXRTPretraining.from_pretrained(
227 | "bert-base-uncased",
228 | task_mask_lm=args.task_mask_lm,
229 | task_obj_predict=args.task_obj_predict,
230 | task_matched=args.task_matched,
231 | task_qa=args.task_qa,
232 | visual_losses=args.visual_losses,
233 | num_answers=train_tuple.dataset.answer_table.num_answers,
234 |         args=args
235 | )
236 |
237 | # Weight initialization and loading
238 | if args.from_scratch:
239 | print("Train from Scratch: re-initialize all BERT weights.")
240 | self.model.apply(self.model.init_bert_weights)
241 | if args.load is not None:
242 | self.load(args.load)
243 | if args.load_lxmert is not None:
244 | # Load lxmert would not load the answer head.
245 | self.load_lxmert(args.load_lxmert)
246 | # GPU Options
247 | self.model = self.model.cuda()
248 | if args.multiGPU:
249 | #self.model = nn.DataParallel(self.model)
250 | self.multiGPU = 1
251 | else:
252 | self.multiGPU = 0
253 | # self.model = nn.DistributedDataParallel(self.model)
254 | self.output = args.output
255 | os.makedirs(self.output, exist_ok=True)
256 |
257 | def forward(self, examples):
258 | train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
259 | for example in examples]
260 |
261 | # language Inputs
262 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
263 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
264 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()
265 |
266 | # Visual Inputs
267 | feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
268 | pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()
269 |
270 | # Language Prediction
271 | lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()
272 |
273 | # Visual Prediction
274 | obj_labels = {}
275 | for key in ('obj', 'attr', 'feat'):
276 | visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
277 | visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
278 | assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
279 | obj_labels[key] = (visn_labels, visn_mask)
280 |
281 | # Joint Prediction
282 | matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
283 | ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()
284 |
285 | """
286 | forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
287 | visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
288 | """
289 | loss, losses, ans_logit = self.model(
290 | input_ids, segment_ids, input_mask, lm_labels,
291 | feats, pos, obj_labels, matched_labels, ans
292 | )
293 | if self.multiGPU:
294 | loss = loss.mean()
295 | losses = losses.mean(0)
296 |
297 | return loss, losses.detach().cpu(), ans_logit
298 |
299 | def train_batch(self, optim, batch):
300 | optim.zero_grad()
301 | loss, losses, ans_logit = self.forward(batch)
302 | # if args.multiGPU:
303 | # loss = loss.mean()
304 | # losses = losses.mean(0)
305 | if args.fp16:
306 | try:
307 | from apex import amp
308 | except ImportError:
309 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
310 | with amp.scale_loss(loss, optim) as scaled_loss:
311 | scaled_loss.backward()
312 | torch.nn.utils.clip_grad_norm_(amp.master_params(optim), 1.)
313 | else:
314 | loss.backward()
315 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
316 |
317 | optim.step()
318 |
319 | return loss.item(), losses.cpu().numpy(), ans_logit
320 |
321 | def valid_batch(self, batch):
322 | with torch.no_grad():
323 | loss, losses, ans_logit = self.forward(batch)
324 | # if args.multiGPU:
325 | # loss = loss.mean()
326 | # losses = losses.mean(0)
327 | return loss.item(), losses.cpu().numpy(), ans_logit
328 |
329 | def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
330 | train_ld = train_tuple.loader
331 |
332 | # Optimizer
333 | from lxrt.optimization import BertAdam
334 | batch_per_epoch = len(train_ld)
335 | t_total = int(batch_per_epoch * args.epochs)
336 | warmup_ratio = 0.05
337 | warmup_iters = int(t_total * warmup_ratio)
338 | print("Batch per epoch: %d" % batch_per_epoch)
339 | print("Total Iters: %d" % t_total)
340 | print("Warm up Iters: %d" % warmup_iters)
341 | optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)
342 | start_epoch = 0
343 |
344 | if args.fp16:
345 | try:
346 | from apex import amp
347 | except ImportError:
348 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
349 | self.model, optim = amp.initialize(self.model, optim, opt_level='O1')
350 |
351 | # GPU Options
352 | if args.multiGPU:
353 | self.model = nn.DataParallel(self.model)
354 |
355 |         if args.start_from > 0 and args.pretraining_index == 0:
356 | start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
357 | print('start training from {0}'.format(start_path))
358 | state = torch.load(start_path)
359 | self.model.load_state_dict(state['state_dict'])
360 |             optim.load_state_dict(state['optimizer'])
361 | start_epoch = args.start_from
362 | del state
363 | torch.cuda.empty_cache()
364 |         elif args.start_from > 0 and args.pretraining_index == 1:
365 | start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
366 | print('start training from {0}'.format(start_path))
367 | state = torch.load(start_path)
368 |             self.model.load_state_dict(state['state_dict'], strict=False)
369 | del state
370 | torch.cuda.empty_cache()
371 |
372 | # Train
373 | best_eval_loss = 9595.
374 | for epoch in range(start_epoch, args.epochs):
375 | # Train
376 | self.model.train()
377 | total_loss = 0.
378 | total_losses = 0.
379 | uid2ans = {}
380 | for batch in tqdm(train_ld, total=len(train_ld)):
381 | loss, losses, logit = self.train_batch(optim, batch)
382 | total_loss += loss
383 | total_losses += losses
384 |
385 | if args.task_qa:
386 | score, label = logit.max(1)
387 | for datum, l in zip(batch, label.cpu().numpy()):
388 | uid = datum.uid
389 | ans = train_tuple.dataset.answer_table.id2ans(l)
390 | uid2ans[uid] = ans
391 |
392 | print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
393 | log_str = "\nThe training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch)
394 | losses_str = "\nThe losses are "
395 | log_str += "\nThe losses are "
396 | for name, loss in zip(LOSSES_NAME, total_losses):
397 | losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
398 | log_str += "\n %s: %0.4f " % (name, loss / batch_per_epoch)
399 | print(losses_str)
400 | with open(self.output + "/log.log", 'a') as f:
401 | f.write(log_str)
402 | f.flush()
403 | if args.task_qa:
404 | train_tuple.evaluator.evaluate(uid2ans, pprint=True)
405 |
406 | # Eval
407 | avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)
408 |
409 | state = {
410 | 'state_dict': self.model.state_dict(),
411 | 'optimizer': optim.state_dict(),
412 | }
413 |
414 | # Save
415 | if avg_eval_loss < best_eval_loss:
416 | best_eval_loss = avg_eval_loss
417 |                 self.save("BEST_EVAL_LOSS", state)
418 | if args.pretraining_index == 0:
419 | self.save("Epoch%02d" % (epoch+1), state)
420 | elif args.pretraining_index == 1:
421 | self.save("Epoch%02d" % (epoch + 1 + args.start_from), state)
422 |
423 | def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1):
424 | self.model.eval()
425 | eval_ld = eval_tuple.loader
426 | total_loss = 0.
427 | total_losses = 0.
428 | uid2ans = {}
429 | for i, batch in enumerate(eval_ld):
430 | loss, losses, logit = self.valid_batch(batch)
431 | total_loss += loss
432 | total_losses += losses
433 | if args.task_qa:
434 | score, label = logit.max(1)
435 | for datum, l in zip(batch, label.cpu().numpy()):
436 | uid = datum.uid
437 | ans = train_tuple.dataset.answer_table.id2ans(l)
438 | uid2ans[uid] = ans
439 | if i == iters:
440 | break
441 |
442 | print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
443 | log_str = "\nThe valid loss is %0.4f" % (total_loss / len(eval_ld))
444 | losses_str = "\nThe losses are "
445 | log_str += "\nThe losses are "
446 | for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
447 | losses_str += "%s: %0.4f " % (name, loss)
448 | log_str += "\n%s: %0.4f " % (name, loss)
449 | print(losses_str)
450 | with open(self.output + "/log.log", 'a') as f:
451 | f.write(log_str)
452 | f.flush()
453 |
454 | if args.task_qa:
455 | eval_tuple.evaluator.evaluate(uid2ans, pprint=True)
456 |
457 | return total_loss / len(eval_ld)
458 |
459 | def save(self, name, state):
460 | torch.save(state,
461 | os.path.join(args.output, "%s_LXRT.pth" % name))
462 |
463 |
464 | def load(self, path):
465 | print("Load BERT extractor from %s" % path)
466 | state_dict = torch.load("%s_LXRT.pth" % path)
467 | self.model.load_state_dict(state_dict)
468 |
469 | def load_lxmert(self, path):
470 | print("Load LXMERT model from %s" % path)
471 | state_dict = torch.load("%s_LXRT.pth" % path)
472 |
473 | # Do not load any answer head
474 | for key in list(state_dict.keys()):
475 | if 'answer' in key:
476 | state_dict.pop(key)
477 |
478 | # Change Multi GPU to single GPU
479 | new_state_dict = {}
480 |         for key, value in state_dict.items():
481 |             # strip a leading "module." (multi-GPU) prefix; keep other keys as-is
482 |             new_state_dict[key[len("module."):] if key.startswith("module.") else key] = value
483 |         state_dict = new_state_dict
484 |
485 | load_keys = set(state_dict.keys())
486 | model_keys = set(self.model.state_dict().keys())
487 | print()
488 | print("Keys in loaded but not in model:")
489 | for key in sorted(load_keys.difference(model_keys)):
490 | print(key)
491 | print()
492 | print("Keys in model but not in loaded:")
493 | for key in sorted(model_keys.difference(load_keys)):
494 | print(key)
495 | print()
496 |
497 | self.model.load_state_dict(state_dict, strict=False)
498 |
499 |
500 | if __name__ == "__main__":
501 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"
502 |
503 | lxmert = LXMERT(max_seq_length=20)
504 |
505 | lxmert.train(train_tuple, valid_tuple)
506 |
--------------------------------------------------------------------------------
/src/pretrain/lxmert_pretrain2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import collections
5 | import os
6 | import random
7 |
8 | from tqdm import tqdm
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 |
14 | from param import args
15 | from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
16 | from lxrt.entry import set_visual_config
17 | from lxrt.tokenization import BertTokenizer
18 | from lxrt.modeling import LXRTPretraining
19 |
20 | DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')
21 |
22 |
23 | def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
24 | # Decide which QA datasets would be used in pre-training.
25 | # Options: vqa, gqa, visual7w
26 | # Note: visual7w is a part of vgqa, we take the name here.
27 | qa_sets = args.qa_sets
28 | if qa_sets is not None:
29 | qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))
30 |
31 | # Build dataset, data loader, and evaluator.
32 | dset = LXMERTDataset(splits, qa_sets=qa_sets)
33 | tset = LXMERTTorchDataset(dset, topk)
34 | data_loader = DataLoader(
35 | tset, batch_size=bs,
36 | shuffle=shuffle, num_workers=args.num_workers,
37 | collate_fn=lambda x: x,
38 | drop_last=drop_last, pin_memory=True
39 | )
40 | evaluator = LXMERTEvaluator(dset)
41 | print()
42 |
43 | return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
44 |
45 |
46 | train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
47 | valid_batch_size = 1024 if args.multiGPU else 512
48 | valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)
49 |
50 |
51 | class InputFeatures(object):
52 | """A single set of features of data."""
53 |
54 | def __init__(self,
55 | input_ids, input_mask, segment_ids, lm_label_ids,
56 | visual_feats, visual_masks, obj_labels,
57 | is_matched, ans, random_index):
58 | self.input_ids = input_ids
59 | self.input_mask = input_mask
60 | self.segment_ids = segment_ids
61 | self.lm_label_ids = lm_label_ids
62 |
63 | self.visual_feats = visual_feats
64 | self.visual_masks = visual_masks
65 | self.obj_labels = obj_labels
66 |
67 | self.is_matched = is_matched
68 |
69 | self.ans = ans
70 | self.random_index = random_index
71 |
72 |
73 | def random_word(tokens, tokenizer, random_index):
74 | """
75 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
76 | :param tokens: list of str, tokenized sentence.
77 |     :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
78 |     :param random_index: if 0, mask the sentence; if 1, leave it unmasked (the image is masked instead)
79 | :return: (list of str, list of int), masked tokens and related labels for LM prediction
80 | """
81 | output_label = []
82 |
83 | for i, token in enumerate(tokens):
84 | prob = random.random()
85 | # mask token with probability
86 | ratio = args.word_mask_rate
87 | if prob < ratio and random_index == 0:
88 | prob /= ratio
89 |
90 | # 80% randomly change token to mask token
91 | if prob < 0.8:
92 | tokens[i] = "[MASK]"
93 |
94 | # 10% randomly change token to random token
95 | elif prob < 0.9:
96 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
97 |
98 | # -> rest 10% randomly keep current token
99 |
100 | # append current token to output (we will predict these later)
101 | try:
102 | output_label.append(tokenizer.vocab[token])
103 | except KeyError:
104 | # For unknown words (should not occur with BPE vocab)
105 | output_label.append(tokenizer.vocab["[UNK]"])
106 | else:
107 | # no masking token (will be ignored by loss function later)
108 | output_label.append(-1)
109 |
110 | return tokens, output_label
111 |
112 |
113 | def random_feat(feats, random_index):
114 | """
115 |     :param random_index: if 1, mask the image features; if 0, leave them unmasked (the sentence is masked instead)
116 | """
117 | mask_feats = feats.copy()
118 | feat_mask = np.zeros(len(feats), dtype=np.float32)
119 | for i in range(len(feats)):
120 | prob = random.random()
121 | # mask token with probability
122 | if prob < args.obj_mask_rate and random_index == 1:
123 | prob /= args.obj_mask_rate
124 |
125 | # 80% randomly change token to zero feat
126 | if prob < 0.8:
127 | mask_feats[i, :] = 0.
128 |
129 | # 10% randomly change token to random feat
130 | elif prob < 0.9:
131 | mask_feats[i, :] = train_tuple.torchdset.random_feat()
132 | # -> rest 10% randomly keep current feat
133 |
134 | # Need to predict this feat
135 | feat_mask[i] = 1.
136 |
137 | return mask_feats, feat_mask
138 |
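# Illustrative note (not part of the original file): unlike
# lxmert_pretrain.py, this variant draws random_index once per example (see
# convert_example_to_features below), so each sample masks exactly one
# modality: random_index == 0 masks words only, random_index == 1 masks
# object features only, and the feature mask doubles as the obj/attr loss
# weights further down.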
139 |
140 | def convert_example_to_features(example: InputExample, max_seq_length, tokenizer)->InputFeatures:
141 | """
142 | Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
143 | IDs, LM labels, input_mask, CLS and SEP tokens etc.
144 | :param example: InputExample, containing sentence input as strings and is_next label
145 | :param max_seq_length: int, maximum length of sequence.
146 | :param tokenizer: Tokenizer
147 | :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
148 | """
149 | tokens = tokenizer.tokenize(example.sent.strip())
150 |
151 | # Account for [CLS] and [SEP] with "- 2"
152 | if len(tokens) > max_seq_length - 2:
153 | tokens = tokens[:(max_seq_length - 2)]
154 |
155 |     # if random_index == 0, mask the sentence; if random_index == 1, mask the image
156 | random_index = np.random.randint(2)
157 |
158 |     # Get randomly masked words
159 | masked_tokens, masked_label = random_word(tokens, tokenizer, random_index)
160 |
161 | # concatenate lm labels and account for CLS, SEP, SEP
162 | masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
163 | input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
164 |
165 | # Mask & Segment Word
166 | lm_label_ids = ([-1] + masked_label + [-1])
167 | input_mask = [1] * len(input_ids)
168 | segment_ids = [0] * len(input_ids)
169 |
170 | # Zero-pad up to the sequence length.
171 | while len(input_ids) < max_seq_length:
172 | input_ids.append(0)
173 | input_mask.append(0)
174 | segment_ids.append(0)
175 | lm_label_ids.append(-1)
176 |
177 | assert len(input_ids) == max_seq_length
178 | assert len(input_mask) == max_seq_length
179 | assert len(segment_ids) == max_seq_length
180 | assert len(lm_label_ids) == max_seq_length
181 |
182 | feat, boxes = example.visual_feats
183 | obj_labels, obj_confs = example.obj_labels
184 | attr_labels, attr_confs = example.attr_labels
185 |
186 | # Mask Image Features:
187 | masked_feat, feat_mask = random_feat(feat, random_index)
188 |     obj_confs = feat_mask        # masked positions (=1) become the obj-loss weights
189 |     attr_confs = feat_mask       # ...and the attr-loss weights
190 |     visn_masks = 1 - feat_mask   # 1 marks positions left unmasked
191 |
192 | # QA answer label
193 | if example.label is None or len(example.label) == 0 or example.is_matched != 1:
194 | # 1. No label 2. Label is pruned 3. unmatched visual + language pair
195 | ans = -1
196 | else:
197 | keys, values = zip(*example.label.items())
198 | if len(keys) == 1:
199 | ans = keys[0]
200 | else:
201 | value_sum = sum(values)
202 | prob = [value / value_sum for value in values]
203 | choice = np.random.multinomial(1, prob).argmax()
204 | ans = keys[choice]
205 |
206 | features = InputFeatures(
207 | input_ids=input_ids,
208 | input_mask=input_mask,
209 | segment_ids=segment_ids,
210 | lm_label_ids=lm_label_ids,
211 | visual_feats=(masked_feat, boxes),
212 |         visual_masks=visn_masks,
213 | obj_labels={
214 | 'obj': (obj_labels, obj_confs),
215 | 'attr': (attr_labels, attr_confs),
216 | 'feat': (feat, feat_mask),
217 | },
218 | is_matched=example.is_matched,
219 | ans=ans,
220 | random_index=random_index
221 | )
222 | return features
223 |
224 |
225 | LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')
226 |
227 |
228 | class LXMERT:
229 | def __init__(self, max_seq_length):
230 | super().__init__()
231 | self.max_seq_length = max_seq_length
232 |
233 | self.tokenizer = BertTokenizer.from_pretrained(
234 | "bert-base-uncased",
235 | do_lower_case=True
236 | )
237 |
238 | # Build model
239 | set_visual_config(args)
240 | self.model = LXRTPretraining.from_pretrained(
241 | "bert-base-uncased",
242 | task_mask_lm=args.task_mask_lm,
243 | task_obj_predict=args.task_obj_predict,
244 | task_matched=args.task_matched,
245 | task_qa=args.task_qa,
246 | visual_losses=args.visual_losses,
247 | num_answers=train_tuple.dataset.answer_table.num_answers,
248 |             args=args
249 | )
250 |
251 | # Weight initialization and loading
252 | if args.from_scratch:
253 | print("Train from Scratch: re-initialize all BERT weights.")
254 | self.model.apply(self.model.init_bert_weights)
255 | if args.load is not None:
256 | self.load(args.load)
257 | if args.load_lxmert is not None:
258 |             # Loading LXMERT weights does not load the answer head.
259 | self.load_lxmert(args.load_lxmert)
260 | # GPU Options
261 | self.model = self.model.cuda()
262 | if args.multiGPU:
263 | #self.model = nn.DataParallel(self.model)
264 | self.multiGPU = 1
265 | else:
266 | self.multiGPU = 0
267 | # self.model = nn.DistributedDataParallel(self.model)
268 | self.output = args.output
269 | os.makedirs(self.output, exist_ok=True)
270 |
271 | def forward(self, examples):
272 | train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
273 | for example in examples]
274 |
275 | # language Inputs
276 | input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
277 | input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
278 | segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()
279 |
280 | # Visual Inputs
281 | feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
282 | pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()
283 | visual_mask = torch.tensor([f.visual_masks for f in train_features], dtype=torch.long).cuda()
284 |
285 | # Language Prediction
286 | lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()
287 |
288 |
289 | # Visual Prediction
290 | obj_labels = {}
291 | for key in ('obj', 'attr', 'feat'):
292 | visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
293 | visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
294 | assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
295 | obj_labels[key] = (visn_labels, visn_mask)
296 |
297 | # Joint Prediction
298 | matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
299 | ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()
300 |
301 | """
302 | forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
303 | visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
304 | """
305 | loss, losses, ans_logit = self.model(
306 | input_ids, segment_ids, input_mask, lm_labels,
307 | feats, pos, obj_labels, matched_labels, ans, visual_mask
308 | )
309 | if self.multiGPU:
310 | loss = loss.mean()
311 | losses = losses.mean(0)
312 |
313 | return loss, losses.detach().cpu(), ans_logit
314 |
315 | def train_batch(self, optim, batch):
316 | optim.zero_grad()
317 | loss, losses, ans_logit = self.forward(batch)
318 | # if args.multiGPU:
319 | # loss = loss.mean()
320 | # losses = losses.mean(0)
321 | if args.fp16:
322 | try:
323 | from apex import amp
324 | except ImportError:
325 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
326 | with amp.scale_loss(loss, optim) as scaled_loss:
327 | scaled_loss.backward()
328 | torch.nn.utils.clip_grad_norm_(amp.master_params(optim), 1.)
329 | else:
330 | loss.backward()
331 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
332 |
333 | optim.step()
334 |
335 | return loss.item(), losses.cpu().numpy(), ans_logit
336 |
337 | def valid_batch(self, batch):
338 | with torch.no_grad():
339 | loss, losses, ans_logit = self.forward(batch)
340 | # if args.multiGPU:
341 | # loss = loss.mean()
342 | # losses = losses.mean(0)
343 | return loss.item(), losses.cpu().numpy(), ans_logit
344 |
345 | def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
346 | train_ld = train_tuple.loader
347 |
348 | # Optimizer
349 | from lxrt.optimization import BertAdam
350 | batch_per_epoch = len(train_ld)
351 | t_total = int(batch_per_epoch * args.epochs)
352 | warmup_ratio = 0.05
353 | warmup_iters = int(t_total * warmup_ratio)
354 | print("Batch per epoch: %d" % batch_per_epoch)
355 | print("Total Iters: %d" % t_total)
356 | print("Warm up Iters: %d" % warmup_iters)
357 | optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)
358 | start_epoch = 0
359 |
360 | if args.fp16:
361 | try:
362 | from apex import amp
363 | except ImportError:
364 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
365 | self.model, optim = amp.initialize(self.model, optim, opt_level='O1')
366 |
367 |
368 | # GPU Options
369 | if args.multiGPU:
370 | self.model = nn.DataParallel(self.model)
371 |
372 |         if args.start_from > 0 and args.pretraining_index == 0:
373 |             start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
374 |             print('start training from {0}'.format(start_path))
375 |             state = torch.load(start_path)
376 |             self.model.load_state_dict(state['state_dict'], strict=False)
377 |             optim.load_state_dict(state['optimizer'])
378 |             start_epoch = args.start_from
379 |             del state
380 |             torch.cuda.empty_cache()
381 |         elif args.start_from > 0 and args.pretraining_index == 1:
382 |             start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
383 |             print('start training from {0}'.format(start_path))
384 |             state = torch.load(start_path)
385 |             self.model.load_state_dict(state['state_dict'], strict=False)
386 | del state
387 | torch.cuda.empty_cache()
388 |
389 | # if args.start_from >0:
390 | # start_path = os.path.join(args.output, "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
391 | # print('start training from {0}'.format(start_path))
392 | # state = torch.load(start_path)
393 | # self.model.load_state_dict(state['state_dict'])
394 | # optim.load_state_dict(state['optimizer'])
395 | # start_epoch = args.start_from
396 | # del state
397 | # torch.cuda.empty_cache()
398 |
399 |
400 | # Train
401 |         best_eval_loss = float('inf')
402 | for epoch in range(start_epoch, args.epochs):
403 | # Train
404 | self.model.train()
405 | total_loss = 0.
406 | total_losses = 0.
407 | uid2ans = {}
408 | for batch in tqdm(train_ld, total=len(train_ld)):
409 | loss, losses, logit = self.train_batch(optim, batch)
410 | total_loss += loss
411 | total_losses += losses
412 |
413 | if args.task_qa:
414 | score, label = logit.max(1)
415 | for datum, l in zip(batch, label.cpu().numpy()):
416 | uid = datum.uid
417 | ans = train_tuple.dataset.answer_table.id2ans(l)
418 | uid2ans[uid] = ans
419 |
420 | print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
421 | log_str = "\nThe training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch)
422 | losses_str = "\nThe losses are "
423 | log_str += "\nThe losses are "
424 | for name, loss in zip(LOSSES_NAME, total_losses):
425 | losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
426 | log_str += "\n %s: %0.4f " % (name, loss / batch_per_epoch)
427 | print(losses_str)
428 | with open(self.output + "/log.log", 'a') as f:
429 | f.write(log_str)
430 | f.flush()
431 | if args.task_qa:
432 | train_tuple.evaluator.evaluate(uid2ans, pprint=True)
433 |
434 | # Eval
435 | avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)
436 |
437 | state = {
438 | 'state_dict': self.model.state_dict(),
439 | 'optimizer': optim.state_dict(),
440 | }
441 |
442 | # Save
443 | if avg_eval_loss < best_eval_loss:
444 | best_eval_loss = avg_eval_loss
445 | self.save("BEST_EVAL_LOSS",state)
446 |
447 | self.save("Epoch%02d" % (epoch+1), state)
448 |
449 | def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1):
450 | self.model.eval()
451 | eval_ld = eval_tuple.loader
452 | total_loss = 0.
453 | total_losses = 0.
454 | uid2ans = {}
455 | for i, batch in enumerate(eval_ld):
456 | loss, losses, logit = self.valid_batch(batch)
457 | total_loss += loss
458 | total_losses += losses
459 | if args.task_qa:
460 | score, label = logit.max(1)
461 | for datum, l in zip(batch, label.cpu().numpy()):
462 | uid = datum.uid
463 | ans = train_tuple.dataset.answer_table.id2ans(l)
464 | uid2ans[uid] = ans
465 | if i == iters:
466 | break
467 |
468 | print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
469 | log_str = "\nThe valid loss is %0.4f" % (total_loss / len(eval_ld))
470 | losses_str = "\nThe losses are "
471 | log_str += "\nThe losses are "
472 | for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
473 | losses_str += "%s: %0.4f " % (name, loss)
474 | log_str += "\n%s: %0.4f " % (name, loss)
475 | print(losses_str)
476 | with open(self.output + "/log.log", 'a') as f:
477 | f.write(log_str)
478 | f.flush()
479 |
480 | if args.task_qa:
481 | eval_tuple.evaluator.evaluate(uid2ans, pprint=True)
482 |
483 | return total_loss / len(eval_ld)
484 |
485 | def save(self, name, state):
486 | torch.save(state,
487 | os.path.join(args.output, "%s_LXRT.pth" % name))
488 |
489 |
490 | def load(self, path):
491 | print("Load BERT extractor from %s" % path)
492 | state_dict = torch.load("%s_LXRT.pth" % path)
493 | self.model.load_state_dict(state_dict)
494 |
495 | def load_lxmert(self, path):
496 | print("Load LXMERT model from %s" % path)
497 | state_dict = torch.load("%s_LXRT.pth" % path)
498 |
499 | # Do not load any answer head
500 | for key in list(state_dict.keys()):
501 | if 'answer' in key:
502 | state_dict.pop(key)
503 |
504 | # Change Multi GPU to single GPU
505 |         new_state_dict = {}
506 |         for key, value in state_dict.items():
507 |             # Keep keys without the DataParallel "module." prefix as-is
508 |             new_state_dict[key[len("module."):] if key.startswith("module.") else key] = value
509 |         state_dict = new_state_dict
510 |
511 | load_keys = set(state_dict.keys())
512 | model_keys = set(self.model.state_dict().keys())
513 | print()
514 | print("Keys in loaded but not in model:")
515 | for key in sorted(load_keys.difference(model_keys)):
516 | print(key)
517 | print()
518 | print("Keys in model but not in loaded:")
519 | for key in sorted(model_keys.difference(load_keys)):
520 | print(key)
521 | print()
522 |
523 | self.model.load_state_dict(state_dict, strict=False)
524 |
525 |
526 | if __name__ == "__main__":
527 | # os.environ['CUDA_VISIBLE_DEVICES'] = "1"
528 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"
529 |
530 | lxmert = LXMERT(max_seq_length=20)
531 |
532 | lxmert.train(train_tuple, valid_tuple)
533 |
--------------------------------------------------------------------------------
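
A minimal, standalone sketch of the 80/10/10 feature-masking scheme that random_feat above implements. The 15% mask rate, the 36x2048 feature shape, and drawing the replacement feature from the same image are illustrative assumptions; the real code uses args.obj_mask_rate and samples replacements from the whole training set via train_tuple.torchdset.random_feat().

    import random
    import numpy as np

    def mask_features(feats, mask_rate=0.15, seed=0):
        random.seed(seed)
        masked = feats.copy()
        feat_mask = np.zeros(len(feats), dtype=np.float32)
        for i in range(len(feats)):
            prob = random.random()
            if prob < mask_rate:
                prob /= mask_rate                  # re-normalize to [0, 1)
                if prob < 0.8:                     # 80%: zero out the feature
                    masked[i, :] = 0.
                elif prob < 0.9:                   # 10%: swap in a random feature
                    masked[i, :] = feats[random.randrange(len(feats))]
                # remaining 10%: keep the feature unchanged
                feat_mask[i] = 1.                  # all three cases must be predicted
        return masked, feat_mask

    feats = np.random.randn(36, 2048).astype(np.float32)   # 36 objects, 2048-d each
    masked, feat_mask = mask_features(feats)
    print("masked %d of %d objects" % (int(feat_mask.sum()), len(feats)))
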
/src/pretrain/qa_answer_table.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import json
5 | import torch
6 |
7 |
8 | class AnswerTable:
9 | ANS_CONVERT = {
10 | "a man": "man",
11 | "the man": "man",
12 | "a woman": "woman",
13 | "the woman": "woman",
14 | 'one': '1',
15 | 'two': '2',
16 | 'three': '3',
17 | 'four': '4',
18 | 'five': '5',
19 | 'six': '6',
20 | 'seven': '7',
21 | 'eight': '8',
22 | 'nine': '9',
23 | 'ten': '10',
24 | 'grey': 'gray',
25 | }
26 |
27 | def __init__(self, dsets=None):
28 | self.all_ans = json.load(open("data/lxmert/all_ans.json"))
29 | if dsets is not None:
30 | dsets = set(dsets)
31 | # If the answer is used in the dsets
32 | self.anss = [ans['ans'] for ans in self.all_ans if
33 | len(set(ans['dsets']) & dsets) > 0]
34 | else:
35 | self.anss = [ans['ans'] for ans in self.all_ans]
36 | self.ans_set = set(self.anss)
37 |
38 | self._id2ans_map = self.anss
39 | self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}
40 |
41 | assert len(self._id2ans_map) == len(self._ans2id_map)
42 | for ans_id, ans in enumerate(self._id2ans_map):
43 | assert self._ans2id_map[ans] == ans_id
44 |
45 | def convert_ans(self, ans):
46 | if len(ans) == 0:
47 | return ""
48 | ans = ans.lower()
49 | if ans[-1] == '.':
50 | ans = ans[:-1].strip()
51 | if ans.startswith("a "):
52 | ans = ans[2:].strip()
53 | if ans.startswith("an "):
54 | ans = ans[3:].strip()
55 | if ans.startswith("the "):
56 | ans = ans[4:].strip()
57 | if ans in self.ANS_CONVERT:
58 | ans = self.ANS_CONVERT[ans]
59 | return ans
60 |
61 | def ans2id(self, ans):
62 | return self._ans2id_map[ans]
63 |
64 | def id2ans(self, ans_id):
65 | return self._id2ans_map[ans_id]
66 |
67 | def ans2id_map(self):
68 | return self._ans2id_map.copy()
69 |
70 | def id2ans_map(self):
71 | return self._id2ans_map.copy()
72 |
73 | def used(self, ans):
74 | return ans in self.ans_set
75 |
76 | def all_answers(self):
77 | return self.anss.copy()
78 |
79 | @property
80 | def num_answers(self):
81 | return len(self.anss)
82 |
83 |
84 | def load_lxmert_qa(path, model, label2ans):
85 | """
86 | Load model weights from LXMERT pre-training.
87 | The answers in the fine-tuned QA task (indicated by label2ans)
88 | would also be properly initialized with LXMERT pre-trained
89 | QA heads.
90 |
91 | :param path: Path to LXMERT snapshot.
92 | :param model: LXRT model instance.
93 | :param label2ans: The label2ans dict of fine-tuned QA datasets, like
94 | {0: 'cat', 1: 'dog', ...}
95 | :return:
96 | """
97 | print("Load QA pre-trained LXMERT from %s " % path)
98 | # loaded_state_dict = torch.load("%s_LXRT.pth" % path)
99 | loaded_state_dict = torch.load("%s_LXRT.pth" % path)['state_dict']
100 |
101 | model_state_dict = model.state_dict()
102 |
103 | # Handle Multi-GPU pre-training --> Single GPU fine-tuning
104 | for key in list(loaded_state_dict.keys()):
105 | loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)
106 |
107 | # Isolate bert model
108 | bert_state_dict = {}
109 | for key, value in loaded_state_dict.items():
110 | if key.startswith('bert.'):
111 | bert_state_dict[key] = value
112 |
113 | # Isolate answer head
114 | answer_state_dict = {}
115 | for key, value in loaded_state_dict.items():
116 | if key.startswith("answer_head."):
117 | answer_state_dict[key.replace('answer_head.', '')] = value
118 |
119 | # Do surgery on answer state dict
120 | ans_weight = answer_state_dict['logit_fc.3.weight']
121 | ans_bias = answer_state_dict['logit_fc.3.bias']
122 | import copy
123 | new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
124 | new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
125 | answer_table = AnswerTable()
126 | loaded = 0
127 | unload = 0
128 |     if isinstance(label2ans, list):
129 | label2ans = {label: ans for label, ans in enumerate(label2ans)}
130 | for label, ans in label2ans.items():
131 | new_ans = answer_table.convert_ans(ans)
132 | if answer_table.used(new_ans):
133 | ans_id_9500 = answer_table.ans2id(new_ans)
134 | new_answer_weight[label] = ans_weight[ans_id_9500]
135 | new_answer_bias[label] = ans_bias[ans_id_9500]
136 | loaded += 1
137 | else:
138 | new_answer_weight[label] = 0.
139 | new_answer_bias[label] = 0.
140 | unload += 1
141 |     print("Loaded %d answers from LXMERT QA pre-training; %d answers were not found" % (loaded, unload))
142 | print()
143 | answer_state_dict['logit_fc.3.weight'] = new_answer_weight
144 | answer_state_dict['logit_fc.3.bias'] = new_answer_bias
145 |
146 | # Load Bert Weights
147 | bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
148 | bert_loaded_keys = set(bert_state_dict.keys())
149 | # assert len(bert_model_keys - bert_loaded_keys) == 0
150 | model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)
151 |
152 | # Load Answer Logic FC Weights
153 | model_keys = set(model.state_dict().keys())
154 | ans_loaded_keys = set(answer_state_dict.keys())
155 | assert len(ans_loaded_keys - model_keys) == 0
156 |
157 | model.load_state_dict(answer_state_dict, strict=False)
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
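
For reference, a self-contained sketch of the normalization AnswerTable.convert_ans performs; the conversion table here is a truncated copy of ANS_CONVERT above.

    ANS_CONVERT = {"a man": "man", "the man": "man", "one": "1", "grey": "gray"}

    def convert_ans(ans):
        # Mirrors AnswerTable.convert_ans: lowercase, strip a trailing '.',
        # strip leading articles, then apply the conversion table.
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        return ANS_CONVERT.get(ans, ans)

    for raw in ("The man.", "Grey", "an apple", "one"):
        print("%r -> %r" % (raw, convert_ans(raw)))
    # 'The man.' -> 'man', 'Grey' -> 'gray', 'an apple' -> 'apple', 'one' -> '1'

This normalization is what lets the answer-head surgery in load_lxmert_qa match fine-tuning labels such as "grey" against the pre-training vocabulary entry "gray".
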
/src/tasks/gqa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import os
5 | import collections
6 |
7 | import torch
8 | from tqdm import tqdm
9 | import torch.nn as nn
10 | from torch.utils.data.dataloader import DataLoader
11 |
12 | from param import args
13 | from pretrain.qa_answer_table import load_lxmert_qa
14 | from tasks.gqa_model import GQAModel
15 | from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator
16 |
17 |
18 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
19 |
20 |
21 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
22 | dset = GQADataset(splits)
23 | tset = GQATorchDataset(dset)
24 | evaluator = GQAEvaluator(dset)
25 | data_loader = DataLoader(
26 | tset, batch_size=bs,
27 | shuffle=shuffle, num_workers=args.num_workers,
28 | drop_last=drop_last, pin_memory=True
29 | )
30 |
31 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
32 |
33 |
34 | class GQA:
35 | def __init__(self):
36 | self.train_tuple = get_tuple(
37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True
38 | )
39 | if args.valid != "":
40 |             valid_bsize = 512
41 | self.valid_tuple = get_tuple(
42 | args.valid, bs=valid_bsize,
43 | shuffle=False, drop_last=False
44 | )
45 | else:
46 | self.valid_tuple = None
47 |
48 | self.model = GQAModel(self.train_tuple.dataset.num_answers)
49 |
50 | # Load pre-trained weights
51 | if args.load_lxmert is not None:
52 | self.model.lxrt_encoder.load(args.load_lxmert)
53 | if args.load_lxmert_qa is not None:
54 | load_lxmert_qa(args.load_lxmert_qa, self.model,
55 | label2ans=self.train_tuple.dataset.label2ans)
56 |
57 | # GPU options
58 | self.model = self.model.cuda()
59 | if args.multiGPU:
60 | self.model.lxrt_encoder.multi_gpu()
61 |
62 | # Losses and optimizer
63 | self.bce_loss = nn.BCEWithLogitsLoss()
64 | self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
65 | if 'bert' in args.optim:
66 | batch_per_epoch = len(self.train_tuple.loader)
67 | t_total = int(batch_per_epoch * args.epochs)
68 | print("Total Iters: %d" % t_total)
69 | from lxrt.optimization import BertAdam
70 | self.optim = BertAdam(list(self.model.parameters()),
71 | lr=args.lr,
72 | warmup=0.1,
73 | t_total=t_total)
74 | else:
75 | self.optim = args.optimizer(list(self.model.parameters()), args.lr)
76 |
77 | self.output = args.output
78 | os.makedirs(self.output, exist_ok=True)
79 |
80 | def train(self, train_tuple, eval_tuple):
81 | dset, loader, evaluator = train_tuple
82 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
83 |
84 | best_valid = 0.
85 | for epoch in range(args.epochs):
86 | quesid2ans = {}
87 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
88 |
89 | self.model.train()
90 | self.optim.zero_grad()
91 |
92 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
93 | logit = self.model(feats, boxes, sent)
94 | assert logit.dim() == target.dim() == 2
95 | if args.mce_loss:
96 | max_value, target = target.max(1)
97 | loss = self.mce_loss(logit, target) * logit.size(1)
98 | else:
99 | loss = self.bce_loss(logit, target)
100 | loss = loss * logit.size(1)
101 |
102 | loss.backward()
103 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
104 | self.optim.step()
105 |
106 | score, label = logit.max(1)
107 | for qid, l in zip(ques_id, label.cpu().numpy()):
108 | ans = dset.label2ans[l]
109 | quesid2ans[qid] = ans
110 |
111 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
112 |
113 | if self.valid_tuple is not None: # Do Validation
114 | valid_score = self.evaluate(eval_tuple)
115 | if valid_score > best_valid:
116 | best_valid = valid_score
117 | self.save("BEST")
118 |
119 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
120 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
121 |
122 | print(log_str, end='')
123 |
124 | with open(self.output + "/log.log", 'a') as f:
125 | f.write(log_str)
126 | f.flush()
127 |
128 | self.save("LAST")
129 |
130 | def predict(self, eval_tuple: DataTuple, dump=None):
131 | self.model.eval()
132 | dset, loader, evaluator = eval_tuple
133 | quesid2ans = {}
134 | for i, datum_tuple in enumerate(loader):
135 |             if i % 100 == 0:
136 | print(i)
137 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target
138 | with torch.no_grad():
139 | feats, boxes = feats.cuda(), boxes.cuda()
140 | logit = self.model(feats, boxes, sent)
141 | score, label = logit.max(1)
142 | for qid, l in zip(ques_id, label.cpu().numpy()):
143 | ans = dset.label2ans[l]
144 | quesid2ans[qid] = ans
145 | if dump is not None:
146 | evaluator.dump_result(quesid2ans, dump)
147 | return quesid2ans
148 |
149 | def evaluate(self, eval_tuple: DataTuple, dump=None):
150 | dset, loader, evaluator = eval_tuple
151 | quesid2ans = self.predict(eval_tuple, dump)
152 | return evaluator.evaluate(quesid2ans)
153 |
154 | @staticmethod
155 | def oracle_score(data_tuple):
156 | dset, loader, evaluator = data_tuple
157 | quesid2ans = {}
158 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
159 | _, label = target.max(1)
160 | for qid, l in zip(ques_id, label.cpu().numpy()):
161 | ans = dset.label2ans[l]
162 | quesid2ans[qid] = ans
163 | return evaluator.evaluate(quesid2ans)
164 |
165 | def save(self, name):
166 | torch.save(self.model.state_dict(),
167 | os.path.join(self.output, "%s.pth" % name))
168 |
169 | def load(self, path):
170 | print("Load model from %s" % path)
171 | state_dict = torch.load("%s.pth" % path)
172 | for key in list(state_dict.keys()):
173 | if '.module' in key:
174 | state_dict[key.replace('.module', '')] = state_dict.pop(key)
175 | self.model.load_state_dict(state_dict, strict=False)
176 |
177 |
178 | if __name__ == "__main__":
179 | # Build Class
180 | gqa = GQA()
181 |
182 | # Load Model
183 | if args.load is not None:
184 | gqa.load(args.load)
185 |
186 | # Test or Train
187 | if args.test is not None:
188 | args.fast = args.tiny = False # Always loading all data in test
189 | if 'submit' in args.test:
190 | gqa.predict(
191 | get_tuple(args.test, bs=args.batch_size,
192 | shuffle=False, drop_last=False),
193 | dump=os.path.join(args.output, 'submit_predict.json')
194 | )
195 | if 'testdev' in args.test:
196 | result = gqa.evaluate(
197 | get_tuple('testdev', bs=args.batch_size,
198 | shuffle=False, drop_last=False),
199 | dump=os.path.join(args.output, 'testdev_predict.json')
200 | )
201 | print(result)
202 | else:
203 | # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
204 | print('Splits in Train data:', gqa.train_tuple.dataset.splits)
205 | if gqa.valid_tuple is not None:
206 | print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
207 | print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
208 | else:
209 | print("DO NOT USE VALIDATION")
210 | gqa.train(gqa.train_tuple, gqa.valid_tuple)
211 |
212 |
213 |
--------------------------------------------------------------------------------
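
A note on the loss scaling in the training loop above: BCEWithLogitsLoss averages over every element of the (batch, num_answers) matrix by default, so multiplying by logit.size(1) turns it into a per-example sum over the answer vocabulary, averaged over the batch. A small self-check with illustrative shapes:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    logit = torch.randn(4, 10)      # batch of 4, 10 candidate answers
    target = torch.rand(4, 10)      # soft targets in [0, 1]

    bce = nn.BCEWithLogitsLoss()    # default reduction: mean over all elements
    scaled = bce(logit, target) * logit.size(1)

    # Reference: sum over the answer dimension, then mean over the batch
    per_elem = nn.BCEWithLogitsLoss(reduction='none')(logit, target)
    reference = per_elem.sum(dim=1).mean()

    assert torch.allclose(scaled, reference)
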
/src/tasks/gqa_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import json
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 |
10 | from param import args
11 | from utils import load_obj_tsv
12 |
13 | # Load part of the dataset for fast checking.
14 | # Note that these counts are numbers of images rather than numbers of examples;
15 | # all examples associated with the loaded images are used.
16 | TINY_IMG_NUM = 512
17 | FAST_IMG_NUM = 5000
18 |
19 |
20 | class GQADataset:
21 | """
22 | A GQA data example in json file:
23 | {
24 | "img_id": "2375429",
25 | "label": {
26 | "pipe": 1.0
27 | },
28 | "question_id": "07333408",
29 | "sent": "What is on the white wall?"
30 | }
31 | """
32 | def __init__(self, splits: str):
33 | self.name = splits
34 | self.splits = splits.split(',')
35 |
36 | # Loading datasets to data
37 | self.data = []
38 | for split in self.splits:
39 | self.data.extend(json.load(open("data/gqa/%s.json" % split)))
40 | print("Load %d data from split(s) %s." % (len(self.data), self.name))
41 |
42 | # List to dict (for evaluation and others)
43 | self.id2datum = {
44 | datum['question_id']: datum
45 | for datum in self.data
46 | }
47 |
48 | # Answers
49 | self.ans2label = json.load(open("data/gqa/trainval_ans2label.json"))
50 | self.label2ans = json.load(open("data/gqa/trainval_label2ans.json"))
51 | assert len(self.ans2label) == len(self.label2ans)
52 | for ans, label in self.ans2label.items():
53 | assert self.label2ans[label] == ans
54 |
55 | @property
56 | def num_answers(self):
57 | return len(self.ans2label)
58 |
59 | def __len__(self):
60 | return len(self.data)
61 |
62 |
63 | class GQABufferLoader():
64 | def __init__(self):
65 | self.key2data = {}
66 |
67 | def load_data(self, name, number):
68 | if name == 'testdev':
69 | path = "data/vg_gqa_imgfeat/gqa_testdev_obj64.tsv"
70 | else:
71 | path = "data/vg_gqa_imgfeat/vg_gqa_obj64.tsv"
72 | key = "%s_%d" % (path, number)
73 | if key not in self.key2data:
74 | self.key2data[key] = load_obj_tsv(
75 | path,
76 | topk=number
77 | )
78 | return self.key2data[key]
79 |
80 |
81 | gqa_buffer_loader = GQABufferLoader()
82 |
83 |
84 | """
85 | Example in obj tsv:
86 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
87 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
88 | """
89 | class GQATorchDataset(Dataset):
90 | def __init__(self, dataset: GQADataset):
91 | super().__init__()
92 | self.raw_dataset = dataset
93 |
94 | if args.tiny:
95 | topk = TINY_IMG_NUM
96 | elif args.fast:
97 | topk = FAST_IMG_NUM
98 | else:
99 | topk = -1
100 |
101 | # Loading detection features to img_data
102 | # Since images in train and valid both come from Visual Genome,
103 | # buffer the image loading to save memory.
104 | img_data = []
105 | if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits: # Always loading all the data in testdev
106 | img_data.extend(gqa_buffer_loader.load_data('testdev', -1))
107 | else:
108 | img_data.extend(gqa_buffer_loader.load_data('train', topk))
109 | self.imgid2img = {}
110 | for img_datum in img_data:
111 | self.imgid2img[img_datum['img_id']] = img_datum
112 |
113 |         # Only keep the data with loaded image features
114 | self.data = []
115 | for datum in self.raw_dataset.data:
116 | if datum['img_id'] in self.imgid2img:
117 | self.data.append(datum)
118 | print("Use %d data in torch dataset" % (len(self.data)))
119 | print()
120 |
121 | def __len__(self):
122 | return len(self.data)
123 |
124 | def __getitem__(self, item: int):
125 | datum = self.data[item]
126 |
127 | img_id = datum['img_id']
128 | ques_id = datum['question_id']
129 | ques = datum['sent']
130 |
131 | # Get image info
132 | img_info = self.imgid2img[img_id]
133 | obj_num = img_info['num_boxes']
134 | boxes = img_info['boxes'].copy()
135 | feats = img_info['features'].copy()
136 | assert len(boxes) == len(feats) == obj_num
137 |
138 | # Normalize the boxes (to 0 ~ 1)
139 | img_h, img_w = img_info['img_h'], img_info['img_w']
140 | boxes = boxes.copy()
141 | boxes[:, (0, 2)] /= img_w
142 | boxes[:, (1, 3)] /= img_h
143 | np.testing.assert_array_less(boxes, 1+1e-5)
144 | np.testing.assert_array_less(-boxes, 0+1e-5)
145 |
146 | # Create target
147 | if 'label' in datum:
148 | label = datum['label']
149 | target = torch.zeros(self.raw_dataset.num_answers)
150 | for ans, score in label.items():
151 | if ans in self.raw_dataset.ans2label:
152 | target[self.raw_dataset.ans2label[ans]] = score
153 | return ques_id, feats, boxes, ques, target
154 | else:
155 | return ques_id, feats, boxes, ques
156 |
157 |
158 | class GQAEvaluator:
159 | def __init__(self, dataset: GQADataset):
160 | self.dataset = dataset
161 |
162 | def evaluate(self, quesid2ans: dict):
163 | score = 0.
164 | for quesid, ans in quesid2ans.items():
165 | datum = self.dataset.id2datum[quesid]
166 | label = datum['label']
167 | if ans in label:
168 | score += label[ans]
169 | return score / len(quesid2ans)
170 |
171 | def dump_result(self, quesid2ans: dict, path):
172 | """
173 | Dump the result to a GQA-challenge submittable json file.
174 | GQA json file submission requirement:
175 | results = [result]
176 | result = {
177 |             "questionId": str, # Note: it is actually an int, but the server requires a str.
178 | "prediction": str
179 | }
180 |
181 | :param quesid2ans: A dict mapping question id to its predicted answer.
182 | :param path: The file path to save the json file.
183 | :return:
184 | """
185 | with open(path, 'w') as f:
186 | result = []
187 | for ques_id, ans in quesid2ans.items():
188 | result.append({
189 | 'questionId': ques_id,
190 | 'prediction': ans
191 | })
192 | json.dump(result, f, indent=4, sort_keys=True)
193 |
194 |
195 |
--------------------------------------------------------------------------------
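
The box normalization in GQATorchDataset.__getitem__ above, isolated with made-up numbers: x coordinates are divided by the image width and y coordinates by the height, so every corner lands in [0, 1], and the assertions catch detector boxes that fall outside the image.

    import numpy as np

    boxes = np.array([[10., 20., 300., 200.],      # (x1, y1, x2, y2) in pixels
                      [ 0.,  0., 640., 480.]], dtype=np.float32)
    img_w, img_h = 640, 480

    boxes = boxes.copy()
    boxes[:, (0, 2)] /= img_w    # normalize x1, x2
    boxes[:, (1, 3)] /= img_h    # normalize y1, y2

    np.testing.assert_array_less(boxes, 1 + 1e-5)
    np.testing.assert_array_less(-boxes, 0 + 1e-5)
    print(boxes)
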
/src/tasks/gqa_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import torch.nn as nn
5 |
6 | from param import args
7 | from lxrt.entry import LXRTEncoder
8 | from lxrt.modeling import BertLayerNorm, GeLU
9 |
10 | # Max length including <bos> and <eos>
11 | MAX_GQA_LENGTH = 20
12 |
13 |
14 | class GQAModel(nn.Module):
15 | def __init__(self, num_answers):
16 | super().__init__()
17 | self.lxrt_encoder = LXRTEncoder(
18 | args,
19 | max_seq_length=MAX_GQA_LENGTH
20 | )
21 | hid_dim = self.lxrt_encoder.dim
22 | self.logit_fc = nn.Sequential(
23 | nn.Linear(hid_dim, hid_dim * 2),
24 | GeLU(),
25 | BertLayerNorm(hid_dim * 2, eps=1e-12),
26 | nn.Linear(hid_dim * 2, num_answers)
27 | )
28 | self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
29 |
30 | def forward(self, feat, pos, sent):
31 | """
32 | b -- batch_size, o -- object_number, f -- visual_feature_size
33 |
34 | :param feat: (b, o, f)
35 | :param pos: (b, o, 4)
36 | :param sent: (b,) Type -- list of string
37 |         (sentences are tokenized inside the encoder, so no length input is needed)
38 |         :return: (b, num_answers) The logit of each answer.
39 | """
40 | x = self.lxrt_encoder(sent, (feat, pos))
41 | logit = self.logit_fc(x)
42 |
43 | return logit
44 |
45 |
46 |
--------------------------------------------------------------------------------
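
For intuition, the classifier head above rebuilt from stock PyTorch modules; hid_dim=768 (BERT-base) and the 1842 answer count are illustrative assumptions, and the original uses the custom GeLU and BertLayerNorm from lxrt.modeling rather than nn.GELU/nn.LayerNorm.

    import torch
    import torch.nn as nn

    hid_dim = 768
    logit_fc = nn.Sequential(
        nn.Linear(hid_dim, hid_dim * 2),        # expand the pooled feature
        nn.GELU(),
        nn.LayerNorm(hid_dim * 2, eps=1e-12),
        nn.Linear(hid_dim * 2, 1842),           # one logit per candidate answer
    )

    x = torch.randn(4, hid_dim)                 # pooled cross-modal feature per example
    print(logit_fc(x).shape)                    # torch.Size([4, 1842])
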
/src/tasks/nlvr2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import os
5 | import collections
6 |
7 | from tqdm import tqdm
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data.dataloader import DataLoader
11 |
12 | from param import args
13 | from tasks.nlvr2_model import NLVR2Model
14 | from tasks.nlvr2_data import NLVR2Dataset, NLVR2TorchDataset, NLVR2Evaluator
15 |
16 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
17 |
18 |
19 | def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
20 | dset = NLVR2Dataset(splits)
21 | tset = NLVR2TorchDataset(dset)
22 | evaluator = NLVR2Evaluator(dset)
23 | data_loader = DataLoader(
24 | tset, batch_size=bs,
25 | shuffle=shuffle, num_workers=args.num_workers,
26 | drop_last=drop_last, pin_memory=True
27 | )
28 |
29 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
30 |
31 |
32 | class NLVR2:
33 | def __init__(self):
34 | self.train_tuple = get_tuple(
35 | args.train, bs=args.batch_size, shuffle=True, drop_last=True
36 | )
37 | if args.valid != "":
38 |             valid_bsize = 256
39 | self.valid_tuple = get_tuple(
40 | args.valid, bs=valid_bsize,
41 | shuffle=False, drop_last=False
42 | )
43 | else:
44 | self.valid_tuple = None
45 |
46 | self.model = NLVR2Model()
47 |
48 | # Load pre-trained weights
49 | if args.load_lxmert is not None:
50 | self.model.lxrt_encoder.load(args.load_lxmert)
51 |
52 | # GPU options
53 | if args.multiGPU:
54 | self.model.lxrt_encoder.multi_gpu()
55 | self.model = self.model.cuda()
56 |
57 | # Losses and optimizer
58 | self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
59 | if 'bert' in args.optim:
60 | batch_per_epoch = len(self.train_tuple.loader)
61 | t_total = int(batch_per_epoch * args.epochs)
62 | print("Total Iters: %d" % t_total)
63 | from lxrt.optimization import BertAdam
64 | self.optim = BertAdam(list(self.model.parameters()),
65 | lr=args.lr,
66 | warmup=0.1,
67 | t_total=t_total)
68 | else:
69 | self.optim = args.optimizer(list(self.model.parameters()), args.lr)
70 |
71 | self.output = args.output
72 | os.makedirs(self.output, exist_ok=True)
73 |
74 | def train(self, train_tuple, eval_tuple):
75 | dset, loader, evaluator = train_tuple
76 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
77 |
78 | best_valid = 0.
79 | for epoch in range(args.epochs):
80 | quesid2ans = {}
81 | for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)):
82 | self.model.train()
83 |
84 | self.optim.zero_grad()
85 | feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda()
86 | logit = self.model(feats, boxes, sent)
87 |
88 | loss = self.mce_loss(logit, label)
89 |
90 | loss.backward()
91 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
92 | self.optim.step()
93 |
94 | score, predict = logit.max(1)
95 | for qid, l in zip(ques_id, predict.cpu().numpy()):
96 | quesid2ans[qid] = l
97 |
98 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
99 |
100 | if self.valid_tuple is not None: # Do Validation
101 | valid_score = self.evaluate(eval_tuple)
102 | if valid_score > best_valid:
103 | best_valid = valid_score
104 | self.save("BEST")
105 |
106 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
107 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
108 |
109 | print(log_str, end='')
110 |
111 | with open(self.output + "/log.log", 'a') as f:
112 | f.write(log_str)
113 | f.flush()
114 |
115 | self.save("LAST")
116 |
117 | def predict(self, eval_tuple: DataTuple, dump=None):
118 | self.model.eval()
119 | dset, loader, evaluator = eval_tuple
120 | quesid2ans = {}
121 | for i, datum_tuple in enumerate(loader):
122 | ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target
123 | with torch.no_grad():
124 | feats, boxes = feats.cuda(), boxes.cuda()
125 | logit = self.model(feats, boxes, sent)
126 | score, predict = logit.max(1)
127 | for qid, l in zip(ques_id, predict.cpu().numpy()):
128 | quesid2ans[qid] = l
129 | if dump is not None:
130 | evaluator.dump_result(quesid2ans, dump)
131 | return quesid2ans
132 |
133 | def evaluate(self, eval_tuple: DataTuple, dump=None):
134 | dset, loader, evaluator = eval_tuple
135 | quesid2ans = self.predict(eval_tuple, dump)
136 | return evaluator.evaluate(quesid2ans)
137 |
138 | def save(self, name):
139 | torch.save(self.model.state_dict(),
140 | os.path.join(self.output, "%s.pth" % name))
141 |
142 | # def load(self, path):
143 | # print("Load model from %s" % path)
144 | # state_dict = torch.load("%s.pth" % path)
145 | # self.model.load_state_dict(state_dict)
146 |
147 | def load(self, path):
148 | print("Load model from %s" % path)
149 | state_dict = torch.load("%s.pth" % path)
150 | for key in list(state_dict.keys()):
151 | if '.module' in key:
152 | state_dict[key.replace('.module', '')] = state_dict.pop(key)
153 | self.model.load_state_dict(state_dict, strict=False)
154 |
155 |
156 | if __name__ == "__main__":
157 | # Build Class
158 | nlvr2 = NLVR2()
159 |
160 | # Load Model
161 | if args.load is not None:
162 | nlvr2.load(args.load)
163 |
164 | # Test or Train
165 | if args.test is not None:
166 | args.fast = args.tiny = False # Always loading all data in test
167 | if 'hidden' in args.test:
168 | nlvr2.predict(
169 | get_tuple(args.test, bs=args.batch_size,
170 | shuffle=False, drop_last=False),
171 | dump=os.path.join(args.output, 'hidden_predict.csv')
172 | )
173 | elif 'test' in args.test or 'valid' in args.test:
174 | result = nlvr2.evaluate(
175 | get_tuple(args.test, bs=args.batch_size,
176 | shuffle=False, drop_last=False),
177 | dump=os.path.join(args.output, '%s_predict.csv' % args.test)
178 | )
179 | print(result)
180 | else:
181 | assert False, "No such test option for %s" % args.test
182 | else:
183 | print('Splits in Train data:', nlvr2.train_tuple.dataset.splits)
184 | if nlvr2.valid_tuple is not None:
185 | print('Splits in Valid data:', nlvr2.valid_tuple.dataset.splits)
186 | else:
187 | print("DO NOT USE VALIDATION")
188 | nlvr2.train(nlvr2.train_tuple, nlvr2.valid_tuple)
189 |
190 |
191 |
--------------------------------------------------------------------------------
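
NLVR2 trains with a plain two-way cross entropy; the ignore_index=-1 in mce_loss above means any example carrying a -1 label (e.g. unlabeled data) contributes nothing to the loss. A toy illustration with made-up logits:

    import torch
    import torch.nn as nn

    mce = nn.CrossEntropyLoss(ignore_index=-1)
    logit = torch.randn(4, 2)
    label = torch.tensor([0, 1, -1, 1])   # the -1 example is skipped

    loss = mce(logit, label)              # averaged over the 3 labeled examples
    print(loss.item())
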
/src/tasks/nlvr2_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import json
5 |
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 |
9 | from param import args
10 | from utils import load_obj_tsv
11 |
12 | # Load part of the dataset for fast checking.
13 | # Note that these counts are numbers of images rather than numbers of examples;
14 | # all examples associated with the loaded images are used.
15 | TINY_IMG_NUM = 512
16 | FAST_IMG_NUM = 5000
17 |
18 |
19 | class NLVR2Dataset:
20 | """
21 | An NLVR2 data example in json file:
22 | {
23 | "identifier": "train-10171-0-0",
24 | "img0": "train-10171-0-img0",
25 | "img1": "train-10171-0-img1",
26 | "label": 0,
27 |         "sent": "An image shows one leather pencil case, displayed open
28 |                  with writing implements tucked inside.",
29 | "uid": "nlvr2_train_0"
30 | }
31 | """
32 | def __init__(self, splits: str):
33 | self.name = splits
34 | self.splits = splits.split(',')
35 |
36 | # Loading datasets to data
37 | self.data = []
38 | for split in self.splits:
39 | self.data.extend(json.load(open("data/nlvr2/%s.json" % split)))
40 | print("Load %d data from split(s) %s." % (len(self.data), self.name))
41 |
42 | # List to dict (for evaluation and others)
43 | self.id2datum = {
44 | datum['uid']: datum
45 | for datum in self.data
46 | }
47 |
48 | def __len__(self):
49 | return len(self.data)
50 |
51 |
52 | """
53 | An example in obj36 tsv:
54 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
55 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
56 | FIELDNAMES would be keys in the dict returned by load_obj_tsv.
57 | """
58 | class NLVR2TorchDataset(Dataset):
59 | def __init__(self, dataset: NLVR2Dataset):
60 | super().__init__()
61 | self.raw_dataset = dataset
62 |
63 | if args.tiny:
64 | topk = TINY_IMG_NUM
65 | elif args.fast:
66 | topk = FAST_IMG_NUM
67 | else:
68 | topk = -1
69 |
70 | # Loading detection features to img_data
71 | img_data = []
72 | if 'train' in dataset.splits:
73 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj64.tsv', topk=topk))
74 | if 'valid' in dataset.splits:
75 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj64.tsv', topk=topk))
76 | if 'test' in dataset.name:
77 | img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj64.tsv', topk=topk))
78 | self.imgid2img = {}
79 | for img_datum in img_data:
80 | self.imgid2img[img_datum['img_id']] = img_datum
81 |
82 | # Filter out the dataset
83 | self.data = []
84 | for datum in self.raw_dataset.data:
85 | if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img:
86 | self.data.append(datum)
87 | print("Use %d data in torch dataset" % (len(self.data)))
88 | print()
89 |
90 | def __len__(self):
91 | return len(self.data)
92 |
93 | def __getitem__(self, item: int):
94 | datum = self.data[item]
95 |
96 | ques_id = datum['uid']
97 | ques = datum['sent']
98 |
99 | # Get image info
100 | boxes2 = []
101 | feats2 = []
102 | for key in ['img0', 'img1']:
103 | img_id = datum[key]
104 | img_info = self.imgid2img[img_id]
105 | boxes = img_info['boxes'].copy()
106 | feats = img_info['features'].copy()
107 | assert len(boxes) == len(feats)
108 |
109 | # Normalize the boxes (to 0 ~ 1)
110 | img_h, img_w = img_info['img_h'], img_info['img_w']
111 | boxes[..., (0, 2)] /= img_w
112 | boxes[..., (1, 3)] /= img_h
113 | np.testing.assert_array_less(boxes, 1+1e-5)
114 | np.testing.assert_array_less(-boxes, 0+1e-5)
115 |
116 | boxes2.append(boxes)
117 | feats2.append(feats)
118 | feats = np.stack(feats2)
119 | boxes = np.stack(boxes2)
120 |
121 | # Create target
122 | if 'label' in datum:
123 | label = datum['label']
124 | return ques_id, feats, boxes, ques, label
125 | else:
126 | return ques_id, feats, boxes, ques
127 |
128 |
129 | class NLVR2Evaluator:
130 | def __init__(self, dataset: NLVR2Dataset):
131 | self.dataset = dataset
132 |
133 | def evaluate(self, quesid2ans: dict):
134 | score = 0.
135 | for quesid, ans in quesid2ans.items():
136 | datum = self.dataset.id2datum[quesid]
137 | label = datum['label']
138 | if ans == label:
139 | score += 1
140 | return score / len(quesid2ans)
141 |
142 | def dump_result(self, quesid2ans: dict, path):
143 | """
144 | Dump result to a CSV file, which is compatible with NLVR2 evaluation system.
145 | NLVR2 CSV file requirement:
146 | Each line contains: identifier, answer
147 |
148 | :param quesid2ans: nlvr2 uid to ans (either "True" or "False")
149 | :param path: The desired path of saved file.
150 | :return:
151 | """
152 | with open(path, 'w') as f:
153 | for uid, ans in quesid2ans.items():
154 | idt = self.dataset.id2datum[uid]["identifier"]
155 | ans = 'True' if ans == 1 else 'False'
156 | f.write("%s,%s\n" % (idt, ans))
157 |
158 |
--------------------------------------------------------------------------------
/src/tasks/nlvr2_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import torch.nn as nn
5 | from lxrt.modeling import GeLU, BertLayerNorm
6 | from lxrt.entry import LXRTEncoder
7 | from param import args
8 |
9 |
10 | class NLVR2Model(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 | self.lxrt_encoder = LXRTEncoder(
14 | args,
15 | max_seq_length=20
16 | )
17 | self.hid_dim = hid_dim = self.lxrt_encoder.dim
18 | self.logit_fc = nn.Sequential(
19 | nn.Linear(hid_dim * 2, hid_dim * 2),
20 | GeLU(),
21 | BertLayerNorm(hid_dim * 2, eps=1e-12),
22 | nn.Linear(hid_dim * 2, 2)
23 | )
24 | self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
25 |
26 | def forward(self, feat, pos, sent):
27 | """
28 | :param feat: b, 2, o, f
29 | :param pos: b, 2, o, 4
30 | :param sent: b, (string)
31 |         (sentences are tokenized inside the encoder, so no length input is needed)
32 |         :return: (b, 2) The logit over the two NLVR2 labels.
33 | """
34 | # Pairing images and sentences:
35 | # The input of NLVR2 is two images and one sentence. In batch level, they are saved as
36 | # [ [img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...]
37 | # Here, we flat them to
38 | # feat/pos = [ img0_0, img0_1, img1_0, img1_1, ...]
39 | # sent = [ sent0, sent0, sent1, sent1, ...]
40 | sent = sum(zip(sent, sent), ())
41 | batch_size, img_num, obj_num, feat_size = feat.size()
42 | assert img_num == 2 and obj_num == 64 and feat_size == 2048
43 | feat = feat.view(batch_size * 2, obj_num, feat_size)
44 | pos = pos.view(batch_size * 2, obj_num, 4)
45 |
46 | # Extract feature --> Concat
47 | x = self.lxrt_encoder(sent, (feat, pos))
48 | x = x.view(-1, self.hid_dim*2)
49 |
50 | # Compute logit of answers
51 | logit = self.logit_fc(x)
52 |
53 | return logit
54 |
55 |
56 |
--------------------------------------------------------------------------------
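
The pairing trick in NLVR2Model.forward above, run on toy shapes (the real model uses 64 objects with 2048-d features): each sentence is duplicated so it lines up with its two images, the image pairs are flattened into the batch dimension, and after encoding the consecutive pair outputs are re-joined with .view(-1, hid_dim * 2).

    import torch

    sent = ['s0', 's1', 's2']             # 3 sentences
    feat = torch.randn(3, 2, 4, 8)        # each paired with 2 images, 4 objects, 8-d

    sent = sum(zip(sent, sent), ())       # duplicate each sentence
    print(sent)                           # ('s0', 's0', 's1', 's1', 's2', 's2')

    b, img_num, obj_num, f = feat.size()
    feat = feat.view(b * 2, obj_num, f)   # flatten image pairs into the batch
    print(feat.shape)                     # torch.Size([6, 4, 8])
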
/src/tasks/vqa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import os
5 | import collections
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data.dataloader import DataLoader
10 | from tqdm import tqdm
11 |
12 | from param import args
13 | from pretrain.qa_answer_table import load_lxmert_qa
14 | from tasks.vqa_model import VQAModel
15 | from tasks.vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
16 |
17 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
18 |
19 |
20 | def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
21 | dset = VQADataset(splits)
22 | tset = VQATorchDataset(dset)
23 | evaluator = VQAEvaluator(dset)
24 | data_loader = DataLoader(
25 | tset, batch_size=bs,
26 | shuffle=shuffle, num_workers=args.num_workers,
27 | drop_last=drop_last, pin_memory=True
28 | )
29 |
30 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
31 |
32 |
33 | class VQA:
34 | def __init__(self):
35 | # Datasets
36 | self.train_tuple = get_data_tuple(
37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True
38 | )
39 | if args.bert_type == 'ft':
40 | bs_infer = 256
41 | else:
42 | bs_infer = 1024
43 | if args.valid != "":
44 | self.valid_tuple = get_data_tuple(
45 | args.valid, bs=bs_infer,
46 | shuffle=False, drop_last=False
47 | )
48 | else:
49 | self.valid_tuple = None
50 |
51 | print("args.lr is {0}".format(args.lr))
52 |
53 | # Model
54 | self.model = VQAModel(self.train_tuple.dataset.num_answers)
55 |
56 | # Load pre-trained weights
57 | if args.load_lxmert is not None:
58 | self.model.lxrt_encoder.load(args.load_lxmert)
59 | if args.load_lxmert_qa is not None:
60 | load_lxmert_qa(args.load_lxmert_qa, self.model,
61 | label2ans=self.train_tuple.dataset.label2ans)
62 |
63 | # GPU options
64 | self.model = self.model.cuda()
65 | if args.multiGPU:
66 | self.model.lxrt_encoder.multi_gpu()
67 |
68 | # Loss and Optimizer
69 | self.bce_loss = nn.BCEWithLogitsLoss()
70 | if 'bert' in args.optim:
71 | 
72 | 
73 |
74 | batch_per_epoch = len(self.train_tuple.loader)
75 | t_total = int(batch_per_epoch * args.epochs)
76 | print("BertAdam Total Iters: %d" % t_total)
77 | from lxrt.optimization import BertAdam
78 | self.optim = BertAdam(list(self.model.parameters()),
79 | lr=args.lr,
80 | warmup=0.1,
81 | t_total=t_total,
82 | schedule=args.lr_schedule,
83 | args=args)
84 |
85 | else:
86 | self.optim = args.optimizer(self.model.parameters(), args.lr)
87 |
88 | # Output Directory
89 | self.output = args.output
90 | os.makedirs(self.output, exist_ok=True)
91 |
92 | def train(self, train_tuple, eval_tuple):
93 | dset, loader, evaluator = train_tuple
94 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
95 |
96 | best_valid = 0.
97 | for epoch in range(args.epochs):
98 | quesid2ans = {}
99 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
100 |
101 | self.model.train()
102 | # pytorch_total_params = sum(p.numel() for p in self.model.parameters())
103 | # print('total parameters are:',pytorch_total_params)
104 | self.optim.zero_grad()
105 |
106 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
107 | logit = self.model(feats, boxes, sent)
108 | assert logit.dim() == target.dim() == 2
109 | loss = self.bce_loss(logit, target)
110 | loss = loss * logit.size(1)
111 |
112 | loss.backward()
113 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
114 | self.optim.step()
115 |
116 | score, label = logit.max(1)
117 | for qid, l in zip(ques_id, label.cpu().numpy()):
118 | ans = dset.label2ans[l]
119 | quesid2ans[qid.item()] = ans
120 |
121 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
122 |
123 | if self.valid_tuple is not None: # Do Validation
124 | valid_score = self.evaluate(eval_tuple)
125 | if valid_score > best_valid:
126 | best_valid = valid_score
127 | self.save("BEST")
128 |
129 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
130 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
131 |
132 | print(log_str, end='')
133 |
134 | with open(self.output + "/log.log", 'a') as f:
135 | f.write(log_str)
136 | f.flush()
137 |
138 | self.save("LAST")
139 |
140 | def predict(self, eval_tuple: DataTuple, dump=None):
141 | """
142 | Predict the answers to questions in a data split.
143 |
144 | :param eval_tuple: The data tuple to be evaluated.
145 | :param dump: The path of saved file to dump results.
146 | :return: A dict of question_id to answer.
147 | """
148 | self.model.eval()
149 | dset, loader, evaluator = eval_tuple
150 | quesid2ans = {}
151 | for i, datum_tuple in enumerate(loader):
152 | ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth
153 | with torch.no_grad():
154 | feats, boxes = feats.cuda(), boxes.cuda()
155 | logit = self.model(feats, boxes, sent)
156 | score, label = logit.max(1)
157 | for qid, l in zip(ques_id, label.cpu().numpy()):
158 | ans = dset.label2ans[l]
159 | quesid2ans[qid.item()] = ans
160 | if dump is not None:
161 | evaluator.dump_result(quesid2ans, dump)
162 | return quesid2ans
163 |
164 | def evaluate(self, eval_tuple: DataTuple, dump=None):
165 | """Evaluate all data in data_tuple."""
166 | quesid2ans = self.predict(eval_tuple, dump)
167 | return eval_tuple.evaluator.evaluate(quesid2ans)
168 |
169 | @staticmethod
170 | def oracle_score(data_tuple):
171 | dset, loader, evaluator = data_tuple
172 | quesid2ans = {}
173 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
174 | _, label = target.max(1)
175 | for qid, l in zip(ques_id, label.cpu().numpy()):
176 | ans = dset.label2ans[l]
177 | quesid2ans[qid.item()] = ans
178 | return evaluator.evaluate(quesid2ans)
179 |
180 | def save(self, name):
181 | torch.save(self.model.state_dict(),
182 | os.path.join(self.output, "%s.pth" % name))
183 |
184 | # def load(self, path):
185 | # print("Load model from %s" % path)
186 | # state_dict = torch.load("%s.pth" % path)
187 | # self.model.load_state_dict(state_dict)
188 | def load(self, path):
189 | print("Load model from %s" % path)
190 | state_dict = torch.load("%s.pth" % path)
191 | for key in list(state_dict.keys()):
192 | if '.module' in key:
193 | state_dict[key.replace('.module', '')] = state_dict.pop(key)
194 | self.model.load_state_dict(state_dict, strict=False)
195 |
196 |
197 | if __name__ == "__main__":
198 | # Build Class
199 | vqa = VQA()
200 |
201 | # Load VQA model weights
202 | # Note: It is different from loading LXMERT pre-trained weights.
203 | if args.load is not None:
204 | vqa.load(args.load)
205 |
206 | # Test or Train
207 | if args.test is not None:
208 | args.fast = args.tiny = False # Always loading all data in test
209 |         # Both bert_type settings currently infer with the same batch size
210 |         bs_infer = 128
211 | 
212 | 
213 | if 'test' in args.test:
214 | vqa.predict(
215 | get_data_tuple(args.test, bs=bs_infer,
216 | shuffle=False, drop_last=False),
217 | dump=os.path.join(args.output, 'test_predict.json')
218 | )
219 | elif 'val' in args.test:
220 |         # Since part of the validation data is used in pre-training/fine-tuning,
221 |         # we only validate on the minival set.
222 | result = vqa.evaluate(
223 | get_data_tuple('minival', bs=bs_infer,
224 | shuffle=False, drop_last=False),
225 | dump=os.path.join(args.output, 'minival_predict.json')
226 | )
227 | print(result)
228 | else:
229 | assert False, "No such test option for %s" % args.test
230 | else:
231 | print('Splits in Train data:', vqa.train_tuple.dataset.splits)
232 | if vqa.valid_tuple is not None:
233 | print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
234 | print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
235 | else:
236 | print("DO NOT USE VALIDATION")
237 | vqa.train(vqa.train_tuple, vqa.valid_tuple)
238 |
239 |
240 |
--------------------------------------------------------------------------------
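
The '.module' key surgery in VQA.load above (and the matching loaders in gqa.py and nlvr2.py) handles checkpoints saved while the inner encoder was wrapped in nn.DataParallel. In isolation, with made-up key names:

    import torch

    state_dict = {
        'lxrt_encoder.module.fc.weight': torch.zeros(2, 4),   # saved under DataParallel
        'logit_fc.weight': torch.zeros(2, 2),                 # unaffected key
    }

    for key in list(state_dict.keys()):
        if '.module' in key:
            state_dict[key.replace('.module', '')] = state_dict.pop(key)

    print(sorted(state_dict.keys()))
    # ['logit_fc.weight', 'lxrt_encoder.fc.weight']
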
/src/tasks/vqa_constant.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import os
5 | import collections
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data.dataloader import DataLoader
10 | from tqdm import tqdm
11 |
12 | from param import args
13 | from pretrain.qa_answer_table import load_lxmert_qa
14 | from tasks.vqa_model import VQAModel
15 | from tasks.vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
16 |
17 | DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
18 |
19 |
20 | def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
21 | dset = VQADataset(splits)
22 | tset = VQATorchDataset(dset)
23 | evaluator = VQAEvaluator(dset)
24 | data_loader = DataLoader(
25 | tset, batch_size=bs,
26 | shuffle=shuffle, num_workers=args.num_workers,
27 | drop_last=drop_last, pin_memory=True
28 | )
29 |
30 | return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
31 |
32 |
33 | class VQA:
34 | def __init__(self):
35 | # Datasets
36 | self.train_tuple = get_data_tuple(
37 | args.train, bs=args.batch_size, shuffle=True, drop_last=True
38 | )
39 | if args.bert_type == 'ft':
40 | bs_infer = 256
41 | else:
42 | bs_infer = 1024
43 | if args.valid != "":
44 | self.valid_tuple = get_data_tuple(
45 | args.valid, bs=bs_infer,
46 | shuffle=False, drop_last=False
47 | )
48 | else:
49 | self.valid_tuple = None
50 |
51 | print("args.lr is {0}".format(args.lr))
52 |
53 | # Model
54 | self.model = VQAModel(self.train_tuple.dataset.num_answers)
55 |
56 | # Load pre-trained weights
57 | if args.load_lxmert is not None:
58 | self.model.lxrt_encoder.load(args.load_lxmert)
59 | if args.load_lxmert_qa is not None:
60 | load_lxmert_qa(args.load_lxmert_qa, self.model,
61 | label2ans=self.train_tuple.dataset.label2ans)
62 |
63 | # GPU options
64 | self.model = self.model.cuda()
65 | if args.multiGPU:
66 | self.model.lxrt_encoder.multi_gpu()
67 |
68 | # Loss and Optimizer
69 | self.bce_loss = nn.BCEWithLogitsLoss()
70 | if 'bert' in args.optim:
71 | 
72 | 
73 |
74 | batch_per_epoch = len(self.train_tuple.loader)
75 | t_total = int(batch_per_epoch * args.epochs)
76 | print("BertAdam Total Iters: %d" % t_total)
77 | from lxrt.optimization import BertAdam
78 | # self.optim = BertAdam(list(self.model.parameters()),
79 | # lr=args.lr,
80 | # warmup=0.1,
81 | # t_total=t_total,
82 | # schedule=args.lr_schedule,
83 | # args=args)
84 |             # Suggested by Yang Xiaofeng
85 |             self.optim = BertAdam(list(self.model.parameters()),
86 |                                   lr=args.lr,
87 |                                   warmup=0.1,
88 |                                   schedule='warmup_constant',
89 |                                   t_total=t_total)
90 | else:
91 | self.optim = args.optimizer(self.model.parameters(), args.lr)
92 |
93 | # Output Directory
94 | self.output = args.output
95 | os.makedirs(self.output, exist_ok=True)
96 |
97 | def train(self, train_tuple, eval_tuple):
98 | dset, loader, evaluator = train_tuple
99 | iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
100 |
101 | best_valid = 0.
102 | for epoch in range(args.epochs):
103 |
104 |             # Suggested by Yang Xiaofeng: decay the learning rate by 10x at epochs 4 and 6
105 |             if epoch == 4:
106 |                 for g in self.optim.param_groups:
107 |                     g['lr'] = g['lr'] / 10
108 |             if epoch == 6:
109 |                 for g in self.optim.param_groups:
110 |                     g['lr'] = g['lr'] / 10
111 |
112 | quesid2ans = {}
113 | for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
114 |
115 | self.model.train()
116 | self.optim.zero_grad()
117 |
118 | feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
119 | logit = self.model(feats, boxes, sent)
120 | assert logit.dim() == target.dim() == 2
121 | loss = self.bce_loss(logit, target)
122 | loss = loss * logit.size(1)
123 |
124 | loss.backward()
125 | nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
126 | self.optim.step()
127 |
128 | score, label = logit.max(1)
129 | for qid, l in zip(ques_id, label.cpu().numpy()):
130 | ans = dset.label2ans[l]
131 | quesid2ans[qid.item()] = ans
132 |
133 | log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
134 |
135 | if self.valid_tuple is not None: # Do Validation
136 | valid_score = self.evaluate(eval_tuple)
137 | if valid_score > best_valid:
138 | best_valid = valid_score
139 | self.save("BEST")
140 |
141 | log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
142 | "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
143 |
144 | print(log_str, end='')
145 |
146 | with open(self.output + "/log.log", 'a') as f:
147 | f.write(log_str)
148 | f.flush()
149 |
150 | self.save("LAST")
151 |
152 | def predict(self, eval_tuple: DataTuple, dump=None):
153 | """
154 | Predict the answers to questions in a data split.
155 |
156 | :param eval_tuple: The data tuple to be evaluated.
157 | :param dump: The path of saved file to dump results.
158 | :return: A dict of question_id to answer.
159 | """
160 | self.model.eval()
161 | dset, loader, evaluator = eval_tuple
162 | quesid2ans = {}
163 | for i, datum_tuple in enumerate(loader):
164 | ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth
165 | with torch.no_grad():
166 | feats, boxes = feats.cuda(), boxes.cuda()
167 | logit = self.model(feats, boxes, sent)
168 | score, label = logit.max(1)
169 | for qid, l in zip(ques_id, label.cpu().numpy()):
170 | ans = dset.label2ans[l]
171 | quesid2ans[qid.item()] = ans
172 | if dump is not None:
173 | evaluator.dump_result(quesid2ans, dump)
174 | return quesid2ans
175 |
176 | def evaluate(self, eval_tuple: DataTuple, dump=None):
177 | """Evaluate all data in data_tuple."""
178 | quesid2ans = self.predict(eval_tuple, dump)
179 | return eval_tuple.evaluator.evaluate(quesid2ans)
180 |
181 | @staticmethod
182 | def oracle_score(data_tuple):
183 | dset, loader, evaluator = data_tuple
184 | quesid2ans = {}
185 | for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
186 | _, label = target.max(1)
187 | for qid, l in zip(ques_id, label.cpu().numpy()):
188 | ans = dset.label2ans[l]
189 | quesid2ans[qid.item()] = ans
190 | return evaluator.evaluate(quesid2ans)
191 |
192 | def save(self, name):
193 | torch.save(self.model.state_dict(),
194 | os.path.join(self.output, "%s.pth" % name))
195 |
196 | def load(self, path):
197 | print("Load model from %s" % path)
198 | state_dict = torch.load("%s.pth" % path)
199 | self.model.load_state_dict(state_dict)
200 |
201 |
202 | if __name__ == "__main__":
203 | # Build Class
204 | vqa = VQA()
205 |
206 | # Load VQA model weights
207 | # Note: It is different from loading LXMERT pre-trained weights.
208 | if args.load is not None:
209 | vqa.load(args.load)
210 |
211 | # Test or Train
212 | if args.test is not None:
213 |         args.fast = args.tiny = False       # Always load all data at test time
214 | if args.bert_type == 'ft':
215 | bs_infer = 256
216 | else:
217 | bs_infer = 950
218 |         if 'test' in args.test:
219 |             vqa.predict(
220 |                 get_data_tuple(args.test, bs=bs_infer,
221 |                                shuffle=False, drop_last=False),
222 |                 dump=os.path.join(args.output, 'test_predict.json')
223 |             )
224 |         elif 'val' in args.test:
225 |             # Since part of the validation data is used in pre-training/fine-tuning,
226 |             # only validate on the minival set.
227 |             result = vqa.evaluate(
228 |                 get_data_tuple('minival', bs=bs_infer,
229 |                                shuffle=False, drop_last=False),
230 |                 dump=os.path.join(args.output, 'minival_predict.json')
231 |             )
232 | print(result)
233 | else:
234 | assert False, "No such test option for %s" % args.test
235 | else:
236 | print('Splits in Train data:', vqa.train_tuple.dataset.splits)
237 | if vqa.valid_tuple is not None:
238 | print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
239 | print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
240 | else:
241 | print("DO NOT USE VALIDATION")
242 | vqa.train(vqa.train_tuple, vqa.valid_tuple)
243 |
244 |
245 |
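A minimal sketch (not part of the repo) of the loss computation in train() above: nn.BCEWithLogitsLoss with its default reduction averages over every element of the (batch, num_answers) logit matrix, so multiplying by logit.size(1) turns the value into a per-example sum over answers, averaged over the batch. The sizes below are illustrative.

    import torch
    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()               # default reduction='mean'
    logit = torch.randn(32, 3129)              # (batch, num_answers); illustrative sizes
    target = torch.rand(32, 3129)              # soft answer scores in [0, 1]

    loss = bce(logit, target) * logit.size(1)  # as in train() above

    # Equivalent formulation: sum over answers, then mean over the batch.
    per_elem = nn.BCEWithLogitsLoss(reduction='none')(logit, target)
    assert torch.allclose(loss, per_elem.sum(dim=1).mean(), rtol=1e-4)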
--------------------------------------------------------------------------------
/src/tasks/vqa_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import json
5 | import os
6 | import pickle
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | from param import args
13 | from utils import load_obj_tsv
14 |
15 | # Load part of the dataset for fast checking.
16 | # Note that these counts are numbers of images, not numbers of examples;
17 | # every example tied to a loaded image is kept.
18 | TINY_IMG_NUM = 512
19 | FAST_IMG_NUM = 5000
20 |
21 | # The path to data and image features.
22 | VQA_DATA_ROOT = 'data/vqa/'
23 | MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/'
24 | SPLIT2NAME = {
25 | 'train': 'train2014',
26 | 'valid': 'val2014',
27 | 'minival': 'val2014',
28 | 'nominival': 'val2014',
29 | 'test': 'test2015',
30 | }
31 |
32 |
33 | class VQADataset:
34 | """
35 |     A VQA data example in the json file:
36 | {
37 | "answer_type": "other",
38 | "img_id": "COCO_train2014_000000458752",
39 | "label": {
40 | "net": 1
41 | },
42 | "question_id": 458752000,
43 | "question_type": "what is this",
44 | "sent": "What is this photo taken looking through?"
45 | }
46 | """
47 | def __init__(self, splits: str):
48 | self.name = splits
49 | self.splits = splits.split(',')
50 |
51 | # Loading datasets
52 | self.data = []
53 | for split in self.splits:
54 | self.data.extend(json.load(open("data/vqa/%s.json" % split)))
55 |         print("Loaded %d examples from split(s) %s." % (len(self.data), self.name))
56 |
57 | # Convert list to dict (for evaluation)
58 | self.id2datum = {
59 | datum['question_id']: datum
60 | for datum in self.data
61 | }
62 |
63 | # Answers
64 | self.ans2label = json.load(open("data/vqa/trainval_ans2label.json"))
65 | self.label2ans = json.load(open("data/vqa/trainval_label2ans.json"))
66 | assert len(self.ans2label) == len(self.label2ans)
67 |
68 | @property
69 | def num_answers(self):
70 | return len(self.ans2label)
71 |
72 | def __len__(self):
73 | return len(self.data)
74 |
75 |
76 | """
77 | An example in obj36 tsv:
78 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
79 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
80 | FIELDNAMES would be keys in the dict returned by load_obj_tsv.
81 | """
82 | class VQATorchDataset(Dataset):
83 | def __init__(self, dataset: VQADataset):
84 | super().__init__()
85 | self.raw_dataset = dataset
86 |
87 | if args.tiny:
88 | topk = TINY_IMG_NUM
89 | elif args.fast:
90 | topk = FAST_IMG_NUM
91 | else:
92 | topk = None
93 |
94 | # Loading detection features to img_data
95 | img_data = []
96 | for split in dataset.splits:
 97 |             # Minival is 5K images in MS COCO, used for evaluating VQA and LXMERT pre-training.
 98 |             # It is saved as the top 5K features in val2014_***.tsv
 99 |             load_topk = 5000 if (split == 'minival' and topk is None) else topk
100 |             # Read the 64-object features (*_obj64.tsv) rather than the original
101 |             # 36-object features (*_obj36.tsv).
103 |             img_data.extend(load_obj_tsv(
104 |                 os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj64.tsv' % (SPLIT2NAME[split])),
105 |                 topk=load_topk))
106 |             print('Loaded features from {0}_obj64.tsv'.format(SPLIT2NAME[split]))
107 |
108 | # Convert img list to dict
109 | self.imgid2img = {}
110 | for img_datum in img_data:
111 | self.imgid2img[img_datum['img_id']] = img_datum
112 |
113 |         # Only keep the examples whose image features were loaded
114 | self.data = []
115 | for datum in self.raw_dataset.data:
116 | if datum['img_id'] in self.imgid2img:
117 | self.data.append(datum)
118 |         print("Using %d examples in the torch dataset" % (len(self.data)))
119 | print()
120 |
121 | def __len__(self):
122 | return len(self.data)
123 |
124 | def __getitem__(self, item: int):
125 | datum = self.data[item]
126 |
127 | img_id = datum['img_id']
128 | ques_id = datum['question_id']
129 | ques = datum['sent']
130 |
131 | # Get image info
132 | img_info = self.imgid2img[img_id]
133 | obj_num = img_info['num_boxes']
134 | feats = img_info['features'].copy()
135 | boxes = img_info['boxes'].copy()
136 | assert obj_num == len(boxes) == len(feats)
137 |
138 | # Normalize the boxes (to 0 ~ 1)
139 | img_h, img_w = img_info['img_h'], img_info['img_w']
141 | boxes[:, (0, 2)] /= img_w
142 | boxes[:, (1, 3)] /= img_h
143 | np.testing.assert_array_less(boxes, 1+1e-5)
144 | np.testing.assert_array_less(-boxes, 0+1e-5)
145 |
146 | # Provide label (target)
147 | if 'label' in datum:
148 | label = datum['label']
149 | target = torch.zeros(self.raw_dataset.num_answers)
150 | for ans, score in label.items():
151 | target[self.raw_dataset.ans2label[ans]] = score
152 | return ques_id, feats, boxes, ques, target
153 | else:
154 | return ques_id, feats, boxes, ques
155 |
156 |
157 | class VQAEvaluator:
158 | def __init__(self, dataset: VQADataset):
159 | self.dataset = dataset
160 |
161 | def evaluate(self, quesid2ans: dict):
162 | score = 0.
163 | for quesid, ans in quesid2ans.items():
164 | datum = self.dataset.id2datum[quesid]
165 | label = datum['label']
166 | if ans in label:
167 | score += label[ans]
168 | return score / len(quesid2ans)
169 |
170 | def dump_result(self, quesid2ans: dict, path):
171 | """
172 |         Dump results to a json file, which can be submitted to the VQA online evaluation server.
173 | VQA json file submission requirement:
174 | results = [result]
175 | result = {
176 | "question_id": int,
177 | "answer": str
178 | }
179 |
180 | :param quesid2ans: dict of quesid --> ans
181 | :param path: The desired path of saved file.
182 | """
183 | with open(path, 'w') as f:
184 | result = []
185 | for ques_id, ans in quesid2ans.items():
186 | result.append({
187 | 'question_id': ques_id,
188 | 'answer': ans
189 | })
190 | json.dump(result, f, indent=4, sort_keys=True)
191 |
192 |
193 |
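A minimal sketch of the soft-target construction in VQATorchDataset.__getitem__ and the scoring rule in VQAEvaluator.evaluate above. The three-answer vocabulary and the scores are illustrative stand-ins for the trainval_ans2label.json entries, not real data.

    import torch

    # Illustrative miniature answer vocabulary.
    ans2label = {'yes': 0, 'no': 1, 'net': 2}
    label2ans = ['yes', 'no', 'net']

    # Soft label of one datum, as in the json example above.
    label = {'net': 1.0, 'no': 0.3}

    # Target construction (__getitem__): one slot per answer, filled with its score.
    target = torch.zeros(len(ans2label))
    for ans, score in label.items():
        target[ans2label[ans]] = score

    # Scoring (evaluate): a prediction earns the soft score of the chosen answer,
    # or 0 if that answer is not among the labels.
    pred = label2ans[int(target.argmax())]   # 'net'
    score = label.get(pred, 0.)              # 1.0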
--------------------------------------------------------------------------------
/src/tasks/vqa_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 project LXRT.
3 |
4 | import torch.nn as nn
5 |
6 | from param import args
7 | from lxrt.entry import LXRTEncoder
8 | from lxrt.modeling import BertLayerNorm, GeLU
9 |
13 |
14 | # Max length including <bos> and <eos>
15 | MAX_VQA_LENGTH = 20
16 |
17 |
18 | class VQAModel(nn.Module):
19 | def __init__(self, num_answers):
20 | super().__init__()
21 |
22 | # Build LXRT encoder
23 | self.lxrt_encoder = LXRTEncoder(
24 | args,
25 | max_seq_length=MAX_VQA_LENGTH
26 | )
27 | hid_dim = self.lxrt_encoder.dim
28 |
29 | # VQA Answer heads
30 | self.logit_fc = nn.Sequential(
31 | nn.Linear(hid_dim, hid_dim * 2),
32 | GeLU(),
33 | BertLayerNorm(hid_dim * 2, eps=1e-12),
34 | nn.Linear(hid_dim * 2, num_answers)
35 | )
36 | self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
37 |
38 | def forward(self, feat, pos, sent):
39 | """
40 | b -- batch_size, o -- object_number, f -- visual_feature_size
41 |
42 | :param feat: (b, o, f)
43 | :param pos: (b, o, 4)
44 | :param sent: (b,) Type -- list of string
 46 |         :return: (b, num_answers) The logits over all candidate answers.
47 | """
48 | x = self.lxrt_encoder(sent, (feat, pos))
49 | logit = self.logit_fc(x)
50 |
51 | return logit
52 |
53 |
54 |
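The answer head above projects the pooled cross-modal feature to answer logits. Below is a self-contained sketch of the same head using plain-PyTorch stand-ins for the repo's GeLU and BertLayerNorm modules; hid_dim=768 and num_answers=3129 are illustrative assumptions, not values read from the code.

    import torch
    import torch.nn as nn

    hid_dim, num_answers = 768, 3129        # illustrative sizes

    logit_fc = nn.Sequential(
        nn.Linear(hid_dim, hid_dim * 2),
        nn.GELU(),                          # stands in for lxrt.modeling.GeLU
        nn.LayerNorm(hid_dim * 2, eps=1e-12),  # stands in for BertLayerNorm
        nn.Linear(hid_dim * 2, num_answers),
    )

    x = torch.randn(8, hid_dim)             # pooled cross-modal feature from the LXRT encoder
    logit = logit_fc(x)                     # shape: (8, num_answers)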
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyleft 2019 Project LXRT
3 |
4 | import sys
5 | import csv
6 | import base64
7 | import time
8 |
9 | import numpy as np
10 |
11 | csv.field_size_limit(sys.maxsize)
12 | FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
13 | "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
14 |
15 |
16 | def load_obj_tsv(fname, topk=None, fp16=False):
17 | """Load object features from tsv file.
18 |
19 | :param fname: The path to the tsv file.
20 | :param topk: Only load features for top K images (lines) in the tsv file.
21 |                  Will load all the features if topk is either -1 or None.
22 |     :param fp16: If True, store float32 arrays as float16 to halve memory use.
23 |     :return: A list of feature dicts, one per image. See FIELDNAMES above for the keys.
24 |     """
25 | data = []
26 | start_time = time.time()
27 | print("Start to load Faster-RCNN detected objects from %s" % fname)
28 | with open(fname) as f:
29 | reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
30 | for i, item in enumerate(reader):
31 |             if i % 1000 == 0:
32 |                 print("Processed %d images..." % i)
33 |
34 | for key in ['img_h', 'img_w', 'num_boxes']:
35 | item[key] = int(item[key])
36 |
37 | boxes = item['num_boxes']
38 | decode_config = [
39 | ('objects_id', (boxes, ), np.int64),
40 | ('objects_conf', (boxes, ), np.float32),
41 | ('attrs_id', (boxes, ), np.int64),
42 | ('attrs_conf', (boxes, ), np.float32),
43 | ('boxes', (boxes, 4), np.float32),
44 | ('features', (boxes, -1), np.float32),
45 | ]
46 | for key, shape, dtype in decode_config:
47 | item[key] = np.frombuffer(base64.b64decode(item[key]), dtype=dtype)
48 |
49 | if fp16 and item[key].dtype == np.float32:
50 | item[key] = item[key].astype(np.float16) # Save features as half-precision in memory.
51 | item[key] = item[key].reshape(shape)
52 | item[key].setflags(write=False)
53 |
54 | data.append(item)
55 | if topk is not None and len(data) == topk:
56 | break
57 | elapsed_time = time.time() - start_time
58 | print("Loaded %d images in file %s in %d seconds." % (len(data), fname, elapsed_time))
59 | return data
60 |
61 |
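For reference, a minimal sketch of the encoding that load_obj_tsv reverses: each array field in the tsv is the base64 encoding of the array's raw bytes. The writer side is not part of this repo, so the snippet below is only an illustrative round trip.

    import base64
    import numpy as np

    num_boxes = 36
    boxes = np.random.rand(num_boxes, 4).astype(np.float32)

    # Encode: raw array bytes, base64-encoded (one tsv field).
    field = base64.b64encode(boxes.tobytes()).decode('ascii')

    # Decode: the same frombuffer/reshape steps load_obj_tsv performs above.
    decoded = np.frombuffer(base64.b64decode(field), dtype=np.float32).reshape(num_boxes, 4)
    assert np.array_equal(boxes, decoded)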
--------------------------------------------------------------------------------