├── cal_aer.sh
├── run_align.sh
├── train.sh
├── train_utils.py
├── README.md
├── aligner
│   ├── word_align.py
│   └── sent_aligner.py
├── self_training_modeling_adapter.py
├── aer.py
└── train_alignment_adapter.py
/cal_aer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | export WORKLOC=/home/acc_align 7 | datadir=$WORKLOC/data/deen 8 | 9 | ref_align=$datadir/deen.talp 10 | reftype='--oneRef' 11 | 12 | 13 | 14 | for LayerNum in `seq 1 12`; do 15 | echo "===== AER for de-en, layer=${LayerNum} ..." 16 | python $WORKLOC/aer.py ${ref_align} $WORKLOC/xxx/de2en.align.$LayerNum --fAlpha 0.5 $reftype 17 | done 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /run_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | WorkLOC=/home/acc_align # set to your working directory 5 | 6 | SRC=$WorkLOC/xxx/roen/roen.src 7 | TGT=$WorkLOC/xxx/roen/roen.tgt 8 | 9 | OUTPUT_DIR=$WorkLOC/xxx/infer_output 10 | ADAPTER=$WorkLOC/xxx/adapter 11 | Model=$WorkLOC/xxx/LaBSE 12 | 13 | 14 | 15 | python $WorkLOC/github_open/train_alignment_adapter.py \ 16 | --infer_path $OUTPUT_DIR \ 17 | --adapter_path $ADAPTER \ 18 | --model_name_or_path $Model \ 19 | --extraction 'softmax' \ 20 | --infer_data_file_src $SRC \ 21 | --infer_data_file_tgt $TGT \ 22 | --per_gpu_train_batch_size 40 \ 23 | --gradient_accumulation_steps 1 \ 24 | --align_layer 6 \ 25 | --softmax_threshold 0.1 \ 26 | --do_test \ 27 | 28 | 29 | exit 30 | 31 | 32 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | WorkLOC=/home/acc_align # set to your working directory 5 | 6 | 7 | 8 | 9 | TRAIN_FILE_SRC=$WorkLOC/xxx/train.src 10 | TRAIN_FILE_TGT=$WorkLOC/xxx/train.tgt 11 | TRAIN_FILE_ALIGN=$WorkLOC/xxx/train.talp 12 | 13 | EVAL_FILE_SRC=$WorkLOC/xxx/dev.src 14 | EVAL_FILE_TGT=$WorkLOC/xxx/dev.tgt 15 | Eval_gold_file=$WorkLOC/xxx/dev.talp 16 | 17 | OUTPUT_DIR_ADAPTER=$WorkLOC/adapter_output 18 | Model=$WorkLOC/models/LaBSE 19 | 20 | EVAL_RES=$WorkLOC/xxx/eval_result 21 | 22 | 23 | CUDA_VISIBLE_DEVICES=0 python $WorkLOC/github_open/train_alignment_adapter.py \ 24 | --output_dir_adapter $OUTPUT_DIR_ADAPTER \ 25 | --eval_res_dir $EVAL_RES \ 26 | --model_name_or_path $Model \ 27 | --extraction 'softmax' \ 28 | --train_so \ 29 | --do_train \ 30 | --do_eval \ 31 | --train_data_file_src $TRAIN_FILE_SRC \ 32 | --train_data_file_tgt $TRAIN_FILE_TGT \ 33 | --eval_data_file_src $EVAL_FILE_SRC \ 34 | --eval_data_file_tgt $EVAL_FILE_TGT \ 35 | --per_gpu_train_batch_size 40 \ 36 | --gradient_accumulation_steps 1 \ 37 | --num_train_epochs 20 \ 38 | --learning_rate 1e-4 \ 39 | --save_steps 100 \ 40 | --max_steps 1200 \ 41 | --align_layer 6 \ 42 | --logging_steps 50 \ 43 | --eval_gold_file $Eval_gold_file \ 44 | --gold_one_index \ 45 | --softmax_threshold 0.1 \ 46 | --train_gold_file $TRAIN_FILE_ALIGN \ 47 | #--output_dir $OUTPUT_DIR \ 48 | 49 | exit 50 | 51 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Modifications copyright (C) 2020 Zi-Yi Dou 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | WEIGHTS_NAME = "pytorch_model.bin" 18 | 19 | import os 20 | import glob 21 | import re 22 | import shutil 23 | from typing import Dict, List, Text, Tuple 24 | import logging 25 | import math 26 | import torch 27 | from torch.optim import Optimizer 28 | from torch.optim.lr_scheduler import LambdaLR 29 | 30 | 31 | 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def get_logger(name: Text, filename: Text = None, level: int = logging.DEBUG) -> logging.Logger: 38 | logger = logging.getLogger(name) 39 | logger.setLevel(level) 40 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 41 | 42 | ch = logging.StreamHandler() 43 | ch.setLevel(level) 44 | ch.setFormatter(formatter) 45 | logger.addHandler(ch) 46 | 47 | if filename is not None: 48 | fh = logging.FileHandler(filename) 49 | fh.setLevel(level) 50 | fh.setFormatter(formatter) 51 | logger.addHandler(fh) 52 | 53 | return logger 54 | 55 | def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]: 56 | ordering_and_checkpoint_path = [] 57 | 58 | glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))) 59 | 60 | for path in glob_checkpoints: 61 | if use_mtime: 62 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 63 | else: 64 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) 65 | if regex_match and regex_match.groups(): 66 | ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) 67 | 68 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 69 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 70 | return checkpoints_sorted 71 | 72 | def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None: 73 | if not args.save_total_limit: 74 | return 75 | if args.save_total_limit <= 0: 76 | return 77 | 78 | # Check if we should delete older checkpoint(s) 79 | checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime) 80 | if len(checkpoints_sorted) <= args.save_total_limit: 81 | return 82 | 83 | number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) 84 | checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] 85 | for checkpoint in checkpoints_to_be_deleted: 86 | logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) 87 | shutil.rmtree(checkpoint) 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multilingual Sentence Transformer as A Multilingual Word Aligner 2 | 3 | 4 | 5 | 6 | 7 | ## Requirements 8 | 9 | Trained with Python 3.7, adapter-transformers 4.16.2, Torch 1.9.0,
and tqdm 4.62.3. 10 | 11 | 12 | ## Data Format 13 | 14 | - source 15 | ``` 16 | Wir glauben nicht , daß wir nur Rosinen herauspicken sollten . 17 | Das stimmt nicht ! 18 | ``` 19 | 20 | - target 21 | ``` 22 | We do not believe that we should cherry-pick . 23 | But this is not what happens . 24 | ``` 25 | 26 | - gold alignment (whitespace-separated `i-j` pairs in Pharaoh format: the i-th source word aligns to the j-th target word, one-indexed here; cf. `--gold_one_index` in `train.sh`) 27 | ``` 28 | 9-8 8-8 7-8 6-6 1-1 2-2 4-5 5-5 3-3 11-9 10-7 2-4 29 | 3-4 2-3 2-6 4-7 2-5 1-2 30 | ``` 31 | ## Directly extract alignments 32 | 33 | ```shell 34 | bash run_align.sh 35 | ``` 36 | 37 | 38 | ## Fine-tuning on training data 39 | 40 | ```shell 41 | bash train.sh 42 | ``` 43 | 44 | ## Calculate AER 45 | 46 | ```shell 47 | bash cal_aer.sh 48 | ``` 49 | 50 | ## Data 51 | 52 | Links to the test sets used in the paper: 53 | 54 | 55 | | Language Pair | Type | Link | 56 | | ------------- | ------------- | ------------- | 57 | | En-De | Gold Alignment | www-i6.informatik.rwth-aachen.de/goldAlignment/ | 58 | | En-Fr | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt/ | 59 | | En-Ro | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt05/ | 60 | | En-Fa | Gold Alignment | https://ece.ut.ac.ir/en/web/nlp/resources | 61 | | En-Zh | Gold Alignment | https://nlp.csai.tsinghua.edu.cn/~ly/systems/TsinghuaAligner/TsinghuaAligner.html | 62 | | En-Ja | Gold Alignment | http://www.phontron.com/kftt | 63 | | En-Sv | Gold Alignment | https://www.ida.liu.se/divisions/hcs/nlplab/resources/ges/ | 64 | 65 | The training and validation sets used in the paper are available [here](https://drive.google.com/file/d/19X0mhTx6-EhgILm7_mtVWrT2qal-o-uV/view?usp=share_link). 66 | 67 | ## LaBSE 68 | 69 | You can access the LaBSE model [here](https://huggingface.co/sentence-transformers/LaBSE). 70 | 71 | ## Adapter Checkpoints 72 | 73 | The multilingual adapter checkpoint is available [here](https://drive.google.com/open?id=1eB8aWd4iM6DSQWJZOA5so4rB4MCQQyQf&usp=drive_copy). 74 | 75 | ## Citation 76 | 77 | ``` 78 | @inproceedings{wang-etal-2022-multilingual, 79 | title = "Multilingual Sentence Transformer as A Multilingual Word Aligner", 80 | author = "Wang, Weikang and 81 | Chen, Guanhua and 82 | Wang, Hanqing and 83 | Han, Yue and 84 | Chen, Yun", 85 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2022", 86 | month = dec, 87 | year = "2022", 88 | address = "Abu Dhabi, United Arab Emirates", 89 | publisher = "Association for Computational Linguistics", 90 | url = "https://aclanthology.org/2022.findings-emnlp.215", 91 | pages = "2952--2963", 92 | abstract = "Multilingual pretrained language models (mPLMs) have shown their effectiveness in multilingual word alignment induction. However, these methods usually start from mBERT or XLM-R. In this paper, we investigate whether multilingual sentence Transformer LaBSE is a strong multilingual word aligner. This idea is non-trivial as LaBSE is trained to learn language-agnostic sentence-level embeddings, while the alignment extraction task requires the more fine-grained word-level embeddings to be language-agnostic. We demonstrate that the vanilla LaBSE outperforms other mPLMs currently used in the alignment task, and then propose to finetune LaBSE on parallel corpus for further improvement. Experiment results on seven language pairs show that our best aligner outperforms previous state-of-the-art models of all varieties.
In addition, our aligner supports different language pairs in a single model, and even achieves new state-of-the-art on zero-shot language pairs that does not appear in the finetuning process.", 93 | } 94 | 95 | ``` 96 | -------------------------------------------------------------------------------- /aligner/word_align.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import logging 5 | from typing import Dict, List, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | from transformers import BertModel, BertTokenizer, XLMModel, XLMTokenizer, RobertaModel, RobertaTokenizer, \ 10 | XLMRobertaModel, XLMRobertaTokenizer, AutoConfig, AutoModel, AutoTokenizer 11 | import torch.nn as nn 12 | from train_utils import get_logger 13 | 14 | LOG = get_logger(__name__) 15 | 16 | 17 | def return_extended_attention_mask(attention_mask, dtype): 18 | if attention_mask.dim() == 3: 19 | extended_attention_mask = attention_mask[:, None, :, :] 20 | elif attention_mask.dim() == 2: 21 | extended_attention_mask = attention_mask[:, None, None, :] 22 | else: 23 | raise ValueError( 24 | "Wrong shape for input_ids or attention_mask" 25 | ) 26 | extended_attention_mask = extended_attention_mask.to(dtype=dtype) 27 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 28 | return extended_attention_mask 29 | 30 | 31 | 32 | class SentenceAligner_word(object): 33 | def __init__(self, args, model): 34 | 35 | self.guide = None 36 | self.softmax_threshold = args.softmax_threshold 37 | self.embed_loader = model 38 | 39 | def transpose_for_scores(self, x): 40 | new_x_shape = x.size()[:-1] + (1, x.size(-1)) 41 | x = x.view(*new_x_shape) 42 | return x.permute(0, 2, 1, 3) 43 | 44 | def get_subword_matrix(self, args, inputs_src, inputs_tgt, PAD_ID, CLS_ID, SEP_ID, output_prob=False): 45 | 46 | output_src,output_tgt = self.embed_loader( 47 | inputs_src=inputs_src, inputs_tgt=inputs_tgt, attention_mask_src=(inputs_src != PAD_ID), 48 | attention_mask_tgt=(inputs_tgt != PAD_ID), guide=None, align_layer=args.align_layer, 49 | extraction=args.extraction, softmax_threshold=args.softmax_threshold, 50 | train_so=args.train_so, train_co=args.train_co, do_infer=True, 51 | ) 52 | 53 | align_matrix_all_layers = {} 54 | 55 | for layer_id in range(1, len(output_src.hidden_states)): 56 | 57 | hidden_states_src = output_src.hidden_states[layer_id] 58 | hidden_states_tgt = output_tgt.hidden_states[layer_id] 59 | # mask 60 | attention_mask_src = ((inputs_src == PAD_ID) + (inputs_src == CLS_ID) + (inputs_src == SEP_ID)).float() 61 | attention_mask_tgt = ((inputs_tgt == PAD_ID) + (inputs_tgt == CLS_ID) + (inputs_tgt == SEP_ID)).float() 62 | len_src = torch.sum(1 - attention_mask_src, -1) 63 | len_tgt = torch.sum(1 - attention_mask_tgt, -1) 64 | attention_mask_src = return_extended_attention_mask(1 - attention_mask_src, hidden_states_src.dtype) 65 | attention_mask_tgt = return_extended_attention_mask(1 - attention_mask_tgt, hidden_states_tgt.dtype) 66 | 67 | # qkv 68 | query_src = self.transpose_for_scores(hidden_states_src) 69 | query_tgt = self.transpose_for_scores(hidden_states_tgt) 70 | key_src = query_src 71 | key_tgt = query_tgt 72 | value_src = query_src 73 | value_tgt = query_tgt 74 | 75 | # att 76 | attention_scores = torch.matmul(query_src, key_tgt.transpose(-1, -2)) 77 | attention_scores_src = attention_scores + attention_mask_tgt 78 | attention_scores_tgt = attention_scores + attention_mask_src.transpose(-1, -2) 79 | 80 | 
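                # The shared similarity matrix is normalized in both directions:
                # softmax over dim=-1 gives each source subword a distribution over
                # target subwords (src->tgt), softmax over dim=-2 the reverse (tgt->src).
                # A subword pair survives only if it clears the probability threshold
                # in both directions (the elementwise product below acts as an AND).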
attention_probs_src = nn.Softmax(dim=-1)( 81 | attention_scores_src) # if extraction == 'softmax' else entmax15(attention_scores_src, dim=-1) 82 | attention_probs_tgt = nn.Softmax(dim=-2)( 83 | attention_scores_tgt) # if extraction == 'softmax' else entmax15(attention_scores_tgt, dim=-2) 84 | 85 | if self.guide is None: 86 | # threshold = softmax_threshold if extraction == 'softmax' else 0 87 | threshold = self.softmax_threshold 88 | align_matrix = (attention_probs_src > threshold) * (attention_probs_tgt > threshold) 89 | 90 | if not output_prob: 91 | # return align_matrix 92 | align_matrix_all_layers[layer_id] = align_matrix 93 | # A heuristic of generating the alignment probability 94 | """ 95 | attention_probs_src = nn.Softmax(dim=-1)(attention_scores_src/torch.sqrt(len_tgt.view(-1, 1, 1, 1))) 96 | attention_probs_tgt = nn.Softmax(dim=-2)(attention_scores_tgt/torch.sqrt(len_src.view(-1, 1, 1, 1))) 97 | align_prob = (2*attention_probs_src*attention_probs_tgt)/(attention_probs_src+attention_probs_tgt+1e-9) 98 | return align_matrix, align_prob 99 | """ 100 | 101 | return align_matrix_all_layers 102 | 103 | def get_aligned_word(self, args, inputs_src, inputs_tgt, bpe2word_map_src, bpe2word_map_tgt, PAD_ID, CLS_ID, SEP_ID, 104 | output_prob=False): 105 | 106 | attention_probs_inter_all_layers = self.get_subword_matrix(args, inputs_src, inputs_tgt, PAD_ID, CLS_ID, SEP_ID, 107 | output_prob) 108 | if output_prob: 109 | # probabilities are only produced by the commented-out heuristic above 110 | raise NotImplementedError('output_prob is not supported in the all-layers extraction path') 111 | 112 | word_aligns_all_layers = {} 113 | 114 | for layer_id in attention_probs_inter_all_layers: 115 | 116 | attention_probs_inter = attention_probs_inter_all_layers[layer_id].float() 117 | 118 | word_aligns = [] 119 | attention_probs_inter = attention_probs_inter[:, 0, 1:-1, 1:-1] 120 | 121 | for idx, (attention, b2w_src, b2w_tgt) in enumerate( 122 | zip(attention_probs_inter, bpe2word_map_src, bpe2word_map_tgt)): 123 | aligns = set() if not output_prob else dict() 124 | non_zeros = torch.nonzero(attention) 125 | for i, j in non_zeros: 126 | word_pair = (b2w_src[i], b2w_tgt[j]) 127 | if output_prob: 128 | prob = alignment_probs[idx, i, j] 129 | if not word_pair in aligns: 130 | aligns[word_pair] = prob 131 | else: 132 | aligns[word_pair] = max(aligns[word_pair], prob) 133 | else: 134 | aligns.add(word_pair) 135 | word_aligns.append(aligns) 136 | 137 | word_aligns_all_layers[layer_id] = word_aligns 138 | return word_aligns_all_layers 139 | -------------------------------------------------------------------------------- /self_training_modeling_adapter.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import AutoModel 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | import torch 6 | import numpy as np 7 | from transformers import PreTrainedModel 8 | 9 | 10 | PAD_ID=0 11 | CLS_ID=101 12 | SEP_ID=102 13 | 14 | def return_extended_attention_mask(attention_mask, dtype): 15 | if attention_mask.dim() == 3: 16 | extended_attention_mask = attention_mask[:, None, :, :] 17 | elif attention_mask.dim() == 2: 18 | extended_attention_mask = attention_mask[:, None, None, :] 19 | else: 20 | raise ValueError( 21 | "Wrong shape for input_ids or attention_mask" 22 | ) 23 | extended_attention_mask = extended_attention_mask.to(dtype=dtype) 24 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 25 | return extended_attention_mask 26 | 27 | 28 | class
ModelGuideHead(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | 32 | def transpose_for_scores(self, x): 33 | new_x_shape = x.size()[:-1] + (1, x.size(-1)) 34 | x = x.view(*new_x_shape) 35 | return x.permute(0, 2, 1, 3) 36 | 37 | def forward( 38 | self, 39 | hidden_states_src, hidden_states_tgt, 40 | inputs_src, inputs_tgt, 41 | guide=None, 42 | extraction='softmax', softmax_threshold=0.1, 43 | output_prob=False, 44 | ): 45 | #mask 46 | attention_mask_src = ( (inputs_src==PAD_ID) + (inputs_src==CLS_ID) + (inputs_src==SEP_ID) ).float() 47 | attention_mask_tgt = ( (inputs_tgt==PAD_ID) + (inputs_tgt==CLS_ID) + (inputs_tgt==SEP_ID) ).float() 48 | len_src = torch.sum(1-attention_mask_src, -1) 49 | len_tgt = torch.sum(1-attention_mask_tgt, -1) 50 | attention_mask_src = return_extended_attention_mask(1-attention_mask_src, hidden_states_src.dtype) 51 | attention_mask_tgt = return_extended_attention_mask(1-attention_mask_tgt, hidden_states_src.dtype) 52 | 53 | #qkv 54 | query_src = self.transpose_for_scores(hidden_states_src) 55 | query_tgt = self.transpose_for_scores(hidden_states_tgt) 56 | key_src = query_src 57 | key_tgt = query_tgt 58 | value_src = query_src 59 | value_tgt = query_tgt 60 | 61 | #att 62 | attention_scores = torch.matmul(query_src, key_tgt.transpose(-1, -2)) 63 | attention_scores_src = attention_scores + attention_mask_tgt 64 | attention_scores_tgt = attention_scores + attention_mask_src.transpose(-1, -2) 65 | 66 | 67 | attention_probs_src = nn.Softmax(dim=-1)(attention_scores_src) #if extraction == 'softmax' else entmax15(attention_scores_src, dim=-1) 68 | attention_probs_tgt = nn.Softmax(dim=-2)(attention_scores_tgt) #if extraction == 'softmax' else entmax15(attention_scores_tgt, dim=-2) 69 | 70 | 71 | 72 | if guide is None: 73 | 74 | threshold = softmax_threshold if extraction == 'softmax' else 0 75 | align_matrix = (attention_probs_src>threshold)*(attention_probs_tgt>threshold) 76 | if not output_prob: 77 | return align_matrix 78 | # A heuristic of generating the alignment probability 79 | attention_probs_src = nn.Softmax(dim=-1)(attention_scores_src/torch.sqrt(len_tgt.view(-1, 1, 1, 1))) 80 | attention_probs_tgt = nn.Softmax(dim=-2)(attention_scores_tgt/torch.sqrt(len_src.view(-1, 1, 1, 1))) 81 | align_prob = (2*attention_probs_src*attention_probs_tgt)/(attention_probs_src+attention_probs_tgt+1e-9) 82 | return align_matrix, align_prob 83 | 84 | 85 | 86 | 87 | so_loss_src = torch.sum(torch.sum (attention_probs_src*guide, -1), -1).view(-1) 88 | so_loss_tgt = torch.sum(torch.sum (attention_probs_tgt*guide, -1), -1).view(-1) 89 | 90 | so_loss = so_loss_src/len_src + so_loss_tgt/len_tgt 91 | so_loss = -torch.mean(so_loss) 92 | 93 | 94 | 95 | return so_loss 96 | 97 | 98 | 99 | 100 | 101 | class BertForSO(PreTrainedModel): 102 | def __init__(self, args, config, model_adapter): 103 | super().__init__(config) 104 | self.model = model_adapter 105 | self.guide_layer = ModelGuideHead() 106 | 107 | def forward( 108 | self, 109 | inputs_src, 110 | inputs_tgt=None, 111 | labels_src=None, 112 | labels_tgt=None, 113 | attention_mask_src=None, 114 | attention_mask_tgt=None, 115 | align_layer=6, 116 | guide=None, 117 | extraction='softmax', softmax_threshold=0.1, 118 | position_ids1=None, 119 | position_ids2=None, 120 | do_infer=False, 121 | ): 122 | 123 | loss_fct =CrossEntropyLoss(reduction='none') 124 | batch_size = inputs_src.size(0) 125 | 126 | output_src = self.model( 127 | inputs_src, 128 | attention_mask=attention_mask_src, 129 | position_ids=position_ids1, 130 | ) 131 
| 132 | 133 | output_tgt = self.model( 134 | inputs_tgt, 135 | attention_mask=attention_mask_tgt, 136 | position_ids=position_ids2, 137 | ) 138 | if do_infer: 139 | return output_src, output_tgt 140 | 141 | if guide is None: 142 | raise ValueError('must specify guide alignments for the self-training objective') 143 | 144 | 145 | 146 | hidden_states_src = output_src.hidden_states[align_layer] 147 | hidden_states_tgt = output_tgt.hidden_states[align_layer] 148 | 149 | 150 | so_loss = self.guide_layer(hidden_states_src, hidden_states_tgt, inputs_src, inputs_tgt, guide=guide, 151 | extraction=extraction, softmax_threshold=softmax_threshold) 152 | return so_loss 153 | 154 | def save_adapter(self, save_directory, adapter_name): 155 | self.model.save_adapter(save_directory, adapter_name) 156 | 157 | def get_aligned_word(self, inputs_src, inputs_tgt, bpe2word_map_src, bpe2word_map_tgt, device, src_len, tgt_len, 158 | align_layer=6, extraction='softmax', softmax_threshold=0.1, test=False, output_prob=False, 159 | word_aligns=None, pairs_len=None): 160 | batch_size = inputs_src.size(0) 161 | bpelen_src, bpelen_tgt = inputs_src.size(1) - 2, inputs_tgt.size(1) - 2 162 | if word_aligns is None: 163 | inputs_src = inputs_src.to(dtype=torch.long, device=device).clone() 164 | inputs_tgt = inputs_tgt.to(dtype=torch.long, device=device).clone() 165 | 166 | with torch.no_grad(): 167 | outputs_src = self.model( 168 | inputs_src, 169 | attention_mask=(inputs_src != PAD_ID), 170 | ) 171 | outputs_tgt = self.model( 172 | inputs_tgt, 173 | attention_mask=(inputs_tgt != PAD_ID), 174 | ) 175 | 176 | 177 | hidden_states_src = outputs_src.hidden_states[align_layer] 178 | hidden_states_tgt = outputs_tgt.hidden_states[align_layer] 179 | 180 | attention_probs_inter = self.guide_layer(hidden_states_src, hidden_states_tgt, inputs_src, inputs_tgt, 181 | extraction=extraction, softmax_threshold=softmax_threshold, 182 | output_prob=output_prob) 183 | if output_prob: 184 | attention_probs_inter, alignment_probs = attention_probs_inter 185 | alignment_probs = alignment_probs[:, 0, 1:-1, 1:-1] 186 | attention_probs_inter = attention_probs_inter.float() 187 | 188 | word_aligns = [] 189 | attention_probs_inter = attention_probs_inter[:, 0, 1:-1, 1:-1] 190 | 191 | for idx, (attention, b2w_src, b2w_tgt) in enumerate( 192 | zip(attention_probs_inter, bpe2word_map_src, bpe2word_map_tgt)): 193 | aligns = set() if not output_prob else dict() 194 | non_zeros = torch.nonzero(attention) 195 | for i, j in non_zeros: 196 | word_pair = (b2w_src[i], b2w_tgt[j]) 197 | if output_prob: 198 | prob = alignment_probs[idx, i, j] 199 | if not word_pair in aligns: 200 | aligns[word_pair] = prob 201 | else: 202 | aligns[word_pair] = max(aligns[word_pair], prob) 203 | else: 204 | aligns.add(word_pair) 205 | word_aligns.append(aligns) 206 | 207 | if test: 208 | 209 | return word_aligns 210 | 211 | 212 | 213 | guide = torch.zeros(batch_size, 1, src_len, tgt_len) 214 | for idx, (word_align, b2w_src, b2w_tgt) in enumerate(zip(word_aligns, bpe2word_map_src, bpe2word_map_tgt)): 215 | len_src = min(bpelen_src, len(b2w_src)) 216 | len_tgt = min(bpelen_tgt, len(b2w_tgt)) 217 | 218 | for i in range(len_src): 219 | for j in range(len_tgt): 220 | if (b2w_src[i], b2w_tgt[j]) in word_align: 221 | guide[idx, 0, i + 1, j + 1] = 1.0 222 | 223 | 224 | 225 | return guide 226 | -------------------------------------------------------------------------------- /aer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 |
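# Computes Alignment Error Rate following Och & Ney (2003). With sure links S,
# possible links P (S is a subset of P) and hypothesis links A:
#   precision = |A & P| / |A|,  recall = |A & S| / |S|,
#   AER = 1 - (|A & S| + |A & P|) / (|A| + |S|)
# Worked example: S={(0,0)}, P={(0,0),(1,1)}, A={(0,0),(1,2)} gives
# precision 1/2, recall 1/1, AER = 1 - (1+1)/(2+1) = 1/3.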
import argparse 4 | import itertools 5 | from collections import Counter 6 | 7 | 8 | PUNCTUATION_MARKS = {".", ",", "!", "?", ";", ":", "(", ")"} 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser("Calculates Alignment Error Rate, output format: AER (Precision, Recall, Alignment-Links-Hypothesis)") 13 | parser.add_argument("reference", help="path of reference alignment, e.g. '10-9 11p42'") 14 | parser.add_argument("hypothesis", help="path to hypothesis alignment") 15 | 16 | parser.add_argument("--reverseRef", help="reverse reference alignment", action='store_true') 17 | parser.add_argument("--reverseHyp", help="reverse hypothesis alignment", action='store_true') 18 | 19 | parser.add_argument("--oneRef", help="reference indices start at index 1", action='store_true') 20 | parser.add_argument("--oneHyp", help="hypothesis indices start at index 1", action='store_true') 21 | parser.add_argument("--allSure", help="treat all alignments in the reference as sure alignments", action='store_true') 22 | parser.add_argument("--ignorePossible", help="Ignore all possible links", action='store_true') 23 | parser.add_argument("--fAlpha", help="alpha parameter used to calculate f measure (has to be set to a value >= 0.0 to report the f-measure)", default=-1.0, type=float) 24 | 25 | parser.add_argument("--source", default="", help="the source sentence, used for an error analysis") 26 | parser.add_argument("--target", default="", help="the target sentence, used for an error analysis") 27 | parser.add_argument("--cleanPunctuation", action="store_true", help="Removes alignments including punctuation marks, that are not aligned to the same punctuation mark (e.g. ','-'that')") 28 | parser.add_argument("--most_common_errors", default=10, type=int) 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def calculate_internal_jumps(alignments): 34 | """ Count number of times the set of source word indices aligned to a target word index are not adjacent 35 | Each non-adjacent set of source word indices counts only once 36 | >>> calculate_internal_jumps([{1,2,4}, {42}]) 37 | 1 38 | >>> calculate_internal_jumps([{1,2,3,4}]) 39 | 0 40 | >>> calculate_internal_jumps([set()]) 41 | 0 42 | """ 43 | def contiguous(s): 44 | if len(s) <= 1: 45 | return True 46 | else: 47 | elements_in_contiguous_set = max(s) - min(s) + 1 48 | return elements_in_contiguous_set == len(s) 49 | 50 | return [contiguous(s) for s in alignments].count(False) 51 | 52 | 53 | def calculate_external_jumps(alignments): 54 | """ Count number of times the (smallest) source index aligned to target word x is not adjacent or identical to any source word index aligned to the next target word index x+1 55 | Target words which do not have any source word aligned to them are ignored 56 | >>> calculate_external_jumps([set(), {1,2,4}, {2}, {4}, set()]) 57 | 1 58 | """ 59 | 60 | jumps = 0 61 | 62 | for prev, current in zip(alignments, alignments[1:]): 63 | if len(prev) > 0 and len(current) > 0: 64 | src = sorted(prev)[0] 65 | if src in current or src+1 in current or src-1 in current: 66 | pass 67 | else: 68 | jumps += 1 69 | return jumps 70 | 71 | 72 | def to_list(A): 73 | """ converts set of src-tgt alignments to a list containing a set of aligned source words for each target position 74 | >>> to_list({(2,1)}) 75 | [set(), {2}] 76 | """ 77 | max_tgt_idx = max({y for x, y in A}) if len(A) > 0 else 0 78 | lst = [set() for _ in range(max_tgt_idx+1)] 79 | for x, y in A: 80 | lst[y].add(x) 81 | return lst 82 | 83 | 84 | def calculate_metrics(array_sure,
array_possible, array_hypothesis, f_alpha, source_sentences=(), target_sentences=(), clean_punctuation=False): 85 | """ Calculates precision, recall and alignment error rate as described in "A Systematic Comparison of Various 86 | Statistical Alignment Models" (https://www.aclweb.org/anthology/J/J03/J03-1002.pdf) in chapter 5 87 | 88 | 89 | Args: 90 | array_sure: array of sure alignment links 91 | array_possible: array of possible alignment links 92 | array_hypothesis: array of hypothesis alignment links 93 | """ 94 | 95 | number_of_sentences = len(array_sure) 96 | assert number_of_sentences == len(array_possible) 97 | assert number_of_sentences == len(array_hypothesis) 98 | 99 | errors = Counter() 100 | 101 | sum_a_intersect_p, sum_a_intersect_s, sum_s, sum_a, aligned_source_words, aligned_target_words = 6 * [0.0] 102 | sum_source_words, sum_target_words = map(lambda s: max(1.0, sum(len(x) for x in s)), [source_sentences, target_sentences]) 103 | internal_jumps, external_jumps = 0, 0 104 | 105 | for S, P, A, source, target in itertools.zip_longest(array_sure, array_possible, array_hypothesis, source_sentences, target_sentences): 106 | if clean_punctuation: 107 | A = {(s, t) for (s, t) in A if not ((source[s] in PUNCTUATION_MARKS or target[t] in PUNCTUATION_MARKS) and source[s] != target[t])} 108 | sum_a += len(A) 109 | sum_s += len(S) 110 | sum_a_intersect_p += len(A.intersection(P)) 111 | sum_a_intersect_s += len(A.intersection(S)) 112 | aligned_source_words += len({x for x, y in A}) 113 | aligned_target_words += len({y for x, y in A}) 114 | al = to_list(A) 115 | internal_jumps += calculate_internal_jumps(al) 116 | external_jumps += calculate_external_jumps(al) 117 | 118 | if source and target: 119 | for src_pos, tgt_pos in A: 120 | if not src_pos < len(source): 121 | print(source, len(source), src_pos) 122 | if not tgt_pos < len(target): 123 | print(target, len(target), tgt_pos) 124 | if (src_pos, tgt_pos) not in P: 125 | errors[source[src_pos], target[tgt_pos]] += 1 126 | 127 | precision = sum_a_intersect_p / sum_a 128 | recall = sum_a_intersect_s / sum_s 129 | aer = 1.0 - ((sum_a_intersect_p + sum_a_intersect_s) / (sum_a + sum_s)) 130 | 131 | if f_alpha < 0.0: 132 | f_measure = 0.0 133 | else: 134 | f_denominator = f_alpha / precision 135 | f_denominator += (1.0 - f_alpha) / recall 136 | f_measure = 1.0 / f_denominator 137 | 138 | source_coverage = aligned_source_words / sum_source_words 139 | target_coverage = aligned_target_words / sum_target_words 140 | 141 | return precision, recall, aer, f_measure, errors, source_coverage, target_coverage, internal_jumps, external_jumps 142 | 143 | 144 | def parse_single_alignment(string, reverse=False, one_indexed=False): 145 | assert ('-' in string or 'p' in string), 'Bad alignment separator' 146 | 147 | a, b = string.replace('p', '-').split('-') 148 | a, b = int(a), int(b) 149 | 150 | if one_indexed: 151 | a = a - 1 152 | b = b - 1 153 | 154 | if reverse: 155 | a, b = b, a 156 | 157 | return a, b 158 | 159 | 160 | def read_text(path): 161 | if path == "": 162 | return [] 163 | with open(path, "r", encoding="utf-8") as f: 164 | return [l.split() for l in f] 165 | 166 | 167 | if __name__ == "__main__": 168 | args = parse_args() 169 | sure, possible, hypothesis = [], [], [] 170 | 171 | source, target = map(read_text, [args.source, args.target]) 172 | 173 | assert len(source) == len(target), "Length of source and target does not match" 174 | assert (not args.cleanPunctuation) or len(source) > 0, "To clean punctuation alignments, specify a source and
target text file" 175 | 176 | 177 | 178 | 179 | 180 | with open(args.reference, 'r') as f: 181 | for line in f: 182 | sure.append(set()) 183 | possible.append(set()) 184 | 185 | for alignment_string in line.split(): 186 | 187 | sure_alignment = True if '-' in alignment_string else False 188 | alignment_tuple = parse_single_alignment(alignment_string, args.reverseRef, args.oneRef) 189 | 190 | if sure_alignment or args.allSure: 191 | sure[-1].add(alignment_tuple) 192 | if sure_alignment or not args.ignorePossible: 193 | possible[-1].add(alignment_tuple) 194 | with open(args.hypothesis, 'r') as f: 195 | for line in f: 196 | hypothesis.append(set()) 197 | 198 | for alignment_string in line.split(): 199 | alignment_tuple = parse_single_alignment(alignment_string, args.reverseHyp, args.oneHyp) 200 | hypothesis[-1].add(alignment_tuple) 201 | 202 | precision, recall, aer, f_measure, errors, source_coverage, target_coverage, internal_jumps, external_jumps = calculate_metrics(sure, possible, hypothesis, args.fAlpha, source, target, args.cleanPunctuation) 203 | print("{0}: {1:.1f}% ({2:.1f}%/{3:.1f}%/{4})".format(args.hypothesis, 204 | aer * 100.0, precision * 100.0, recall * 100.0, sum([len(x) for x in hypothesis]))) 205 | 206 | if args.fAlpha >= 0.0: 207 | print("F-Measure: {:.3f}".format(f_measure)) 208 | 209 | if args.source: 210 | assert args.target and args.most_common_errors > 0, "To output the most common errors, define a source and target file and the number of errors to output" 211 | print(errors.most_common(args.most_common_errors)) 212 | print("Internal Jumps: {}, External Jumps: {}".format(internal_jumps, external_jumps)) 213 | print("Source Coverage: {:.1f}%, Target Coverage: {:.1f}%".format(source_coverage * 100.0, target_coverage * 100.0)) 214 | -------------------------------------------------------------------------------- /aligner/sent_aligner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Modifications copyright (C) 2020 Zi-Yi Dou 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License.
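# Streams a parallel corpus line by line, tokenizes each word separately so that a
# subword-to-word map can be rebuilt, and extracts word alignments from every
# encoder layer in a single pass (see word_align below).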
17 | 18 | import argparse 19 | import random 20 | import itertools 21 | import os 22 | 23 | import numpy as np 24 | import torch 25 | from tqdm import trange 26 | from torch.nn.utils.rnn import pad_sequence 27 | from torch.utils.data import DataLoader, IterableDataset 28 | 29 | from transformers import AutoTokenizer, AutoConfig, AutoModel 30 | from aligner.word_align import SentenceAligner_word 31 | from tqdm import tqdm 32 | 33 | 34 | 35 | 36 | class LineByLineTextDataset(IterableDataset): 37 | def __init__(self, tokenizer, file_path_src, file_path_tgt, offsets=None): 38 | assert os.path.isfile(file_path_src) 39 | assert os.path.isfile(file_path_tgt) 40 | print('Loading the dataset...') 41 | self.examples = [] 42 | self.tokenizer = tokenizer 43 | self.file_path_src = file_path_src 44 | self.file_path_tgt = file_path_tgt 45 | self.offsets = offsets 46 | 47 | def process_line(self, line_src, line_tgt): 48 | """ 49 | if len(line) == 0 or line.isspace() or not len(line.split(' ||| ')) == 2: 50 | return None 51 | """ 52 | if len(line_src) == 0 or len(line_tgt) == 0: 53 | return None 54 | 55 | sent_src, sent_tgt = line_src.strip().split(), line_tgt.strip().split() 56 | token_src, token_tgt = [self.tokenizer.tokenize(word) for word in sent_src], [self.tokenizer.tokenize(word) for 57 | word in sent_tgt] 58 | wid_src, wid_tgt = [self.tokenizer.convert_tokens_to_ids(x) for x in token_src], [ 59 | self.tokenizer.convert_tokens_to_ids(x) for x in token_tgt] 60 | 61 | ids_src, ids_tgt = self.tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', 62 | max_length=512)['input_ids'], \ 63 | self.tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', 64 | max_length=512)['input_ids'] 65 | 66 | 67 | bpe2word_map_src = [] 68 | for i, word_list in enumerate(token_src): 69 | bpe2word_map_src += [i for x in word_list] 70 | bpe2word_map_tgt = [] 71 | for i, word_list in enumerate(token_tgt): 72 | bpe2word_map_tgt += [i for x in word_list] 73 | return (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sent_src, sent_tgt) 74 | 75 | def __iter__(self): 76 | 77 | f_src = open(self.file_path_src, encoding="utf-8") 78 | f_tgt = open(self.file_path_tgt, encoding="utf-8") 79 | i = 0 80 | for line_src, line_tgt in zip(f_src, f_tgt): 81 | i = i+1 82 | if line_src and line_tgt: 83 | 84 | processed = self.process_line(line_src, line_tgt) 85 | if processed is None: 86 | print( 87 | f'Line "{line_src.strip()}" (offset in bytes: {f_src.tell()}) is not in the correct format. 
Skipping...') 88 | empty_tensor = torch.tensor([self.tokenizer.cls_token_id, 999, self.tokenizer.sep_token_id]) 89 | empty_sent = '' 90 | yield (empty_tensor, empty_tensor, [-1], [-1], empty_sent, empty_sent) 91 | else: 92 | yield processed 93 | 94 | 95 | 96 | 97 | 98 | 99 | def word_align(args, tokenizer, model, folder_path, src_path, tgt_path): 100 | 101 | 102 | def collate(examples): 103 | ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = zip(*examples) 104 | ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id) 105 | ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id) 106 | return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt 107 | 108 | 109 | dataset = LineByLineTextDataset(tokenizer, file_path_src=src_path, file_path_tgt=tgt_path) 110 | dataloader = DataLoader( 111 | dataset, batch_size=args.per_gpu_train_batch_size, collate_fn=collate 112 | ) 113 | 114 | 115 | model_sentence = SentenceAligner_word(args, model) 116 | 117 | 118 | word_aligns_list_all_layer_dic = {} 119 | model.eval() 120 | for batch in tqdm(dataloader): 121 | with torch.no_grad(): 122 | ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = batch 123 | 124 | 125 | word_aligns_list_all_layer_dic_one_batch = model_sentence.get_aligned_word(args, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id, output_prob = False) 126 | 127 | for layer_id in word_aligns_list_all_layer_dic_one_batch: 128 | 129 | if layer_id not in word_aligns_list_all_layer_dic: 130 | word_aligns_list_all_layer_dic[layer_id] = word_aligns_list_all_layer_dic_one_batch[layer_id] 131 | else: 132 | word_aligns_list_all_layer_dic[layer_id] = word_aligns_list_all_layer_dic[layer_id] + word_aligns_list_all_layer_dic_one_batch[layer_id] 133 | 134 | 135 | 136 | for layer_id in word_aligns_list_all_layer_dic: 137 | with open(os.path.join(folder_path, f'XX2XX.align.{layer_id}'), 'w', encoding='utf-8') as writers: 138 | for word_aligns in word_aligns_list_all_layer_dic[layer_id]: 139 | output_str = [] 140 | for word_align in word_aligns: 141 | if word_align[0] != -1: 142 | output_str.append(f'{word_align[0]}-{word_align[1]}') 143 | writers.write(' '.join(output_str) + '\n') 144 | 145 | 146 | # def main(): 147 | # parser = argparse.ArgumentParser() 148 | # 149 | # # Required parameters 150 | # parser.add_argument( 151 | # "--data_file_src", default="/nfsshare/home/wangweikang/en2ces_101/output/en2Ces.src", type=str, 152 | # help="The input data file (a text file)." 153 | # ) 154 | # parser.add_argument( 155 | # "--data_file_tgt", default="/nfsshare/home/wangweikang/en2ces_101/output/en2Ces.tgt", type=str, 156 | # help="The input data file (a text file)." 157 | # ) 158 | # parser.add_argument( 159 | # "--output_file", 160 | # default='/nfsshare/home/wangweikang/my_alignment/valid_output/full/800', 161 | # type=str, 162 | # help="The output file."
163 | # ) 164 | # parser.add_argument("--align_layer", type=int, default=6, help="layer for alignment extraction") 165 | # parser.add_argument( 166 | # "--extraction", default='softmax', type=str, help='softmax or others' 167 | # ) 168 | # parser.add_argument( 169 | # "--softmax_threshold", type=float, default=0.1 170 | # ) 171 | # parser.add_argument( 172 | # "--output_prob_file", default=None, type=str, help='The output probability file.' 173 | # ) 174 | # parser.add_argument( 175 | # "--output_word_file", default=None, type=str, help='The output word file.' 176 | # ) 177 | # parser.add_argument( 178 | # "--model_name_or_path", 179 | # default="/nfsshare/home/wangweikang/my_alignment/ckpt_output/full_fune/checkpoint-800", 180 | # type=str, 181 | # help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.", 182 | # ) 183 | # parser.add_argument( 184 | # "--adapter_path", 185 | # default="/nfsshare/home/wangweikang/my_alignment/ckpt_output/full_fune/checkpoint-800", 186 | # type=str, 187 | # help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.", 188 | # ) 189 | # parser.add_argument("--batch_size", default=32, type=int) 190 | # 191 | # parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 192 | # parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading") 193 | # parser.add_argument("--tokenizer_name_or_path", type=str, default="xlm-roberta-base") 194 | # args = parser.parse_args() 195 | # device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 196 | # args.device = device 197 | # 198 | # 199 | # tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 200 | # model = AutoModel.from_pretrained(args.model_name_or_path) 201 | # if args.adapter_path: 202 | # model.load_adapter(args.adapter_path) 203 | # model.set_active_adapters('alignment_adapter') 204 | # 205 | # 206 | # word_align(args, tokenizer, model, args.output_file, args.data_file_src, args.data_file_tgt) 207 | # 208 | # 209 | # if __name__ == "__main__": 210 | # main() 211 | 212 | -------------------------------------------------------------------------------- /train_alignment_adapter.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Modifications copyright (C) 2020, Zi-Yi Dou 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
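# Trains an alignment adapter on parallel text with the self-training objective
# (--train_so): guide matrices, either extracted by the current model or built from
# gold alignments, supervise the cross-attention softmax at --align_layer.
# Periodic evaluation extracts dev-set alignments and writes them in Pharaoh format.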
17 | 18 | 19 | import argparse 20 | import glob 21 | import logging 22 | import os 23 | import random 24 | import re 25 | from typing import Dict, List, Tuple 26 | 27 | import copy 28 | 29 | import numpy as np 30 | import torch 31 | from torch.nn.utils.rnn import pad_sequence 32 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 33 | from torch.utils.data.distributed import DistributedSampler 34 | from tqdm import tqdm, trange 35 | from self_training_modeling_adapter import BertForSO 36 | from transformers import AutoTokenizer, AutoConfig, AutoModel, AdamW, get_linear_schedule_with_warmup, HoulsbyConfig 37 | from train_utils import _sorted_checkpoints, _rotate_checkpoints, WEIGHTS_NAME 38 | from transformers import AdapterConfig 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | import itertools 44 | from aligner.sent_aligner import word_align 45 | 46 | 47 | class LineByLineTextDataset(Dataset): 48 | def __init__(self, tokenizer, args, file_path_src, file_path_tgt, gold_path): 49 | assert os.path.isfile(file_path_src) 50 | assert os.path.isfile(file_path_tgt) 51 | logger.info("Creating features from dataset file at %s", file_path_src) 52 | 53 | assert file_path_src != file_path_tgt 54 | 55 | cache_fn = f'{file_path_src}.cache' if gold_path is None else f'{file_path_src}.gold.cache' 56 | if args.cache_data and os.path.isfile(cache_fn) and not args.overwrite_cache: 57 | logger.info("Loading cached data from %s", cache_fn) 58 | self.examples = torch.load(cache_fn) 59 | else: 60 | # Loading text data 61 | self.examples = [] 62 | with open(file_path_src, encoding="utf-8") as fs: 63 | lines_src = fs.readlines() 64 | with open(file_path_tgt, encoding="utf-8") as ft: 65 | lines_tgt = ft.readlines() 66 | 67 | # Loading gold data 68 | if gold_path is not None: 69 | assert os.path.isfile(gold_path) 70 | logger.info("Loading gold alignments at %s", gold_path) 71 | with open(gold_path, encoding="utf-8") as f: 72 | gold_lines = f.readlines() 73 | assert len(gold_lines) == len(lines_src) 74 | 75 | i = 0 76 | for line_id, (line_src, line_tgt) in tqdm(enumerate(zip(lines_src, lines_tgt))): 77 | i = i + 1 78 | if line_src and line_tgt: 79 | sent_src, sent_tgt = line_src.strip().split(), line_tgt.strip().split() 80 | token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [tokenizer.tokenize(word) 81 | for word in sent_tgt] 82 | wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [ 83 | tokenizer.convert_tokens_to_ids(x) for x in token_tgt] 84 | ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', 85 | max_length=args.max_len)['input_ids'], \ 86 | tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', 87 | max_length=args.max_len)['input_ids'] 88 | 89 | 90 | if len(ids_src) == 2 or len(ids_tgt) == 2: 91 | #logger.info("Skipping instance src %s", line_src) 92 | #logger.info("Skipping instance tgt %s", lines_tgt) 93 | continue 94 | 95 | 96 | 97 | bpe2word_map_src = [] 98 | for i, word_list in enumerate(token_src): 99 | bpe2word_map_src += [i for x in word_list] 100 | bpe2word_map_tgt = [] 101 | for i, word_list in enumerate(token_tgt): 102 | bpe2word_map_tgt += [i for x in word_list] 103 | 104 | if gold_path is not None: 105 | try: 106 | gold_line = gold_lines[line_id].strip().split() 107 | gold_word_pairs = [] 108 | for src_tgt in gold_line: 109 | if 'p' in src_tgt: 110 | if args.ignore_possible_alignments: 111 | continue 112 | wsrc,
wtgt = src_tgt.split('p') 113 | else: 114 | wsrc, wtgt = src_tgt.split('-') 115 | wsrc, wtgt = (int(wsrc), int(wtgt)) if not args.gold_one_index else ( 116 | int(wsrc) - 1, int(wtgt) - 1) 117 | gold_word_pairs.append((wsrc, wtgt)) 118 | self.examples.append( 119 | (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, gold_word_pairs,[len(ids_src)-2, len(ids_tgt)-2])) 120 | except Exception: 121 | logger.info("Error when processing the gold alignment %s, skipping", 122 | gold_lines[line_id].strip()) 123 | continue 124 | else: 125 | self.examples.append((ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, None, [len(ids_src)-2, len(ids_tgt)-2])) 126 | 127 | if args.cache_data: 128 | logger.info("Saving cached data to %s", cache_fn) 129 | torch.save(self.examples, cache_fn) 130 | 131 | 132 | 133 | def __len__(self): 134 | return len(self.examples) 135 | 136 | def __getitem__(self, i): 137 | neg_i = random.randint(0, len(self.examples) - 1) 138 | while neg_i == i: 139 | neg_i = random.randint(0, len(self.examples) - 1) 140 | return tuple(list(self.examples[i]) + list(self.examples[neg_i][:2])) 141 | 142 | 143 | def load_and_cache_examples(args, tokenizer, evaluate=False): 144 | file_path_src = args.eval_data_file_src if evaluate else args.train_data_file_src 145 | file_path_tgt = args.eval_data_file_tgt if evaluate else args.train_data_file_tgt 146 | gold_path = args.eval_gold_file if evaluate else args.train_gold_file 147 | return LineByLineTextDataset(tokenizer, args, file_path_src=file_path_src, file_path_tgt=file_path_tgt, gold_path=gold_path) 148 | 149 | 150 | def set_seed(args): 151 | if args.seed >= 0: 152 | random.seed(args.seed) 153 | np.random.seed(args.seed) 154 | torch.manual_seed(args.seed) 155 | torch.cuda.manual_seed_all(args.seed) 156 | 157 | 158 | 159 | 160 | 161 | def train(args, train_dataset, model, tokenizer) -> Tuple[int, float]: 162 | """ Train the model """ 163 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 164 | 165 | 166 | 167 | def collate(examples): 168 | 169 | model.eval() 170 | examples_src, examples_tgt = [], [] 171 | src_len = tgt_len = 0 172 | bpe2word_map_src, bpe2word_map_tgt = [], [] 173 | word_aligns = [] 174 | pairs_len = [] 175 | for example in examples: 176 | end_id = example[0][-1].view(-1) 177 | 178 | src_id = example[0][:args.block_size] 179 | src_id = torch.cat([src_id[:-1], end_id]) 180 | tgt_id = example[1][:args.block_size] 181 | tgt_id = torch.cat([tgt_id[:-1], end_id]) 182 | 183 | examples_src.append(src_id) 184 | examples_tgt.append(tgt_id) 185 | src_len = max(src_len, len(src_id)) 186 | tgt_len = max(tgt_len, len(tgt_id)) 187 | 188 | bpe2word_map_src.append(example[2]) 189 | bpe2word_map_tgt.append(example[3]) 190 | word_aligns.append(example[4]) 191 | 192 | pairs_len.append(example[5]) 193 | 194 | examples_src = pad_sequence(examples_src, batch_first=True, padding_value=tokenizer.pad_token_id) 195 | examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=tokenizer.pad_token_id) 196 | 197 | 198 | if word_aligns[0] is None: 199 | word_aligns = None 200 | if args.n_gpu > 1 or args.local_rank != -1: 201 | 202 | guides = model.get_aligned_word(examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, 203 | args.device, src_len, tgt_len, align_layer=args.align_layer, 204 | extraction=args.extraction, softmax_threshold=args.softmax_threshold, 205 |
word_aligns=word_aligns, pairs_len=pairs_len) 206 | else: 207 | guides = model.get_aligned_word(examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, 208 | src_len, tgt_len, align_layer=args.align_layer, extraction=args.extraction, 209 | softmax_threshold=args.softmax_threshold, word_aligns=word_aligns, pairs_len=pairs_len) 210 | 211 | return examples_src, examples_tgt, guides 212 | 213 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 214 | train_dataloader = DataLoader( 215 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate 216 | ) 217 | 218 | 219 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 220 | 221 | if args.max_steps > 0 and args.max_steps < t_total: 222 | t_total = args.max_steps 223 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 224 | 225 | # Prepare optimizer and schedule (linear warmup and decay) 226 | no_decay = ["bias", "LayerNorm.weight"] 227 | 228 | optimizer_grouped_parameters = [ 229 | { 230 | "params": [p for n, p in model.named_parameters() if (not (any(nd in n for nd in no_decay)))], 231 | "weight_decay": args.weight_decay, 232 | }, 233 | {"params": [p for n, p in model.named_parameters() if ((any(nd in n for nd in no_decay)))], 234 | "weight_decay": 0.0}, 235 | ] 236 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 237 | scheduler = get_linear_schedule_with_warmup( 238 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 239 | ) 240 | 241 | if args.fp16: 242 | try: 243 | from apex import amp 244 | except ImportError: 245 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 246 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 247 | 248 | # multi-gpu training (should be after apex fp16 initialization) 249 | if args.n_gpu > 1: 250 | model = torch.nn.DataParallel(model) 251 | 252 | # Distributed training (should be after apex fp16 initialization) 253 | if args.local_rank != -1: 254 | model = torch.nn.parallel.DistributedDataParallel( 255 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True 256 | ) 257 | 258 | # Train! 259 | logger.info("***** Running training *****") 260 | logger.info(" Num examples = %d", len(train_dataset)) 261 | logger.info(" Num Epochs = %d", args.num_train_epochs) 262 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 263 | logger.info( 264 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 265 | args.train_batch_size 266 | * args.gradient_accumulation_steps 267 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 268 | ) 269 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 270 | logger.info(" Total optimization steps = %d", t_total) 271 | 272 | global_step = 0 273 | # Check if continuing training from a checkpoint 274 | tr_loss, logging_loss = 0.0, 0.0 275 | 276 | model.zero_grad() 277 | set_seed(args) # Added here for reproducibility 278 | 279 | #writer_train = open(args.train_res_dir, "w") 280 | 281 | 282 | def backward_loss(loss, tot_loss): 283 | if args.n_gpu > 1: 284 | loss = loss.mean() # mean() to average on multi-gpu parallel training 285 | if args.gradient_accumulation_steps > 1: 286 | loss = loss / args.gradient_accumulation_steps 287 | 288 | tot_loss += loss.item() 289 | if args.fp16: 290 | with amp.scale_loss(loss, optimizer) as scaled_loss: 291 | scaled_loss.backward() 292 | else: 293 | loss.backward() 294 | return tot_loss 295 | 296 | tqdm_iterator = trange(int(t_total), desc="Iteration", disable=args.local_rank not in [-1, 0]) 297 | for _ in range(int(args.num_train_epochs)): 298 | for step, batch in enumerate(train_dataloader): 299 | model.train() 300 | 301 | if args.train_so: 302 | inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone() 303 | inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device) 304 | attention_mask_src, attention_mask_tgt = (inputs_src != 0), (inputs_tgt != 0) 305 | guide = batch[2].to(args.device) 306 | loss = model(inputs_src=inputs_src, inputs_tgt=inputs_tgt, attention_mask_src=attention_mask_src, 307 | attention_mask_tgt=attention_mask_tgt, guide=guide, align_layer=args.align_layer, 308 | extraction=args.extraction, softmax_threshold=args.softmax_threshold, 309 | ) 310 | tr_loss = backward_loss(loss, tr_loss) 311 | 312 | if (step + 1) % args.gradient_accumulation_steps == 0: 313 | if args.fp16: 314 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 315 | else: 316 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 317 | optimizer.step() 318 | scheduler.step() # Update learning rate schedule 319 | model.zero_grad() 320 | global_step += 1 321 | tqdm_iterator.update() 322 | 323 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 324 | logger.info(" Step %s. 
Training loss = %s", str(global_step), 325 | str((tr_loss - logging_loss) / args.logging_steps)) 326 | 327 | 328 | logger.info("***** Training results {} *****".format(global_step)) 329 | #for key in sorted(result.keys()): 330 | logger.info(" %s = %s", str(global_step)+' steps', str((tr_loss - logging_loss) / args.logging_steps)) 331 | #writer_train.write("%s = %s\n" % (str(global_step)+' steps', str((tr_loss - logging_loss) / args.logging_steps))) 332 | 333 | logging_loss = tr_loss 334 | 335 | evaluate(args, model, tokenizer, global_step, prefix='') 336 | 337 | 338 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 339 | checkpoint_prefix = "checkpoint" 340 | 341 | 342 | 343 | output_dir_adapter = os.path.join(args.output_dir_adapter, "{}-{}".format(checkpoint_prefix, global_step)) 344 | 345 | model.save_adapter(output_dir_adapter, "alignment_adapter") 346 | logger.info("Saving adapters to %s", output_dir_adapter) 347 | 348 | if global_step > t_total: 349 | break 350 | 351 | if global_step > t_total: 352 | break 353 | 354 | return global_step, tr_loss / global_step 355 | 356 | 357 | def evaluate(args, model, tokenizer, global_step, prefix="") -> Dict: 358 | # Loop to handle MNLI double evaluation (matched, mis-matched) 359 | eval_output_dir = args.eval_res_dir 360 | 361 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) 362 | 363 | if args.local_rank in [-1, 0]: 364 | os.makedirs(eval_output_dir, exist_ok=True) 365 | 366 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 367 | 368 | # Note that DistributedSampler samples randomly 369 | def collate(examples): 370 | model.eval() 371 | examples_src, examples_tgt, examples_srctgt, examples_tgtsrc, langid_srctgt, langid_tgtsrc, psi_examples_srctgt, psi_labels = [], [], [], [], [], [], [], [] 372 | src_len = tgt_len = 0 373 | bpe2word_map_src, bpe2word_map_tgt = [], [] 374 | word_aligns = [] 375 | pairs_len = [] 376 | for example in examples: 377 | end_id = example[0][-1].view(-1) 378 | 379 | src_id = example[0][:args.block_size] 380 | src_id = torch.cat([src_id[:-1], end_id]) 381 | tgt_id = example[1][:args.block_size] 382 | tgt_id = torch.cat([tgt_id[:-1], end_id]) 383 | 384 | examples_src.append(src_id) 385 | examples_tgt.append(tgt_id) 386 | src_len = max(src_len, len(src_id)) 387 | tgt_len = max(tgt_len, len(tgt_id)) 388 | 389 | bpe2word_map_src.append(example[2]) 390 | bpe2word_map_tgt.append(example[3]) 391 | word_aligns.append(example[4]) 392 | pairs_len.append(example[5]) 393 | 394 | examples_src = pad_sequence(examples_src, batch_first=True, padding_value=tokenizer.pad_token_id) 395 | examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=tokenizer.pad_token_id) 396 | 397 | if word_aligns[0] is None: 398 | word_aligns = None 399 | 400 | guides = model.get_aligned_word(examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, 401 | src_len, tgt_len, align_layer=args.align_layer, extraction=args.extraction, 402 | softmax_threshold=args.softmax_threshold, test=False, word_aligns=word_aligns,pairs_len=pairs_len) 403 | 404 | return examples_src, examples_tgt, guides, bpe2word_map_src, bpe2word_map_tgt 405 | 406 | eval_sampler = SequentialSampler(eval_dataset) 407 | eval_dataloader = DataLoader( 408 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate 409 | ) 410 | 411 | # multi-gpu evaluate 412 | if args.n_gpu > 1: 413 | model = torch.nn.DataParallel(model) 414 | 415 | # Eval! 
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    model.eval()
    set_seed(args)  # Re-seed so evaluation is reproducible across checkpoints

    folder_path = os.path.join(args.eval_res_dir, str(global_step) + '_steps')
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    output_file = os.path.join(folder_path, "dev.align.{}".format(args.align_layer))

    writers = open(output_file, 'w', encoding='utf-8')

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        with torch.no_grad():
            inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
            inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)

            bpe2word_map_src, bpe2word_map_tgt = batch[3], batch[4]

            word_aligns_list_batch = model.get_aligned_word(
                inputs_src, inputs_tgt, bpe2word_map_src, bpe2word_map_tgt, args.device, [], [],
                align_layer=args.align_layer, extraction=args.extraction,
                softmax_threshold=args.softmax_threshold, test=True, output_prob=False,
                word_aligns=None, pairs_len=None)

            # One sentence pair per line, Pharaoh format: space-separated "i-j"
            # pairs meaning source word i aligns to target word j.
            for aligns_set in word_aligns_list_batch:
                output_str = []
                for word_align in aligns_set:
                    if word_align[0] != -1:
                        output_str.append(f'{word_align[0]}-{word_align[1]}')
                writers.write(' '.join(output_str) + '\n')

    writers.close()
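

# The alignment files written by evaluate() are scored with aer.py (see
# cal_aer.sh). For reference, the standard Alignment Error Rate (Och & Ney,
# 2003) compares a predicted link set A against sure gold links S and possible
# gold links P (S is a subset of P):
#     AER = 1 - (|A & S| + |A & P|) / (|A| + |S|)
# The helper below is an illustrative sketch, not part of the original pipeline.
def _aer_sketch(predicted, sure, possible):
    A, S, P = set(predicted), set(sure), set(possible)
    return 1.0 - (len(A & S) + len(A & P)) / (len(A) + len(S))
# Example: _aer_sketch({(1, 1), (2, 2)}, {(1, 1)}, {(1, 1), (2, 2)}) == 0.0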


def main():
    parser = argparse.ArgumentParser()

    # Data files
    parser.add_argument(
        "--eval_data_file_src", default='', type=str, help="The input evaluation data file (a text file)."
    )
    parser.add_argument(
        "--eval_data_file_tgt", default='', type=str, help="The input evaluation data file (a text file)."
    )
    parser.add_argument(
        "--train_data_file_src", default='', type=str, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--train_data_file_tgt", default='', type=str, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--infer_data_file_src", default='', type=str, help="The input inference data file (a text file)."
    )
    parser.add_argument(
        "--infer_data_file_tgt", default='', type=str, help="The input inference data file (a text file)."
    )

    # Output directories
    parser.add_argument(
        "--output_dir", default='', type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--output_dir_adapter", default='', type=str,
        help="The output directory where the adapters will be written.",
    )
    parser.add_argument(
        "--train_res_dir", default='', type=str,
        help="The output directory where the training loss will be written.",
    )
    parser.add_argument(
        "--eval_res_dir", default='', type=str,
        help="The output directory where the eval results will be written.",
    )
    parser.add_argument(
        "--infer_path", default='', type=str,
        help="The output directory where the inference results will be written.",
    )

    parser.add_argument("--train_so", action="store_true")
    # Supervised settings
    parser.add_argument(
        "--train_gold_file", default=None, type=str, help="Gold alignment for training data"
    )
    parser.add_argument(
        "--eval_gold_file", default=None, type=str, help="Gold alignment for evaluation data"
    )
    parser.add_argument(
        "--ignore_possible_alignments", action="store_true", help="Whether to ignore possible gold alignments"
    )
    parser.add_argument(
        "--gold_one_index", action="store_true", help="Whether the gold alignment files are one-indexed"
    )
    # Other parameters
    parser.add_argument("--cache_data", action="store_true", help="Whether to cache the dataset")
    parser.add_argument("--align_layer", type=int, default=6, help="Layer for alignment extraction")
    parser.add_argument(
        "--extraction", default='softmax', type=str, choices=['softmax', 'entmax'], help="softmax or entmax"
    )
    parser.add_argument("--softmax_threshold", type=float, default=0.1)

    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from the latest checkpoint in output_dir"
    )
    parser.add_argument(
        "--model_name_or_path", default=None, type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )
    parser.add_argument(
        "--adapter_path", default=None, type=str,
        help="The adapter checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )
    parser.add_argument(
        "--config_name", default=None, type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )
    parser.add_argument(
        "--tokenizer_name", default=None, type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument(
        "--cache_dir", default=None, type=str,
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
    )
    parser.add_argument(
        "--block_size", default=-1, type=int,
        help="Optional input sequence length after tokenization. "
             "The training dataset will be truncated in blocks of this size for training. "
             "Defaults to the model max input length for single-sentence inputs (taking special tokens into account).",
    )
    parser.add_argument("--max_len", default=512, type=int, help="Max sequence length")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run inference.")
    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=32, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=4,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps", default=-1, type=int,
        help="If > 0: set the maximum number of training steps to perform.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=25, help="Log every X update steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X update steps.")
    parser.add_argument(
        "--save_total_limit", type=int, default=None,
        help="Limit the total number of checkpoints; deletes the older checkpoints in the output_dir. Does not delete by default.",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument("--reduction_factor", type=int, default=6, help="Reduction factor for the adapter bottleneck")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization")
    parser.add_argument(
        "--fp16", action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level", type=str, default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
             "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    args = parser.parse_args()

    if args.eval_data_file_src is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file_src "
            "or remove the --do_eval argument."
        )
    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir_adapter)
        and os.listdir(args.output_dir_adapter)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir_adapter
            )
        )
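
    # Device setup below follows the usual HF recipe: local_rank == -1 means a
    # single process (with DataParallel when several GPUs are visible), while
    # local_rank >= 0 means one process per GPU coordinated over NCCL, as set
    # up by a launcher, e.g. (illustrative command, not from this repo):
    #     python -m torch.distributed.launch --nproc_per_node=2 \
    #         train_alignment_adapter.py --do_train ...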
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bit training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training downloads model & vocab

    modelforALING, tokenizer_class = BertForSO, AutoTokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path)

    if args.tokenizer_name:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
    elif args.model_name_or_path:
        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another "
            "script, save it, and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    if args.block_size <= 0:
        # Use the max possible input length for the model
        args.block_size = args.max_len
    else:
        args.block_size = min(args.block_size, args.max_len)

    labse_model = AutoModel.from_pretrained(args.model_name_or_path, output_hidden_states=True)
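
    # Background on the adapter used below (a sketch of the shapes, assuming
    # LaBSE's hidden size of 768): a Houlsby adapter inserts a small bottleneck
    # (down-project -> nonlinearity -> up-project -> residual add) after the
    # attention and feed-forward sublayers of every transformer layer, and only
    # these new weights are trained. reduction_factor = 6 gives a bottleneck
    # width of hidden_size / 6, i.e. 768 / 6 = 128.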
    if args.do_train:
        # Add a fresh Houlsby adapter and freeze everything else.
        config_align = HoulsbyConfig(reduction_factor=args.reduction_factor)
        labse_model.add_adapter("alignment_adapter", config=config_align)
        labse_model.train_adapter("alignment_adapter")
        labse_model.set_active_adapters("alignment_adapter")

        model = modelforALING(args, config, labse_model)
        model.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier()  # End of barrier: only the first process in distributed training downloads model & vocab

    logger.info("Training/evaluation parameters %s", args)

    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Only the first process processes the dataset; the others use the cache

        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)

        if args.local_rank == 0:
            torch.distributed.barrier()

        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    if args.do_test:
        # Extract word alignments for all layers
        if args.adapter_path:
            labse_model.load_adapter(args.adapter_path)
            labse_model.set_active_adapters('alignment_adapter')
        model = modelforALING(args, config, labse_model)
        model.to(args.device)

        folder_path = args.infer_path
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        word_align(args, tokenizer, model, folder_path, args.infer_data_file_src,
                   args.infer_data_file_tgt)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------