├── requirements.txt ├── download_backbones.py ├── Geoformer ├── scripts │ ├── evaluate.sh │ ├── pretrain.sh │ └── train.sh └── src │ ├── ManualProgram │ ├── operators.py │ └── eval_equ.py │ ├── preprocess.py │ ├── geo_model.py │ ├── data_utils.py │ ├── utils.py │ ├── tokenization.py │ ├── param.py │ ├── trainer_base.py │ ├── dist_utils.py │ ├── pretrain_data.py │ ├── geo_data.py │ ├── pretrain.py │ ├── geo.py │ └── modeling_t5.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | transformers==4.2.1 4 | -------------------------------------------------------------------------------- /download_backbones.py: -------------------------------------------------------------------------------- 1 | from transformers import T5ForConditionalGeneration, T5Tokenizer 2 | 3 | if __name__ == '__main__': 4 | 5 | print('Downloading checkpoints if not cached') 6 | print('T5-base') 7 | model = T5ForConditionalGeneration.from_pretrained('t5-base') 8 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 9 | print('Done!') 10 | 11 | -------------------------------------------------------------------------------- /Geoformer/scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | # inference 2 | output=snap/test 3 | 4 | PYTHONPATH=$PYTHONPATH:./src \ 5 | python -m torch.distributed.launch \ 6 | --nproc_per_node=$1 --master_port 2098 \ 7 | src/geo.py \ 8 | --distributed --multiGPU \ 9 | --test_only \ 10 | --train calculation_train \ 11 | --valid calculation_val \ 12 | --test calculation_test,proving_test \ 13 | --optim adamw \ 14 | --warmup_ratio 0.1 \ 15 | --lr 1e-3 \ 16 | --epochs 100 \ 17 | --num_workers 4 \ 18 | --backbone 't5-base' \ 19 | --output $output ${@:2} \ 20 | --load $output/BEST \ 21 | --num_beams 10 \ 22 | --batch_size 10 \ 23 | --max_text_length 200 \ 24 | --gen_max_length 40 \ -------------------------------------------------------------------------------- /Geoformer/scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | # The name of experiment 2 | output=snap/pretrain_test 3 | 4 | PYTHONPATH=$PYTHONPATH:./src \ 5 | python -m torch.distributed.launch \ 6 | --nproc_per_node=$1 --master_port 20417 \ 7 | src/pretrain.py \ 8 | --distributed --multiGPU \ 9 | --train calculation_train \ 10 | --valid calculation_val \ 11 | --test calculation_test \ 12 | --optim adamw \ 13 | --warmup_ratio 0.1 \ 14 | --lr 5e-4 \ 15 | --epochs 20 \ 16 | --batch_size 5 \ 17 | --wordMaskRate 0.3 \ 18 | --backbone 't5-base' \ 19 | --output $output ${@:2} \ 20 | --num_beams 1 \ 21 | --max_text_length 200 \ 22 | --gen_max_length 200 \ 23 | --num_workers 8 \ 24 | 25 | -------------------------------------------------------------------------------- /Geoformer/scripts/train.sh: -------------------------------------------------------------------------------- 1 | # The name of experiment 2 | output=snap/test 3 | 4 | PYTHONPATH=$PYTHONPATH:./src \ 5 | python -m torch.distributed.launch \ 6 | --nproc_per_node=$1 --master_port 2132 \ 7 | src/geo.py \ 8 | --distributed --multiGPU \ 9 | --train calculation_train,proving_train \ 10 | --valid calculation_val,proving_val \ 11 | --test calculation_test,proving_test \ 12 | --optim adamw \ 13 | --warmup_ratio 0.1 \ 14 | --clip_grad_norm 5 \ 15 | --lr 2e-4 \ 16 | --batch_size 10 \ 17 | --epochs 100 \ 18 | --num_workers 8 \ 19 | --backbone 't5-base' \ 20 | --output $output ${@:2} \ 21 | 
--num_beams 1 \ 22 | --max_text_length 200 \ 23 | --gen_max_length 40 \ 24 | --load snap/pretrained -------------------------------------------------------------------------------- /Geoformer/src/ManualProgram/operators.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def g_equal(n1): # 0 5 | return n1 6 | 7 | 8 | def g_double(n1): # 1 9 | return n1*2 10 | 11 | 12 | def g_half(n1): # 2 13 | return n1/2 14 | 15 | 16 | def g_add(n1, n2): # 3 17 | return n1 + n2 18 | 19 | 20 | def g_minus(n1, n2): # 4 21 | return math.fabs(n1 - n2) 22 | 23 | 24 | def g_sin(n1): # 5 25 | if n1 % 15 == 0 and 0 <= n1 <= 180: 26 | return math.sin(n1/180*math.pi) 27 | return False 28 | 29 | 30 | def g_cos(n1): # 6 31 | if n1 % 15 == 0 and 0 <= n1 <= 180: 32 | return math.cos(n1/180*math.pi) 33 | return False 34 | 35 | 36 | def g_tan(n1): # 7 37 | if n1 % 15 == 0 and 5 <= n1 <= 85: 38 | return math.tan(n1/180*math.pi) 39 | return False 40 | 41 | 42 | def g_asin(n1): # 8 43 | if -1 < n1 < 1: 44 | n1 = math.asin(n1) 45 | n1 = math.degrees(n1) 46 | return n1 47 | return False 48 | 49 | 50 | def g_acos(n1): # 9 51 | if -1 < n1 < 1: 52 | n1 = math.acos(n1) 53 | n1 = math.degrees(n1) 54 | return n1 55 | return False 56 | 57 | 58 | def gougu_add(n1, n2): # 13 59 | return math.sqrt(n1*n1+n2*n2) 60 | 61 | 62 | def gougu_minus(n1, n2): # 14 63 | if n1 != n2: 64 | return math.sqrt(math.fabs(n1*n1-n2*n2)) 65 | return False 66 | 67 | 68 | def g_bili(n1, n2, n3): # 16 69 | if n1 > 0 and n2 > 0 and n3 > 0: 70 | return n1/n2*n3 71 | else: 72 | return False 73 | 74 | 75 | def g_mul(n1, n2): # 17 76 | return n1*n2 77 | 78 | 79 | def g_divide(n1, n2): # 18 80 | if n1 > 0 and n2 > 0: 81 | return n1/n2 82 | return False 83 | 84 | 85 | def cal_circle_area(n1): # 19 86 | return n1*n1*math.pi 87 | 88 | 89 | def cal_circle_perimeter(n1): # 20 90 | return 2*math.pi*n1 91 | 92 | 93 | def cal_cone(n1, n2): # 21 94 | return n1*n2*math.pi 95 | 96 | -------------------------------------------------------------------------------- /Geoformer/src/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | 5 | def corrupt_bart(input_text, mask_ratio=0.30, prefix="denoise text:"): 6 | """BART-style Masked Language Modeling with corrupted span prediction 7 | Args: 8 | text 9 | 10 | Returns: 11 | source_text (masked_text) 12 | target_text 13 | 14 | Ex) (in vocab ids) 15 | input 16 | In this tutorial, we’ll explore how to preprocess your data using Transformers. The main tool for this is what we call a tokenizer. 17 | 18 | masked_text 19 | denoise text: In we’ll explore how to preprocess your data Transformers. main for this is what we a tokenizer. 
20 |         target_text
21 |             same as the input text
22 |     """
23 | 
24 |     tokens = input_text.split()
25 | 
26 |     n_tokens = len(tokens)
27 | 
28 |     n_mask = int(max(mask_ratio * n_tokens, 1))
29 |     mask_indices = torch.randperm(n_tokens)[:n_mask].sort().values
30 | 
31 |     assert len(mask_indices) > 0, input_text
32 | 
33 |     mask_indices = mask_indices.tolist()
34 |     span = [mask_indices[0], mask_indices[0]+1]
35 |     spans = []
36 | 
37 |     for i, mask_index in enumerate(mask_indices):
38 |         # if current mask is not the last one & the next mask is right after current mask
39 |         if i < len(mask_indices) - 1 and mask_indices[i+1] == mask_index + 1:
40 |             contiguous = True
41 |         else:
42 |             contiguous = False
43 | 
44 |         if contiguous:
45 |             span[1] += 1
46 | 
47 |         else:
48 |             # non contiguous -> output current span
49 |             spans.append(span)
50 |             # if current mask is not the last one -> create next span
51 |             if i < len(mask_indices) - 1:
52 |                 span = [mask_indices[i+1], mask_indices[i+1]+1]
53 | 
54 |     masked_tokens = deepcopy(tokens)
55 | 
56 |     cum_span_length = 0
57 |     for i, span in enumerate(spans):
58 |         start, end = span
59 |         # collapse each masked span into a single mask sentinel token
60 |         masked_tokens[start-cum_span_length +
61 |                       i: end-cum_span_length+i] = ['<mask>']
62 | 
63 |         cum_span_length += (end - start)
64 | 
65 |     masked_text = " ".join(masked_tokens)
66 | 
67 |     if prefix is None:
68 |         source_text = masked_text
69 |     else:
70 |         source_text = f"{prefix} {masked_text}"
71 | 
72 |     target_text = input_text
73 | 
74 |     return source_text, target_text
75 | 
76 | 
--------------------------------------------------------------------------------
/Geoformer/src/geo_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from modeling_t5 import VLT5
4 | 
5 | 
6 | def build_model():
7 |     cnn = getattr(torchvision.models, 'resnet101')(pretrained=True)
8 |     layers = [cnn.conv1,
9 |               cnn.bn1,
10 |               cnn.relu,
11 |               cnn.maxpool]
12 |     for i in range(4):
13 |         name = 'layer%d' % (i + 1)
14 |         layers.append(getattr(cnn, name))
15 |     model = torch.nn.Sequential(*layers)
16 |     model.cuda()
17 |     model.eval()
18 |     return model
19 | 
20 | 
21 | class VLT5Geo(VLT5):
22 |     def __init__(self, config):
23 |         super().__init__(config)
24 |         self.resnet = build_model()
25 | 
26 |     def train_step(self, batch):
27 |         device = next(self.parameters()).device
28 |         image = batch['image_list'].to(device)
29 | 
30 |         with torch.no_grad():
31 |             vis_feats = self.resnet(image)
32 | 
33 |         N, C, H, W = vis_feats.shape
34 |         vis_feats = vis_feats.reshape(N, C, -1).permute(0, 2, 1)
35 | 
36 |         input_ids = batch['input_ids'].to(device)
37 |         vis_pos = batch['boxes'].to(device)
38 | 
39 |         vis_attention_mask = batch['vis_attention_mask'].to(device)
40 | 
41 |         lm_labels = batch["target_ids"].to(device)
42 | 
43 |         output = self(
44 |             input_ids=input_ids,
45 |             vis_inputs=(vis_feats, vis_pos),
46 |             vis_attention_mask=vis_attention_mask,
47 |             labels=lm_labels,
48 |             reduce_loss=True,
49 |             return_dict=True
50 |         )
51 | 
52 |         loss = output['loss']
53 | 
54 |         result = {
55 |             'loss': loss
56 |         }
57 |         return result
58 | 
59 |     def test_step(self, batch, **kwargs):
60 |         device = next(self.parameters()).device
61 |         image = batch['image_list'].to(device)
62 | 
63 |         with torch.no_grad():
64 |             vis_feats = self.resnet(image)
65 | 
66 |         N, C, H, W = vis_feats.shape
67 |         vis_feats = vis_feats.reshape(N, C, -1).permute(0, 2, 1)
68 | 
69 |         input_ids = batch['input_ids'].to(device)
70 |         vis_pos = batch['boxes'].to(device)
71 | 
72 |         vis_attention_mask = batch['vis_attention_mask'].to(device)
73 | 
74 |         output = self.generate(
75 |             input_ids=input_ids,
76 |             vis_inputs=(vis_feats, vis_pos),
77 |             vis_attention_mask=vis_attention_mask,
78 |             **kwargs
79 |         )
80 | 
81 |         generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
82 | 
83 |         result = {}
84 |         result['pred'] = generated_sents
85 | 
86 |         return result
87 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniGeo
2 | 
3 | Jiaqi Chen, Tong Li, Jinghui Qin, Pan Lu, Liang Lin, Chongyu Chen, Xiaodan Liang.
4 | ["UniGeo: Unifying Geometry Logical Reasoning via Reformulating Mathematical Expression"](https://arxiv.org/abs/2212.02746).
5 | Conference on Empirical Methods in Natural Language Processing (EMNLP 2022).
6 | 
7 | We construct a large-scale Unified Geometry problem benchmark, UniGeo, which contains
8 | 4,998 calculation problems and 9,543 proving problems.
9 | We also present a unified multitask Geometric Transformer framework, Geoformer,
10 | to tackle calculation and proving problems simultaneously in the form of sequence
11 | generation, which shows that the reasoning ability on both tasks can be improved by the unified formulation.
12 | 
13 | If you have any questions, please contact me by email: jadgechen@gmail.com
14 | 
15 | ## Datasets
16 | Download the UniGeo dataset from [Google Drive](https://drive.google.com/drive/folders/1NifdHLJe5U08u2Zb1sWL6C-8krpV2z2O?usp=share_link).
17 | 
18 | Create a *datasets* directory and move the UniGeo data into it.
19 | 
20 | The code structure is shown below:
21 | 
22 | ```bash
23 | ./datasets
24 |     UniGeo/
25 |         proving_test.pk
26 |         proving_train.pk
27 |         proving_val.pk
28 |         ...
29 | 
30 | ./Geoformer
31 |     scripts/
32 |     snap/
33 |     src/
34 |     ...
35 | ```
36 | 
37 | ## Setup
38 | ```bash
39 | # Create python environment
40 | conda create -n unigeo python=3.7
41 | source activate unigeo
42 | 
43 | # Install python dependencies
44 | pip install -r requirements.txt
45 | 
46 | # Download T5 backbone checkpoint
47 | python download_backbones.py
48 | ```
49 | 
50 | 
51 | 
52 | 
53 | 
54 | ## Unified Training
55 | Execute this script to train the model.
56 | 
57 | ```bash
58 | cd Geoformer
59 | bash scripts/train.sh 1
60 | ```
61 | 
62 | The pretrained checkpoint can be found here ([pretrained.pth](https://drive.google.com/drive/folders/1NifdHLJe5U08u2Zb1sWL6C-8krpV2z2O?usp=share_link)).
63 | You can modify the following argument to change the path to the pre-trained model.
64 | ```bash
65 | --load snap/pretrained
66 | ```
67 | 
68 | ## Pre-training
69 | You can also execute this script to pre-train a new model.
70 | ```bash
71 | cd Geoformer
72 | bash scripts/pretrain.sh 1
73 | ```
74 | 
75 | 
76 | ## Evaluation
77 | Execute this script to evaluate the model.
78 | ```bash
79 | cd Geoformer
80 | bash scripts/evaluate.sh 1
81 | ```
82 | 
83 | The model checkpoint of the reported **Geoformer + Pretraining** can be found here ([geoformer.pth](https://drive.google.com/drive/folders/1NifdHLJe5U08u2Zb1sWL6C-8krpV2z2O?usp=share_link)).
84 | You can modify the following argument to test *geoformer.pth* or your trained model.
85 | ```bash 86 | --load snap/geoformer 87 | ``` 88 | 89 | -------------------------------------------------------------------------------- /Geoformer/src/ManualProgram/eval_equ.py: -------------------------------------------------------------------------------- 1 | from ManualProgram import operators 2 | from inspect import getmembers, isfunction 3 | import itertools 4 | import math 5 | 6 | 7 | constant = [30, 60, 90, 180, 360, math.pi, 0.618] 8 | op_dict = {0: 'g_equal', 1: 'g_double', 2: 'g_half', 3: 'g_add', 4: 'g_minus', 9 | 5: 'g_sin', 6: 'g_cos', 7: 'g_tan', 8: 'g_asin', 9: 'g_acos', 10 | 10: 'gougu_add', 11: 'gougu_minus', 12: 'g_bili', 11 | 13: 'g_mul', 14: 'g_divide', 15: 'cal_circle_area', 16: 'cal_circle_perimeter', 17: 'cal_cone'} 12 | op_list = [op_dict[key] for key in sorted(op_dict.keys())] 13 | 14 | 15 | class Equations: 16 | def __init__(self): 17 | 18 | self.op_list = op_list 19 | self.op_num = {} 20 | self.call_op = {} 21 | self.exp_info = None 22 | self.results = [] 23 | self.max_step = 3 24 | self.max_len = 7 25 | for op in self.op_list: 26 | self.call_op[op] = eval('operators.{}'.format(op)) 27 | # self.call_op[op] = eval(op) 28 | self.op_num[op] = self.call_op[op].__code__.co_argcount 29 | 30 | def str2exp(self, inputs): 31 | inputs = inputs.split(',') 32 | exp = inputs.copy() 33 | for i, s in enumerate(inputs): 34 | if 'n' in s or 'v' in s or 'c' in s: 35 | exp[i] = s.replace('n', 'N_').replace('v', 'V_').replace('c', 'C_') 36 | else: 37 | exp[i] = op_dict[int(s[2:])] 38 | exp[i] = exp[i].strip() 39 | 40 | self.exp = exp 41 | return exp 42 | 43 | def excuate_equation(self, exp, source_nums=None): 44 | 45 | if source_nums is None: 46 | source_nums = self.exp_info['nums'] 47 | vars = [] 48 | idx = 0 49 | while idx < len(exp): 50 | op = exp[idx] 51 | if op not in self.op_list: 52 | return None 53 | op_nums = self.op_num[op] 54 | if idx + op_nums >= len(exp): 55 | return None 56 | excuate_nums = [] 57 | for tmp in exp[idx + 1: idx + 1 + op_nums]: 58 | if tmp[0] == 'N' and int(tmp[-1]) < len(source_nums): 59 | excuate_nums.append(source_nums[int(tmp[-1])]) 60 | elif tmp[0] == 'V' and int(tmp[-1]) < len(vars): 61 | excuate_nums.append(vars[int(tmp[-1])]) 62 | elif tmp[0] == 'C' and int(tmp[-1]) < len(constant): 63 | excuate_nums.append(constant[int(tmp[-1])]) 64 | else: 65 | return None 66 | idx += op_nums + 1 67 | v = self.call_op[op](*excuate_nums) 68 | if v is None: 69 | return None 70 | vars.append(v) 71 | return vars 72 | 73 | 74 | if __name__ == '__main__': 75 | eq = Equations() 76 | 77 | -------------------------------------------------------------------------------- /Geoformer/src/data_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cv2 as cv 3 | import numpy as np 4 | 5 | 6 | def process_image(img, min_side=224): 7 | # fill the diagram with a white background and resize it 8 | size = img.shape 9 | h, w = size[0], size[1] 10 | 11 | scale = max(w, h) / float(min_side) 12 | new_w, new_h = int(w/scale), int(h/scale) 13 | resize_img = cv.resize(img, (new_w, new_h)) 14 | 15 | top, bottom, left, right = 0, min_side-new_h, 0, min_side-new_w 16 | pad_img = cv.copyMakeBorder(resize_img, int(top), int(bottom), int(left), int(right), 17 | cv.BORDER_CONSTANT, value=[255,255,255]) 18 | pad_img = pad_img / 255 19 | 20 | return pad_img 21 | 22 | 23 | def create_patch(patch_num=7): 24 | bboxes = [] 25 | for i in range(patch_num): 26 | for j in range(patch_num): 27 | box = [1.0 * j / patch_num, 1.0 * i / 
patch_num, 1.0 * (j + 1) / patch_num, 1.0 * (i + 1) / patch_num] 28 | bboxes.append(box) 29 | bboxes = np.array(bboxes) 30 | return bboxes 31 | 32 | 33 | def split_elements(sequence): 34 | new_sequence = [] 35 | for token in sequence: 36 | if 'N_' in token or 'NS_' in token or 'frac' in token: 37 | new_sequence.append(token) 38 | elif token.istitle(): 39 | new_sequence.append(token) 40 | elif re.search(r'[A-Z]', token): 41 | # split geometry elements with a space: ABC -> A B C 42 | new_sequence.extend(token) 43 | else: 44 | new_sequence.append(token) 45 | 46 | return new_sequence 47 | 48 | 49 | def process_english_text(ori_text): 50 | text = re.split(r'([=≠≈+-/△∠∥⊙☉⊥⟂≌≅▱∽⁀⌒;,:.•?])', ori_text) 51 | text = ' '.join(text) 52 | 53 | text = text.split() 54 | text = split_elements(text) 55 | text = ' '.join(text) 56 | 57 | # The initial version of the calculation problem (GeoQA) is in Chinese. 58 | # The translated English version still contains some Chinese tokens, 59 | # which should be replaced by English words. 60 | replace_dict ={'≠': 'not-equal', '≈': 'approximate', '△': 'triangle', '∠': 'angle', '∥': 'parallel', 61 | '⊙': 'circle', '☉': 'circle', '⊥': 'perpendicular', '⟂': 'perpendicular', '≌': 'congruent', '≅': 'congruent', 62 | '▱': 'parallelogram', '∽': 'similar', '⁀': 'arc', '⌒': 'arc' 63 | } 64 | for k, v in replace_dict.items(): 65 | text = text.replace(k, v) 66 | 67 | return text 68 | 69 | 70 | def process_Chinese_solving(ori_text): 71 | index = ori_text.find('故选') 72 | text = ori_text[:index] 73 | 74 | # delete special tokens 75 | delete_list = ['^{°}', '{', '}', '°', 'cm', 'm', '米', ',', ':', '.', '、', '′', '~', '″', '【', '】', '$'] 76 | for d in delete_list: 77 | text = text.replace(d, ' ') 78 | # delete Chinese tokens 79 | zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') 80 | text1 = re.sub(zh_pattern, ' ', text) 81 | 82 | # split 83 | pattern = re.compile(r'([=≠≈+-]|π|×|\\frac|\\sqrt|\\cdot|\√|[∵∴△∠∥⊙☉⊥⟂≌≅▱∽⁀⌒]|[.,,:;;,:.•?]|\d+\.?\d*|)') 84 | text2 = re.split(pattern, text1) 85 | # split elements: ABC -> A B C 86 | text2 = split_elements(text2) 87 | 88 | # store numbers 89 | text3 = [] 90 | nums = [] 91 | # replace only nums 92 | for t in text2: 93 | if re.search(r'\d', t): # NS: number in solving 94 | if float(t) in nums: 95 | text3.append('NS_'+str(nums.index(float(t)))) 96 | else: 97 | text3.append('NS_'+str(len(nums))) 98 | nums.append(float(t)) 99 | else: 100 | text3.append(t) 101 | 102 | # replace 103 | text4 = ' '.join(text3) 104 | replace_dict = {'≠': 'not-equal', '≈': 'approximate', '△': 'triangle', '∠': 'angle', '∥': 'parallel', 105 | '⊙': 'circle', '☉': 'circle', '⊥': 'perpendicular', '⟂': 'perpendicular', '≌': 'congruent', '≅': 'congruent', 106 | '▱': 'parallelogram', '∽': 'similar', '⁀': 'arc', '⌒': 'arc', 107 | '/ /': 'parallel', '∵': 'because', '∴': 'therefore', '²': 'square', '√': 'root' 108 | } 109 | for k, v in replace_dict.items(): 110 | text4 = text4.replace(k, v) 111 | text4 = text4.split() 112 | 113 | return text4, nums -------------------------------------------------------------------------------- /Geoformer/src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import torch 4 | import collections 5 | import logging 6 | 7 | def get_area(pos): 8 | """ 9 | Args 10 | pos: [B, N, 4] 11 | (x1, x2, y1, y2) 12 | 13 | Return 14 | area : [B, N] 15 | """ 16 | # [B, N] 17 | height = pos[:, :, 3] - pos[:, :, 2] 18 | width = pos[:, :, 1] - pos[:, :, 0] 19 | area = height * width 20 | 
return area 21 | 22 | def get_relative_distance(pos): 23 | """ 24 | Args 25 | pos: [B, N, 4] 26 | (x1, x2, y1, y2) 27 | 28 | Return 29 | out : [B, N, N, 4] 30 | """ 31 | # B, N = pos.size()[:-1] 32 | 33 | # [B, N, N, 4] 34 | relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2) 35 | 36 | return relative_distance 37 | 38 | 39 | class LossMeter(object): 40 | def __init__(self, maxlen=100): 41 | """Computes and stores the running average""" 42 | self.vals = collections.deque([], maxlen=maxlen) 43 | 44 | def __len__(self): 45 | return len(self.vals) 46 | 47 | def update(self, new_val): 48 | self.vals.append(new_val) 49 | 50 | @property 51 | def val(self): 52 | return sum(self.vals) / len(self.vals) 53 | 54 | def __repr__(self): 55 | return str(self.val) 56 | 57 | 58 | class AverageMeter(object): 59 | """Computes and stores the average and current value""" 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.count = 0 68 | 69 | def update(self, val, n=1): 70 | self.val = val 71 | self.sum += val * n 72 | self.count += n 73 | self.avg = 1.0 * self.sum / self.count 74 | 75 | def get_avg(self): 76 | return self.avg 77 | 78 | def get_count(self): 79 | return self.count 80 | 81 | 82 | def count_parameters(model): 83 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 84 | 85 | 86 | def load_state_dict(state_dict_path, loc='cpu'): 87 | state_dict = torch.load(state_dict_path, map_location=loc) 88 | # Change Multi GPU to single GPU 89 | original_keys = list(state_dict.keys()) 90 | for key in original_keys: 91 | if key.startswith("module."): 92 | new_key = key[len("module."):] 93 | state_dict[new_key] = state_dict.pop(key) 94 | return state_dict 95 | 96 | 97 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 98 | """ 99 | Override logging levels of different modules based on their name as a prefix. 100 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized. 101 | 102 | Args: 103 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR 104 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. 105 | Default is `[""]` to match all active loggers. 
106 |         The match is a case-sensitive `module_name.startswith(prefix)`
107 |     """
108 |     prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
109 |     for name in logging.root.manager.loggerDict:
110 |         if re.match(prefix_re, name):
111 |             logging.getLogger(name).setLevel(level)
112 | 
113 | 
114 | def get_iou(anchors, gt_boxes):
115 |     """
116 |     anchors: (N, 4) torch floattensor
117 |     gt_boxes: (K, 4) torch floattensor
118 |     overlaps: (N, K) ndarray of overlap between boxes and query_boxes
119 |     """
120 |     N = anchors.size(0)
121 | 
122 |     if gt_boxes.size() == (4,):
123 |         gt_boxes = gt_boxes.view(1, 4)
124 |     K = gt_boxes.size(0)
125 | 
126 |     gt_boxes_area = (
127 |         (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
128 |         (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
129 |     ).view(1, K)
130 | 
131 |     anchors_area = (
132 |         (anchors[:, 2] - anchors[:, 0] + 1) *
133 |         (anchors[:, 3] - anchors[:, 1] + 1)
134 |     ).view(N, 1)
135 | 
136 |     boxes = anchors.view(N, 1, 4).expand(N, K, 4)
137 |     query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
138 | 
139 |     iw = (
140 |         torch.min(boxes[:, :, 2], query_boxes[:, :, 2])
141 |         - torch.max(boxes[:, :, 0], query_boxes[:, :, 0])
142 |         + 1
143 |     )
144 |     iw[iw < 0] = 0
145 | 
146 |     ih = (
147 |         torch.min(boxes[:, :, 3], query_boxes[:, :, 3])
148 |         - torch.max(boxes[:, :, 1], query_boxes[:, :, 1])
149 |         + 1
150 |     )
151 |     ih[ih < 0] = 0
152 | 
153 |     ua = anchors_area + gt_boxes_area - (iw * ih)
154 |     overlaps = iw * ih / ua
155 | 
156 |     return overlaps
157 | 
158 | 
159 | def xywh_to_xyxy(boxes):
160 |     """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
161 |     return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))
162 | 
--------------------------------------------------------------------------------
/Geoformer/src/tokenization.py:
--------------------------------------------------------------------------------
1 | from transformers import T5Tokenizer, T5TokenizerFast, PreTrainedTokenizer, PreTrainedTokenizerFast, PreTrainedTokenizerBase
2 | import re
3 | import sentencepiece as spm
4 | 
5 | 
6 | class VLT5Tokenizer(T5Tokenizer):
7 |     def __init__(
8 |         self,
9 |         vocab_file,
10 |         eos_token="</s>",
11 |         unk_token="<unk>",
12 |         pad_token="<pad>",
13 |         extra_ids=100,
14 |         vis_extra_ids=100,
15 |         additional_special_tokens=None,
16 |         **kwargs
17 |     ):
18 |         # Add extra_ids to the special token list
19 |         if extra_ids > 0 and additional_special_tokens is None:
20 |             additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
21 |         elif extra_ids > 0 and additional_special_tokens is not None:
22 |             # Check that we have the right number of extra_id special tokens
23 |             extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
24 |             if extra_tokens != extra_ids:
25 |                 raise ValueError(
26 |                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
27 |                     "In this case the additional_special_tokens must include the extra_ids tokens"
28 |                 )
29 | 
30 |         if vis_extra_ids > 0:
31 |             additional_special_tokens.extend(["<vis_extra_id_{}>".format(i) for i in range(vis_extra_ids)])
32 | 
33 |         PreTrainedTokenizer.__init__(
34 |             self,
35 |             eos_token=eos_token,
36 |             unk_token=unk_token,
37 |             pad_token=pad_token,
38 |             extra_ids=extra_ids,
39 |             additional_special_tokens=additional_special_tokens,
40 |             **kwargs,
41 |         )
42 | 
43 |         self.vocab_file = vocab_file
44 |         self._extra_ids = extra_ids
45 |         self._vis_extra_ids = vis_extra_ids
46 | 
47 |         self.sp_model = spm.SentencePieceProcessor()
48 |         self.sp_model.Load(vocab_file)
49 | 
50 |     @property
51 |     def vocab_size(self):
52 |         return self.sp_model.get_piece_size() + self._extra_ids + self._vis_extra_ids
53 | 
54 |     def get_vocab(self):
55 |         vocab = {self.convert_ids_to_tokens(
56 |             i): i for i in range(self.vocab_size)}
57 |         vocab.update(self.added_tokens_encoder)
58 |         return vocab
59 | 
60 |     def _convert_token_to_id(self, token):
61 |         """ Converts a token (str) in an id using the vocab. """
62 |         if token.startswith("<extra_id_"):
63 |             match = re.match(r"<extra_id_(\d+)>", token)
64 |             num = int(match.group(1))
65 |             return self.vocab_size - num - 1 - self._vis_extra_ids
66 |         elif token.startswith("<vis_extra_id_"):
67 |             match = re.match(r"<vis_extra_id_(\d+)>", token)
68 |             num = int(match.group(1))
69 |             return self.vocab_size - num - 1
70 |         return self.sp_model.piece_to_id(token)
71 | 
72 |     def _convert_id_to_token(self, index):
73 |         """Converts an index (integer) in a token (str) using the vocab."""
74 |         if index < self.sp_model.get_piece_size():
75 |             token = self.sp_model.IdToPiece(index)
76 |         else:
77 |             if index > self.sp_model.get_piece_size() + self._extra_ids - 1:
78 |                 token = "<vis_extra_id_{}>".format(self.vocab_size - 1 - index)
79 |             else:
80 |                 token = "<extra_id_{}>".format(self.vocab_size - self._vis_extra_ids - 1 - index)
81 |         return token
82 | 
83 | 
84 | # Below are for Rust-based Fast Tokenizer
85 | 
86 | from transformers.convert_slow_tokenizer import SpmConverter
87 | from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
88 | from typing import Any, Dict, List, Optional, Tuple, Union
89 | 
90 | 
91 | class VLT5Converter(SpmConverter):
92 |     def vocab(self, proto):
93 |         vocab = [(piece.piece, piece.score) for piece in proto.pieces]
94 |         num_extra_ids = self.original_tokenizer._extra_ids
95 |         vocab += [("<extra_id_{}>".format(i), 0.0)
96 |                   for i in range(num_extra_ids - 1, -1, -1)]
97 | 
98 |         num_vis_extra_ids = self.original_tokenizer._vis_extra_ids
99 |         vocab += [("<vis_extra_id_{}>".format(i), 0.0)
100 |                   for i in range(num_vis_extra_ids - 1, -1, -1)]
101 | 
102 |         return vocab
103 | 
104 |     def post_processor(self):
105 |         return processors.TemplateProcessing(
106 |             single=["$A", "</s>"],
107 |             pair=["$A", "</s>", "$B", "</s>"],
108 |             special_tokens=[
109 |                 ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
110 |             ],
111 |         )
112 | 
113 | 
114 | def convert_slow_vlt5tokenizer(vlt5tokenizer):
115 |     return VLT5Converter(vlt5tokenizer).converted()
116 | 
117 | 
118 | class VLT5TokenizerFast(T5TokenizerFast):
119 | 
120 |     slow_tokenizer_class = VLT5Tokenizer
121 | 
122 |     prefix_tokens: List[int] = []
123 | 
124 |     def __init__(
125 |         self,
126 |         vocab_file,
127 |         tokenizer_file=None,
128 |         eos_token="</s>",
129 |         unk_token="<unk>",
130 |         pad_token="<pad>",
131 |         extra_ids=100,
132 |         vis_extra_ids=100,
133 |         additional_special_tokens=None,
134 |         **kwargs
135 |     ):
136 |         # Add extra_ids to the special token list
137 |         if extra_ids > 0 and additional_special_tokens is None:
138 |             additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
139 |         elif extra_ids > 0 and additional_special_tokens is not None:
140 |             # Check that we have the right number of extra_id special tokens
141 |             extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
142 |             if extra_tokens != extra_ids:
143 |                 raise ValueError(
144 |                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
145 |                     "In this case the additional_special_tokens must include the extra_ids tokens"
146 |                 )
147 | 
148 |         if vis_extra_ids > 0:
149 |             additional_special_tokens.extend(["<vis_extra_id_{}>".format(i) for i in range(vis_extra_ids)])
150 | 
151 |         slow_tokenizer = self.slow_tokenizer_class(
152 |             vocab_file,
153 |             tokenizer_file=tokenizer_file,
154 |             eos_token=eos_token,
155 |             unk_token=unk_token,
156 |             pad_token=pad_token,
157 |             extra_ids=extra_ids,
158 |             vis_extra_ids=vis_extra_ids,
159 |             # additional_special_tokens=additional_special_tokens,
160 |             **kwargs
161 |         )
162 |         fast_tokenizer = convert_slow_vlt5tokenizer(slow_tokenizer)
163 |         self._tokenizer = fast_tokenizer
164 | 
165 |         PreTrainedTokenizerBase.__init__(
166 |             self,
167 |             tokenizer_file=tokenizer_file,
168 |             eos_token=eos_token,
169 |             unk_token=unk_token,
170 |             pad_token=pad_token,
171 |             extra_ids=extra_ids,
172 |             vis_extra_ids=vis_extra_ids,
173 |             additional_special_tokens=additional_special_tokens,
174 |             **kwargs,
175 |         )
176 | 
177 |         self.vocab_file = vocab_file
178 |         self._extra_ids = extra_ids
179 |         self._vis_extra_ids = vis_extra_ids
180 | 
--------------------------------------------------------------------------------
/Geoformer/src/param.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | import pprint
8 | import yaml
9 | import os
10 | 
11 | 
12 | def str2bool(v):
13 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
14 |         return True
15 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
16 |         return False
17 |     else:
18 |         raise argparse.ArgumentTypeError('Boolean value expected.')
19 | 
20 | 
21 | def is_interactive():
22 |     import __main__ as main
23 |     return not hasattr(main, '__file__')
24 | 
25 | 
26 | def get_optimizer(optim, verbose=False):
27 |     # Bind the optimizer
28 |     if optim == 'rms':
29 |         if verbose:
30 |             print("Optimizer: Using RMSProp")
31 |         optimizer = torch.optim.RMSprop
32 |     elif optim == 'adam':
33 |         if verbose:
34 |             print("Optimizer: Using Adam")
35 |         optimizer = torch.optim.Adam
36 |     elif optim == 'adamw':
37 |         if verbose:
38 |             print("Optimizer: Using AdamW")
39 |         # optimizer = torch.optim.AdamW
40 |         optimizer = 'adamw'
41 |     elif optim == 'adamax':
42 |         if verbose:
43 |             print("Optimizer: Using Adamax")
44 |         optimizer = torch.optim.Adamax
45 |     elif optim == 'sgd':
46 |         if verbose:
47 |             print("Optimizer: SGD")
48 |         optimizer = torch.optim.SGD
49 |     else:
50 |         assert False, "Please add your optimizer %s in the list." 
% optim 51 | 52 | return optimizer 53 | 54 | 55 | def parse_args(parse=True, **optional_kwargs): 56 | parser = argparse.ArgumentParser() 57 | 58 | parser.add_argument('--seed', type=int, default=9595, help='random seed') 59 | 60 | # Data Splits 61 | parser.add_argument("--train", default='train') 62 | parser.add_argument("--valid", default='valid') 63 | parser.add_argument("--test", default=None) 64 | parser.add_argument('--test_only', action='store_true') 65 | 66 | parser.add_argument('--submit', action='store_true') 67 | 68 | # Quick experiments 69 | parser.add_argument('--train_topk', type=int, default=-1) 70 | parser.add_argument('--valid_topk', type=int, default=-1) 71 | 72 | # Checkpoint 73 | parser.add_argument('--output', type=str, default='snap/test') 74 | parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).') 75 | parser.add_argument('--from_scratch', action='store_true') 76 | 77 | # CPU/GPU 78 | parser.add_argument("--multiGPU", action='store_const', default=False, const=True) 79 | parser.add_argument('--fp16', action='store_true') 80 | parser.add_argument("--distributed", action='store_true') 81 | parser.add_argument("--num_workers", default=0, type=int) 82 | parser.add_argument('--local_rank', type=int, default=-1) 83 | 84 | # Model Config 85 | parser.add_argument('--backbone', type=str, default='t5-base') 86 | parser.add_argument('--tokenizer', type=str, default=None) 87 | 88 | parser.add_argument('--feat_dim', type=float, default=2048) 89 | parser.add_argument('--pos_dim', type=float, default=4) 90 | 91 | parser.add_argument('--use_vision', default=True, type=str2bool) 92 | parser.add_argument('--use_vis_order_embedding', default=True, type=str2bool) 93 | parser.add_argument('--use_vis_layer_norm', default=True, type=str2bool) 94 | parser.add_argument('--individual_vis_layer_norm', default=True, type=str2bool) 95 | parser.add_argument('--share_vis_lang_layer_norm', action='store_true') 96 | 97 | parser.add_argument('--n_boxes', type=int, default=49) 98 | parser.add_argument('--max_n_boxes', type=int, default=49) 99 | parser.add_argument('--max_text_length', type=int, default=20) 100 | 101 | # Training 102 | parser.add_argument('--batch_size', type=int, default=256) 103 | parser.add_argument('--valid_batch_size', type=int, default=None) 104 | parser.add_argument('--optim', default='adamw') 105 | parser.add_argument('--warmup_ratio', type=float, default=0.05) 106 | parser.add_argument('--weight_decay', type=float, default=0.01) 107 | parser.add_argument('--clip_grad_norm', type=float, default=5) 108 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 109 | parser.add_argument('--lr', type=float, default=1e-4) 110 | parser.add_argument('--adam_eps', type=float, default=1e-6) 111 | parser.add_argument('--adam_beta1', type=float, default=0.9) 112 | parser.add_argument('--adam_beta2', type=float, default=0.999) 113 | parser.add_argument('--epochs', type=int, default=12) 114 | parser.add_argument('--dropout', type=float, default=0.1) 115 | 116 | parser.add_argument("--losses", default='lm,obj,attr,feat', type=str) 117 | 118 | parser.add_argument('--log_train_accuracy', action='store_true') 119 | 120 | parser.add_argument('--n_ground', type=int, default=1) 121 | parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float) 122 | parser.add_argument("--objMaskRate", dest='obj_mask_rate',default=0.15, type=float) 123 | 124 | # Inference 125 | parser.add_argument('--num_beams', 
type=int, default=1) 126 | parser.add_argument('--gen_max_length', type=int, default=20) 127 | 128 | # Data 129 | parser.add_argument('--caption_only', action='store_true') 130 | parser.add_argument('--coco_only', action='store_true') 131 | parser.add_argument('--caption_cocoonly', default=True, type=str2bool) 132 | 133 | parser.add_argument('--do_lower_case', default=True, type=str2bool) 134 | parser.add_argument('--oscar_tags', action='store_true') 135 | 136 | parser.add_argument('--prefix', type=str, default=None) 137 | 138 | # Pretraining 139 | parser.add_argument('--ground_upsample', type=int, default=1) 140 | parser.add_argument('--ground_weight', type=int, default=1) 141 | parser.add_argument('--itm_cocoonly', default=True, type=str2bool) 142 | parser.add_argument('--single_vqa_prefix', action='store_true') 143 | 144 | # COCO Caption 145 | parser.add_argument('--no_prefix', action='store_true') 146 | 147 | # VQA 148 | parser.add_argument("--raw_label", action='store_true') 149 | parser.add_argument("--answer_normalize", action='store_true') 150 | parser.add_argument("--classifier", action='store_true') 151 | parser.add_argument("--test_answerable", action='store_true') 152 | 153 | # RefCOCOg 154 | parser.add_argument('--RefCOCO_GT', action='store_true') 155 | parser.add_argument('--RefCOCO_BUTD', action='store_true') 156 | parser.add_argument("--shuffle_boxes", action='store_true') 157 | 158 | # Multitask 159 | parser.add_argument("--multitask_sampling", type=str, default='roundrobin') 160 | parser.add_argument("--tasks", type=str, default='') 161 | 162 | # Etc. 163 | parser.add_argument('--comment', type=str, default='') 164 | parser.add_argument("--dry", action='store_true') 165 | 166 | # Parse the arguments. 167 | if parse: 168 | args = parser.parse_args() 169 | # For interative engironmnet (ex. jupyter) 170 | else: 171 | args = parser.parse_known_args()[0] 172 | 173 | # Namespace => Dictionary 174 | kwargs = vars(args) 175 | kwargs.update(optional_kwargs) 176 | 177 | args = Config(**kwargs) 178 | 179 | # Bind optimizer class. 
180 | verbose = False 181 | args.optimizer = get_optimizer(args.optim, verbose=verbose) 182 | 183 | # Set seeds 184 | torch.manual_seed(args.seed) 185 | random.seed(args.seed) 186 | np.random.seed(args.seed) 187 | 188 | return args 189 | 190 | 191 | class Config(object): 192 | def __init__(self, **kwargs): 193 | """Configuration Class: set kwargs as class attributes with setattr""" 194 | for k, v in kwargs.items(): 195 | setattr(self, k, v) 196 | 197 | @property 198 | def config_str(self): 199 | return pprint.pformat(self.__dict__) 200 | 201 | def __repr__(self): 202 | """Pretty-print configurations in alphabetical order""" 203 | config_str = 'Configurations\n' 204 | config_str += self.config_str 205 | return config_str 206 | 207 | def save(self, path): 208 | with open(path, 'w') as f: 209 | yaml.dump(self.__dict__, f, default_flow_style=False) 210 | 211 | @classmethod 212 | def load(cls, path): 213 | with open(path, 'r') as f: 214 | kwargs = yaml.load(f) 215 | 216 | return Config(**kwargs) 217 | 218 | 219 | if __name__ == '__main__': 220 | args = parse_args(True) 221 | -------------------------------------------------------------------------------- /Geoformer/src/trainer_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from packaging import version 4 | import torch 5 | import torch.nn as nn 6 | import logging 7 | from pprint import pprint 8 | 9 | from utils import load_state_dict, LossMeter, set_global_logging_level 10 | 11 | proj_dir = Path(__file__).resolve().parent.parent 12 | 13 | _use_native_amp = False 14 | _use_apex = False 15 | 16 | # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex 17 | if version.parse(torch.__version__) < version.parse("1.6"): 18 | from transformers.file_utils import is_apex_available 19 | if is_apex_available(): 20 | from apex import amp 21 | _use_apex = True 22 | else: 23 | _use_native_amp = True 24 | from torch.cuda.amp import autocast 25 | 26 | 27 | class TrainerBase(object): 28 | def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True): 29 | self.args = args 30 | 31 | self.train_loader = train_loader 32 | self.val_loader = val_loader 33 | self.test_loader = test_loader 34 | 35 | self.verbose = True 36 | if self.args.distributed: 37 | if self.args.gpu != 0: 38 | self.verbose = False 39 | 40 | if self.args.tokenizer is None: 41 | self.args.tokenizer = self.args.backbone 42 | 43 | if not self.verbose: 44 | set_global_logging_level(logging.ERROR, ["transformers"]) 45 | 46 | def create_config(self): 47 | from transformers import T5Config, BartConfig 48 | 49 | if 't5' in self.args.backbone: 50 | config_class = T5Config 51 | elif 'bart' in self.args.backbone: 52 | config_class = BartConfig 53 | else: 54 | return None 55 | 56 | config = config_class.from_pretrained(self.args.backbone) 57 | 58 | args = self.args 59 | 60 | config.feat_dim = args.feat_dim 61 | config.pos_dim = args.pos_dim 62 | config.n_images = 2 63 | 64 | config.use_vis_order_embedding = args.use_vis_order_embedding 65 | 66 | config.dropout_rate = args.dropout 67 | config.dropout = args.dropout 68 | config.attention_dropout = args.dropout 69 | config.activation_dropout = args.dropout 70 | 71 | config.use_vis_layer_norm = args.use_vis_layer_norm 72 | config.individual_vis_layer_norm = args.individual_vis_layer_norm 73 | config.losses = args.losses 74 | 75 | config.share_vis_lang_layer_norm = args.share_vis_lang_layer_norm 76 | config.classifier = 
args.classifier 77 | 78 | return config 79 | 80 | 81 | def create_model(self, model_class, config=None, **kwargs): 82 | print(f'Building Model at GPU {self.args.gpu}') 83 | 84 | model_name = self.args.backbone 85 | 86 | model = model_class.from_pretrained( 87 | model_name, 88 | config=config, 89 | **kwargs 90 | ) 91 | return model 92 | 93 | def create_tokenizer(self, **kwargs): 94 | from transformers import T5Tokenizer, BartTokenizer, T5TokenizerFast, BartTokenizerFast 95 | from tokenization import VLT5Tokenizer, VLT5TokenizerFast 96 | 97 | if 't5' in self.args.tokenizer: 98 | if self.args.use_vision: 99 | # tokenizer_class = VLT5Tokenizer 100 | tokenizer_class = VLT5TokenizerFast 101 | else: 102 | # tokenizer_class = T5Tokenizer 103 | tokenizer_class = T5TokenizerFast 104 | elif 'bart' in self.args.tokenizer: 105 | tokenizer_class = BartTokenizer 106 | # tokenizer_class = BartTokenizerFast 107 | 108 | tokenizer_name = self.args.backbone 109 | 110 | tokenizer = tokenizer_class.from_pretrained( 111 | tokenizer_name, 112 | max_length=self.args.max_text_length, 113 | do_lower_case=self.args.do_lower_case, 114 | **kwargs 115 | ) 116 | 117 | return tokenizer 118 | 119 | def create_optimizer_and_scheduler(self): 120 | if self.verbose: 121 | print('Building Optimizer') 122 | 123 | lr_scheduler = None 124 | 125 | if 'adamw' in self.args.optim: 126 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 127 | batch_per_epoch = len(self.train_loader) 128 | t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs 129 | warmup_ratio = self.args.warmup_ratio 130 | warmup_iters = int(t_total * warmup_ratio) 131 | if self.verbose: 132 | print("Batch per epoch: %d" % batch_per_epoch) 133 | print("Total Iters: %d" % t_total) 134 | print('Warmup ratio:', warmup_ratio) 135 | print("Warm up Iters: %d" % warmup_iters) 136 | 137 | no_decay = ["bias", "LayerNorm.weight"] 138 | optimizer_grouped_parameters = [ 139 | { 140 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 141 | "weight_decay": self.args.weight_decay, 142 | }, 143 | { 144 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 145 | "weight_decay": 0.0, 146 | }, 147 | ] 148 | 149 | optim = AdamW(optimizer_grouped_parameters, 150 | lr=self.args.lr, eps=self.args.adam_eps) 151 | lr_scheduler = get_linear_schedule_with_warmup( 152 | optim, warmup_iters, t_total) 153 | 154 | else: 155 | optim = self.args.optimizer( 156 | list(self.model.parameters()), self.args.lr) 157 | 158 | return optim, lr_scheduler 159 | 160 | def load_checkpoint(self, ckpt_path): 161 | state_dict = load_state_dict(ckpt_path, 'cpu') 162 | 163 | original_keys = list(state_dict.keys()) 164 | for key in original_keys: 165 | if key.startswith("vis_encoder."): 166 | new_key = 'encoder.' + key[len("vis_encoder."):] 167 | state_dict[new_key] = state_dict.pop(key) 168 | 169 | if key.startswith("model.vis_encoder."): 170 | new_key = 'model.encoder.' 
+ key[len("model.vis_encoder."):] 171 | state_dict[new_key] = state_dict.pop(key) 172 | 173 | results = self.model.load_state_dict(state_dict, strict=False) 174 | if self.verbose: 175 | print('Model loaded from ', ckpt_path) 176 | pprint(results) 177 | 178 | def init_weights(self): 179 | 180 | def init_bert_weights(module): 181 | """ Initialize the weights.""" 182 | if isinstance(module, (nn.Linear, nn.Embedding)): 183 | # Slightly different from the TF version which uses truncated_normal for initialization 184 | # cf https://github.com/pytorch/pytorch/pull/5617 185 | module.weight.data.normal_(mean=0.0, std=1) 186 | elif isinstance(module, nn.LayerNorm): 187 | module.bias.data.zero_() 188 | module.weight.data.fill_(1.0) 189 | if isinstance(module, nn.Linear) and module.bias is not None: 190 | module.bias.data.zero_() 191 | self.model.apply(init_bert_weights) 192 | self.model.init_weights() 193 | 194 | def predict(self): 195 | pass 196 | 197 | def evaluate(self): 198 | pass 199 | 200 | def save(self, name): 201 | if not os.path.isdir(self.args.output): 202 | os.makedirs(self.args.output, exist_ok=True) 203 | torch.save(self.model.state_dict(), os.path.join(self.args.output, "%s.pth" % name)) 204 | 205 | def load(self, path, loc=None): 206 | if loc is None and hasattr(self.args, 'gpu'): 207 | loc = f'cuda:{self.args.gpu}' 208 | state_dict = torch.load("%s.pth" % path, map_location=loc) 209 | 210 | original_keys = list(state_dict.keys()) 211 | for key in original_keys: 212 | if key.startswith("module.vis_encoder."): 213 | new_key = 'module.encoder.' + key[len("module.vis_encoder."):] 214 | state_dict[new_key] = state_dict.pop(key) 215 | 216 | if key.startswith("module.model.vis_encoder."): 217 | new_key = 'module.model.encoder.' + key[len("module.model.vis_encoder."):] 218 | state_dict[new_key] = state_dict.pop(key) 219 | 220 | results = self.model.load_state_dict(state_dict, strict=False) 221 | if self.verbose: 222 | print('Model loaded from ', path) 223 | pprint(results) 224 | -------------------------------------------------------------------------------- /Geoformer/src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch.distributed as dist 12 | import torch 13 | 14 | _LOCAL_PROCESS_GROUP = None 15 | """ 16 | A torch process group which only includes processes that on the same machine as the current process. 17 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 18 | """ 19 | 20 | 21 | def get_world_size() -> int: 22 | if not dist.is_available(): 23 | return 1 24 | if not dist.is_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | 29 | def get_rank() -> int: 30 | if not dist.is_available(): 31 | return 0 32 | if not dist.is_initialized(): 33 | return 0 34 | return dist.get_rank() 35 | 36 | 37 | def get_local_rank() -> int: 38 | """ 39 | Returns: 40 | The rank of the current process within the local (per-machine) process group. 
41 | """ 42 | if not dist.is_available(): 43 | return 0 44 | if not dist.is_initialized(): 45 | return 0 46 | assert _LOCAL_PROCESS_GROUP is not None 47 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 48 | 49 | 50 | def get_local_size() -> int: 51 | """ 52 | Returns: 53 | The size of the per-machine process group, 54 | i.e. the number of processes per machine. 55 | """ 56 | if not dist.is_available(): 57 | return 1 58 | if not dist.is_initialized(): 59 | return 1 60 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 61 | 62 | 63 | def is_main_process() -> bool: 64 | return get_rank() == 0 65 | 66 | 67 | def synchronize(): 68 | """ 69 | Helper function to synchronize (barrier) among all processes when 70 | using distributed training 71 | """ 72 | if not dist.is_available(): 73 | return 74 | if not dist.is_initialized(): 75 | return 76 | world_size = dist.get_world_size() 77 | if world_size == 1: 78 | return 79 | dist.barrier() 80 | 81 | 82 | @functools.lru_cache() 83 | def _get_global_gloo_group(): 84 | """ 85 | Return a process group based on gloo backend, containing all the ranks 86 | The result is cached. 87 | """ 88 | if dist.get_backend() == "nccl": 89 | return dist.new_group(backend="gloo") 90 | else: 91 | return dist.group.WORLD 92 | 93 | 94 | def _serialize_to_tensor(data, group): 95 | backend = dist.get_backend(group) 96 | assert backend in ["gloo", "nccl"] 97 | device = torch.device("cpu" if backend == "gloo" else "cuda") 98 | 99 | buffer = pickle.dumps(data) 100 | if len(buffer) > 1024 ** 3: 101 | logger = logging.getLogger(__name__) 102 | logger.warning( 103 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 104 | get_rank(), len(buffer) / (1024 ** 3), device 105 | ) 106 | ) 107 | storage = torch.ByteStorage.from_buffer(buffer) 108 | tensor = torch.ByteTensor(storage).to(device=device) 109 | return tensor 110 | 111 | 112 | def _pad_to_largest_tensor(tensor, group): 113 | """ 114 | Returns: 115 | list[int]: size of the tensor, on each rank 116 | Tensor: padded tensor that has the max size 117 | """ 118 | world_size = dist.get_world_size(group=group) 119 | assert ( 120 | world_size >= 1 121 | ), "comm.gather/all_gather must be called from ranks within the given group!" 122 | local_size = torch.tensor( 123 | [tensor.numel()], dtype=torch.int64, device=tensor.device) 124 | size_list = [ 125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 126 | for _ in range(world_size) 127 | ] 128 | dist.all_gather(size_list, local_size, group=group) 129 | size_list = [int(size.item()) for size in size_list] 130 | 131 | max_size = max(size_list) 132 | 133 | # we pad the tensor because torch all_gather does not support 134 | # gathering tensors of different shapes 135 | if local_size != max_size: 136 | padding = torch.zeros( 137 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 138 | ) 139 | tensor = torch.cat((tensor, padding), dim=0) 140 | return size_list, tensor 141 | 142 | 143 | def all_gather(data, group=None): 144 | """ 145 | Run all_gather on arbitrary picklable data (not necessarily tensors). 146 | Args: 147 | data: any picklable object 148 | group: a torch process group. By default, will use a group which 149 | contains all ranks on gloo backend. 
150 | Returns: 151 | list[data]: list of data gathered from each rank 152 | """ 153 | if get_world_size() == 1: 154 | return [data] 155 | if group is None: 156 | group = _get_global_gloo_group() 157 | if dist.get_world_size(group) == 1: 158 | return [data] 159 | 160 | tensor = _serialize_to_tensor(data, group) 161 | 162 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 163 | max_size = max(size_list) 164 | 165 | # receiving Tensor from all ranks 166 | tensor_list = [ 167 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 168 | for _ in size_list 169 | ] 170 | dist.all_gather(tensor_list, tensor, group=group) 171 | 172 | data_list = [] 173 | for size, tensor in zip(size_list, tensor_list): 174 | buffer = tensor.cpu().numpy().tobytes()[:size] 175 | data_list.append(pickle.loads(buffer)) 176 | 177 | return data_list 178 | 179 | 180 | def gather(data, dst=0, group=None): 181 | """ 182 | Run gather on arbitrary picklable data (not necessarily tensors). 183 | Args: 184 | data: any picklable object 185 | dst (int): destination rank 186 | group: a torch process group. By default, will use a group which 187 | contains all ranks on gloo backend. 188 | Returns: 189 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 190 | an empty list. 191 | """ 192 | if get_world_size() == 1: 193 | return [data] 194 | if group is None: 195 | group = _get_global_gloo_group() 196 | if dist.get_world_size(group=group) == 1: 197 | return [data] 198 | rank = dist.get_rank(group=group) 199 | 200 | tensor = _serialize_to_tensor(data, group) 201 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 202 | 203 | # receiving Tensor from all ranks 204 | if rank == dst: 205 | max_size = max(size_list) 206 | tensor_list = [ 207 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 208 | for _ in size_list 209 | ] 210 | dist.gather(tensor, tensor_list, dst=dst, group=group) 211 | 212 | data_list = [] 213 | for size, tensor in zip(size_list, tensor_list): 214 | buffer = tensor.cpu().numpy().tobytes()[:size] 215 | data_list.append(pickle.loads(buffer)) 216 | return data_list 217 | else: 218 | dist.gather(tensor, [], dst=dst, group=group) 219 | return [] 220 | 221 | 222 | def shared_random_seed(): 223 | """ 224 | Returns: 225 | int: a random number that is the same across all workers. 226 | If workers need a shared RNG, they can use this shared seed to 227 | create one. 228 | All workers must call this function, otherwise it will deadlock. 229 | """ 230 | ints = np.random.randint(2 ** 31) 231 | all_ints = all_gather(ints) 232 | return all_ints[0] 233 | 234 | 235 | def reduce_dict(input_dict, average=True): 236 | """ 237 | Reduce the values in the dictionary from all processes so that process with rank 238 | 0 has the reduced results. 239 | Args: 240 | input_dict (dict): inputs to be reduced. (values not necessarily tensors). 241 | average (bool): whether to do average or sum 242 | Returns: 243 | a dict with the same keys as input_dict, after reduction. 
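        Example: calling reduce_dict({'loss': loss}) on every rank returns, on rank 0,
        a dict whose 'loss' entry is the loss averaged over all processes.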
244 | """ 245 | 246 | world_size = get_world_size() 247 | if world_size < 2: 248 | return input_dict 249 | 250 | with torch.no_grad(): 251 | 252 | # Convert to CUDA Tensor for dist.reduce() 253 | input_dict_cuda_vals = {} 254 | for k, v in input_dict.items(): 255 | if type(v) == torch.Tensor: 256 | input_dict_cuda_vals[k] = v.to('cuda') 257 | else: 258 | input_dict_cuda_vals[k] = torch.tensor(v, device='cuda') 259 | 260 | names = [] 261 | values = [] 262 | for k, v in sorted(input_dict_cuda_vals.items()): 263 | names.append(k) 264 | values.append(v) 265 | values = torch.stack(values, dim=0) 266 | dist.reduce(values, dst=0) # reduce to gpu 0 267 | 268 | if dist.get_rank() == 0 and average: 269 | # only main process gets accumulated, so only divide by 270 | # world_size in this case 271 | values /= world_size 272 | reduced_dict = {k: v for k, v in zip(names, values)} 273 | return reduced_dict 274 | -------------------------------------------------------------------------------- /Geoformer/src/pretrain_data.py: -------------------------------------------------------------------------------- 1 | import sacrebleu 2 | from torch.utils.data import DataLoader, Dataset, Sampler 3 | from pathlib import Path 4 | import random 5 | import pickle 6 | import torch 7 | import numpy as np 8 | from torch.utils.data.distributed import DistributedSampler 9 | from tokenization import VLT5TokenizerFast 10 | import preprocess 11 | 12 | from data_utils import process_image, create_patch, process_english_text, process_Chinese_solving 13 | 14 | project_dir = Path(__file__).resolve().parent.parent 15 | workspace_dir = project_dir.parent 16 | 17 | dataset_dir = workspace_dir.joinpath('datasets/').resolve() 18 | geo_dir = dataset_dir.joinpath('UniGeo') 19 | 20 | 21 | class GeoDataset(Dataset): 22 | def __init__(self, split='train', raw_dataset=None, verbose=True, args=None, mode='train'): 23 | super().__init__() 24 | 25 | self.raw_dataset = raw_dataset 26 | self.verbose = verbose 27 | self.args = args 28 | self.mode = mode 29 | 30 | # Loading datasets to data 31 | self.source = split.split(',') 32 | if self.verbose: 33 | print('Data source: ', self.source) 34 | 35 | if self.args.tokenizer is None: 36 | self.args.tokenizer = self.args.backbone 37 | 38 | self.tokenizer = VLT5TokenizerFast.from_pretrained( 39 | args.backbone, 40 | do_lower_case=self.args.do_lower_case) 41 | 42 | # geo dataset 43 | target_text_list = [] 44 | source_text_list = [] 45 | image_list = [] 46 | source_nums_list = [] 47 | choice_nums_list = [] 48 | label_list = [] 49 | problem_form_list = [] 50 | 51 | for source in self.source: 52 | with open(geo_dir.joinpath(f'{source}.pk'), "rb") as f: 53 | dataset = pickle.load(f) 54 | for sample in dataset: 55 | r = random.random() 56 | 57 | if r > 0.5: 58 | prefix = 'solving prediction: ' 59 | problem_with_space = process_english_text(sample['English_problem']) 60 | problem_with_space = prefix + problem_with_space 61 | source_text_list.append(problem_with_space) 62 | 63 | ori_solving = sample['answer'] 64 | solving, solving_nums = process_Chinese_solving(ori_solving) 65 | 66 | text_i = " ".join(solving) 67 | target_text_list.append(text_i) 68 | 69 | else: 70 | prefix = 'denoise text: ' 71 | problem_with_space = process_english_text(sample['English_problem']) 72 | source_text, target_text = preprocess.corrupt_bart( 73 | problem_with_space, mask_ratio=self.args.word_mask_rate, prefix=prefix) 74 | 75 | source_text_list.append(source_text) 76 | target_text_list.append(target_text) 77 | 78 | image = 
sample['image'] 79 | image = process_image(image) 80 | img_rgb = np.zeros((3, image.shape[0], image.shape[1])) 81 | for i in range(3): 82 | img_rgb[i, :, :] = image 83 | image_list.append(img_rgb) 84 | 85 | source_nums_list.append(sample["numbers"]) 86 | choice_nums_list.append(sample["choice_nums"]) 87 | label_list.append(sample["label"]) 88 | 89 | problem_form_list.append('calculation') 90 | 91 | assert len(source_text_list) == len(target_text_list) 92 | 93 | data = [] 94 | for source_text, target_text, image, source_nums, choice_nums, label, problem_form in zip(source_text_list, target_text_list, image_list, source_nums_list, choice_nums_list, label_list, problem_form_list): 95 | datum = { 96 | 'image': image, 97 | 'source_text': source_text.strip(), 98 | 'target_text': target_text.strip(), 99 | 'source_nums': source_nums, 100 | 'choice_nums': choice_nums, 101 | 'label': label, 102 | 'problem_form': problem_form, 103 | } 104 | data.append(datum) 105 | 106 | if self.verbose: 107 | print(f"Loaded {len(data)} data from", split) 108 | 109 | self.n_gpus = torch.cuda.device_count() 110 | 111 | self.data = data 112 | 113 | if self.verbose: 114 | print("# all sentences:", len(self.data)) 115 | 116 | def __len__(self): 117 | return len(self.data) 118 | 119 | def __getitem__(self, idx): 120 | 121 | out_dict = {} 122 | out_dict['args'] = self.args 123 | 124 | datum = self.data[idx] 125 | 126 | ###### Image ###### 127 | image = datum['image'] 128 | out_dict['image'] = image 129 | boxes = create_patch(patch_num=7) 130 | boxes = torch.from_numpy(boxes) 131 | boxes.clamp_(min=0.0, max=1.0) 132 | n_boxes = len(boxes) 133 | 134 | # n_boxes = min(n_boxes, self.args.max_n_boxes) 135 | out_dict['n_boxes'] = n_boxes 136 | out_dict['boxes'] = boxes[:n_boxes] 137 | 138 | input_text = datum['source_text'] 139 | 140 | if 't5' in self.args.tokenizer: 141 | input_ids = self.tokenizer.encode( 142 | input_text, 143 | max_length=self.args.max_text_length, truncation=True) 144 | else: 145 | input_ids = self.tokenizer.convert_tokens_to_ids( 146 | self.tokenizer.tokenize(input_text)[:self.args.max_text_length - 1] + ['[SEP]']) 147 | 148 | out_dict['input_text'] = input_text 149 | out_dict['input_ids'] = torch.LongTensor(input_ids) 150 | out_dict['input_length'] = len(input_ids) 151 | target_text = datum['target_text'] 152 | 153 | if 't5' in self.args.tokenizer: 154 | target_ids = self.tokenizer.encode( 155 | target_text, max_length=self.args.gen_max_length, truncation=True) 156 | else: 157 | target_ids = self.tokenizer.convert_tokens_to_ids( 158 | self.tokenizer.tokenize(target_text)[:self.args.gen_max_length - 1] + ['[SEP]']) 159 | 160 | assert len(target_ids) <= self.args.gen_max_length, len(target_ids) 161 | out_dict['target_text'] = target_text 162 | out_dict['target_ids'] = torch.LongTensor(target_ids) 163 | out_dict['target_length'] = len(target_ids) 164 | 165 | out_dict['choice_nums'] = datum["choice_nums"] 166 | out_dict['source_nums'] = datum["source_nums"] 167 | out_dict['label'] = datum["label"] 168 | 169 | out_dict['problem_form'] = datum["problem_form"] 170 | 171 | return out_dict 172 | 173 | 174 | def collate_fn(self, batch): 175 | batch_entry = {} 176 | B = len(batch) 177 | 178 | S_W_L = max(entry['input_length'] for entry in batch) 179 | input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id 180 | 181 | V_L = max(entry['n_boxes'] for entry in batch) 182 | boxes = torch.zeros(B, V_L, 4, dtype=torch.float) 183 | vis_attention_mask = torch.zeros(B, V_L, dtype=torch.float) 184 | 
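        # Collation note: every example is padded to the longest sequence in this
        # batch. input_ids / target_ids are pre-filled with pad_token_id and then
        # overwritten below; padded target positions are later set to -100 so the
        # LM loss ignores them. vis_attention_mask marks which of the V_L box
        # slots hold real patch boxes for each example.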
185 | if 'target_ids' in batch[0]: 186 | T_W_L = max(entry['target_length'] for entry in batch) 187 | target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id 188 | 189 | img_paths = [] 190 | input_text = [] 191 | target_text = [] 192 | image_lists = [] 193 | 194 | source_nums_list = [] 195 | choice_nums_list = [] 196 | label_list = [] 197 | 198 | problem_form_list = [] 199 | 200 | for i, entry in enumerate(batch): 201 | input_ids[i, :entry['input_length']] = entry['input_ids'] 202 | 203 | n_boxes = entry['n_boxes'] 204 | boxes[i, :n_boxes] = entry['boxes'] 205 | 206 | vis_attention_mask[i, :n_boxes] = 1 207 | image_lists.append(entry['image']) 208 | 209 | if 'target_ids' in entry: 210 | target_ids[i, :entry['target_length']] = entry['target_ids'] 211 | 212 | if 'input_text' in entry: 213 | input_text.append(entry['input_text']) 214 | 215 | if 'target_text' in entry: 216 | target_text.append(entry['target_text']) 217 | 218 | source_nums_list.append(entry["source_nums"]) 219 | choice_nums_list.append(entry["choice_nums"]) 220 | label_list.append(entry["label"]) 221 | problem_form_list.append(entry['problem_form']) 222 | 223 | 224 | batch_entry['input_ids'] = input_ids 225 | if 'target_ids' in batch[0]: 226 | word_mask = target_ids != self.tokenizer.pad_token_id 227 | target_ids[~word_mask] = -100 228 | batch_entry['target_ids'] = target_ids 229 | 230 | batch_entry['boxes'] = boxes 231 | batch_entry['vis_attention_mask'] = vis_attention_mask 232 | batch_entry['image_list'] = torch.Tensor(image_lists) 233 | batch_entry['img_paths'] = img_paths 234 | 235 | batch_entry['input_text'] = input_text 236 | batch_entry['target_text'] = target_text 237 | 238 | batch_entry['source_nums'] = source_nums_list 239 | batch_entry['choice_nums'] = choice_nums_list 240 | batch_entry['label'] = label_list 241 | 242 | batch_entry['problem_form'] = problem_form_list 243 | 244 | batch_entry['task'] = 'geo' 245 | 246 | return batch_entry 247 | 248 | 249 | def get_loader(args, split='train', mode='train', 250 | batch_size=32, workers=4, distributed=False, gpu=0 251 | ): 252 | 253 | verbose = (gpu == 0) 254 | 255 | dataset = GeoDataset( 256 | split, 257 | verbose=verbose, 258 | args=args, 259 | mode=mode) 260 | 261 | if distributed and mode == 'train': 262 | train_sampler = DistributedSampler(dataset) 263 | 264 | else: 265 | train_sampler = None 266 | if mode == 'train': 267 | loader = DataLoader( 268 | dataset, batch_size=batch_size, shuffle=(train_sampler is None), 269 | num_workers=workers, pin_memory=True, sampler=train_sampler, 270 | collate_fn=dataset.collate_fn) 271 | else: 272 | loader = DataLoader( 273 | dataset, 274 | batch_size=batch_size, shuffle=False, 275 | num_workers=workers, pin_memory=True, 276 | sampler=None, 277 | collate_fn=dataset.collate_fn, 278 | drop_last=False) 279 | 280 | if verbose: 281 | loader.evaluator = GeoEvaluator() 282 | 283 | loader.task = 'geo' 284 | 285 | return loader 286 | 287 | 288 | class GeoEvaluator: 289 | def __init__(self): 290 | pass 291 | 292 | def evaluate(self, predicts, answers): 293 | 294 | try: 295 | bleu = sacrebleu.corpus_bleu(predicts, answers, 296 | lowercase=True) 297 | except EOFError: 298 | print('# preds', len(predicts)) 299 | print('# tgts', len(answers)) 300 | exit() 301 | return { 302 | 'BLEU': bleu.score 303 | } 304 | -------------------------------------------------------------------------------- /Geoformer/src/geo_data.py: -------------------------------------------------------------------------------- 1 | import sacrebleu 2 | 
from torch.utils.data import DataLoader, Dataset 3 | from pathlib import Path 4 | import pickle 5 | import torch 6 | import numpy as np 7 | import os 8 | 9 | from torch.utils.data.distributed import DistributedSampler 10 | from tokenization import VLT5TokenizerFast 11 | 12 | from data_utils import process_image, create_patch, process_english_text 13 | 14 | project_dir = Path(__file__).resolve().parent.parent 15 | workspace_dir = project_dir.parent 16 | 17 | dataset_dir = workspace_dir.joinpath('datasets/').resolve() 18 | geo_dir = dataset_dir.joinpath('UniGeo') 19 | 20 | 21 | class GeoDataset(Dataset): 22 | def __init__(self, split='train', raw_dataset=None, verbose=True, args=None, mode='train'): 23 | super().__init__() 24 | 25 | self.raw_dataset = raw_dataset 26 | self.verbose = verbose 27 | self.args = args 28 | self.mode = mode 29 | 30 | # Loading datasets to data 31 | self.source = split.split(',') 32 | if self.verbose: 33 | print('Data source: ', self.source) 34 | 35 | if self.args.tokenizer is None: 36 | self.args.tokenizer = self.args.backbone 37 | 38 | self.tokenizer = VLT5TokenizerFast.from_pretrained( 39 | args.backbone, 40 | do_lower_case=self.args.do_lower_case) 41 | 42 | sub_dict_path = os.path.join(geo_dir, "sub_dataset_dict.pk") # problems type 43 | with open(sub_dict_path, 'rb') as file: 44 | subset_dict = pickle.load(file) 45 | self.subset_dict = subset_dict 46 | 47 | # geo dataset 48 | target_text_list = [] 49 | source_text_list = [] 50 | image_list = [] 51 | 52 | source_nums_list = [] 53 | choice_nums_list = [] 54 | label_list = [] 55 | 56 | problem_form_list = [] 57 | problem_type_list = [] 58 | 59 | for source in self.source: 60 | with open(geo_dir.joinpath(f'{source}.pk'), "rb") as f: 61 | dataset = pickle.load(f) 62 | for sample in dataset: 63 | if 'calculation' in source: 64 | problem_with_space = process_english_text(sample['English_problem']) 65 | problem_with_space = 'Calculation: ' + problem_with_space 66 | source_text_list.append(problem_with_space) 67 | 68 | text_i = " ".join(sample["manual_program"]) 69 | target_text_list.append(text_i) 70 | 71 | image = sample['image'] 72 | image = process_image(image) 73 | img_rgb = np.zeros((3, image.shape[0], image.shape[1])) 74 | for i in range(3): 75 | img_rgb[i, :, :] = image 76 | image_list.append(img_rgb) 77 | 78 | source_nums_list.append(sample["numbers"]) 79 | choice_nums_list.append(sample["choice_nums"]) 80 | label_list.append(sample["label"]) 81 | 82 | problem_form_list.append('calculation') 83 | type = self.subset_dict[sample['id']] 84 | problem_type_list.append(type) 85 | 86 | else: 87 | assert 'proving' in source 88 | problem_with_space = sample['input_text'] 89 | problem_with_space = 'Proving: ' + problem_with_space 90 | 91 | source_text_list.append(problem_with_space) 92 | 93 | text_i = ' '.join(sample['proving_sequence']) 94 | target_text_list.append(text_i) 95 | 96 | image = sample['img'] 97 | image = process_image(image) 98 | image = image.transpose(2, 0, 1) 99 | image_list.append(image) 100 | 101 | source_nums_list.append(None) 102 | choice_nums_list.append(None) 103 | label_list.append(None) 104 | 105 | problem_form_list.append('proving') 106 | problem_type_list.append(sample['problem_type']) 107 | 108 | 109 | assert len(source_text_list) == len(target_text_list) 110 | 111 | data = [] 112 | for source_text, target_text, image, source_nums, choice_nums, label, problem_form, problem_type in \ 113 | zip(source_text_list, target_text_list, image_list, source_nums_list, choice_nums_list, label_list, 
problem_form_list, problem_type_list): 114 | datum = { 115 | 'image': image, 116 | 'source_text': source_text.strip(), 117 | 'target_text': target_text.strip(), 118 | 'source_nums': source_nums, 119 | 'choice_nums': choice_nums, 120 | 'label': label, 121 | 'problem_form': problem_form, 122 | 'problem_type': problem_type 123 | } 124 | data.append(datum) 125 | 126 | if self.verbose: 127 | print(f"Loaded {len(data)} data from", split) 128 | 129 | self.n_gpus = torch.cuda.device_count() 130 | 131 | self.data = data 132 | 133 | if self.verbose: 134 | print("# all sentences:", len(self.data)) 135 | 136 | def __len__(self): 137 | return len(self.data) 138 | 139 | def __getitem__(self, idx): 140 | 141 | out_dict = {} 142 | out_dict['args'] = self.args 143 | 144 | datum = self.data[idx] 145 | 146 | ###### Image ###### 147 | image = datum['image'] 148 | out_dict['image'] = image 149 | boxes = create_patch(patch_num=7) 150 | boxes = torch.from_numpy(boxes) 151 | boxes.clamp_(min=0.0, max=1.0) 152 | n_boxes = len(boxes) 153 | 154 | out_dict['n_boxes'] = n_boxes 155 | out_dict['boxes'] = boxes[:n_boxes] 156 | 157 | 158 | input_text = datum['source_text'] 159 | 160 | if 't5' in self.args.tokenizer: 161 | input_ids = self.tokenizer.encode( 162 | input_text, 163 | max_length=self.args.max_text_length, truncation=True) 164 | else: 165 | input_ids = self.tokenizer.convert_tokens_to_ids( 166 | self.tokenizer.tokenize(input_text)[:self.args.max_text_length - 1] + ['[SEP]']) 167 | 168 | out_dict['input_text'] = input_text 169 | out_dict['input_ids'] = torch.LongTensor(input_ids) 170 | out_dict['input_length'] = len(input_ids) 171 | 172 | target_text = datum['target_text'] 173 | if 't5' in self.args.tokenizer: 174 | target_ids = self.tokenizer.encode( 175 | target_text, max_length=self.args.gen_max_length, truncation=True) 176 | else: 177 | target_ids = self.tokenizer.convert_tokens_to_ids( 178 | self.tokenizer.tokenize(target_text)[:self.args.gen_max_length - 1] + ['[SEP]']) 179 | 180 | assert len(target_ids) <= self.args.gen_max_length, len(target_ids) 181 | out_dict['target_text'] = target_text 182 | out_dict['target_ids'] = torch.LongTensor(target_ids) 183 | out_dict['target_length'] = len(target_ids) 184 | 185 | out_dict['choice_nums'] = datum["choice_nums"] 186 | out_dict['source_nums'] = datum["source_nums"] 187 | out_dict['label'] = datum["label"] 188 | 189 | out_dict['problem_form'] = datum["problem_form"] 190 | out_dict['problem_type'] = datum["problem_type"] 191 | 192 | return out_dict 193 | 194 | 195 | def collate_fn(self, batch): 196 | batch_entry = {} 197 | B = len(batch) 198 | 199 | S_W_L = max(entry['input_length'] for entry in batch) 200 | input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id 201 | 202 | V_L = max(entry['n_boxes'] for entry in batch) 203 | boxes = torch.zeros(B, V_L, 4, dtype=torch.float) 204 | vis_attention_mask = torch.zeros(B, V_L, dtype=torch.float) 205 | 206 | if 'target_ids' in batch[0]: 207 | T_W_L = max(entry['target_length'] for entry in batch) 208 | target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id 209 | 210 | img_paths = [] 211 | input_text = [] 212 | target_text = [] 213 | image_lists = [] 214 | 215 | source_nums_list = [] 216 | choice_nums_list = [] 217 | label_list = [] 218 | 219 | problem_form_list = [] 220 | problem_type_list = [] 221 | 222 | for i, entry in enumerate(batch): 223 | input_ids[i, :entry['input_length']] = entry['input_ids'] 224 | 225 | n_boxes = entry['n_boxes'] 226 | boxes[i, :n_boxes] = 
entry['boxes'] 227 | 228 | vis_attention_mask[i, :n_boxes] = 1 229 | image_lists.append(entry['image']) 230 | 231 | if 'target_ids' in entry: 232 | target_ids[i, :entry['target_length']] = entry['target_ids'] 233 | 234 | if 'input_text' in entry: 235 | input_text.append(entry['input_text']) 236 | 237 | if 'target_text' in entry: 238 | target_text.append(entry['target_text']) 239 | 240 | source_nums_list.append(entry["source_nums"]) 241 | choice_nums_list.append(entry["choice_nums"]) 242 | label_list.append(entry["label"]) 243 | problem_form_list.append(entry['problem_form']) 244 | problem_type_list.append(entry['problem_type']) 245 | 246 | 247 | batch_entry['input_ids'] = input_ids 248 | if 'target_ids' in batch[0]: 249 | word_mask = target_ids != self.tokenizer.pad_token_id 250 | target_ids[~word_mask] = -100 251 | batch_entry['target_ids'] = target_ids 252 | 253 | batch_entry['boxes'] = boxes 254 | batch_entry['vis_attention_mask'] = vis_attention_mask 255 | batch_entry['image_list'] = torch.Tensor(image_lists) 256 | batch_entry['img_paths'] = img_paths 257 | 258 | batch_entry['input_text'] = input_text 259 | batch_entry['target_text'] = target_text 260 | 261 | batch_entry['source_nums'] = source_nums_list 262 | batch_entry['choice_nums'] = choice_nums_list 263 | batch_entry['label'] = label_list 264 | 265 | batch_entry['problem_form'] = problem_form_list 266 | batch_entry['problem_type'] = problem_type_list 267 | 268 | batch_entry['task'] = 'geo' 269 | 270 | return batch_entry 271 | 272 | 273 | def get_loader(args, split='train', mode='train', 274 | batch_size=32, workers=4, distributed=False, gpu=0, 275 | ): 276 | 277 | verbose = (gpu == 0) 278 | 279 | dataset = GeoDataset( 280 | split, 281 | verbose=verbose, 282 | args=args, 283 | mode=mode) 284 | 285 | if distributed and mode == 'train': 286 | train_sampler = DistributedSampler(dataset) 287 | 288 | else: 289 | train_sampler = None 290 | if mode == 'train': 291 | loader = DataLoader( 292 | dataset, batch_size=batch_size, shuffle=(train_sampler is None), 293 | num_workers=workers, pin_memory=True, sampler=train_sampler, 294 | collate_fn=dataset.collate_fn) 295 | else: 296 | loader = DataLoader( 297 | dataset, 298 | batch_size=batch_size, shuffle=False, 299 | num_workers=workers, pin_memory=True, 300 | sampler=None, 301 | collate_fn=dataset.collate_fn, 302 | drop_last=False) 303 | 304 | if verbose: 305 | loader.evaluator = GeoEvaluator() 306 | 307 | loader.task = 'geo' 308 | 309 | return loader 310 | 311 | 312 | class GeoEvaluator: 313 | def __init__(self): 314 | pass 315 | 316 | def evaluate(self, predicts, answers): 317 | 318 | try: 319 | bleu = sacrebleu.corpus_bleu(predicts, answers, 320 | lowercase=True) 321 | except EOFError: 322 | print('# preds', len(predicts)) 323 | print('# tgts', len(answers)) 324 | exit() 325 | return { 326 | 'BLEU': bleu.score 327 | } 328 | -------------------------------------------------------------------------------- /Geoformer/src/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.distributed as dist 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | import os 5 | from pathlib import Path 6 | from packaging import version 7 | from tqdm import tqdm 8 | import torch 9 | import logging 10 | from param import parse_args 11 | from pretrain_data import get_loader 12 | from utils import load_state_dict, LossMeter, set_global_logging_level, AverageMeter 13 | from pprint import pformat 14 | from 
ManualProgram.eval_equ import Equations 15 | import math 16 | 17 | set_global_logging_level(logging.ERROR, ["transformers"]) 18 | proj_dir = Path(__file__).resolve().parent.parent 19 | 20 | 21 | _use_native_amp = False 22 | _use_apex = False 23 | 24 | # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex 25 | if version.parse(torch.__version__) < version.parse("1.6"): 26 | from transormers.file_utils import is_apex_available 27 | if is_apex_available(): 28 | from apex import amp 29 | _use_apex = True 30 | else: 31 | _use_native_amp = True 32 | from torch.cuda.amp import autocast 33 | 34 | from trainer_base import TrainerBase 35 | 36 | class Trainer(TrainerBase): 37 | def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True): 38 | super().__init__( 39 | args, 40 | train_loader=train_loader, 41 | val_loader=val_loader, 42 | test_loader=test_loader, 43 | train=train) 44 | 45 | from geo_model import VLT5Geo 46 | model_class = VLT5Geo 47 | config = self.create_config() 48 | self.tokenizer = self.create_tokenizer() 49 | self.model = self.create_model(model_class, config) 50 | self.model.resize_token_embeddings(self.tokenizer.vocab_size) 51 | self.model.tokenizer = self.tokenizer 52 | 53 | # Load Checkpoint 54 | self.start_epoch = None 55 | if args.load is not None: 56 | ckpt_path = args.load + '.pth' 57 | self.load_checkpoint(ckpt_path) 58 | 59 | # GPU Options 60 | print(f'Model Launching at GPU {self.args.gpu}') 61 | if self.verbose: 62 | from time import time 63 | start = time() 64 | self.model = self.model.to(args.gpu) 65 | 66 | # Optimizer 67 | if train: 68 | self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler() 69 | 70 | if self.args.fp16 and _use_native_amp: 71 | self.scaler = torch.cuda.amp.GradScaler() 72 | elif _use_apex: 73 | self.model, self.optim = amp.initialize( 74 | self.model, self.optim, opt_level='O1', verbosity=self.verbose) 75 | 76 | if args.multiGPU: 77 | if args.distributed: 78 | self.model = DDP(self.model, device_ids=[args.gpu], 79 | find_unused_parameters=True 80 | ) 81 | if self.verbose: 82 | print(f'It took {time() - start:.1f}s') 83 | 84 | self._equ = Equations() 85 | self.calculation_acc = AverageMeter() 86 | self.calculation_no_result = AverageMeter() 87 | self.proving_acc = AverageMeter() 88 | self.proving_no_result = AverageMeter() 89 | 90 | def train(self): 91 | if self.verbose: 92 | loss_meter = LossMeter() 93 | best_valid = 0. 
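            # best_valid tracks the best validation BLEU seen so far; whenever it
            # improves, the checkpoint is re-saved as "BEST" in the validation
            # block below (per-epoch copies are also kept from epoch 20 onward).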
94 | best_epoch = 0 95 | 96 | if self.args.distributed: 97 | dist.barrier() 98 | 99 | global_step = 0 100 | for epoch in range(self.args.epochs): 101 | if self.start_epoch is not None: 102 | epoch += self.start_epoch 103 | self.model.train() 104 | if self.args.distributed: 105 | self.train_loader.sampler.set_epoch(epoch) 106 | if self.verbose: 107 | pbar = tqdm(total=len(self.train_loader), ncols=120) 108 | 109 | epoch_results = { 110 | 'loss': 0., 111 | } 112 | 113 | for step_i, batch in enumerate(self.train_loader): 114 | if self.args.fp16 and _use_native_amp: 115 | with autocast(): 116 | if self.args.distributed: 117 | results = self.model.module.train_step(batch) 118 | else: 119 | results = self.model.train_step(batch) 120 | else: 121 | if self.args.distributed: 122 | results = self.model.module.train_step(batch) 123 | else: 124 | results = self.model.train_step(batch) 125 | 126 | loss = results['loss'] 127 | 128 | if self.args.fp16 and _use_native_amp: 129 | self.scaler.scale(loss).backward() 130 | elif self.args.fp16 and _use_apex: 131 | with amp.scale_loss(loss, self.optim) as scaled_loss: 132 | scaled_loss.backward() 133 | else: 134 | loss.backward() 135 | 136 | loss = loss.detach() 137 | 138 | # Update Parameters 139 | if self.args.clip_grad_norm > 0: 140 | if self.args.fp16 and _use_native_amp: 141 | self.scaler.unscale_(self.optim) 142 | torch.nn.utils.clip_grad_norm_( 143 | self.model.parameters(), self.args.clip_grad_norm) 144 | elif self.args.fp16 and _use_apex: 145 | torch.nn.utils.clip_grad_norm_(amp.master_params( 146 | self.optim), self.args.clip_grad_norm) 147 | else: 148 | torch.nn.utils.clip_grad_norm_( 149 | self.model.parameters(), self.args.clip_grad_norm) 150 | 151 | update = True 152 | if self.args.gradient_accumulation_steps > 1: 153 | if step_i == 0: 154 | update = False 155 | elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1: 156 | update = True 157 | else: 158 | update = False 159 | 160 | if update: 161 | if self.args.fp16 and _use_native_amp: 162 | self.scaler.step(self.optim) 163 | self.scaler.update() 164 | else: 165 | self.optim.step() 166 | 167 | if self.lr_scheduler: 168 | self.lr_scheduler.step() 169 | for param in self.model.parameters(): 170 | param.grad = None 171 | global_step += 1 172 | 173 | for k, v in results.items(): 174 | if k in epoch_results: 175 | epoch_results[k] += v.item() 176 | 177 | if self.lr_scheduler: 178 | if version.parse(torch.__version__) >= version.parse("1.4"): 179 | lr = self.lr_scheduler.get_last_lr()[0] 180 | else: 181 | lr = self.lr_scheduler.get_lr()[0] 182 | else: 183 | try: 184 | lr = self.optim.get_lr()[0] 185 | except AttributeError: 186 | lr = self.args.lr 187 | 188 | if self.verbose: 189 | loss_meter.update(loss.item()) 190 | desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}' 191 | desc_str += f' | Loss {loss_meter.val:4f}' 192 | pbar.set_description(desc_str) 193 | pbar.update(1) 194 | 195 | if self.verbose: 196 | pbar.close() 197 | 198 | # Validation 199 | valid_results = self.evaluate(self.val_loader) 200 | 201 | valid_score = valid_results['BLEU'] 202 | 203 | if valid_score > best_valid: 204 | best_valid = valid_score 205 | best_epoch = epoch 206 | self.save("BEST") 207 | if epoch >= 20: 208 | self.save(f'Epoch{epoch}') 209 | 210 | log_str = '' 211 | # log_str += pformat(valid_results) 212 | log_str += "Epoch %d: Valid BLEU %0.4f" % (epoch, valid_score) 213 | log_str += "\nEpoch %d: Best BLEU %0.4f\n" % (best_epoch, best_valid) 214 | 215 | 
print(log_str) 216 | print('save path:', self.args.output) 217 | 218 | self.calculation_acc.reset() 219 | self.calculation_no_result.reset() 220 | self.proving_acc.reset() 221 | self.proving_no_result.reset() 222 | 223 | if self.args.distributed: 224 | dist.barrier() 225 | 226 | # Test Set 227 | if self.verbose: 228 | self.save("LAST") 229 | best_path = os.path.join(self.args.output, 'BEST') 230 | self.load(best_path) 231 | self.test() 232 | 233 | 234 | def test(self): 235 | 236 | if self.args.distributed: 237 | dist.barrier() 238 | 239 | if isinstance(self.test_loader, list): 240 | test_loaders = self.test_loader 241 | else: 242 | test_loaders = [self.test_loader] 243 | 244 | for loader in test_loaders: 245 | split = loader.dataset.source 246 | test_results = self.evaluate(loader) 247 | 248 | log_str = f'{split} set results\n' 249 | log_str += pformat(test_results) 250 | 251 | if 'calculation' in split[0]: 252 | print('Calculation BLEU Acc: ', test_results, self.calculation_acc.get_avg(), self.calculation_no_result.get_avg()) 253 | if 'proving' in split[0]: 254 | print('Proving BLEU Acc: ', test_results, self.proving_acc.get_avg(), self.proving_no_result.get_avg()) 255 | 256 | 257 | def predict(self, loader, dump_path=None): 258 | self.model.eval() 259 | with torch.no_grad(): 260 | 261 | predictions = [] 262 | targets = [[]] 263 | 264 | gen_kwargs = {} 265 | gen_kwargs['num_beams'] = self.args.num_beams 266 | gen_kwargs['num_return_sequences'] = self.args.num_beams 267 | gen_kwargs['max_length'] = self.args.gen_max_length 268 | 269 | for i, batch in enumerate(tqdm(loader, ncols=120, desc=f"Prediction {loader.dataset.source}")): 270 | 271 | if self.args.distributed: 272 | results = self.model.module.test_step( 273 | batch, 274 | **gen_kwargs) 275 | else: 276 | results = self.model.test_step( 277 | batch, 278 | **gen_kwargs) 279 | 280 | results_with_beams = results['pred'][0:len(results['pred']):self.args.num_beams] 281 | predictions.extend(results_with_beams) 282 | targets[0].extend(batch['target_text']) 283 | 284 | assert len(targets) == 1 285 | 286 | results = { 287 | 'predictions': predictions, 288 | 'targets': targets 289 | } 290 | 291 | if dump_path is not None: 292 | print('Dumping prediction') 293 | with open(dump_path, 'w') as f: 294 | for i, pred in enumerate(predictions): 295 | f.write(pred.lower().strip()) 296 | if i+1 < len(predictions): 297 | f.write('\n') 298 | 299 | return results 300 | 301 | def evaluate(self, loader, dump_path=None): 302 | evaluator = loader.evaluator 303 | results = self.predict(loader, dump_path) 304 | 305 | predictions = results['predictions'] 306 | targets = results['targets'] 307 | eval_results = evaluator.evaluate(predictions, targets) 308 | return eval_results 309 | 310 | def geo_evaluation(self, pred, batch): 311 | 312 | source_nums = batch['source_nums'] 313 | choice_nums = batch['choice_nums'] 314 | label = batch['label'] 315 | problem_form = batch['problem_form'] 316 | target = batch['target_text'] 317 | 318 | batch_size = len(source_nums) 319 | num_beam = self.args.num_beams 320 | for b in range(batch_size): 321 | if problem_form[b] == 'calculation': 322 | choice = self.evaluate_calculation(pred[b*num_beam:(b+1)*num_beam], choice_nums[b], source_nums[b]) 323 | if choice is None: 324 | self.calculation_acc.update(0) 325 | self.calculation_no_result.update(1.0) 326 | elif choice == label[b]: 327 | self.calculation_acc.update(1.0) 328 | self.calculation_no_result.update(0) 329 | else: 330 | self.calculation_acc.update(0) 331 | 
self.calculation_no_result.update(0) 332 | 333 | else: 334 | assert problem_form[b] == 'proving' 335 | success = self.evaluate_proving(pred[b*num_beam:(b+1)*num_beam], target[b]) 336 | if success is None: 337 | self.proving_acc.update(0) 338 | self.proving_no_result.update(1.0) 339 | else: 340 | self.proving_acc.update(1.0) 341 | self.proving_no_result.update(0) 342 | 343 | def evaluate_calculation(self, top_k_predictions, choice_nums, source_nums): 344 | choice = None 345 | for i in range(self.args.num_beams): 346 | if choice is not None: 347 | break 348 | hypo = top_k_predictions[i].split() 349 | try: 350 | res = self._equ.excuate_equation(hypo, source_nums) 351 | except: 352 | res = None 353 | if res is not None and len(res) > 0: 354 | for j in range(4): 355 | if choice_nums[j] is not None and math.fabs(res[-1] - choice_nums[j]) < 0.001: 356 | choice = j 357 | 358 | return choice 359 | 360 | def evaluate_proving(self, top_k_predictions, target): 361 | success = None 362 | target = target.split() 363 | for i in range(self.args.num_beams): 364 | if success is not None: 365 | break 366 | hypo = top_k_predictions[i].split() 367 | 368 | if hypo == target: 369 | success = True 370 | 371 | return success 372 | 373 | def main_worker(gpu, args): 374 | # GPU is assigned 375 | args.gpu = gpu 376 | args.rank = gpu 377 | print(f'Process Launching at GPU {gpu}') 378 | 379 | if args.distributed: 380 | torch.cuda.set_device(args.gpu) 381 | dist.init_process_group(backend='nccl') 382 | 383 | if args.valid_batch_size is not None: 384 | valid_batch_size = args.valid_batch_size 385 | else: 386 | valid_batch_size = args.batch_size 387 | 388 | print(f'Building train loader at GPU {gpu}') 389 | train_loader = get_loader( 390 | args, 391 | split=args.train, mode='train', batch_size=args.batch_size, 392 | distributed=args.distributed, gpu=args.gpu, 393 | workers=args.num_workers, 394 | ) 395 | 396 | val_loader = test_loader = None 397 | if gpu == 0: 398 | print(f'Building val loader at GPU {gpu}') 399 | val_loader = get_loader( 400 | args, 401 | split=args.valid, mode='val', batch_size=valid_batch_size, 402 | distributed=False, gpu=args.gpu, 403 | workers=4, 404 | ) 405 | 406 | print(f'Building test loader at GPU {gpu}') 407 | if len(args.test.split(',')) == 1: 408 | test_loader = get_loader( 409 | args, 410 | split=args.test, mode='test', batch_size=valid_batch_size, 411 | distributed=False, gpu=args.gpu, 412 | workers=4, 413 | ) 414 | 415 | elif len(args.test.split(',')) > 1: 416 | test_loader = [] 417 | 418 | for test_split in args.test.split(','): 419 | test_loader.append(get_loader( 420 | args, 421 | split=test_split, mode='test', batch_size=valid_batch_size, 422 | distributed=False, gpu=args.gpu, 423 | workers=4, 424 | )) 425 | 426 | trainer = Trainer(args, train_loader, val_loader, test_loader, train=True) 427 | 428 | if not args.test_only: 429 | trainer.train() 430 | else: 431 | trainer.test() 432 | 433 | 434 | if __name__ == "__main__": 435 | cudnn.benchmark = True 436 | args = parse_args() 437 | ngpus_per_node = torch.cuda.device_count() 438 | args.world_size = ngpus_per_node 439 | if args.local_rank in [0, -1]: 440 | print(args) 441 | 442 | comments = [] 443 | if args.load is not None: 444 | ckpt_str = "_".join(args.load.split('/')[-3:]) 445 | comments.append(ckpt_str) 446 | if args.comment != '': 447 | comments.append(args.comment) 448 | comment = '_'.join(comments) 449 | 450 | from datetime import datetime 451 | current_time = datetime.now().strftime('%b%d_%H-%M') 452 | 453 | run_name = 
f'{current_time}_GPU{args.world_size}' 454 | if len(comments) > 0: 455 | run_name += f'_{comment}' 456 | 457 | args.run_name = run_name 458 | 459 | if args.distributed: 460 | main_worker(args.local_rank, args) 461 | -------------------------------------------------------------------------------- /Geoformer/src/geo.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.distributed as dist 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | import os 5 | from pathlib import Path 6 | from packaging import version 7 | from tqdm import tqdm 8 | import torch 9 | import logging 10 | from param import parse_args 11 | from geo_data import get_loader 12 | from utils import load_state_dict, LossMeter, set_global_logging_level, AverageMeter 13 | from pprint import pformat 14 | 15 | from ManualProgram.eval_equ import Equations 16 | import math 17 | 18 | set_global_logging_level(logging.ERROR, ["transformers"]) 19 | proj_dir = Path(__file__).resolve().parent.parent 20 | 21 | 22 | _use_native_amp = False 23 | _use_apex = False 24 | 25 | # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex 26 | if version.parse(torch.__version__) < version.parse("1.6"): 27 | from transormers.file_utils import is_apex_available 28 | if is_apex_available(): 29 | from apex import amp 30 | _use_apex = True 31 | else: 32 | _use_native_amp = True 33 | from torch.cuda.amp import autocast 34 | 35 | from trainer_base import TrainerBase 36 | 37 | class Trainer(TrainerBase): 38 | def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True): 39 | super().__init__( 40 | args, 41 | train_loader=train_loader, 42 | val_loader=val_loader, 43 | test_loader=test_loader, 44 | train=train) 45 | 46 | from geo_model import VLT5Geo 47 | model_class = VLT5Geo 48 | config = self.create_config() 49 | self.tokenizer = self.create_tokenizer() 50 | self.model = self.create_model(model_class, config) 51 | self.model.resize_token_embeddings(self.tokenizer.vocab_size) 52 | self.model.tokenizer = self.tokenizer 53 | 54 | # Load Checkpoint 55 | self.start_epoch = None 56 | if args.load is not None: 57 | ckpt_path = args.load + '.pth' 58 | self.load_checkpoint(ckpt_path) 59 | 60 | # GPU Options 61 | print(f'Model Launching at GPU {self.args.gpu}') 62 | if self.verbose: 63 | from time import time 64 | start = time() 65 | self.model = self.model.to(args.gpu) 66 | 67 | # Optimizer 68 | if train: 69 | self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler() 70 | 71 | if self.args.fp16 and _use_native_amp: 72 | self.scaler = torch.cuda.amp.GradScaler() 73 | elif _use_apex: 74 | self.model, self.optim = amp.initialize( 75 | self.model, self.optim, opt_level='O1', verbosity=self.verbose) 76 | 77 | if args.multiGPU: 78 | if args.distributed: 79 | self.model = DDP(self.model, device_ids=[args.gpu], 80 | find_unused_parameters=True 81 | ) 82 | if self.verbose: 83 | print(f'It took {time() - start:.1f}s') 84 | 85 | self._equ = Equations() 86 | self.calculation_acc = AverageMeter() 87 | self.calculation_no_result = AverageMeter() 88 | self.proving_acc = AverageMeter() 89 | self.proving_no_result = AverageMeter() 90 | 91 | self.cal_angle = AverageMeter() 92 | self.cal_length = AverageMeter() 93 | self.cal_other = AverageMeter() 94 | 95 | self.prove_parallel = AverageMeter() 96 | self.prove_triangle = AverageMeter() 97 | self.prove_quadrilateral = AverageMeter() 98 | self.prove_congruent = AverageMeter() 99 
| self.prove_similarity = AverageMeter() 100 | 101 | def train(self): 102 | if self.verbose: 103 | loss_meter = LossMeter() 104 | best_valid = 0. 105 | best_epoch = 0 106 | 107 | if self.args.distributed: 108 | dist.barrier() 109 | 110 | global_step = 0 111 | for epoch in range(self.args.epochs): 112 | if self.start_epoch is not None: 113 | epoch += self.start_epoch 114 | self.model.train() 115 | if self.args.distributed: 116 | self.train_loader.sampler.set_epoch(epoch) 117 | if self.verbose: 118 | pbar = tqdm(total=len(self.train_loader), ncols=120) 119 | 120 | epoch_results = { 121 | 'loss': 0., 122 | } 123 | 124 | for step_i, batch in enumerate(self.train_loader): 125 | if self.args.fp16 and _use_native_amp: 126 | with autocast(): 127 | if self.args.distributed: 128 | results = self.model.module.train_step(batch) 129 | else: 130 | results = self.model.train_step(batch) 131 | else: 132 | if self.args.distributed: 133 | results = self.model.module.train_step(batch) 134 | else: 135 | results = self.model.train_step(batch) 136 | 137 | loss = results['loss'] 138 | 139 | if self.args.fp16 and _use_native_amp: 140 | self.scaler.scale(loss).backward() 141 | elif self.args.fp16 and _use_apex: 142 | with amp.scale_loss(loss, self.optim) as scaled_loss: 143 | scaled_loss.backward() 144 | else: 145 | loss.backward() 146 | 147 | loss = loss.detach() 148 | 149 | # Update Parameters 150 | if self.args.clip_grad_norm > 0: 151 | if self.args.fp16 and _use_native_amp: 152 | self.scaler.unscale_(self.optim) 153 | torch.nn.utils.clip_grad_norm_( 154 | self.model.parameters(), self.args.clip_grad_norm) 155 | elif self.args.fp16 and _use_apex: 156 | torch.nn.utils.clip_grad_norm_(amp.master_params( 157 | self.optim), self.args.clip_grad_norm) 158 | else: 159 | torch.nn.utils.clip_grad_norm_( 160 | self.model.parameters(), self.args.clip_grad_norm) 161 | 162 | update = True 163 | if self.args.gradient_accumulation_steps > 1: 164 | if step_i == 0: 165 | update = False 166 | elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1: 167 | update = True 168 | else: 169 | update = False 170 | 171 | if update: 172 | if self.args.fp16 and _use_native_amp: 173 | self.scaler.step(self.optim) 174 | self.scaler.update() 175 | else: 176 | self.optim.step() 177 | 178 | if self.lr_scheduler: 179 | self.lr_scheduler.step() 180 | for param in self.model.parameters(): 181 | param.grad = None 182 | global_step += 1 183 | 184 | for k, v in results.items(): 185 | if k in epoch_results: 186 | epoch_results[k] += v.item() 187 | 188 | if self.lr_scheduler: 189 | if version.parse(torch.__version__) >= version.parse("1.4"): 190 | lr = self.lr_scheduler.get_last_lr()[0] 191 | else: 192 | lr = self.lr_scheduler.get_lr()[0] 193 | else: 194 | try: 195 | lr = self.optim.get_lr()[0] 196 | except AttributeError: 197 | lr = self.args.lr 198 | 199 | if self.verbose: 200 | loss_meter.update(loss.item()) 201 | desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}' 202 | desc_str += f' | Loss {loss_meter.val:4f}' 203 | pbar.set_description(desc_str) 204 | pbar.update(1) 205 | 206 | if self.verbose: 207 | pbar.close() 208 | 209 | # Validation 210 | valid_results = self.evaluate(self.val_loader) 211 | 212 | valid_score = valid_results['BLEU'] 213 | 214 | if valid_score > best_valid: 215 | best_valid = valid_score 216 | best_epoch = epoch 217 | self.save("BEST") 218 | if epoch >= 20: 219 | self.save(f'Epoch{epoch}') 220 | 221 | log_str = '' 222 | # log_str += pformat(valid_results) 223 | log_str += 
"Epoch %d: Valid BLEU %0.4f" % (epoch, valid_score) 224 | log_str += "\nEpoch %d: Best BLEU %0.4f\n" % (best_epoch, best_valid) 225 | 226 | print(log_str) 227 | print('save path:', self.args.output) 228 | 229 | self.calculation_acc.reset() 230 | self.calculation_no_result.reset() 231 | self.proving_acc.reset() 232 | self.proving_no_result.reset() 233 | 234 | self.cal_angle.reset() 235 | self.cal_length.reset() 236 | self.cal_other.reset() 237 | 238 | self.prove_parallel.reset() 239 | self.prove_triangle.reset() 240 | self.prove_quadrilateral.reset() 241 | self.prove_congruent.reset() 242 | self.prove_similarity.reset() 243 | 244 | if self.args.distributed: 245 | dist.barrier() 246 | 247 | # Test Set 248 | if self.verbose: 249 | self.save("LAST") 250 | best_path = os.path.join(self.args.output, 'BEST') 251 | self.load(best_path) 252 | self.test() 253 | 254 | 255 | def test(self): 256 | if self.args.distributed: 257 | dist.barrier() 258 | 259 | if isinstance(self.test_loader, list): 260 | test_loaders = self.test_loader 261 | else: 262 | test_loaders = [self.test_loader] 263 | 264 | for loader in test_loaders: 265 | split = loader.dataset.source 266 | test_results = self.evaluate(loader) 267 | 268 | log_str = f'{split} set results\n' 269 | log_str += pformat(test_results) 270 | 271 | if 'calculation' in split[0]: 272 | print('Calculation %s Acc %.4f %.4f' % (test_results, self.calculation_acc.get_avg(), self.calculation_no_result.get_avg())) 273 | print('Subsets: ', self.cal_angle.get_avg(), self.cal_length.get_avg(), self.cal_other.get_avg()) 274 | if 'proving' in split[0]: 275 | print('Proving %s Acc %.4f %.4f ' % (test_results, self.proving_acc.get_avg(), self.proving_no_result.get_avg())) 276 | print('Subsets: %.4f %.4f %.4f %.4f %.4f' % (self.prove_parallel.get_avg(), self.prove_triangle.get_avg(), 277 | self.prove_quadrilateral.get_avg(), self.prove_congruent.get_avg(), 278 | self.prove_similarity.get_avg())) 279 | 280 | def predict(self, loader, dump_path=None): 281 | self.model.eval() 282 | with torch.no_grad(): 283 | 284 | predictions = [] 285 | targets = [[]] 286 | 287 | gen_kwargs = {} 288 | gen_kwargs['num_beams'] = self.args.num_beams 289 | gen_kwargs['num_return_sequences'] = self.args.num_beams 290 | gen_kwargs['max_length'] = self.args.gen_max_length 291 | 292 | for i, batch in enumerate(tqdm(loader, ncols=120, desc=f"Prediction {loader.dataset.source}")): 293 | 294 | if self.args.distributed: 295 | results = self.model.module.test_step( 296 | batch, 297 | **gen_kwargs) 298 | else: 299 | results = self.model.test_step( 300 | batch, 301 | **gen_kwargs) 302 | 303 | results_with_beams = results['pred'][0:len(results['pred']):self.args.num_beams] 304 | predictions.extend(results_with_beams) 305 | targets[0].extend(batch['target_text']) 306 | 307 | self.geo_evaluation(results['pred'], batch) 308 | 309 | assert len(targets) == 1 310 | 311 | results = { 312 | 'predictions': predictions, 313 | 'targets': targets 314 | } 315 | 316 | if dump_path is not None: 317 | print('Dumping prediction') 318 | with open(dump_path, 'w') as f: 319 | for i, pred in enumerate(predictions): 320 | f.write(pred.lower().strip()) 321 | if i+1 < len(predictions): 322 | f.write('\n') 323 | 324 | return results 325 | 326 | def evaluate(self, loader, dump_path=None): 327 | evaluator = loader.evaluator 328 | results = self.predict(loader, dump_path) 329 | 330 | predictions = results['predictions'] 331 | targets = results['targets'] 332 | eval_results = evaluator.evaluate(predictions, targets) 333 | return 
eval_results 334 | 335 | def geo_evaluation(self, pred, batch): 336 | 337 | source_nums = batch['source_nums'] 338 | choice_nums = batch['choice_nums'] 339 | label = batch['label'] 340 | problem_form = batch['problem_form'] 341 | target = batch['target_text'] 342 | problem_type = batch['problem_type'] 343 | 344 | batch_size = len(source_nums) 345 | num_beam = self.args.num_beams 346 | for b in range(batch_size): 347 | if problem_form[b] == 'calculation': 348 | choice = self.evaluate_calculation(pred[b*num_beam:(b+1)*num_beam], choice_nums[b], source_nums[b]) 349 | if choice is None: 350 | self.calculation_acc.update(0) 351 | self.calculation_no_result.update(1.0) 352 | elif choice == label[b]: 353 | self.calculation_acc.update(1.0) 354 | self.calculation_no_result.update(0) 355 | else: 356 | self.calculation_acc.update(0) 357 | self.calculation_no_result.update(0) 358 | 359 | flag = 1.0 if choice == label[b] else 0 360 | if problem_type[b] == 'angle': 361 | self.cal_angle.update(flag) 362 | elif problem_type[b] == 'length': 363 | self.cal_length.update(flag) 364 | else: 365 | self.cal_other.update(flag) 366 | 367 | else: 368 | assert problem_form[b] == 'proving' 369 | success = self.evaluate_proving(pred[b*num_beam:(b+1)*num_beam], target[b]) 370 | if success is None: 371 | self.proving_acc.update(0) 372 | self.proving_no_result.update(1.0) 373 | else: 374 | self.proving_acc.update(1.0) 375 | self.proving_no_result.update(0) 376 | 377 | flag = 0 if success is None else 1.0 378 | if problem_type[b] == 'parallel': 379 | self.prove_parallel.update(flag) 380 | elif problem_type[b] == 'triangle': 381 | self.prove_triangle.update(flag) 382 | elif problem_type[b] == 'quadrilateral': 383 | self.prove_quadrilateral.update(flag) 384 | elif problem_type[b] == 'congruent': 385 | self.prove_congruent.update(flag) 386 | elif problem_type[b] == 'similarity': 387 | self.prove_similarity.update(flag) 388 | else: 389 | assert problem_type[b] == 'proportions' 390 | # The proportion problems are also related to triangle 391 | self.prove_triangle.update(flag) 392 | 393 | def evaluate_calculation(self, top_k_predictions, choice_nums, source_nums): 394 | choice = None 395 | for i in range(self.args.num_beams): 396 | if choice is not None: 397 | break 398 | hypo = top_k_predictions[i].split() 399 | try: 400 | res = self._equ.excuate_equation(hypo, source_nums) 401 | except: 402 | res = None 403 | if res is not None and len(res) > 0: 404 | for j in range(4): 405 | if choice_nums[j] is not None and math.fabs(res[-1] - choice_nums[j]) < 0.001: 406 | choice = j 407 | 408 | return choice 409 | 410 | def evaluate_proving(self, top_k_predictions, target): 411 | success = None 412 | target = target.split() 413 | for i in range(self.args.num_beams): 414 | if success is not None: 415 | break 416 | hypo = top_k_predictions[i].split() 417 | 418 | if hypo == target: 419 | success = True 420 | 421 | return success 422 | 423 | def main_worker(gpu, args): 424 | # GPU is assigned 425 | args.gpu = gpu 426 | args.rank = gpu 427 | print(f'Process Launching at GPU {gpu}') 428 | 429 | if args.distributed: 430 | torch.cuda.set_device(args.gpu) 431 | dist.init_process_group(backend='nccl') 432 | 433 | if args.valid_batch_size is not None: 434 | valid_batch_size = args.valid_batch_size 435 | else: 436 | valid_batch_size = args.batch_size 437 | 438 | print(f'Building train loader at GPU {gpu}') 439 | train_loader = get_loader( 440 | args, 441 | split=args.train, mode='train', batch_size=args.batch_size, 442 | distributed=args.distributed, 
gpu=args.gpu, 443 | workers=args.num_workers, 444 | ) 445 | 446 | val_loader = test_loader = None 447 | if gpu == 0: 448 | print(f'Building val loader at GPU {gpu}') 449 | val_loader = get_loader( 450 | args, 451 | split=args.valid, mode='val', batch_size=valid_batch_size, 452 | distributed=False, gpu=args.gpu, 453 | workers=4, 454 | ) 455 | 456 | print(f'Building test loader at GPU {gpu}') 457 | if len(args.test.split(',')) == 1: 458 | test_loader = get_loader( 459 | args, 460 | split=args.test, mode='test', batch_size=valid_batch_size, 461 | distributed=False, gpu=args.gpu, 462 | workers=4, 463 | ) 464 | 465 | elif len(args.test.split(',')) > 1: 466 | test_loader = [] 467 | 468 | for test_split in args.test.split(','): 469 | test_loader.append(get_loader( 470 | args, 471 | split=test_split, mode='test', batch_size=valid_batch_size, 472 | distributed=False, gpu=args.gpu, 473 | workers=4, 474 | )) 475 | 476 | trainer = Trainer(args, train_loader, val_loader, test_loader, train=True) 477 | 478 | if not args.test_only: 479 | trainer.train() 480 | else: 481 | trainer.test() 482 | 483 | 484 | 485 | if __name__ == "__main__": 486 | cudnn.benchmark = True 487 | args = parse_args() 488 | ngpus_per_node = torch.cuda.device_count() 489 | args.world_size = ngpus_per_node 490 | if args.local_rank in [0, -1]: 491 | print(args) 492 | 493 | comments = [] 494 | if args.load is not None: 495 | ckpt_str = "_".join(args.load.split('/')[-3:]) 496 | comments.append(ckpt_str) 497 | if args.comment != '': 498 | comments.append(args.comment) 499 | comment = '_'.join(comments) 500 | 501 | from datetime import datetime 502 | current_time = datetime.now().strftime('%b%d_%H-%M') 503 | 504 | run_name = f'{current_time}_GPU{args.world_size}' 505 | if len(comments) > 0: 506 | run_name += f'_{comment}' 507 | 508 | args.run_name = run_name 509 | 510 | if args.distributed: 511 | main_worker(args.local_rank, args) 512 | -------------------------------------------------------------------------------- /Geoformer/src/modeling_t5.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from transformers.models.t5.modeling_t5 import ( 4 | T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention, 5 | T5PreTrainedModel, T5ForConditionalGeneration 6 | ) 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import CrossEntropyLoss 11 | 12 | from typing import Any, Dict, List, Optional, Tuple 13 | import copy 14 | 15 | from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput 16 | from transformers.utils import logging 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | 21 | class VisualEmbedding(nn.Module): 22 | def __init__(self, config, obj_order_embedding): 23 | super().__init__() 24 | self.config = config 25 | feat_dim = config.feat_dim 26 | pos_dim = config.pos_dim 27 | n_images = config.n_images 28 | 29 | if self.config.individual_vis_layer_norm: 30 | 31 | # Object feature encoding 32 | feat_embedding = [nn.Linear(feat_dim, config.d_model)] 33 | if self.config.use_vis_layer_norm: 34 | feat_embedding.append(T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)) 35 | self.feat_embedding = nn.Sequential(*feat_embedding) 36 | 37 | absolute_vis_pos_embedding = [nn.Linear(pos_dim + 1, config.d_model)] 38 | if self.config.use_vis_layer_norm: 39 | 
absolute_vis_pos_embedding.append(T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)) 40 | self.absolute_vis_pos_embedding = nn.Sequential(*absolute_vis_pos_embedding) 41 | 42 | if self.config.use_vis_order_embedding: 43 | self.obj_order_embedding = obj_order_embedding 44 | self.img_order_embedding = nn.Embedding(n_images, config.d_model) 45 | 46 | else: 47 | # Object feature encoding 48 | feat_embedding = [nn.Linear(feat_dim, config.d_model)] 49 | self.feat_embedding = nn.Sequential(*feat_embedding) 50 | 51 | absolute_vis_pos_embedding = [nn.Linear(pos_dim + 1, config.d_model)] 52 | self.absolute_vis_pos_embedding = nn.Sequential(*absolute_vis_pos_embedding) 53 | 54 | if self.config.use_vis_order_embedding: 55 | self.obj_order_embedding = obj_order_embedding 56 | self.img_order_embedding = nn.Embedding(n_images, config.d_model) 57 | 58 | if self.config.use_vis_layer_norm: 59 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 60 | 61 | def get_area(self, pos): 62 | """ 63 | Args 64 | pos: [B, N, 4] 65 | (x1, x2, y1, y2) 66 | Return 67 | area : [B, N] 68 | """ 69 | # [B, N] 70 | height = pos[:, :, 3] - pos[:, :, 2] 71 | width = pos[:, :, 1] - pos[:, :, 0] 72 | area = height * width 73 | return area 74 | 75 | 76 | def forward(self, feats, pos, img_order_ids=None, obj_order_ids=None): 77 | """ 78 | Args 79 | feats: [B, N, feat_dim] 80 | pos: [B, N, 4] 81 | (x1, x2, y1, y2) 82 | Return 83 | relative_vis_pos_embedding: [B, N, N, n_heads] 84 | absolute_vis_pos_embedding: # [B, N, d_model] 85 | """ 86 | 87 | B, N, _ = feats.size() 88 | assert pos.size() == (B, N, 4) 89 | 90 | feat_embedding = self.feat_embedding(feats) 91 | 92 | device = feats.device 93 | 94 | area = self.get_area(pos).unsqueeze(2) # [B, N, 1] 95 | pos = torch.cat([pos, area], dim=2) # [B, N, 5] 96 | 97 | # [B, N, d_model] 98 | absolute_vis_pos_embedding = self.absolute_vis_pos_embedding(pos) 99 | 100 | if self.config.use_vis_order_embedding: 101 | if img_order_ids is None: 102 | img_order_ids = torch.zeros(N, dtype=torch.long, device=device) 103 | img_order_ids = img_order_ids.unsqueeze(0) #.expand(B, -1) 104 | img_order_embedding = self.img_order_embedding(img_order_ids) 105 | 106 | if obj_order_ids is None: 107 | obj_order_ids = torch.arange(N, dtype=torch.long, device=device) 108 | obj_order_ids = obj_order_ids.unsqueeze(0) #.expand(B,-1) 109 | obj_order_ids = self.obj_order_embedding.num_embeddings - obj_order_ids - 1 110 | obj_order_embedding = self.obj_order_embedding(obj_order_ids) 111 | 112 | vis_embedding = feat_embedding + absolute_vis_pos_embedding + \ 113 | img_order_embedding + obj_order_embedding 114 | 115 | else: 116 | vis_embedding = feat_embedding + absolute_vis_pos_embedding 117 | 118 | if not self.config.individual_vis_layer_norm: 119 | if self.config.use_vis_layer_norm: 120 | vis_embedding = self.layer_norm(vis_embedding) 121 | 122 | return vis_embedding 123 | 124 | 125 | class JointEncoder(T5Stack): 126 | def __init__(self, config, embed_tokens=None): 127 | super(T5Stack, self).__init__(config) 128 | self.config = config 129 | 130 | self.embed_tokens = embed_tokens 131 | self.is_decoder = self.config.is_decoder 132 | assert self.config.is_decoder is False 133 | 134 | self.visual_embedding = VisualEmbedding(self.config, embed_tokens) 135 | 136 | self.block = nn.ModuleList( 137 | [T5Block(config, has_relative_attention_bias=(i == 0)) 138 | for i in range(config.num_layers)] 139 | ) 140 | self.final_layer_norm = T5LayerNorm( 141 | config.d_model, 
eps=config.layer_norm_epsilon) 142 | self.dropout = nn.Dropout(config.dropout_rate) 143 | 144 | self.init_weights() 145 | self.model_parallel = False 146 | self.device_map = None 147 | 148 | def set_input_embeddings(self, new_embeddings): 149 | self.embed_tokens = new_embeddings 150 | self.visual_embedding.obj_order_embedding = new_embeddings 151 | 152 | def forward( 153 | self, 154 | input_ids=None, 155 | attention_mask=None, 156 | 157 | vis_inputs=None, 158 | vis_attention_mask=None, 159 | 160 | inputs_embeds=None, 161 | head_mask=None, 162 | past_key_values=None, 163 | use_cache=None, 164 | output_attentions=None, 165 | output_hidden_states=None, 166 | return_dict=None, 167 | ): 168 | 169 | if inputs_embeds is None: 170 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 171 | inputs_embeds = self.embed_tokens(input_ids) 172 | 173 | B, L = inputs_embeds.size()[:-1] 174 | 175 | vis_feats = vis_inputs[0] 176 | boxes = vis_inputs[1] 177 | img_order_ids = None 178 | obj_order_ids = None 179 | if len(vis_inputs) >= 3: 180 | img_order_ids = vis_inputs[2] 181 | if len(vis_inputs) == 4: 182 | obj_order_ids = vis_inputs[3] 183 | 184 | vis_embeds = self.visual_embedding( 185 | vis_feats, boxes, img_order_ids, obj_order_ids) 186 | 187 | V_L = vis_embeds.size(1) 188 | 189 | inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1) 190 | 191 | if attention_mask is None: 192 | attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) 193 | 194 | if vis_attention_mask is None: 195 | vis_attention_mask = attention_mask.new_ones(B, V_L) 196 | 197 | attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1) 198 | 199 | # ourselves in which case we just need to make it broadcastable to all heads. 
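        # The concatenated (text + visual) padding mask of shape [B, L + V_L] is
        # converted into an additive mask broadcastable over attention heads
        # (0 for kept positions, a large negative value for masked ones). The
        # relative position bias computed further below covers only the text-text
        # block; visual positions get zero bias plus this additive mask.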
200 | extended_attention_mask = self.get_extended_attention_mask( 201 | attention_mask, 202 | (B, L+V_L), 203 | inputs_embeds.device) 204 | 205 | # initialize past_key_values with `None` if past does not exist 206 | if past_key_values is None: 207 | past_key_values = [None] * len(self.block) 208 | 209 | # Prepare head mask if needed 210 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 211 | present_key_value_states = () if use_cache else None 212 | all_hidden_states = () if output_hidden_states else None 213 | all_attentions = () if output_attentions else None 214 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 215 | 216 | hidden_states = self.dropout(inputs_embeds) 217 | 218 | if self.config.num_layers > 0: 219 | 220 | assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias 221 | 222 | seq_length = L + V_L 223 | q_len = seq_length 224 | k_len = seq_length 225 | 226 | # [1, n_heads, Q_len, K_len] 227 | text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias( 228 | L, L) 229 | num_heads = text_position_bias.size(1) 230 | position_bias = text_position_bias.new_zeros( 231 | 1, num_heads, seq_length, seq_length) 232 | position_bias[:, :, :L, :L] = text_position_bias 233 | 234 | position_bias = position_bias + extended_attention_mask 235 | 236 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 237 | 238 | layer_outputs = layer_module( 239 | hidden_states, 240 | attention_mask=extended_attention_mask, 241 | position_bias=position_bias, 242 | encoder_hidden_states=None, 243 | encoder_attention_mask=None, 244 | encoder_decoder_position_bias=None, 245 | head_mask=head_mask[i], 246 | past_key_value=past_key_value, 247 | use_cache=use_cache, 248 | output_attentions=output_attentions, 249 | ) 250 | # layer_outputs is a tuple with: 251 | # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) 252 | hidden_states, present_key_value_state = layer_outputs[:2] 253 | 254 | # We share the position biases between the layers - the first layer store them 255 | # layer_outputs = hidden-states, key-value-states (self-attention weights), 256 | # (self-attention position bias), (cross-attention weights), (cross-attention position bias) 257 | position_bias = layer_outputs[2] 258 | 259 | # append next layer key value states 260 | if use_cache: 261 | present_key_value_states = present_key_value_states + \ 262 | (present_key_value_state,) 263 | 264 | hidden_states = self.final_layer_norm(hidden_states) 265 | hidden_states = self.dropout(hidden_states) 266 | 267 | # Add last layer 268 | if output_hidden_states: 269 | all_hidden_states = all_hidden_states + (hidden_states,) 270 | 271 | if not return_dict: 272 | return tuple( 273 | v 274 | for v in [ 275 | hidden_states, 276 | present_key_value_states, 277 | all_hidden_states, 278 | all_attentions, 279 | all_cross_attentions, 280 | ] 281 | if v is not None 282 | ) 283 | return BaseModelOutputWithPastAndCrossAttentions( 284 | last_hidden_state=hidden_states, 285 | past_key_values=present_key_value_states, 286 | hidden_states=all_hidden_states, 287 | attentions=all_attentions, 288 | cross_attentions=all_cross_attentions, 289 | ) 290 | 291 | 292 | class VLT5(T5ForConditionalGeneration): 293 | _keys_to_ignore_on_load_missing = [ 294 | r"encoder\.embed_tokens\.weight", 295 | r"decoder\.embed_tokens\.weight", 296 | r"lm_head\.weight", 297 | ] 298 | 
_keys_to_ignore_on_load_unexpected = [ 299 | r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", 300 | ] 301 | 302 | def __init__(self, config): 303 | super(T5ForConditionalGeneration, self).__init__(config) 304 | 305 | self.config = config 306 | 307 | self.model_dim = config.d_model 308 | 309 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 310 | 311 | encoder_config = copy.deepcopy(config) 312 | encoder_config.is_decoder = False 313 | encoder_config.use_cache = False 314 | encoder_config.is_encoder_decoder = False 315 | 316 | #---- Modified ----# 317 | # self.encoder = T5Stack(encoder_config, self.shared) 318 | self.encoder = JointEncoder(encoder_config, self.shared) 319 | #------------------# 320 | 321 | decoder_config = copy.deepcopy(config) 322 | decoder_config.is_decoder = True 323 | decoder_config.is_encoder_decoder = False 324 | 325 | self.decoder = T5Stack(decoder_config, self.shared) 326 | 327 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 328 | 329 | self.init_weights() 330 | 331 | # Model parallel 332 | self.model_parallel = False 333 | self.device_map = None 334 | 335 | def set_input_embeddings(self, new_embeddings): 336 | self.shared = new_embeddings 337 | self.encoder.set_input_embeddings(new_embeddings) 338 | self.decoder.set_input_embeddings(new_embeddings) 339 | 340 | def extend_vocab(self, vocab_size): 341 | 342 | new_shared = nn.Embedding(vocab_size, self.config.d_model) 343 | old_weight = self.shared.weight.data.detach().clone() 344 | old_vocab_size = old_weight.size(0) 345 | new_shared.weight.data[:old_vocab_size, :] = old_weight 346 | self.shared = new_shared 347 | 348 | new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False) 349 | old_weight = self.lm_head.weight.data.detach().clone() 350 | old_vocab_size = old_weight.size(0) 351 | new_lm_head.weight.data[:old_vocab_size, :] = old_weight 352 | self.lm_head = new_lm_head 353 | 354 | self.vis_encoder.visual_embedding.obj_order_embedding = self.shared 355 | 356 | self.encoder.embed_tokens = self.shared 357 | self.decoder.embed_tokens = self.shared 358 | 359 | self.lm_head.weight = self.shared.weight 360 | 361 | self.config.vocab_size = vocab_size 362 | self.encoder.config.vocab_size = vocab_size 363 | self.vis_encoder.config.vocab_size = vocab_size 364 | self.decoder.config.vocab_size = vocab_size 365 | 366 | 367 | # @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING) 368 | # @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 369 | def forward( 370 | self, 371 | input_ids=None, 372 | attention_mask=None, 373 | encoder_outputs=None, 374 | 375 | vis_inputs=None, 376 | vis_attention_mask=None, 377 | 378 | decoder_input_ids=None, 379 | decoder_attention_mask=None, 380 | past_key_values=None, 381 | use_cache=None, 382 | labels=None, 383 | inputs_embeds=None, 384 | decoder_inputs_embeds=None, 385 | head_mask=None, 386 | output_attentions=None, 387 | output_hidden_states=None, 388 | return_dict=None, 389 | reduce_loss=False, 390 | 391 | return_hidden_state=False, 392 | 393 | **kwargs, 394 | ): 395 | 396 | use_cache = use_cache if use_cache is not None else self.config.use_cache 397 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 398 | 399 | if encoder_outputs is None: 400 | 401 | encoder_outputs = self.encoder( 402 | input_ids=input_ids, 403 | attention_mask=attention_mask, 404 | inputs_embeds=inputs_embeds, 405 | 406 | vis_inputs=vis_inputs, 407 | 
vis_attention_mask=vis_attention_mask, 408 | 409 | head_mask=head_mask, 410 | output_attentions=output_attentions, 411 | output_hidden_states=output_hidden_states, 412 | return_dict=return_dict, 413 | ) 414 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 415 | encoder_outputs = BaseModelOutput( 416 | last_hidden_state=encoder_outputs[0], 417 | hidden_states=encoder_outputs[1] if len( 418 | encoder_outputs) > 1 else None, 419 | attentions=encoder_outputs[2] if len( 420 | encoder_outputs) > 2 else None, 421 | ) 422 | 423 | hidden_states = encoder_outputs[0] 424 | 425 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 426 | # get decoder inputs from shifting lm labels to the right 427 | decoder_input_ids = self._shift_right(labels) 428 | 429 | # If decoding with past key value states, only the last tokens 430 | # should be given as an input 431 | if past_key_values is not None: 432 | assert labels is None, "Decoder should not use cached key value states when training." 433 | if decoder_input_ids is not None: 434 | decoder_input_ids = decoder_input_ids[:, -1:] 435 | if decoder_inputs_embeds is not None: 436 | decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] 437 | 438 | if attention_mask is None: 439 | attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=hidden_states.dtype, device=hidden_states.device) 440 | if vis_attention_mask is None: 441 | B, L = attention_mask.size() 442 | V_L = encoder_outputs[0].size(1) - L 443 | vis_attention_mask = attention_mask.new_ones(B, V_L) 444 | encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1) 445 | 446 | # Decode 447 | decoder_outputs = self.decoder( 448 | input_ids=decoder_input_ids, 449 | attention_mask=decoder_attention_mask, 450 | inputs_embeds=decoder_inputs_embeds, 451 | past_key_values=past_key_values, 452 | 453 | encoder_hidden_states=hidden_states, 454 | encoder_attention_mask=encoder_attention_mask, 455 | 456 | head_mask=head_mask, 457 | use_cache=use_cache, 458 | output_attentions=output_attentions, 459 | output_hidden_states=output_hidden_states, 460 | return_dict=return_dict, 461 | ) 462 | 463 | sequence_output = decoder_outputs[0] 464 | 465 | assert self.config.tie_word_embeddings is True 466 | 467 | if self.config.tie_word_embeddings: 468 | # Rescale output before projecting on vocab 469 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 470 | sequence_output = sequence_output * (self.model_dim ** -0.5) 471 | 472 | if return_hidden_state: 473 | return sequence_output 474 | 475 | lm_logits = self.lm_head(sequence_output) 476 | 477 | loss = None 478 | if labels is not None: 479 | if reduce_loss: 480 | loss_fct = CrossEntropyLoss(ignore_index=-100) 481 | else: 482 | loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none') 483 | loss = loss_fct( 484 | lm_logits.view(-1, lm_logits.size(-1)), 485 | labels.view(-1)) 486 | 487 | return VLSeq2SeqLMOutput( 488 | loss=loss, 489 | logits=lm_logits, 490 | past_key_values=decoder_outputs.past_key_values, 491 | decoder_last_hidden_state=decoder_outputs.last_hidden_state, 492 | decoder_hidden_states=decoder_outputs.hidden_states, 493 | ) 494 | 495 | def prepare_inputs_for_generation( 496 | self, input_ids, past=None, attention_mask=None, use_cache=None, 497 | encoder_outputs=None, 498 | **kwargs): 499 | 500 | # cut decoder_input_ids if past is used 501 | if past is not None: 502 | input_ids = input_ids[:, 
-1:] 503 | 504 | output = { 505 | "decoder_input_ids": input_ids, 506 | "past_key_values": past, 507 | "encoder_outputs": encoder_outputs, 508 | "attention_mask": attention_mask, 509 | "use_cache": use_cache, 510 | } 511 | 512 | if 'vis_attention_mask' in kwargs: 513 | output['vis_attention_mask'] = kwargs['vis_attention_mask'] 514 | 515 | return output 516 | 517 | @staticmethod 518 | def _expand_inputs_for_generation( 519 | input_ids: torch.LongTensor, 520 | expand_size: int = 1, 521 | is_encoder_decoder: bool = False, 522 | attention_mask: torch.LongTensor = None, 523 | encoder_outputs: ModelOutput = None, 524 | **model_kwargs 525 | ) -> Tuple[torch.LongTensor, Dict[str, Any]]: 526 | expanded_return_idx = ( 527 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, 528 | expand_size).view(-1).to(input_ids.device) 529 | ) 530 | input_ids = input_ids.index_select(0, expanded_return_idx) 531 | 532 | if "token_type_ids" in model_kwargs: 533 | token_type_ids = model_kwargs["token_type_ids"] 534 | model_kwargs["token_type_ids"] = token_type_ids.index_select( 535 | 0, expanded_return_idx) 536 | 537 | if attention_mask is not None: 538 | model_kwargs["attention_mask"] = attention_mask.index_select( 539 | 0, expanded_return_idx) 540 | 541 | if model_kwargs.get("vis_attention_mask", None) is not None: 542 | model_kwargs['vis_attention_mask'] = model_kwargs['vis_attention_mask'].index_select( 543 | 0, expanded_return_idx) 544 | 545 | if is_encoder_decoder: 546 | assert encoder_outputs is not None 547 | encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 548 | 0, expanded_return_idx 549 | ) 550 | model_kwargs["encoder_outputs"] = encoder_outputs 551 | 552 | return input_ids, model_kwargs 553 | 554 | 555 | @dataclass 556 | class VLSeq2SeqLMOutput(ModelOutput): 557 | """ 558 | Base class for sequence-to-sequence language model outputs. 559 | 560 | Args: 561 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): 562 | Language modeling loss. 563 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 564 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 565 | past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 566 | List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape 567 | :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`. 568 | 569 | Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be 570 | used (see ``past_key_values`` input) to speed up sequential decoding. 571 | decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 572 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 573 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 574 | 575 | Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. 576 | decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 577 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 578 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
579 | 580 | Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the 581 | self-attention heads. 582 | encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 583 | Sequence of hidden-states at the output of the last layer of the encoder of the model. 584 | encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 585 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 586 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 587 | 588 | Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. 589 | encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 590 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 591 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 592 | 593 | Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the 594 | self-attention heads. 595 | """ 596 | 597 | loss: Optional[torch.FloatTensor] = None 598 | logits: torch.FloatTensor = None 599 | past_key_values: Optional[List[torch.FloatTensor]] = None 600 | decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None 601 | decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 602 | decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None 603 | encoder_last_hidden_state: Optional[torch.FloatTensor] = None 604 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 605 | encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None 606 | 607 | vis_encoder_last_hidden_state: Optional[torch.FloatTensor] = None 608 | vis_encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 609 | vis_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None 610 | --------------------------------------------------------------------------------
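
A minimal, standalone sketch of the joint-attention arithmetic used by JointEncoder.forward in modeling_t5.py above: text and visual embeddings are concatenated, the T5 relative-position bias is padded with zeros over the visual positions, and a broadcastable mask is added on top before the layers run. All sizes below (batch, text length, number of visual tokens, head count) and the random stand-in for SelfAttention.compute_bias are illustrative assumptions, not values taken from this repository.

import torch

# Toy sizes (assumed): 2 examples, 5 text tokens, 3 visual tokens, 4 attention heads.
B, L, V_L, n_heads = 2, 5, 3, 4
seq_len = L + V_L

# Stand-in for the text-only relative position bias that the real encoder gets
# from self.block[0].layer[0].SelfAttention.compute_bias(L, L).
text_position_bias = torch.randn(1, n_heads, L, L)

# Pad the bias to the joint text+visual length; visual positions receive zero bias.
position_bias = text_position_bias.new_zeros(1, n_heads, seq_len, seq_len)
position_bias[:, :, :L, :L] = text_position_bias

# Joint attention mask: text mask (all ones here, i.e. no padding) plus visual mask.
attention_mask = torch.ones(B, L)
vis_attention_mask = attention_mask.new_ones(B, V_L)
joint_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)   # (B, seq_len)

# Broadcastable "extended" mask: 0 where attention is allowed, a large negative
# value where it is masked (this mirrors what get_extended_attention_mask
# produces for an encoder-style mask).
extended_mask = (1.0 - joint_mask[:, None, None, :]) * -1e9           # (B, 1, 1, seq_len)

# The bias fed to every encoder layer is the padded position bias plus the
# extended mask, broadcast over the batch and query dimensions.
biased = position_bias + extended_mask
print(biased.shape)                                                    # torch.Size([2, 4, 8, 8])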