├── .DS_Store ├── LICENSE ├── README.md ├── fs_plugins ├── .DS_Store ├── __init__.py ├── criterions │ ├── __init__.py │ ├── nat_pcfg_loss.py │ └── utilities.py ├── custom_ops │ ├── __init__.py │ ├── logsoftmax_gather.cu │ ├── pcfg_best_tree.cu │ ├── pcfg_loss.cpp │ ├── pcfg_loss.cu │ ├── pcfg_loss.py │ ├── pcfg_loss_backward.cu │ ├── pcfg_viterbi.cu │ └── utilities.h ├── models │ ├── __init__.py │ ├── glat_decomposed_with_link_two_hands_tri_pcfg.py │ └── lemon_tree.py └── tasks │ ├── __init__.py │ └── translation_lev_modified.py ├── test_scripts ├── test_pcfg_viterbi_wmt14_deen.sh ├── test_pcfg_viterbi_wmt14_ende.sh ├── test_pcfg_viterbi_wmt16_enro.sh ├── test_pcfg_viterbi_wmt16_roen.sh ├── test_pcfg_viterbi_wmt17_enzh.sh └── test_pcfg_viterbi_wmt17_zhen.sh └── train_scripts ├── train_wmt14_deen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh ├── train_wmt14_ende_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh ├── train_wmt16_enro_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh ├── train_wmt16_roen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh ├── train_wmt17_enzh_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh └── train_wmt17_zhen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/PCFG-NAT/30c6174320f62a9b3d559155c63f82f8515d3d66/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ICTNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Readme 2 | 3 | This repository contains the code for the paper [**Non-autoregressive Machine Translation with Probabilistic Context-free Grammar**](https://neurips.cc/virtual/2023/poster/71942). 4 | 5 | 6 | 7 | This project is based on [fairseq](https://github.com/facebookresearch/fairseq) and [DA-Transformer](https://github.com/thu-coai/DA-Transformer). 8 | 9 | 10 | #### PCFG-NAT files 11 | 12 | We provide the fairseq plugins in the directory fs_plugins/, some of them (custom_ops/, utilities.py, translation_lev_moditied.py) are copied from the original [DA-Transformer](https://github.com/thu-coai/DA-Transformer). 
13 | 14 | 15 | ``` 16 | fs_plugins 17 | ├── __init__.py 18 | ├── criterions 19 | │ ├── __init__.py 20 | │ ├── nat_pcfg_loss.py ## PCFG-NAT loss 21 | │ └── utilities.py 22 | ├── custom_ops ## CUDA implementations 23 | │ ├── __init__.py 24 | │ ├── pcfg_best_tree.cu ## best alignment for glat training 25 | │ ├── pcfg_loss.cpp ## cpp wrapper of PCFG-NAT loss 26 | │ ├── pcfg_loss.cu ## forward of PCFG-NAT loss 27 | │ ├── pcfg_loss_backward.cu ## backward of PCFG-NAT loss 28 | │ ├── pcfg_viterbi.cu ## Viterbi algorithm for PCFG-NAT inference 29 | │ ├── pcfg_loss.py ## python wrapper of PCFG-NAT loss 30 | │ ├── logsoftmax_gather.cu ## logsoftmax gather 31 | │ └── utilities.h 32 | ├── models 33 | │ ├── __init__.py 34 | │ ├── glat_decomposed_with_link_two_hands_tri_pcfg.py ## PCFG-NAT model 35 | │ └── lemon_tree.py ## support tree structure of PCFG-NAT 36 | └── tasks 37 | ├── __init__.py 38 | └── translation_lev_modified.py ## PCFG-NAT translation task 39 | ``` 40 | 41 | #### Requirements and Installation 42 | 43 | * Python >= 3.7 44 | * PyTorch == 1.10.1 (tested with CUDA == 11.3) 45 | * gcc >= 7.0.0 46 | * Install fairseq via `pip install -e fairseq/.` 47 | 48 | #### Preparing Data 49 | Fairseq provides the preprocessed raw datasets here. Please build the binarized dataset with the following script: 50 | 51 | ```bash 52 | input_dir=path/to/raw_data # directory of raw text data 53 | data_dir=path/to/binarized_data # directory of the generated binarized data 54 | src=en # source language id 55 | tgt=de # target language id 56 | fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \ 57 | --trainpref ${input_dir}/train.${src}-${tgt} --validpref ${input_dir}/valid.${src}-${tgt} --testpref ${input_dir}/test.${src}-${tgt} \ 58 | --src-dict ${input_dir}/dict.${src}.txt --tgt-dict ${input_dir}/dict.${tgt}.txt \ 59 | --destdir ${data_dir} --workers 32 60 | ``` 61 | 62 | #### Training 63 | 64 | Here we provide the training script of PCFG-NAT on WMT-14 En-De; the training scripts for WMT-17 En-Zh and WMT-16 En-Ro are in `train_scripts/`. 
65 | ```bash 66 | exp=exp_name 67 | root=fairseq 68 | data_dir=data_dir 69 | checkpoint_dir=checkpoint_dir 70 | user_dir=fs_plugins 71 | fairseq-train ${data_dir} \ 72 | --user-dir $user_dir \ 73 | --task translation_lev_modified --noise full_mask \ 74 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 75 | --decoder-learned-pos --encoder-learned-pos \ 76 | --share-all-embeddings --activation-fn gelu \ 77 | --apply-bert-init \ 78 | --links-feature feature:position --decode-strategy lookahead \ 79 | --max-source-positions 128 --max-target-positions 1030 --src-upsample-scale 4.0 \ 80 | --left-tree-layer 1 \ 81 | --criterion nat_pcfg_loss \ 82 | --length-loss-factor 0 --max-transition-length 99999 \ 83 | --glat-p 0.5:0.1@200k --glance-strategy number-random \ 84 | --no-force-emit \ 85 | --optimizer adam --adam-betas '(0.9,0.999)' \ 86 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \ 87 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 88 | --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 89 | --min-loss-scale 0 --ddp-backend c10d \ 90 | --max-tokens 2730 --update-freq 3 --grouped-shuffling \ 91 | --max-update 300000 --max-tokens-valid 1024 \ 92 | --save-interval 1 --save-interval-updates 10000 \ 93 | --seed 0 --fp16 \ 94 | --validate-interval 1 --validate-interval-updates 10000 \ 95 | --skip-invalid-size-inputs-valid-test \ 96 | --fixed-validation-seed 7 \ 97 | --best-checkpoint-metric loss \ 98 | --keep-last-epochs 32 \ 99 | --keep-best-checkpoints 10 --save-dir ${checkpoint_dir} \ 100 | --log-format 'simple' --log-interval 100 101 | ``` 102 | Most of the command-line arguments are the same as in [fairseq](https://github.com/facebookresearch/fairseq) and [DA-Transformer](https://github.com/thu-coai/DA-Transformer). 103 | `--left-tree-layer 1` means that the local prefix tree in the support tree has only one layer. 104 | 105 | 106 | #### Evaluation 107 | 108 | * Average the best 5 checkpoints into `average_best_5.pt` (a sketch using fairseq's `scripts/average_checkpoints.py` is given below). 109 | * Here we provide the decoding script of PCFG-NAT on WMT-14 En-De; the evaluation scripts for WMT-17 En-Zh and WMT-16 En-Ro are in `test_scripts/`. 
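To produce `average_best_5.pt` for the decoding command below, one option is fairseq's `scripts/average_checkpoints.py`. This is a sketch, assuming fairseq is cloned in `fairseq/` as in the installation step and that the best checkpoints follow fairseq's default `checkpoint.best_loss_*` naming; otherwise, list the five checkpoint files explicitly after `--inputs`:

```bash
# Average the five best checkpoints kept by --keep-best-checkpoints into the
# file expected by the decoding script (adjust the glob so it selects exactly five).
python fairseq/scripts/average_checkpoints.py \
    --inputs ${checkpoint_dir}/checkpoint.best_loss_*.pt \
    --output ${checkpoint_dir}/average_best_5.pt
```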
110 | 111 | ```bash 112 | exp=exp_name 113 | root=fairseq 114 | data_dir=data_dir 115 | checkpoint_dir=checkpoint_dir 116 | user_dir=fs_plugins 117 | 118 | fairseq-generate ${data_dir} \ 119 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 120 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 121 | --remove-bpe --batch-size 1 --seed 0 \ 122 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 123 | --path $checkpoint_dir/average_best_5.pt 124 | ``` 125 | 126 | #### Citation 127 | 128 | If this repository is useful for you, please cite as: 129 | ``` 130 | @inproceedings{ 131 | gui2023pcfg, 132 | title={Non-autoregressive Machine Translation with Probabilistic Context-free Grammar}, 133 | author={Gui, Shangtong and Shao, Chenze and Ma, Zhengrui and Zhang, Xishan and Chen, Yunji and Feng, Yang}, 134 | booktitle={Advances in Neural Information Processing Systems}, 135 | year={2023}, 136 | } 137 | ``` 138 | 139 | -------------------------------------------------------------------------------- /fs_plugins/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/PCFG-NAT/30c6174320f62a9b3d559155c63f82f8515d3d66/fs_plugins/.DS_Store -------------------------------------------------------------------------------- /fs_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * 4 | 5 | print("fairseq plugins loaded...") -------------------------------------------------------------------------------- /fs_plugins/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.criterions." 
+ file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/criterions/nat_pcfg_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | import logging 4 | from functools import reduce 5 | import numpy as np 6 | from typing import Union, Tuple, Optional 7 | import sys 8 | 9 | import torch 10 | from torch import Tensor 11 | import torch.nn.functional as F 12 | from fairseq import metrics, utils 13 | from fairseq.criterions import FairseqCriterion, register_criterion 14 | from torch.autograd import Function 15 | from ..custom_ops import cuda_pcfg_loss, cuda_pcfg_best_tree 16 | from ..custom_ops import dag_logsoftmax_gather_inplace 17 | 18 | from .utilities import parse_anneal_argument, get_anneal_value 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | ########### gpu use tracker ########### 23 | # import inspect 24 | SHOW_MEMORY_USE=False 25 | if SHOW_MEMORY_USE: 26 | from fairseq.gpu_mem_track import MemTracker 27 | gpu_tracker = MemTracker() 28 | ######################################## 29 | 30 | @register_criterion("nat_pcfg_loss") 31 | class NATPCFGoss(FairseqCriterion): 32 | 33 | def __init__(self, cfg, task): 34 | super().__init__(task) 35 | self.cfg = cfg 36 | assert cfg.label_smoothing == 0, "pcfg does not support label smoothing" 37 | self.glance_strategy = cfg.glance_strategy 38 | self._glat_p_anneal_params = parse_anneal_argument(cfg.glat_p) 39 | 40 | self.set_update_num(0) 41 | 42 | @staticmethod 43 | def add_args(parser): 44 | """Add criterion-specific arguments to the parser.""" 45 | parser.add_argument("--label-smoothing", type=float, default=0, help="DA-Transformer does not use label smoothing for now") 46 | parser.add_argument("--glat-p", type=str, default="0", help="Glancing probability. 0.5:0.1@200k indicates annealing p from 0.5 to 0.1 in 200k steps.") 47 | parser.add_argument("--glance-strategy", type=str, default=None, help='Glancing strategy. Possible values: "number-random" or "None" or "CMLM"') 48 | parser.add_argument("--no-force-emit", action="store_true", help="If true, do not fix the position of glance tokens in the second forward pass") 49 | 50 | parser.add_argument("--torch-pfcg-logsoftmax-gather", action="store_true", help="Use torch implementation for logsoftmax-gather, which supports GPU and CPU device. (Cuda implementation only supports GPU)") 51 | parser.add_argument("--torch-pfcg-best-alignment", action="store_true", help="Use torch implementation for pfcg-best-alignment, which supports GPU and CPU device. (Cuda implementation only supports GPU)") 52 | 53 | def _compute_loss(self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0): 54 | """ 55 | outputs: batch x len x d_model 56 | targets: batch x len 57 | masks: batch x len 58 | 59 | policy_logprob: if there is some policy 60 | depends on the likelihood score as rewards. 
61 | """ 62 | 63 | def mean_ds(x: Tensor, dim=None) -> Tensor: 64 | return ( 65 | x.float().mean().type_as(x) 66 | if dim is None 67 | else x.float().mean(dim).type_as(x) 68 | ) 69 | 70 | if masks is not None: 71 | outputs, targets = outputs[masks], targets[masks] 72 | 73 | if masks is not None and not masks.any(): 74 | nll_loss = torch.tensor(0) 75 | loss = nll_loss 76 | else: 77 | logits = utils.log_softmax(outputs, dim=-1) 78 | if targets.dim() == 1: 79 | losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") 80 | 81 | else: # soft-labels 82 | losses = F.kl_div(logits, targets.to(logits.device), reduction="none") 83 | losses = losses.sum(-1) 84 | 85 | nll_loss = mean_ds(losses) 86 | if label_smoothing > 0: 87 | loss = ( 88 | nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing 89 | ) 90 | else: 91 | loss = nll_loss 92 | 93 | loss_nofactor = loss 94 | loss = loss * factor 95 | 96 | return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor, "ntokens": outputs.shape[0], "loss_nofactor": loss_nofactor} 97 | 98 | def _compute_pfcg_loss(self, outputs, output_masks, targets, target_masks, links, label_smoothing=0.0, name="loss", 99 | factor=1.0, matchmask=None, keep_word_mask=None, model=None): 100 | 101 | batch_size = outputs.shape[0] 102 | prelen = outputs.shape[1] 103 | tarlen = targets.shape[1] 104 | 105 | output_length = output_masks.sum(dim=-1) 106 | target_length = target_masks.sum(dim=-1) 107 | 108 | 109 | outputs, match_all = dag_logsoftmax_gather_inplace(outputs, targets.unsqueeze(1).expand(-1, prelen, -1)) 110 | 111 | nvalidtokens = output_masks.sum() 112 | eos_mask = torch.zeros_like(match_all).to(match_all).bool() 113 | eos_mask[range(batch_size), :, target_length-1] = True 114 | eos_mask[range(batch_size), output_length-1, target_length-1] = False 115 | match_all_eos_masked = match_all.masked_fill(eos_mask, float("-inf")) 116 | if matchmask is not None and not self.cfg.no_force_emit: 117 | glat_prev_mask = keep_word_mask.unsqueeze(-1) 118 | 119 | matchmask = matchmask.transpose(1,2) 120 | # print(match_all_eos_masked.size(), matchmask.size(), glat_prev_mask.size()) 121 | match_all_eos_masked = match_all_eos_masked.masked_fill(glat_prev_mask, 0) + match_all_eos_masked.masked_fill(~matchmask, float("-inf")).masked_fill(~glat_prev_mask, 0).detach() 122 | loss_result = cuda_pcfg_loss(match_all_eos_masked, links, output_length, target_length) 123 | invalid_masks = loss_result.isinf().logical_or(loss_result.isnan()) 124 | # if loss_result.isinf().any(): 125 | # print(loss_result, match_all_eos_masked.size()) 126 | # np.savez('/data/guishangtong/PCFG-NAT/fs_plugins/numba_tests/test.npz', match_all=match_all_eos_masked.detach().cpu().numpy(), links=links.detach().cpu().numpy(), output_length=output_length.detach().cpu().numpy(), target_length=target_length.detach().cpu().numpy()) 127 | # assert(False) 128 | # print(loss_result) 129 | # if loss_result.isinf().any(): 130 | # print(loss_result) 131 | # assert(False) 132 | loss_result.masked_fill_(invalid_masks, 0) 133 | invalid_nsentences = invalid_masks.sum().detach() 134 | 135 | loss = -(loss_result / target_length).mean() 136 | 137 | nll_loss = loss.detach() 138 | nsentences, ntokens = targets.shape[0], targets.ne(self.task.tgt_dict.pad()).sum() 139 | 140 | loss_nofactor = loss 141 | loss = loss * factor 142 | 143 | return {"name": name, "loss": loss, "nll_loss": nll_loss, 144 | "factor": factor, "ntokens": ntokens, "nvalidtokens": nvalidtokens, "nsentences": nsentences, 145 | 
"loss_nofactor": loss_nofactor, "invalid_nsentences": invalid_nsentences} 146 | 147 | def _custom_loss(self, loss, name="loss", factor=1.0): 148 | return {"name": name, "loss": loss, "factor": factor} 149 | 150 | def set_update_num(self, update_num): 151 | self.glat_p = get_anneal_value(self._glat_p_anneal_params, update_num) 152 | 153 | def forward(self, model, sample, reduce=True): 154 | """Compute the loss for the given sample. 155 | Returns a tuple with three elements: 156 | 1) the loss 157 | 2) the sample size, which is used as the denominator for the gradient 158 | 3) logging outputs to display while training 159 | """ 160 | 161 | # import gc 162 | # gc.collect() 163 | # if SHOW_MEMORY_USE: 164 | # print(torch.cuda.memory_reserved() / 1024 / 1024, file=sys.stderr, flush=True) 165 | # gpu_tracker.clear_cache() 166 | if SHOW_MEMORY_USE: 167 | gpu_tracker.track() 168 | 169 | # B x T 170 | src_tokens, src_lengths = ( 171 | sample["net_input"]["src_tokens"], 172 | sample["net_input"]["src_lengths"], 173 | ) 174 | tgt_tokens = sample["target"] 175 | 176 | # if SHOW_MEMORY_USE: 177 | # print(sample["net_input"]["src_tokens"].shape[0], sample["net_input"]["src_tokens"].shape[1], tgt_tokens.shape[1], file=sys.stderr, end=" ") 178 | 179 | if sample.get("update_num", None) is not None: # in training 180 | self.set_update_num(sample['update_num']) 181 | 182 | prev_output_tokens = model.initialize_output_tokens_by_tokens(src_tokens, tgt_tokens) 183 | 184 | 185 | if self.glat_p == 0: 186 | glat = None 187 | else: 188 | glat = { 189 | "context_p": max(self.glat_p, 0), 190 | "require_glance_grad": False 191 | } 192 | 193 | def glat_function(model, word_ins_out, tgt_tokens, prev_output_tokens, glat, links=None): 194 | batch_size, prelen, _, _ = links.shape 195 | tarlen = tgt_tokens.shape[1] 196 | nonpad_positions = ~tgt_tokens.eq(model.pad) 197 | target_length = (nonpad_positions).sum(1) 198 | output_length = prev_output_tokens.ne(model.pad).sum(1) 199 | pred_tokens = word_ins_out.argmax(-1) 200 | word_ins_out, match = dag_logsoftmax_gather_inplace(word_ins_out, tgt_tokens.unsqueeze(1).expand(-1, prelen, -1)) 201 | eos_mask = torch.zeros_like(match).to(match).bool() 202 | # for batch_id in range(batch_size): 203 | # eos_mask[batch_id, :, target_length[batch_id]-1] = True 204 | # eos_mask[batch_id, output_length[batch_id]-1, target_length[batch_id]-1] = False 205 | eos_mask[range(batch_size), :, target_length-1] = True 206 | eos_mask[range(batch_size), output_length-1, target_length-1] = False 207 | match_all_eos_masked = match.masked_fill(eos_mask, float("-inf")) 208 | 209 | 210 | path, max_tree_lprob = cuda_pcfg_best_tree(match_all_eos_masked, links, output_length, target_length) # batch * prelen 211 | # if max_tree_lprob.isinf().any(): 212 | # np.savez('/data/guishangtong/PCFG-NAT/fs_plugins/numba_tests/test.npz', match_all=match_all_eos_masked.detach().cpu().numpy(), links=links.detach().cpu().numpy(), output_length=output_length.detach().cpu().numpy(), target_length=target_length.detach().cpu().numpy()) 213 | # assert(False) 214 | invalid_masks = max_tree_lprob.isinf().logical_or(max_tree_lprob.isnan()) 215 | max_tree_lprob.masked_fill_(invalid_masks, 0) 216 | predict_align_mask = path >= 0 217 | path[path+1>=(tarlen+1)] = -1 218 | path[path+1<0] = -1 219 | # if not ((path+1) < (tarlen+1)).all(): 220 | # torch.set_printoptions(profile="full") 221 | # print(tarlen+1) 222 | # print(path[(path+1) >= (tarlen+1)]) 223 | # print(path+1) 224 | # assert(((path+1) < (tarlen+1)).all()) 225 | matchmask = 
torch.zeros(batch_size, tarlen + 1, prelen, device=match.device, dtype=torch.bool).scatter_(1, path.unsqueeze(1) + 1, 1)[:, 1:] 226 | 227 | oracle = tgt_tokens.gather(-1, path.clip(min=0)) # bsz * prelen 228 | # print(pred_tokens.size(), oracle.size(), matchmask.size()) 229 | 230 | same_num = ((pred_tokens == oracle) & predict_align_mask).sum(1) 231 | 232 | if self.glance_strategy is None: 233 | keep_prob = ((target_length - same_num) / target_length * glat['context_p']).unsqueeze(-1) * predict_align_mask.float() 234 | 235 | elif self.glance_strategy in ['number-random']: 236 | prob = torch.randn(oracle.shape, device=tgt_tokens.device, dtype=torch.float) 237 | prob.masked_fill_(~predict_align_mask, -100) 238 | glance_nums = ((target_length - same_num) * glat['context_p'] + 0.5).to(torch.long) 239 | #prob_thresh = prob.topk(glance_nums.max().clip(min=1))[0].gather(-1, (glance_nums - 1).clip(min=0).unsqueeze(-1)).squeeze(-1) 240 | prob_thresh = prob.sort(descending=True)[0].gather(-1, (glance_nums - 1).clip(min=0).unsqueeze(-1)).squeeze(-1) 241 | prob_thresh.masked_fill_(glance_nums == 0, 100) 242 | keep_prob = (prob >= prob_thresh.unsqueeze(-1)).to(prob.dtype) 243 | 244 | elif self.glance_strategy == "cmlm": 245 | prob = torch.randn(oracle.shape, device=tgt_tokens.device, dtype=torch.float) 246 | prob.masked_fill_(~predict_align_mask, -100) 247 | glance_nums = (target_length * torch.rand_like(target_length, dtype=torch.float) + 0.5).to(torch.long) 248 | #prob_thresh = prob.topk(glance_nums.max().clip(min=1))[0].gather(-1, (glance_nums - 1).clip(min=0).unsqueeze(-1)).squeeze(-1) 249 | prob_thresh = prob.sort(descending=True)[0].gather(-1, (glance_nums - 1).clip(min=0).unsqueeze(-1)).squeeze(-1) 250 | prob_thresh.masked_fill_(glance_nums == 0, 100) 251 | keep_prob = (prob >= prob_thresh.unsqueeze(-1)).to(prob.dtype) 252 | 253 | keep_word_mask = (torch.rand(prev_output_tokens.shape, device=prev_output_tokens.device) < keep_prob).bool() 254 | 255 | glat_prev_output_tokens = prev_output_tokens.masked_fill(keep_word_mask, 0) + oracle.masked_fill(~keep_word_mask, 0) 256 | output_length = prev_output_tokens.ne(model.pad).sum(dim=-1) 257 | glat_tgt_tokens = tgt_tokens 258 | 259 | glat_info = { 260 | "glat_accu": (same_num.sum() / target_length.sum()).detach(), 261 | "glat_context_p": glat['context_p'], 262 | "glat_keep": keep_prob.mean().detach(), 263 | "matchmask": matchmask, 264 | "keep_word_mask": keep_word_mask, 265 | "glat_prev_output_tokens": glat_prev_output_tokens, 266 | "max_tree_lprob": -(max_tree_lprob / target_length).mean(), 267 | } 268 | 269 | return glat_prev_output_tokens, glat_tgt_tokens, glat_info 270 | outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens, glat, glat_function) 271 | 272 | losses = [] 273 | 274 | # PCFG loss 275 | _losses = self._compute_pfcg_loss( 276 | outputs["word_ins"].get("out"), 277 | prev_output_tokens.ne(self.task.tgt_dict.pad()), 278 | outputs["word_ins"].get("tgt"), 279 | outputs["word_ins"].get("mask", None), 280 | outputs["links"], 281 | name="pcfg-loss", 282 | factor=1, 283 | matchmask=outputs.get('matchmask', None), 284 | keep_word_mask=outputs.get('keep_word_mask', None), 285 | model=model 286 | ) 287 | 288 | losses += [_losses] 289 | pcfg_nll_loss = _losses.get("nll_loss", 0.0) 290 | nsentences = _losses["nsentences"] 291 | ntokens = _losses["ntokens"] 292 | nvalidtokens = _losses["nvalidtokens"] 293 | invalid_nsentences = _losses["invalid_nsentences"] 294 | 295 | #length 296 | _losses = self._compute_loss( 297 | 
outputs["length"].get("out"), 298 | outputs["length"].get("tgt"), 299 | None, 300 | 0, 301 | name="length-loss", 302 | factor=outputs["length"]["factor"], ) 303 | losses += [_losses] 304 | length_nll_loss = _losses.get("nll_loss", 0.0) 305 | 306 | loss = sum(l["loss"] for l in losses) 307 | 308 | sample_size = 1 309 | logging_output = { 310 | "loss": loss.data, 311 | "pcfg_nll-loss": pcfg_nll_loss.data, 312 | "length_nll-loss": length_nll_loss.data, 313 | "ntokens": ntokens, 314 | "nvalidtokens": nvalidtokens, 315 | "nsentences": nsentences, 316 | "invalid_nsentences": invalid_nsentences, 317 | "sample_size": sample_size, 318 | "glat_acc": outputs.get("glat_accu", 0), 319 | "glat_keep": outputs.get("glat_keep", 0), 320 | "max_tree_lprob": outputs.get("max_tree_lprob", 0), 321 | } 322 | 323 | for l in losses: 324 | logging_output[l["name"]] = ( 325 | utils.item(l["loss_nofactor"]) 326 | if reduce 327 | else l["loss_nofactor"] 328 | ) 329 | if SHOW_MEMORY_USE: 330 | gpu_tracker.track() 331 | return loss, sample_size, logging_output 332 | 333 | @staticmethod 334 | def reduce_metrics(logging_outputs) -> None: 335 | """Aggregate logging outputs from data parallel training.""" 336 | sample_size = utils.item( 337 | sum(log.get("sample_size", 0) for log in logging_outputs) 338 | ) # each batch is 1 339 | loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) # token-level loss 340 | 341 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 342 | nvalidtokens = sum(log.get('nvalidtokens', 0) for log in logging_outputs) 343 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 344 | invalid_nsentences = sum(log.get('invalid_nsentences', 0) for log in logging_outputs) 345 | loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) # token-level loss 346 | glat_acc = utils.item(sum(log.get("glat_acc", 0) for log in logging_outputs)) 347 | glat_keep = utils.item(sum(log.get("glat_keep", 0) for log in logging_outputs)) 348 | max_tree_lprob = utils.item(sum(log.get("max_tree_lprob", 0) for log in logging_outputs)) 349 | 350 | res = { 351 | "ntokens": utils.item(ntokens), 352 | "nsentences": utils.item(nsentences), 353 | "nvalidtokens": utils.item(nvalidtokens), 354 | "invalid_nsentences": utils.item(invalid_nsentences), 355 | 'tokens_perc': utils.item(nvalidtokens / ntokens), 356 | 'sentences_perc': 1 - utils.item(invalid_nsentences / nsentences), 357 | } 358 | res["loss"] = loss / sample_size 359 | res["glat_acc"] = glat_acc / sample_size 360 | res["glat_keep"] = glat_keep / sample_size 361 | res["max_tree_lprob"] = max_tree_lprob / sample_size 362 | 363 | for key, value in res.items(): 364 | metrics.log_scalar( 365 | key, value, sample_size, round=3 366 | ) 367 | 368 | for key in logging_outputs[0]: 369 | if key[-5:] == "-loss": 370 | val = utils.item(sum(log.get(key, 0) for log in logging_outputs)) 371 | metrics.log_scalar( 372 | key[:-5], 373 | val / sample_size if sample_size > 0 else 0.0, 374 | sample_size, 375 | round=3, 376 | ) 377 | 378 | @staticmethod 379 | def logging_outputs_can_be_summed() -> bool: 380 | """ 381 | Whether the logging outputs returned by `forward` can be summed 382 | across workers prior to calling `reduce_metrics`. Setting this 383 | to True will improves distributed training speed. 
384 | """ 385 | return True 386 | -------------------------------------------------------------------------------- /fs_plugins/criterions/utilities.py: -------------------------------------------------------------------------------- 1 | def parse_anneal_argument(anneal_str): 2 | def parse_value_pos(value_str): 3 | if "@" in value_str: 4 | value, pos = value_str.split("@") 5 | else: 6 | value = value_str 7 | pos = "0" 8 | return float(value), float(pos.replace("k", "000")) 9 | 10 | res = [] 11 | for value_str in anneal_str.split(":"): 12 | res.append(parse_value_pos(value_str)) 13 | return res 14 | 15 | def get_anneal_value(anneal_params, update_num): 16 | last_value, last_pos = anneal_params[0][0], 0 17 | for value, pos in anneal_params: 18 | if update_num < pos: 19 | return last_value + (value - last_value) * (update_num - last_pos) / (pos - last_pos + 1) 20 | last_value, last_pos = value, pos 21 | return anneal_params[-1][0] 22 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .pcfg_loss import cuda_pcfg_loss, cuda_pcfg_best_tree, cuda_pcfg_viterbi, dag_logsoftmax_gather_inplace, viterbi_decoding -------------------------------------------------------------------------------- /fs_plugins/custom_ops/logsoftmax_gather.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CPU(x) TORCH_CHECK(x.type().is_cpu(), #x " must be a CPU tensor") 21 | 22 | #define MY_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ 23 | case enum_type: { \ 24 | using HINT = type; \ 25 | return __VA_ARGS__(); \ 26 | } 27 | 28 | #define MY_DISPATCH_FLOATING_TYPES_AND_HALF_WITH_HINT(TYPE, NAME, HINT, ...) \ 29 | [&] { \ 30 | const auto& the_type = TYPE; \ 31 | /* don't use TYPE again in case it is an expensive or side-effect op */ \ 32 | at::ScalarType _st = ::detail::scalar_type(the_type); \ 33 | switch (_st) { \ 34 | MY_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Double, double, HINT, __VA_ARGS__) \ 35 | MY_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Float, float, HINT, __VA_ARGS__) \ 36 | MY_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Half, at::Half, HINT, __VA_ARGS__) \ 37 | default: \ 38 | AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ 39 | } \ 40 | }() 41 | 42 | #define MY_PRIVATE_VALUE(val, HINT, ...) \ 43 | case val: { \ 44 | const int HINT = val; \ 45 | return __VA_ARGS__(); \ 46 | } 47 | 48 | #define MY_DISPATCH_VALUE(VAL, NAME, HINT, ...) \ 49 | [&] { \ 50 | switch (VAL) { \ 51 | MY_PRIVATE_VALUE(1, HINT, __VA_ARGS__) \ 52 | MY_PRIVATE_VALUE(2, HINT, __VA_ARGS__) \ 53 | default: \ 54 | AT_ERROR(#NAME, " not implemented for this value"); \ 55 | } \ 56 | }() 57 | 58 | #define MY_DISPATCH_BOOL(VAL, NAME, HINT, ...) 
\ 59 | [&] { \ 60 | if (VAL) { \ 61 | const bool HINT = true; \ 62 | return __VA_ARGS__(); \ 63 | }else{ \ 64 | const bool HINT = false; \ 65 | return __VA_ARGS__(); \ 66 | } \ 67 | }() 68 | 69 | 70 | template 71 | struct SumOp { 72 | __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } 73 | }; 74 | 75 | template 76 | struct MaxOp { 77 | __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } 78 | }; 79 | 80 | template class ReductionOp, typename T, int block_size> 81 | __inline__ __device__ T BlockAllReduce(T val) { 82 | typedef cub::BlockReduce BlockReduce; 83 | __shared__ typename BlockReduce::TempStorage temp_storage; 84 | __shared__ T result_broadcast; 85 | T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); 86 | if (threadIdx.x == 0) { result_broadcast = result; } 87 | __syncthreads(); 88 | return result_broadcast; 89 | } 90 | 91 | template 92 | __inline__ __device__ T Inf(); 93 | 94 | template<> 95 | __inline__ __device__ float Inf() { 96 | return CUDART_INF_F; 97 | } 98 | 99 | template<> 100 | __inline__ __device__ double Inf() { 101 | return CUDART_INF; 102 | } 103 | 104 | template 105 | __inline__ __device__ T Exp(T x); 106 | 107 | template<> 108 | __inline__ __device__ float Exp(float x) { 109 | #ifdef OF_SOFTMAX_USE_FAST_MATH 110 | return __expf(x); 111 | #else 112 | return exp(x); 113 | #endif 114 | } 115 | 116 | template<> 117 | __inline__ __device__ double Exp(double x) { 118 | return exp(x); 119 | } 120 | 121 | template 122 | __inline__ __device__ T Div(T a, T b); 123 | 124 | template<> 125 | __inline__ __device__ float Div(float a, float b) { 126 | #ifdef OF_SOFTMAX_USE_FAST_MATH 127 | return __fdividef(a, b); 128 | #else 129 | return a / b; 130 | #endif 131 | } 132 | 133 | template<> 134 | __inline__ __device__ double Div(double a, double b) { 135 | return a / b; 136 | } 137 | 138 | template 139 | __inline__ __device__ T Log(T x); 140 | 141 | template<> 142 | __inline__ __device__ float Log(float x) { 143 | #ifdef OF_SOFTMAX_USE_FAST_MATH 144 | return __logf(x); 145 | #else 146 | return log(x); 147 | #endif 148 | } 149 | template<> 150 | __inline__ __device__ double Log(double x) { 151 | return log(x); 152 | } 153 | 154 | inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves, 155 | int* num_blocks) { 156 | int dev; 157 | { 158 | cudaError_t err = cudaGetDevice(&dev); 159 | if (err != cudaSuccess) { return err; } 160 | } 161 | int sm_count; 162 | { 163 | cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); 164 | if (err != cudaSuccess) { return err; } 165 | } 166 | int tpm; 167 | { 168 | cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); 169 | if (err != cudaSuccess) { return err; } 170 | } 171 | *num_blocks = 172 | std::max(1, std::min(max_blocks, sm_count * tpm / block_size * waves)); 173 | return cudaSuccess; 174 | } 175 | 176 | template 177 | struct DefaultComputeType{ 178 | using type = T; 179 | }; 180 | 181 | template<> 182 | struct DefaultComputeType { 183 | using type = float; 184 | }; 185 | 186 | 187 | template 188 | struct GetPackType { 189 | using type = typename std::aligned_storage::type; 190 | }; 191 | 192 | template 193 | using PackType = typename GetPackType::type; 194 | 195 | template 196 | union Pack { 197 | static_assert(sizeof(PackType) == sizeof(T) * N, ""); 198 | __device__ Pack() { 199 | // do nothing 200 | } 201 | PackType storage; 202 | T elem[N]; 203 | }; 
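// Note: Pack bundles N elements of T into one aligned storage word so that a group of
// elements can be moved with a single load/store. DirectLoad and DirectStore below view a
// raw pointer as a [rows x row_size] buffer and copy pack_size elements at a time, casting
// between the stored type (SRC/DST) and the kernel's compute type.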
204 | 205 | template 206 | struct DirectLoad { 207 | DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} 208 | template 209 | __device__ void load(DST* dst, int64_t row, int64_t col) const { 210 | Pack pack; 211 | const int64_t offset = (row * row_size + col) / N; 212 | pack.storage = *(reinterpret_cast*>(src) + offset); 213 | #pragma unroll 214 | for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } 215 | } 216 | const SRC* src; 217 | int64_t row_size; 218 | }; 219 | 220 | template 221 | struct DirectStore { 222 | DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} 223 | template 224 | __device__ void store(const SRC* src, int64_t row, int64_t col) { 225 | Pack pack; 226 | const int64_t offset = (row * row_size + col) / N; 227 | #pragma unroll 228 | for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } 229 | *(reinterpret_cast*>(dst) + offset) = pack.storage; 230 | } 231 | DST* dst; 232 | int64_t row_size; 233 | }; 234 | 235 | 236 | 237 | 238 | template 240 | __global__ void logsoftmax_gather_kernel( 241 | LOAD load, STORE store, int64_t rows, int64_t cols, 242 | Accessor word_ins_out, Accessor2 selected_result, Accessor3 select_idx, 243 | int bsz, int prelen, int slen, int vocabsize) 244 | { 245 | const int tid = threadIdx.x; 246 | // assert(cols % pack_size == 0); 247 | static_assert(pack_size == 1, "pack_size should not be 1"); 248 | const int num_packs = cols / pack_size; 249 | 250 | for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { 251 | ComputeType thread_max = -Inf(); 252 | for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { 253 | ComputeType pack[pack_size]; 254 | load.template load(pack, row, pack_id * pack_size); 255 | 256 | #pragma unroll 257 | for (int i = 0; i < pack_size; ++i) { thread_max = max(thread_max, pack[i]); } 258 | } 259 | 260 | const ComputeType row_max = BlockAllReduce(thread_max); 261 | ComputeType thread_sum = 0; 262 | for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { 263 | ComputeType pack[pack_size]; 264 | load.template load(pack, row, pack_id * pack_size); 265 | 266 | #pragma unroll 267 | for (int i = 0; i < pack_size; ++i) { thread_sum += Exp(pack[i] - row_max); } 268 | } 269 | 270 | const ComputeType row_sum = BlockAllReduce(thread_sum); 271 | int batch_id = row / prelen; 272 | int prepos = row % prelen; 273 | for(int sid = tid; sid < slen; sid += block_size){ 274 | int64_t target_idx = select_idx[batch_id][prepos][sid]; 275 | selected_result[batch_id][prepos][sid] = (static_cast(word_ins_out[batch_id][prepos][target_idx]) - row_max) - Log(row_sum); 276 | } 277 | 278 | if (require_grad){ 279 | __syncthreads(); 280 | for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { 281 | ComputeType pack[pack_size]; 282 | load.template load(pack, row, pack_id * pack_size); 283 | 284 | #pragma unroll 285 | for (int i = 0; i < pack_size; ++i) { pack[i] = Div(Exp(pack[i] - row_max), row_sum); } 286 | 287 | store.template store(pack, row, pack_id * pack_size); 288 | } 289 | } 290 | 291 | } 292 | } 293 | 294 | 295 | torch::Tensor logsoftmax_gather(torch::Tensor word_ins_out, const torch::Tensor &select_idx, bool require_gradient) 296 | { 297 | CHECK_CUDA(word_ins_out); // bsz * prelen * vocabsize 298 | CHECK_CUDA(select_idx); // bsz * prelen * slen 299 | TORCH_CHECK(word_ins_out.dim() == 3, "word_ins_out dim != 3"); 300 | TORCH_CHECK(select_idx.dim() == 3, "select_idx dim != 3"); 301 | 302 | auto bsz = word_ins_out.size(0); 303 | 
auto prelen = word_ins_out.size(1); 304 | auto vocabsize = word_ins_out.size(2); 305 | auto slen = select_idx.size(2); 306 | TORCH_CHECK(select_idx.size(0) == bsz, "batch size not match"); 307 | TORCH_CHECK(select_idx.size(1) == prelen, "prelen size not match"); 308 | TORCH_CHECK(select_idx.scalar_type() == at::kLong, "select_idx should be long"); 309 | TORCH_CHECK(word_ins_out.is_contiguous(), "word_ins_out is not contiguous"); 310 | 311 | constexpr int block_size = 1024; 312 | constexpr int waves = 32; 313 | int grid_dim_x; 314 | { 315 | cudaError_t err = GetNumBlocks(block_size, bsz * prelen, waves, &grid_dim_x); 316 | assert(err == cudaSuccess); 317 | } 318 | 319 | torch::Tensor selected_result; 320 | cudaStream_t stream = 0; 321 | 322 | MY_DISPATCH_FLOATING_TYPES_AND_HALF_WITH_HINT( 323 | word_ins_out.scalar_type(), "logsoftmax_gather_kernel_scalar_t", scalar_t, [&] { 324 | using ComputeType = typename DefaultComputeType::type; 325 | if (std::is_same::value){ 326 | selected_result = at::zeros({bsz, prelen, slen}, word_ins_out.options().dtype(at::kFloat)); 327 | }else{ 328 | selected_result = at::zeros({bsz, prelen, slen}, word_ins_out.options().dtype(at::kDouble)); 329 | } 330 | using Load = DirectLoad; 331 | using Store = DirectStore; 332 | Load load(word_ins_out.data_ptr(), vocabsize); 333 | Store store(word_ins_out.data_ptr(), vocabsize); 334 | int64_t cols = vocabsize; 335 | int64_t rows = bsz * prelen; 336 | const int PackSize = 1; 337 | // MY_DISPATCH_VALUE( 338 | // pack_size, "GatherVocabLogitsKernel_pack_size", PackSize, [&]{ 339 | MY_DISPATCH_BOOL( 340 | require_gradient, "logsoftmax_gather_kernel_require_gradient", RequireGrad, [&]{ 341 | logsoftmax_gather_kernel 342 | <<>> 343 | ( 344 | load, store, rows, cols, 345 | word_ins_out.packed_accessor64(), 346 | selected_result.packed_accessor64(), 347 | select_idx.packed_accessor64(), 348 | bsz, prelen, slen, vocabsize 349 | ); 350 | assert(cudaPeekAtLastError() == cudaSuccess); 351 | } 352 | ); 353 | // } 354 | // ); 355 | } 356 | ); 357 | 358 | return selected_result; 359 | } 360 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_best_tree.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include "utilities.h" 17 | 18 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 19 | #define CHECK_CPU(x) TORCH_CHECK(x.type().is_cpu(), #x " must be a CPU tensor") 20 | 21 | #define MAIN_CHAIN_I2E(x, max_left) ((x)*((max_left)+1)+1) 22 | #define MAIN_CHAIN_E2I(x, max_left) (((x)-1)/((max_left)+1)) 23 | #define LOCAL_TREE_I2E(x, max_left) (((x)==0)?0:((x) + ((x)-1) / (max_left) + 1)) 24 | #define LOCAL_TREE_E2I(x, max_left) ((x)==0?0:((x) - ((x) - 1) / ((max_left)+1) - 1)) 25 | 26 | #define BLOCK_BUCKET 16 27 | 28 | // Define this to turn on error checking 29 | #define CUDA_ERROR_CHECK 30 | 31 | #define CudaSafeCall( err ) __cudaSafeCall( err, __FILE__, __LINE__ ) 32 | #define CudaCheckError() __cudaCheckError( __FILE__, __LINE__ ) 33 | 34 | inline void __cudaSafeCall( cudaError err, const char *file, const int line ) 35 | { 36 | #ifdef CUDA_ERROR_CHECK 37 | if ( cudaSuccess != err ) 38 | { 39 | fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", 40 | file, line, cudaGetErrorString( err ) ); 41 | exit( -1 ); 42 | } 43 | #endif 44 | 
45 | return; 46 | } 47 | 48 | inline void __cudaCheckError( const char *file, const int line ) 49 | { 50 | #ifdef CUDA_ERROR_CHECK 51 | cudaError err = cudaGetLastError(); 52 | if ( cudaSuccess != err ) 53 | { 54 | fprintf( stderr, "cudaCheckError() failed at %s:%i : %s\n", 55 | file, line, cudaGetErrorString( err ) ); 56 | exit( -1 ); 57 | } 58 | #endif 59 | 60 | return; 61 | } 62 | 63 | 64 | template 65 | __global__ void calculate_S_trace_kernel( 66 | volatile int *bucket_queue, volatile int *accomplish_queue, volatile int *start_queue, 67 | Accessor1 S_trace, 68 | Accessor2 C_trace, 69 | Accessor3 tree, 70 | Accessor1 S, 71 | Accessor2 C, 72 | Accessor1 match_all, 73 | Accessor2 links, 74 | Accessor4 output_length, 75 | Accessor4 target_length, 76 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 77 | { 78 | __shared__ volatile int task_id; 79 | __shared__ volatile int seg_id; 80 | __shared__ volatile int start; 81 | 82 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 83 | 84 | int ticket_no = blockIdx.y; 85 | int batch_id = ticket_no % bsz; 86 | 87 | 88 | int m = output_length[batch_id]; 89 | int n = target_length[batch_id]; 90 | __threadfence(); 91 | __syncthreads(); 92 | unsigned shfl_mask = (1 << TRANS_BLOCK_SIZE) - 1; 93 | shfl_mask = shfl_mask << (threadIdx.y % (32 / TRANS_BLOCK_SIZE) * TRANS_BLOCK_SIZE); 94 | while(start_queue[batch_id]< (n+2)*n_seg){ 95 | if(main_thread){ 96 | task_id = atomicAdd((int*)start_queue + batch_id, 1); 97 | // printf("batch_id:%d task_id:%d addr:%d\n", batch_id, task_id, (int*)start_queue + batch_id); 98 | seg_id = task_id % n_seg; 99 | start = task_id/n_seg; 100 | bool done = false; 101 | while(!done){ 102 | done = true; 103 | for(int i=0; i= start); 105 | } 106 | } 107 | } 108 | __threadfence(); 109 | __syncthreads(); 110 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 111 | int a = MAIN_CHAIN_I2E(a_id, max_left); 112 | if(start == 0){ 113 | if(seg_id == 0 && main_thread){ 114 | S_trace[batch_id][0][0] = 1; 115 | // printf("S_trace: %d 1 0 %f\n", batch_id, S_trace[batch_id][1][0]); 116 | } 117 | } 118 | else{ 119 | int _start = start - 1; 120 | if (a > 0 && a < m && S_trace[batch_id][a_id][_start]!=0){ 121 | int max_c = -1, max_j = 0, max_b = -1; 122 | scalar_t maxval = -std::numeric_limits::infinity(); 123 | for(int c_id = a_id + threadIdx.x+1; MAIN_CHAIN_I2E(c_id, max_left) ::infinity(); 133 | if(temp > maxval){ maxval = temp; max_c = c_id; max_j = j; max_b = b_id; } 134 | } 135 | } 136 | 137 | 138 | } 139 | 140 | __syncwarp(shfl_mask); 141 | if_constexpr (TRANS_BLOCK_SIZE > 16) { 142 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 16, TRANS_BLOCK_SIZE); 143 | int next_c = __shfl_down_sync(shfl_mask, max_c, 16, TRANS_BLOCK_SIZE); 144 | int next_j = __shfl_down_sync(shfl_mask, max_j, 16, TRANS_BLOCK_SIZE); 145 | int next_b = __shfl_down_sync(shfl_mask, max_b, 16, TRANS_BLOCK_SIZE); 146 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_j = next_j; max_b = next_b; }} 147 | if_constexpr (TRANS_BLOCK_SIZE > 8) { 148 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 8, TRANS_BLOCK_SIZE); 149 | int next_c = __shfl_down_sync(shfl_mask, max_c, 8, TRANS_BLOCK_SIZE); 150 | int next_j = __shfl_down_sync(shfl_mask, max_j, 8, TRANS_BLOCK_SIZE); 151 | int next_b = __shfl_down_sync(shfl_mask, max_b, 8, TRANS_BLOCK_SIZE); 152 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_j = next_j; max_b = next_b; }} 153 | if_constexpr (TRANS_BLOCK_SIZE > 4) { 154 | scalar_t nextval = __shfl_down_sync(shfl_mask, 
maxval, 4, TRANS_BLOCK_SIZE); 155 | int next_c = __shfl_down_sync(shfl_mask, max_c, 4, TRANS_BLOCK_SIZE); 156 | int next_j = __shfl_down_sync(shfl_mask, max_j, 4, TRANS_BLOCK_SIZE); 157 | int next_b = __shfl_down_sync(shfl_mask, max_b, 4, TRANS_BLOCK_SIZE); 158 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_j = next_j; max_b = next_b; }} 159 | if_constexpr (TRANS_BLOCK_SIZE > 2) { 160 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 2, TRANS_BLOCK_SIZE); 161 | int next_c = __shfl_down_sync(shfl_mask, max_c, 2, TRANS_BLOCK_SIZE); 162 | int next_j = __shfl_down_sync(shfl_mask, max_j, 2, TRANS_BLOCK_SIZE); 163 | int next_b = __shfl_down_sync(shfl_mask, max_b, 2, TRANS_BLOCK_SIZE); 164 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_j = next_j; max_b = next_b; }} 165 | if_constexpr (TRANS_BLOCK_SIZE > 1) { 166 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 1, TRANS_BLOCK_SIZE); 167 | int next_c = __shfl_down_sync(shfl_mask, max_c, 1, TRANS_BLOCK_SIZE); 168 | int next_j = __shfl_down_sync(shfl_mask, max_j, 1, TRANS_BLOCK_SIZE); 169 | int next_b = __shfl_down_sync(shfl_mask, max_b, 1, TRANS_BLOCK_SIZE); 170 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_j = next_j; max_b = next_b; }} 171 | maxval = __shfl_sync(shfl_mask, maxval, 0, TRANS_BLOCK_SIZE); 172 | max_c = __shfl_sync(shfl_mask, max_c, 0, TRANS_BLOCK_SIZE); 173 | max_j = __shfl_sync(shfl_mask, max_j, 0, TRANS_BLOCK_SIZE); 174 | max_b = __shfl_sync(shfl_mask, max_b, 0, TRANS_BLOCK_SIZE); 175 | 176 | if(threadIdx.x == 0){ 177 | // printf("new a:%d start:%d max_b:%d max_j:%d max_c:%d\n", a ,_start, max_b, max_j, max_c); 178 | if(max_c!=-1 && max_b!=-1){ 179 | S_trace[batch_id][max_c][_start+max_j+1] = 1; 180 | C_trace[batch_id][max_b][_start][max_j] = 1; 181 | } 182 | tree[batch_id][a] = _start+max_j; 183 | // if(a==1){ 184 | // printf("%d %d\n", _start, max_j); 185 | // } 186 | } 187 | } 188 | 189 | } 190 | __threadfence(); 191 | __syncthreads(); 192 | if (main_thread){ 193 | atomicAdd((int*)accomplish_queue + batch_id*n_seg + seg_id, 1); 194 | } 195 | } 196 | 197 | } 198 | 199 | 200 | template 201 | void invoke_calculate_S_trace(cudaStream_t stream, torch::Tensor &S_trace, torch::Tensor &C_trace, torch::Tensor &tree, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 202 | int bsz, int prelen, int tarlen, int max_left) 203 | { 204 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 205 | int local_tree_size = prelen - main_chain_size; 206 | int n_seg = (main_chain_size - 1) / SEQ_BLOCK_SIZE + 1; 207 | dim3 dimGrid(1, 2 * n_seg * bsz); 208 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 209 | // assert(n_seg <= BLOCK_BUCKET); 210 | int *bucket_queue, *accomplish_queue, *start_queue; 211 | auto tmp_tensor = at::zeros({BLOCK_BUCKET + bsz * n_seg}, match_all.options().dtype(at::kInt)); 212 | // auto tmp_tensor = at::zeros({BLOCK_BUCKET + bsz}, match_all.options().dtype(at::kInt)); 213 | bucket_queue = tmp_tensor.data_ptr(); 214 | accomplish_queue = bucket_queue + BLOCK_BUCKET; 215 | auto tmp_tensor3 = at::zeros({bsz}, match_all.options().dtype(at::kInt)); 216 | start_queue = tmp_tensor3.data_ptr(); 217 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 218 | AT_DISPATCH_FLOATING_TYPES( 219 | match_all.scalar_type(), "invoke_calculate_S_trace", [&] { 220 | tree.fill_(-1); 221 | 
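// Launch the traceback kernel: it follows the highest-scoring derivation along the main
// chain, marks the selected spans in S_trace / C_trace, and records each visited node's
// split position in `tree` (positions left untouched keep the -1 fill value).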
calculate_S_trace_kernel<<>>( 222 | bucket_queue, accomplish_queue, start_queue, 223 | S_trace.packed_accessor64(), 224 | C_trace.packed_accessor64(), 225 | tree.packed_accessor64(), 226 | S.packed_accessor64(), 227 | C.packed_accessor64(), 228 | match_all.packed_accessor64(), 229 | links.packed_accessor64(), 230 | output_length.packed_accessor64(), 231 | target_length.packed_accessor64(), 232 | bsz, prelen, tarlen, max_left, n_seg 233 | ); 234 | } 235 | ); 236 | } 237 | 238 | 239 | template 240 | __global__ void calculate_C_kernel_trace( 241 | Accessor1 C_trace, 242 | Accessor2 tree, 243 | Accessor1 C, 244 | Accessor3 match_all, 245 | Accessor1 links, 246 | Accessor4 output_length, 247 | Accessor4 target_length, 248 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 249 | { 250 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 251 | 252 | 253 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 254 | int ticket_no = blockIdx.y; 255 | int batch_id = ticket_no % bsz; 256 | int seg_id = ticket_no / bsz; 257 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y + 1; 258 | int a = LOCAL_TREE_I2E(a_id, max_left); 259 | int max_left_a = ((a-1) / (max_left+1)) * (max_left+1) + 1; 260 | int max_right_a = ((a-1) / (max_left+1) +1) * (max_left+1) + 1; 261 | int m = output_length[batch_id]; 262 | int n = target_length[batch_id]; 263 | // if(main_thread){ 264 | // printf("batch_id: %d, seg_id: %d, started\n", batch_id, seg_id); 265 | // } 266 | 267 | 268 | for(int gap = max_left+1; gap >= 1; gap--){ 269 | 270 | if (a > 0 && a < m){ 271 | for (int i=threadIdx.x; i::infinity(); 275 | for(int j=0;j::infinity(); 285 | if(temp > maxval){maxval = temp; max_j=j; max_b=b_id, max_c=c_id;} 286 | } 287 | } 288 | } 289 | if(max_b!=-1 && max_c!=-1){ 290 | C_trace[batch_id][max_b][i][max_j] = 1; 291 | C_trace[batch_id][max_c][i+max_j+1][gap-max_j-1] = 1; 292 | } 293 | tree[batch_id][a] = i+max_j; 294 | 295 | 296 | } 297 | } 298 | __threadfence(); 299 | __syncthreads(); 300 | } 301 | } 302 | 303 | 304 | template 305 | void invoke_calculate_C_trace(cudaStream_t stream, torch::Tensor &C_trace, torch::Tensor &tree, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 306 | int bsz, int prelen, int tarlen, int max_left) 307 | { 308 | if (max_left==0) return; 309 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 310 | int local_tree_size = prelen - main_chain_size; 311 | int n_seg = (local_tree_size - 1) / SEQ_BLOCK_SIZE + 1; 312 | dim3 dimGrid(1, n_seg * bsz); 313 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 314 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 315 | 316 | AT_DISPATCH_FLOATING_TYPES( 317 | match_all.scalar_type(), "invoke_calculate_C_trace", [&] { 318 | calculate_C_kernel_trace<<>>( 319 | C_trace.packed_accessor64(), 320 | tree.packed_accessor64(), 321 | C.packed_accessor64(), 322 | match_all.packed_accessor64(), 323 | links.packed_accessor64(), 324 | output_length.packed_accessor64(), 325 | target_length.packed_accessor64(), 326 | bsz, prelen, tarlen, max_left, n_seg 327 | ); 328 | } 329 | ); 330 | 331 | } 332 | 333 | torch::Tensor pcfg_best_tree(const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, 334 | const torch::Tensor &output_length, const torch::Tensor &target_length, 335 | int config) 336 | { 337 | 338 | auto bsz = match_all.size(0); 339 | auto prelen = 
match_all.size(1); 340 | auto tarlen = match_all.size(2); 341 | auto max_left = links.size(2); 342 | max_left = max_left - 1; 343 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 344 | int local_tree_size = prelen - main_chain_size; 345 | torch::Tensor S_trace = at::zeros({bsz, main_chain_size, tarlen+2}, match_all.options()); 346 | torch::Tensor C_trace = at::zeros({bsz, local_tree_size, tarlen+2, max_left+2}, match_all.options()); 347 | 348 | torch::Tensor tree = at::zeros({bsz, prelen}, output_length.options()); 349 | cudaStream_t current_stream = 0; 350 | switch(config){ 351 | case 1: invoke_calculate_S_trace<4, 128>(current_stream, S_trace, C_trace, tree, S, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); break; 352 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 353 | } 354 | switch(config){ 355 | case 1: invoke_calculate_C_trace<4, 128>(current_stream, C_trace, tree, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); break; 356 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 357 | } 358 | 359 | return tree; 360 | } 361 | 362 | 363 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_loss.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | using namespace std; 5 | 6 | std::tuple pcfg_loss(const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, int config); 7 | std::tuple pcfg_loss_backward(const torch::Tensor &grad_output, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, int config); 8 | torch::Tensor pcfg_best_tree(const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, int config); 9 | std::tuple pcfg_viterbi(const torch::Tensor &ob_lprob, const torch::Tensor &links, const torch::Tensor &output_length,int config); 10 | torch::Tensor logsoftmax_gather(torch::Tensor word_ins_out, const torch::Tensor &select_idx, bool require_gradient); 11 | 12 | torch::Tensor viterbi_decoding(torch::Tensor pred_length, torch::Tensor output_length, 13 | torch::Tensor L_trace, 14 | torch::Tensor R_trace, 15 | torch::Tensor M_trace, 16 | torch::Tensor unreduced_tokens, 17 | torch::Tensor left_tree_mask, 18 | int pad_index) 19 | { 20 | auto batch_size = pred_length.size(0); 21 | 22 | vector > unpad_output_tokens; 23 | for (int i = 0; i < batch_size; i++) 24 | { 25 | 26 | int pred_len = pred_length[i].item(); 27 | pair now = make_pair(1, pred_len); 28 | vector > stack; 29 | vector res; 30 | int max_h = output_length[i].item(); 31 | int last = -1; 32 | while ((now.first < max_h && now.first != 0) || stack.size() > 0) 33 | { 34 | while (now.first < max_h && now.first != 0) 35 | { 36 | stack.push_back(now); 37 | auto links_left_idx = L_trace.index({i, now.first, now.second}).item(); 38 | links_left_idx = links_left_idx == now.first ? 
0 : links_left_idx; 39 | if (left_tree_mask.index({now.first, links_left_idx}).item()) 40 | { 41 | break; 42 | } 43 | auto now_length = M_trace.index({i, now.first, now.second}).item(); 44 | now = make_pair(links_left_idx, now_length); 45 | } 46 | now = stack.back(); 47 | stack.pop_back(); 48 | auto now_token = unreduced_tokens.index({i, now.first}).item(); 49 | if (now_token != pad_index && now_token != last) 50 | { 51 | last = now_token; 52 | res.push_back(now_token); 53 | } 54 | auto links_right_idx = R_trace.index({i, now.first, now.second}).item(); 55 | auto now_length = now.second - M_trace.index({i, now.first, now.second}).item() - 1; 56 | now = make_pair(links_right_idx, now_length); 57 | } 58 | unpad_output_tokens.push_back(res); 59 | } 60 | int output_seqlen = 0; 61 | for (int i = 0; i < batch_size; i++) 62 | { 63 | output_seqlen = max(output_seqlen, (int)unpad_output_tokens[i].size()); 64 | } 65 | torch::Tensor output_tokens_tensor = torch::empty({batch_size, output_seqlen}).fill_(pad_index); 66 | for (int i = 0; i < batch_size; i++) 67 | { 68 | output_tokens_tensor.index_put_({i}, torch::tensor(unpad_output_tokens[i])); 69 | } 70 | 71 | return output_tokens_tensor; 72 | 73 | } 74 | 75 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 76 | m.def("pcfg_loss", &pcfg_loss, "PCFG Loss"); 77 | m.def("pcfg_loss_backward", &pcfg_loss_backward, "PCFG Loss Backward"); 78 | m.def("pcfg_best_tree", &pcfg_best_tree, "PCFG Best Tree"); 79 | m.def("pcfg_viterbi", &pcfg_viterbi, "PCFG Viterbi"); 80 | m.def("logsoftmax_gather", &logsoftmax_gather, "logsoftmax + gather"); 81 | m.def("viterbi_decoding", &viterbi_decoding, "Viterbi"); 82 | } 83 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_loss.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include "utilities.h" 17 | 18 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 19 | #define CHECK_CPU(x) TORCH_CHECK(x.type().is_cpu(), #x " must be a CPU tensor") 20 | 21 | #define MAIN_CHAIN_I2E(x, max_left) (x)*((max_left)+1)+1 22 | #define MAIN_CHAIN_E2I(x, max_left) ((x)-1)/((max_left)+1) 23 | #define LOCAL_TREE_I2E(x, max_left) (((x)==0)?0:((x) + ((x)-1) / (max_left) + 1)) 24 | #define LOCAL_TREE_E2I(x, max_left) (((x)==0)?0:((x) - ((x) - 1) / ((max_left)+1) - 1)) 25 | 26 | #define BLOCK_BUCKET 16 27 | 28 | // Define this to turn on error checking 29 | #define CUDA_ERROR_CHECK 30 | 31 | #define CudaSafeCall( err ) __cudaSafeCall( err, __FILE__, __LINE__ ) 32 | #define CudaCheckError() __cudaCheckError( __FILE__, __LINE__ ) 33 | 34 | inline void __cudaSafeCall( cudaError err, const char *file, const int line ) 35 | { 36 | #ifdef CUDA_ERROR_CHECK 37 | if ( cudaSuccess != err ) 38 | { 39 | fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", 40 | file, line, cudaGetErrorString( err ) ); 41 | exit( -1 ); 42 | } 43 | #endif 44 | 45 | return; 46 | } 47 | 48 | inline void __cudaCheckError( const char *file, const int line ) 49 | { 50 | #ifdef CUDA_ERROR_CHECK 51 | cudaError err = cudaDeviceSynchronize(); 52 | if( cudaSuccess != err ) 53 | { 54 | fprintf( stderr, "cudaCheckError() with sync failed at %s:%i : %s\n", 55 | file, line, cudaGetErrorString( err ) ); 56 | exit( -1 ); 57 | } 58 | 59 | err = cudaGetLastError(); 60 | if ( cudaSuccess != err ) 
61 | { 62 | fprintf( stderr, "last cudaCheckError() failed at %s:%i : %s\n", 63 | file, line, cudaGetErrorString( err ) ); 64 | exit( -1 ); 65 | } 66 | 67 | err = cudaPeekAtLastError(); 68 | if ( cudaSuccess != err ) 69 | { 70 | fprintf( stderr, "peek last cudaCheckError() failed at %s:%i : %s\n", 71 | file, line, cudaGetErrorString( err ) ); 72 | exit( -1 ); 73 | } 74 | 75 | // More careful checking. However, this will affect performance. 76 | // Comment away if needed. 77 | 78 | #endif 79 | 80 | return; 81 | } 82 | 83 | 84 | template 85 | __global__ void calculate_S_kernel( 86 | volatile int *bucket_queue, volatile int *accomplish_queue, volatile int *start_queue, 87 | Accessor1 S, 88 | Accessor2 C, 89 | Accessor3 match_all, 90 | Accessor2 links, 91 | Accessor4 output_length, 92 | Accessor4 target_length, 93 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 94 | { 95 | int bucket_idx = blockIdx.y % BLOCK_BUCKET; 96 | __shared__ volatile int task_id; 97 | __shared__ volatile int seg_id; 98 | __shared__ volatile int start; 99 | 100 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 101 | 102 | 103 | int ticket_no = blockIdx.y; 104 | int batch_id = ticket_no % bsz; 105 | 106 | int m = output_length[batch_id]; 107 | int n = target_length[batch_id]; 108 | __threadfence(); 109 | __syncthreads(); 110 | unsigned shfl_mask = (1 << TRANS_BLOCK_SIZE) - 1; 111 | 112 | shfl_mask = shfl_mask << (threadIdx.y % (32 / TRANS_BLOCK_SIZE) * TRANS_BLOCK_SIZE); 113 | int src_line = threadIdx.y % (32 / TRANS_BLOCK_SIZE)* TRANS_BLOCK_SIZE; 114 | src_line = 0; 115 | int a_id = 0; 116 | int a = 0; 117 | bool done = false; 118 | scalar_t maxval = -std::numeric_limits::infinity(); 119 | scalar_t temp = -std::numeric_limits::infinity(); 120 | scalar_t sumval = 0; 121 | while(start_queue[batch_id] < n*n_seg){ 122 | 123 | 124 | if(main_thread){ 125 | task_id = atomicAdd((int*)start_queue + batch_id, 1); 126 | seg_id = n_seg -1 - (task_id % n_seg); 127 | start = n-1 - (task_id / n_seg); 128 | 129 | 130 | done = false; 131 | while(!done){ 132 | done = true; 133 | for(int i=seg_id; i= n-1-start); 135 | } 136 | } 137 | } 138 | __threadfence(); 139 | __syncthreads(); 140 | a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 141 | a = MAIN_CHAIN_I2E(a_id, max_left); 142 | if (start == n-1){ 143 | if(main_thread && seg_id==0){ 144 | int last_id = MAIN_CHAIN_E2I(m-1, max_left); 145 | S[batch_id][last_id][n-1] = match_all[batch_id][m-1][n-1]; 146 | } 147 | } 148 | else{ 149 | if (a > 0 && a < m){ 150 | maxval = -std::numeric_limits::infinity(); 151 | temp = -std::numeric_limits::infinity(); 152 | for(int c_id = a_id + threadIdx.x; c_id*(max_left+1)+1 ::infinity(); 163 | if(temp > maxval) maxval = temp; 164 | 165 | } 166 | } 167 | 168 | 169 | } 170 | // if(a==241 && start == 30){printf("ended x:%d %f\n", threadIdx.x, maxval);} 171 | 172 | 173 | __syncwarp(shfl_mask); 174 | if_constexpr (TRANS_BLOCK_SIZE > 16) {scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 16, TRANS_BLOCK_SIZE); if(nextval > maxval) maxval = nextval;} 175 | if_constexpr (TRANS_BLOCK_SIZE > 8) {scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 8, TRANS_BLOCK_SIZE); if(nextval > maxval) maxval = nextval;} 176 | if_constexpr (TRANS_BLOCK_SIZE > 4) {scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 4, TRANS_BLOCK_SIZE); if(nextval > maxval) maxval = nextval;} 177 | if_constexpr (TRANS_BLOCK_SIZE > 2) {scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 2, TRANS_BLOCK_SIZE); if(nextval > maxval) maxval = nextval;} 178 | if_constexpr 
(TRANS_BLOCK_SIZE > 1) {scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 1, TRANS_BLOCK_SIZE); if(nextval > maxval) maxval = nextval;} 179 | maxval = __shfl_sync(shfl_mask, maxval, src_line, TRANS_BLOCK_SIZE); 180 | 181 | // shfl_mask = __ballot_sync(shfl_mask, !isinf(maxval)); 182 | float res; 183 | if (isinf(maxval)){ 184 | res = maxval; 185 | } 186 | else{ 187 | sumval = 0; 188 | for(int c_id = a_id + threadIdx.x; c_id*(max_left+1)+1 ::infinity(); 198 | sumval += exp(temp); 199 | } 200 | } 201 | 202 | } 203 | __syncwarp(shfl_mask); 204 | if_constexpr (TRANS_BLOCK_SIZE > 16) sumval += __shfl_down_sync(shfl_mask, sumval, 16, TRANS_BLOCK_SIZE); 205 | if_constexpr (TRANS_BLOCK_SIZE > 8) sumval += __shfl_down_sync(shfl_mask, sumval, 8, TRANS_BLOCK_SIZE); 206 | if_constexpr (TRANS_BLOCK_SIZE > 4) sumval += __shfl_down_sync(shfl_mask, sumval, 4, TRANS_BLOCK_SIZE); 207 | if_constexpr (TRANS_BLOCK_SIZE > 2) sumval += __shfl_down_sync(shfl_mask, sumval, 2, TRANS_BLOCK_SIZE); 208 | if_constexpr (TRANS_BLOCK_SIZE > 1) sumval += __shfl_down_sync(shfl_mask, sumval, 1, TRANS_BLOCK_SIZE); 209 | res = log(sumval) + maxval; 210 | } 211 | if(threadIdx.x == 0){ 212 | S[batch_id][a_id][start] = res; 213 | } 214 | } 215 | } 216 | __threadfence(); 217 | __syncthreads(); 218 | if (main_thread){ 219 | atomicAdd((int*)accomplish_queue + batch_id*n_seg + seg_id, 1); 220 | } 221 | 222 | } 223 | 224 | } 225 | 226 | 227 | template 228 | void invoke_calculate_S(cudaStream_t stream,torch::Tensor &S, torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 229 | int bsz, int prelen, int tarlen, int max_left) 230 | { 231 | int main_chain_size = (prelen-2) / (max_left+1) + 1; 232 | int n_seg = (main_chain_size - 1) / SEQ_BLOCK_SIZE + 1; 233 | n_seg = 2*n_seg; 234 | dim3 dimGrid(1, n_seg * bsz); 235 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 236 | int *bucket_queue, *accomplish_queue, *start_queue; 237 | auto tmp_tensor = at::zeros({BLOCK_BUCKET}, match_all.options().dtype(at::kInt)); 238 | bucket_queue = tmp_tensor.data_ptr(); 239 | auto tmp_tensor2 = at::zeros({bsz * n_seg}, match_all.options().dtype(at::kInt)); 240 | accomplish_queue = tmp_tensor2.data_ptr(); 241 | auto tmp_tensor3 = at::zeros({bsz}, match_all.options().dtype(at::kInt)); 242 | start_queue = tmp_tensor3.data_ptr(); 243 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 244 | AT_DISPATCH_FLOATING_TYPES( 245 | match_all.scalar_type(), "invoke_calculate_S", [&] { 246 | S.fill_(-std::numeric_limits::infinity()); 247 | calculate_S_kernel<<>>( 248 | bucket_queue, accomplish_queue, start_queue, 249 | S.packed_accessor64(), 250 | C.packed_accessor64(), 251 | match_all.packed_accessor64(), 252 | links.packed_accessor64(), 253 | output_length.packed_accessor64(), 254 | target_length.packed_accessor64(), 255 | bsz, prelen, tarlen, max_left, n_seg 256 | ); 257 | } 258 | ); 259 | // CudaCheckError(); 260 | } 261 | 262 | template 263 | __global__ void calculate_C_kernel( 264 | Accessor1 C, 265 | Accessor2 match_all, 266 | Accessor1 links, 267 | Accessor3 output_length, 268 | Accessor3 target_length, 269 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 270 | { 271 | 272 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 273 | 274 | 275 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 276 | int ticket_no = blockIdx.y; 277 | int batch_id = ticket_no % bsz; 278 | int seg_id = ticket_no / bsz; 
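// C[batch][a_id][i][gap] holds the inside (log) score that the local left-subtree rooted at
// vertex a = LOCAL_TREE_I2E(a_id) generates the gap target tokens starting at target position i.
// Base cases: C[batch][0][i][0] = 0 (the empty child) and, for leaf vertices,
// C[batch][a_id][i][1] = match_all[batch][a][i]. The loop over gap below builds larger spans
// bottom-up by combining a left and a right child through the links scores, using the usual
// max-then-log-sum-exp trick for numerical stability.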
279 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 280 | int a = LOCAL_TREE_I2E(a_id, max_left); 281 | int max_left_a = ((a-1) / (max_left+1)) * (max_left+1) + 1; 282 | int max_right_a = ((a-1) / (max_left+1) +1) * (max_left+1) + 1; 283 | int m = output_length[batch_id]; 284 | int n = target_length[batch_id]; 285 | 286 | 287 | // start = 0 288 | { 289 | if(main_thread && seg_id == 0){ 290 | C[batch_id][0][n][0] = 0; 291 | } 292 | for(int i=threadIdx.x; i 0 && a < m && ((a-max_left_a) % 2 == 1)){ 294 | C[batch_id][a_id][i][1] = match_all[batch_id][a][i]; 295 | } 296 | C[batch_id][0][i][0] = 0; 297 | } 298 | __threadfence(); 299 | __syncthreads(); 300 | } 301 | for(int gap = 2; gap < max_left+2; gap++){ 302 | if (a > 0 && a < m){ 303 | for (int i=threadIdx.x; i::infinity(); 305 | for(int j=0;j::infinity(); 315 | if(temp > maxval) maxval = temp; 316 | } 317 | } 318 | } 319 | float res; 320 | if (isinf(maxval)){ 321 | res = maxval; 322 | } 323 | else{ 324 | scalar_t sumval = 0; 325 | for(int j=0;j::infinity(); 335 | sumval += exp(temp); 336 | } 337 | } 338 | } 339 | res = log(sumval) + maxval; 340 | 341 | } 342 | C[batch_id][a_id][i][gap] = res; 343 | 344 | } 345 | } 346 | __threadfence(); 347 | __syncthreads(); 348 | } 349 | } 350 | 351 | 352 | template 353 | void invoke_calculate_C(cudaStream_t stream, torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 354 | int bsz, int prelen, int tarlen, int max_left) 355 | { 356 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 357 | int local_tree_size = prelen - main_chain_size; 358 | int n_seg = (local_tree_size - 1) / SEQ_BLOCK_SIZE + 1; 359 | dim3 dimGrid(1, n_seg * bsz); 360 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 361 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 362 | 363 | AT_DISPATCH_FLOATING_TYPES( 364 | match_all.scalar_type(), "invoke_calculate_C", [&] { 365 | C.fill_(-std::numeric_limits::infinity()); 366 | calculate_C_kernel<<>>( 367 | C.packed_accessor64(), 368 | match_all.packed_accessor64(), 369 | links.packed_accessor64(), 370 | output_length.packed_accessor64(), 371 | target_length.packed_accessor64(), 372 | bsz, prelen, tarlen, max_left, n_seg 373 | ); 374 | } 375 | ); 376 | 377 | } 378 | 379 | std::tuple pcfg_loss(const torch::Tensor &match_all, const torch::Tensor &links, 380 | const torch::Tensor &output_length, const torch::Tensor &target_length, 381 | int config) 382 | { 383 | 384 | auto bsz = match_all.size(0); 385 | auto prelen = match_all.size(1); 386 | auto tarlen = match_all.size(2); 387 | auto max_left = links.size(2); 388 | max_left = max_left - 1; 389 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 390 | int local_tree_size = prelen - main_chain_size; 391 | 392 | 393 | torch::Tensor S = at::zeros({bsz, main_chain_size, tarlen+2}, match_all.options()); 394 | torch::Tensor C = at::zeros({bsz, local_tree_size, tarlen+2, max_left+2}, match_all.options()); 395 | cudaStream_t current_stream = 0; 396 | 397 | invoke_calculate_C<4, 128>(current_stream, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); 398 | 399 | invoke_calculate_S<4, 128>(current_stream, S, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); 400 | return std::make_tuple(S, C); 401 | } 402 | 403 | 404 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_loss.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | import math 4 | import sys 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | from torch.utils.cpp_extension import load 11 | from torch.utils.checkpoint import checkpoint 12 | from torch import jit 13 | from typing import Any, Dict, List, Optional, Tuple 14 | 15 | module_path = os.path.dirname(__file__) 16 | pcfg_kernel = None 17 | 18 | def get_pcfg_kernel(): 19 | global pcfg_kernel 20 | if not torch.cuda.is_available(): 21 | raise RuntimeError("You need GPU to use the custom cuda operations") 22 | if pcfg_kernel is not None: 23 | return pcfg_kernel 24 | else: 25 | print("Start compiling cuda operations for PCFG...(It usually takes a few minutes for the first time running.)", file=sys.stderr, flush=True) 26 | 27 | if int(torch.version.cuda.split(".")[0]) < 11: 28 | extra_include_paths = [os.path.join(module_path, "../../cub")] 29 | else: 30 | extra_include_paths = None 31 | 32 | pcfg_kernel = load( 33 | "pcfg_loss_fn", 34 | sources=[ 35 | os.path.join(module_path, "pcfg_loss.cpp"), 36 | os.path.join(module_path, "pcfg_loss.cu"), 37 | os.path.join(module_path, "pcfg_loss_backward.cu"), 38 | os.path.join(module_path, "pcfg_best_tree.cu"), 39 | os.path.join(module_path, "pcfg_viterbi.cu"), 40 | os.path.join(module_path, "logsoftmax_gather.cu"), 41 | ], 42 | extra_cflags=['-DOF_SOFTMAX_USE_FAST_MATH', '-O3'], 43 | extra_cuda_cflags=['-DOF_SOFTMAX_USE_FAST_MATH', '-O3', '-lineinfo'], 44 | extra_include_paths=extra_include_paths, 45 | ) 46 | print("PCFG Cuda operations compiled", file=sys.stderr, flush=True) 47 | return pcfg_kernel 48 | 49 | class CUDAPCFGLossFunc(Function): 50 | config = 1 51 | config1 = 1 52 | config2 = 1 53 | 54 | @staticmethod 55 | def forward( 56 | ctx, 57 | match_all, # bsz * tarlen * prelen 58 | links, # bsz * prelen * translen 59 | output_length, # bsz 60 | target_length, # bsz 61 | ): 62 | 63 | batch_size, prelen, tarlen = match_all.shape 64 | _, _, max_left, _ = links.shape 65 | max_left = max_left-1 66 | 67 | require_gradient = ctx.needs_input_grad[0] or ctx.needs_input_grad[1] 68 | match_all = match_all.contiguous() 69 | links = links.contiguous() 70 | S, C = get_pcfg_kernel().pcfg_loss(match_all, links, output_length, target_length, CUDAPCFGLossFunc.config) # bsz * prelen * tarlen 71 | if require_gradient: 72 | match_all_grad, links_grad = get_pcfg_kernel().pcfg_loss_backward(torch.ones((batch_size)).to(C), S, C, match_all, links, output_length, target_length, CUDAPCFGLossFunc.config) 73 | ctx.save_for_backward(match_all_grad, links_grad) 74 | return S[range(batch_size), 0, 0] 75 | 76 | @staticmethod 77 | def backward(ctx, grad_output): 78 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 79 | match_all_grad, links_grad = ctx.saved_tensors 80 | batch_size, _, _ = match_all_grad.shape 81 | return match_all_grad * grad_output.view(batch_size, 1, 1), links_grad * grad_output.view(batch_size, 1, 1, 1), None, None 82 | else: 83 | return None, None, None, None 84 | 85 | cuda_pcfg_loss = CUDAPCFGLossFunc.apply 86 | 87 | 88 | 89 | class CUDAPCFGBestTreeFunc(Function): 90 | config = 1 91 | 92 | @staticmethod 93 | def forward( 94 | ctx, 95 | match_all, # bsz * tarlen * prelen 96 | links, # bsz * prelen * translen 97 | output_length, # bsz 98 | target_length, # bsz 99 | ): 100 | 101 | batch_size, prelen, tarlen = match_all.shape 102 | _, _, max_left, _ = links.shape 103 | max_left 
= max_left-1 104 | 105 | match_all = match_all.contiguous() 106 | links = links.contiguous() 107 | S, C = get_pcfg_kernel().pcfg_loss(match_all, links, output_length, target_length, CUDAPCFGBestTreeFunc.config) # bsz * prelen * tarlen 108 | 109 | tree = get_pcfg_kernel().pcfg_best_tree(S, C, match_all, links, output_length, target_length, CUDAPCFGBestTreeFunc.config) # bsz * prelen * tarlen 110 | 111 | return tree, S[range(batch_size), 0, 0] 112 | 113 | @staticmethod 114 | def backward(ctx, grad_output): 115 | return None, None, None, None 116 | 117 | cuda_pcfg_best_tree = CUDAPCFGBestTreeFunc.apply 118 | 119 | class CUDAPCFGViterbiFunc(Function): 120 | config = 1 121 | 122 | @staticmethod 123 | def forward( 124 | ctx, 125 | ob_lprob, # bsz * tarlen * prelen 126 | links, # bsz * prelen * translen 127 | output_length, # bsz 128 | ): 129 | 130 | batch_size = ob_lprob.size(0) 131 | ob_lprob = ob_lprob.contiguous() 132 | links = links.contiguous() 133 | S, R, L, M = get_pcfg_kernel().pcfg_viterbi(ob_lprob, links, output_length, CUDAPCFGBestTreeFunc.config) # bsz * prelen * tarlen 134 | # print(S[range(batch_size), 0, :]) 135 | # assert(False) 136 | return S[range(batch_size), 1, :], R, L, M 137 | 138 | @staticmethod 139 | def backward(ctx, grad_output): 140 | return None, None, None, None 141 | 142 | cuda_pcfg_viterbi = CUDAPCFGViterbiFunc.apply 143 | 144 | class DagLogsoftmaxGatherFunc(Function): 145 | 146 | @staticmethod 147 | def forward( 148 | ctx, 149 | word_ins_out, # bsz * prelen * vocabsize 150 | select_idx # bsz * prelen * slen 151 | ): 152 | 153 | require_gradient = ctx.needs_input_grad[0] 154 | selected_result = get_pcfg_kernel().logsoftmax_gather(word_ins_out, select_idx, require_gradient) 155 | # Note: the cuda kernel will modify word_ins_out and then reuse it in backward 156 | ctx.mark_dirty(word_ins_out) 157 | ctx.set_materialize_grads(False) 158 | 159 | if require_gradient: 160 | ctx.save_for_backward(word_ins_out, select_idx) 161 | ctx.has_backward = False 162 | return word_ins_out, selected_result # bsz * prelen * slen 163 | 164 | @staticmethod 165 | def backward(ctx, grad_word_ins_out, grad_output): 166 | if not ctx.needs_input_grad[0]: 167 | return None, None 168 | assert grad_word_ins_out is None, "Cannot reuse word_ins_out after logsoftmax_gather" 169 | if grad_output is None: 170 | return None, None 171 | 172 | assert not ctx.has_backward, "Cannot backward twice in logsoftmax_gather" 173 | ctx.has_backward = True 174 | 175 | grad_input, selected_idx = ctx.saved_tensors 176 | grad_input.mul_(grad_output.sum(-1, keepdim=True).neg_().to(grad_input.dtype)) 177 | grad_input.scatter_add_(-1, selected_idx, grad_output.to(grad_input.dtype)) 178 | 179 | return grad_input, None 180 | 181 | dag_logsoftmax_gather_inplace = DagLogsoftmaxGatherFunc.apply 182 | 183 | def viterbi_decoding(pred_length, output_length, L_trace, R_trace, M_trace, unreduced_tokens, left_tree_mask, pad_index): 184 | return get_pcfg_kernel().viterbi_decoding(pred_length, output_length, L_trace, R_trace, M_trace, unreduced_tokens, left_tree_mask, pad_index) 185 | 186 | if __name__ == "__main__": 187 | get_pcfg_kernel() 188 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_loss_backward.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 
#include "utilities.h" 17 | 18 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 19 | #define CHECK_CPU(x) TORCH_CHECK(x.type().is_cpu(), #x " must be a CPU tensor") 20 | 21 | #define BLOCK_BUCKET 16 22 | #define EPSILON 1e-5 23 | #define MAIN_CHAIN_I2E(x, max_left) ((x)*((max_left)+1)+1) 24 | #define MAIN_CHAIN_E2I(x, max_left) (((x)-1)/((max_left)+1)) 25 | #define LOCAL_TREE_I2E(x, max_left) (((x)==0)?0:((x) + ((x)-1) / (max_left) + 1)) 26 | #define LOCAL_TREE_E2I(x, max_left) ((x)==0?0:((x) - ((x) - 1) / ((max_left)+1) - 1)) 27 | 28 | template 29 | __global__ void calculate_S_kernel_grad( 30 | volatile int *bucket_queue, volatile int *accomplish_queue, volatile int *start_queue, 31 | Accessor1 grad_output, 32 | Accessor2 S_grad, 33 | Accessor2 S, 34 | Accessor3 C, 35 | Accessor2 match_all, 36 | Accessor3 links, 37 | Accessor4 output_length, 38 | Accessor4 target_length, 39 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 40 | { 41 | __shared__ volatile int task_id; 42 | __shared__ volatile int seg_id; 43 | __shared__ volatile int start; 44 | 45 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 46 | 47 | int ticket_no = blockIdx.y; 48 | int batch_id = ticket_no % bsz; 49 | 50 | int m = output_length[batch_id]; 51 | int n = target_length[batch_id]; 52 | __threadfence(); 53 | __syncthreads(); 54 | unsigned shfl_mask = (1 << TRANS_BLOCK_SIZE) - 1; 55 | shfl_mask = shfl_mask << (threadIdx.y % (32 / TRANS_BLOCK_SIZE) * TRANS_BLOCK_SIZE); 56 | 57 | 58 | while(start_queue[batch_id]< (n+1)*n_seg){ 59 | if(main_thread){ 60 | task_id = atomicAdd((int*)start_queue + batch_id, 1); 61 | // printf("batch_id:%d task_id:%d addr:%d\n", batch_id, task_id, (int*)start_queue + batch_id); 62 | seg_id = task_id%n_seg; 63 | start = task_id/n_seg; 64 | bool done = false; 65 | while(!done){ 66 | done = true; 67 | for(int i=0; i= start); 69 | } 70 | } 71 | } 72 | __threadfence(); 73 | __syncthreads(); 74 | int c_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 75 | int c = MAIN_CHAIN_I2E(c_id, max_left); 76 | if(start == 0){ 77 | if(seg_id==0 && main_thread){ 78 | S_grad[batch_id][0][0] = grad_output[batch_id]; 79 | // printf("batch_id: %d %f\n", batch_id, S[batch_id][1][0]); 80 | } 81 | } 82 | else{ 83 | int _start = start - 1; 84 | if(c>0 && c 16) sumval += __shfl_down_sync(shfl_mask, sumval, 16, TRANS_BLOCK_SIZE); 103 | if_constexpr (TRANS_BLOCK_SIZE > 8) sumval += __shfl_down_sync(shfl_mask, sumval, 8, TRANS_BLOCK_SIZE); 104 | if_constexpr (TRANS_BLOCK_SIZE > 4) sumval += __shfl_down_sync(shfl_mask, sumval, 4, TRANS_BLOCK_SIZE); 105 | if_constexpr (TRANS_BLOCK_SIZE > 2) sumval += __shfl_down_sync(shfl_mask, sumval, 2, TRANS_BLOCK_SIZE); 106 | if_constexpr (TRANS_BLOCK_SIZE > 1) sumval += __shfl_down_sync(shfl_mask, sumval, 1, TRANS_BLOCK_SIZE); 107 | if(threadIdx.x==0){ 108 | // if(sumval !=0 ) printf("batch_id:%d c:%d _start+j+1:%d\n", batch_id, c, _start+j+1); 109 | S_grad[batch_id][c_id][_start+j+1] += sumval; 110 | } 111 | } 112 | } 113 | } 114 | __threadfence(); 115 | __syncthreads(); 116 | if (main_thread){ 117 | atomicAdd((int*)accomplish_queue + batch_id*n_seg + seg_id, 1); 118 | } 119 | } 120 | 121 | } 122 | 123 | template 124 | void invoke_calculate_S_grad(cudaStream_t stream, const torch::Tensor &grad_output, torch::Tensor &S_grad, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 125 | int bsz, int prelen, int tarlen, 
int max_left) 126 | { 127 | int main_chain_size = (prelen-2) / (max_left+1) + 1; 128 | int n_seg = (main_chain_size - 1) / SEQ_BLOCK_SIZE + 1; 129 | // n_seg = n_seg * 2; 130 | dim3 dimGrid(1, 2 * n_seg * bsz); 131 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 132 | // assert(n_seg <= BLOCK_BUCKET); 133 | int *bucket_queue, *accomplish_queue, *start_queue; 134 | auto tmp_tensor = at::zeros({BLOCK_BUCKET + bsz * n_seg}, match_all.options().dtype(at::kInt)); 135 | // auto tmp_tensor = at::zeros({BLOCK_BUCKET + bsz}, match_all.options().dtype(at::kInt)); 136 | bucket_queue = tmp_tensor.data_ptr(); 137 | accomplish_queue = bucket_queue + BLOCK_BUCKET; 138 | auto tmp_tensor3 = at::zeros({bsz}, match_all.options().dtype(at::kInt)); 139 | start_queue = tmp_tensor3.data_ptr(); 140 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 141 | AT_DISPATCH_FLOATING_TYPES( 142 | match_all.scalar_type(), "invoke_calculate_S_grad", [&] { 143 | calculate_S_kernel_grad<<>>( 144 | bucket_queue, accomplish_queue, start_queue, 145 | grad_output.packed_accessor64(), 146 | S_grad.packed_accessor64(), 147 | S.packed_accessor64(), 148 | C.packed_accessor64(), 149 | match_all.packed_accessor64(), 150 | links.packed_accessor64(), 151 | output_length.packed_accessor64(), 152 | target_length.packed_accessor64(), 153 | bsz, prelen, tarlen, max_left, n_seg 154 | ); 155 | } 156 | ); 157 | } 158 | 159 | template 160 | __global__ void calculate_C_kernel_grad_1( 161 | Accessor1 S_grad, 162 | Accessor2 C_grad, 163 | Accessor1 S, 164 | Accessor2 C, 165 | Accessor1 match_all, 166 | Accessor2 links, 167 | Accessor3 output_length, 168 | Accessor3 target_length, 169 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 170 | { 171 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 172 | int ticket_no = blockIdx.y; 173 | int batch_id = ticket_no % bsz; 174 | int seg_id = ticket_no / bsz; 175 | int b_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y+1; 176 | int b = LOCAL_TREE_I2E(b_id, max_left); 177 | int a = ((b-1) / (max_left+1) +1) * (max_left+1) + 1; 178 | int a_id = MAIN_CHAIN_E2I(a, max_left); 179 | int m = output_length[batch_id]; 180 | int n = target_length[batch_id]; 181 | 182 | for(int start = n-2; start >= 0; start--){ 183 | if (b > 0 && b < m && a>0 && a 16) sumval += __shfl_down_sync(shfl_mask, sumval, 16, TRANS_BLOCK_SIZE); 197 | if_constexpr (TRANS_BLOCK_SIZE > 8) sumval += __shfl_down_sync(shfl_mask, sumval, 8, TRANS_BLOCK_SIZE); 198 | if_constexpr (TRANS_BLOCK_SIZE > 4) sumval += __shfl_down_sync(shfl_mask, sumval, 4, TRANS_BLOCK_SIZE); 199 | if_constexpr (TRANS_BLOCK_SIZE > 2) sumval += __shfl_down_sync(shfl_mask, sumval, 2, TRANS_BLOCK_SIZE); 200 | if_constexpr (TRANS_BLOCK_SIZE > 1) sumval += __shfl_down_sync(shfl_mask, sumval, 1, TRANS_BLOCK_SIZE); 201 | if (threadIdx.x==0){ 202 | 203 | C_grad[batch_id][b_id][start][j] += sumval; 204 | } 205 | } 206 | 207 | } 208 | } 209 | } 210 | 211 | template 212 | __global__ void calculate_C_kernel_grad_2( 213 | Accessor1 C_grad, 214 | Accessor1 C, 215 | Accessor2 match_all, 216 | Accessor1 links, 217 | Accessor3 output_length, 218 | Accessor3 target_length, 219 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 220 | { 221 | 222 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 223 | int ticket_no = blockIdx.y; 224 | int batch_id = ticket_no % bsz; 225 | int seg_id = ticket_no / bsz; 226 | int selected_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 227 | int selected_h = LOCAL_TREE_I2E(selected_id, 
max_left); 228 | int max_right_a = ((selected_h-1) / (max_left+1) +1) * (max_left+1) + 1; 229 | int max_left_a = ((selected_h-1) / (max_left+1)) * (max_left+1) + 1; 230 | 231 | int m = output_length[batch_id]; 232 | int n = target_length[batch_id]; 233 | 234 | 235 | 236 | for(int gap=max_left+1;gap>=2;gap--){ 237 | if (selected_h > 0 && selected_h < m){ 238 | int b = selected_h; 239 | for (int a=b+1; a 281 | void invoke_calculate_C_grad(cudaStream_t stream, torch::Tensor &S_grad, torch::Tensor &C_grad, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 282 | int bsz, int prelen, int tarlen, int max_left) 283 | { 284 | if(max_left == 0) return; 285 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 286 | int local_tree_size = prelen - main_chain_size; 287 | int n_seg = (local_tree_size - 1) / SEQ_BLOCK_SIZE + 1; 288 | dim3 dimGrid(1, n_seg * bsz); 289 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 290 | // assert(n_seg <= BLOCK_BUCKET); 291 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 292 | AT_DISPATCH_FLOATING_TYPES( 293 | match_all.scalar_type(), "invoke_calculate_C_grad_1", [&] { 294 | calculate_C_kernel_grad_1<<>>( 295 | S_grad.packed_accessor64(), 296 | C_grad.packed_accessor64(), 297 | S.packed_accessor64(), 298 | C.packed_accessor64(), 299 | match_all.packed_accessor64(), 300 | links.packed_accessor64(), 301 | output_length.packed_accessor64(), 302 | target_length.packed_accessor64(), 303 | bsz, prelen, tarlen, max_left, n_seg 304 | ); 305 | } 306 | ); 307 | int n_seg2 = (prelen - 1) / SEQ_BLOCK_SIZE + 1; 308 | dim3 dimGrid2(1, n_seg2 * bsz); 309 | AT_DISPATCH_FLOATING_TYPES( 310 | match_all.scalar_type(), "invoke_calculate_C_grad_2", [&] { 311 | calculate_C_kernel_grad_2<<>>( 312 | C_grad.packed_accessor64(), 313 | C.packed_accessor64(), 314 | match_all.packed_accessor64(), 315 | links.packed_accessor64(), 316 | output_length.packed_accessor64(), 317 | target_length.packed_accessor64(), 318 | bsz, prelen, tarlen, max_left, n_seg2 319 | ); 320 | } 321 | ); 322 | 323 | } 324 | 325 | template 326 | __global__ void calculate_match_all_kernel_grad_1( 327 | Accessor1 match_all_grad, 328 | Accessor1 S_grad, 329 | Accessor2 C_grad, 330 | Accessor1 S, 331 | Accessor2 C, 332 | Accessor1 match_all, 333 | Accessor2 links, 334 | Accessor3 output_length, 335 | Accessor3 target_length, 336 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 337 | { 338 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 339 | int ticket_no = blockIdx.y; 340 | int batch_id = ticket_no % bsz; 341 | int m = output_length[batch_id]; 342 | int n = target_length[batch_id]; 343 | 344 | int seg_id = ticket_no / bsz; 345 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 346 | int a = MAIN_CHAIN_I2E(a_id, max_left); 347 | 348 | for(int start=0;start0 && a 16) sumval += __shfl_down_sync(shfl_mask, sumval, 16, TRANS_BLOCK_SIZE); 369 | if_constexpr (TRANS_BLOCK_SIZE > 8) sumval += __shfl_down_sync(shfl_mask, sumval, 8, TRANS_BLOCK_SIZE); 370 | if_constexpr (TRANS_BLOCK_SIZE > 4) sumval += __shfl_down_sync(shfl_mask, sumval, 4, TRANS_BLOCK_SIZE); 371 | if_constexpr (TRANS_BLOCK_SIZE > 2) sumval += __shfl_down_sync(shfl_mask, sumval, 2, TRANS_BLOCK_SIZE); 372 | if_constexpr (TRANS_BLOCK_SIZE > 1) sumval += __shfl_down_sync(shfl_mask, sumval, 1, TRANS_BLOCK_SIZE); 373 | if(threadIdx.x==0){ 374 | 
match_all_grad[batch_id][a][start+j] += sumval; 375 | } 376 | } 377 | } 378 | } 379 | if(main_thread && seg_id==0){ 380 | int last_id = MAIN_CHAIN_E2I(m-1, max_left); 381 | match_all_grad[batch_id][m-1][n-1] += S_grad[batch_id][last_id][n-1]; 382 | } 383 | } 384 | 385 | template 386 | __global__ void calculate_match_all_kernel_grad_2( 387 | Accessor1 match_all_grad, 388 | Accessor2 C_grad, 389 | Accessor2 C, 390 | Accessor1 match_all, 391 | Accessor2 links, 392 | Accessor3 output_length, 393 | Accessor3 target_length, 394 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 395 | { 396 | int ticket_no = blockIdx.y; 397 | int batch_id = ticket_no % bsz; 398 | int m = output_length[batch_id]; 399 | int n = target_length[batch_id]; 400 | 401 | int seg_id = ticket_no / bsz; 402 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 403 | int a = LOCAL_TREE_I2E(a_id, max_left); 404 | int max_left_a = ((a-1) / (max_left+1)) * (max_left+1) + 1; 405 | int max_right_a = ((a-1) / (max_left+1) +1) * (max_left+1) + 1; 406 | 407 | for(int gap=max_left+1;gap>2;gap--){ 408 | if(a>0 && a0 && a 442 | void invoke_calculate_match_all_grad(cudaStream_t stream, torch::Tensor &match_all_grad, torch::Tensor &S_grad, torch::Tensor &C_grad, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 443 | int bsz, int prelen, int tarlen, int max_left) 444 | { 445 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 446 | int local_tree_size = prelen - main_chain_size; 447 | 448 | int n_seg = (main_chain_size - 1) / SEQ_BLOCK_SIZE + 1; 449 | 450 | dim3 dimGrid(1, n_seg * bsz); 451 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 452 | // assert(n_seg <= BLOCK_BUCKET); 453 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 454 | AT_DISPATCH_FLOATING_TYPES( 455 | match_all.scalar_type(), "calculate_match_all_kernel_grad_1", [&] { 456 | calculate_match_all_kernel_grad_1<<>>( 457 | match_all_grad.packed_accessor64(), 458 | S_grad.packed_accessor64(), 459 | C_grad.packed_accessor64(), 460 | S.packed_accessor64(), 461 | C.packed_accessor64(), 462 | match_all.packed_accessor64(), 463 | links.packed_accessor64(), 464 | output_length.packed_accessor64(), 465 | target_length.packed_accessor64(), 466 | bsz, prelen, tarlen, max_left, n_seg 467 | ); 468 | } 469 | ); 470 | int n_seg_2 = (local_tree_size - 1) / SEQ_BLOCK_SIZE + 1; 471 | dim3 dimGrid2(1, n_seg_2 * bsz); 472 | AT_DISPATCH_FLOATING_TYPES( 473 | match_all.scalar_type(), "calculate_match_all_kernel_grad_2", [&] { 474 | calculate_match_all_kernel_grad_2<<>>( 475 | match_all_grad.packed_accessor64(), 476 | C_grad.packed_accessor64(), 477 | C.packed_accessor64(), 478 | match_all.packed_accessor64(), 479 | links.packed_accessor64(), 480 | output_length.packed_accessor64(), 481 | target_length.packed_accessor64(), 482 | bsz, prelen, tarlen, max_left, n_seg_2 483 | ); 484 | } 485 | ); 486 | 487 | } 488 | 489 | 490 | template 491 | __global__ void calculate_links_kernel_grad_1( 492 | Accessor1 links_grad, 493 | Accessor2 S_grad, 494 | Accessor1 C_grad, 495 | Accessor2 S, 496 | Accessor1 C, 497 | Accessor2 match_all, 498 | Accessor1 links, 499 | Accessor3 output_length, 500 | Accessor3 target_length, 501 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 502 | { 503 | int ticket_no = blockIdx.y; 504 | int batch_id = ticket_no % bsz; 505 | int m = output_length[batch_id]; 506 | int n = 
target_length[batch_id]; 507 | 508 | int seg_id = ticket_no / bsz; 509 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 510 | int a = MAIN_CHAIN_I2E(a_id, max_left); 511 | 512 | for(int start=0;start0 && a 536 | __global__ void calculate_links_kernel_grad_2( 537 | Accessor2 links_grad, 538 | Accessor2 C_grad, 539 | Accessor2 C, 540 | Accessor1 match_all, 541 | Accessor2 links, 542 | Accessor3 output_length, 543 | Accessor3 target_length, 544 | int bsz, int prelen, int tarlen, int max_left, int n_seg) 545 | { 546 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 547 | int ticket_no = blockIdx.y; 548 | int batch_id = ticket_no % bsz; 549 | int seg_id = ticket_no / bsz; 550 | int a_id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y+1; 551 | int a = LOCAL_TREE_I2E(a_id, max_left); 552 | int max_right = (a / (max_left+1) +1) * (max_left+1) + 1; 553 | int m = output_length[batch_id]; 554 | int n = target_length[batch_id]; 555 | 556 | for(int gap=max_left+1;gap>2;gap--){ 557 | if (a > 0 && a < m && (a % (max_left+1)!=1)){ 558 | 559 | for(int c=a+1; c 16) sumval += __shfl_down_sync(shfl_mask, sumval, 16, TRANS_BLOCK_SIZE); 590 | if_constexpr (TRANS_BLOCK_SIZE > 8) sumval += __shfl_down_sync(shfl_mask, sumval, 8, TRANS_BLOCK_SIZE); 591 | if_constexpr (TRANS_BLOCK_SIZE > 4) sumval += __shfl_down_sync(shfl_mask, sumval, 4, TRANS_BLOCK_SIZE); 592 | if_constexpr (TRANS_BLOCK_SIZE > 2) sumval += __shfl_down_sync(shfl_mask, sumval, 2, TRANS_BLOCK_SIZE); 593 | if_constexpr (TRANS_BLOCK_SIZE > 1) sumval += __shfl_down_sync(shfl_mask, sumval, 1, TRANS_BLOCK_SIZE); 594 | if(threadIdx.x==0){ 595 | links_grad[batch_id][a][a-b][c] += sumval; 596 | } 597 | } 598 | 599 | } 600 | } 601 | 602 | } 603 | } 604 | 605 | 606 | template 607 | void invoke_calculate_links_grad(cudaStream_t stream, torch::Tensor &links_grad, torch::Tensor &S_grad, torch::Tensor &C_grad, const torch::Tensor &S, const torch::Tensor &C, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, \ 608 | int bsz, int prelen, int tarlen, int max_left) 609 | { 610 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 611 | int local_tree_size = prelen - main_chain_size; 612 | int n_seg = (main_chain_size - 1) / SEQ_BLOCK_SIZE + 1; 613 | dim3 dimGrid(1, n_seg * bsz); 614 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 615 | // assert(n_seg <= BLOCK_BUCKET); 616 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 617 | AT_DISPATCH_FLOATING_TYPES( 618 | match_all.scalar_type(), "calculate_links_kernel_grad_1", [&] { 619 | calculate_links_kernel_grad_1<<>>( 620 | links_grad.packed_accessor64(), 621 | S_grad.packed_accessor64(), 622 | C_grad.packed_accessor64(), 623 | S.packed_accessor64(), 624 | C.packed_accessor64(), 625 | match_all.packed_accessor64(), 626 | links.packed_accessor64(), 627 | output_length.packed_accessor64(), 628 | target_length.packed_accessor64(), 629 | bsz, prelen, tarlen, max_left, n_seg 630 | ); 631 | } 632 | ); 633 | 634 | int n_seg_2 = (local_tree_size - 1) / SEQ_BLOCK_SIZE + 1; 635 | dim3 dimGrid2(1, n_seg_2 * bsz); 636 | // assert(n_seg <= BLOCK_BUCKET); 637 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 638 | AT_DISPATCH_FLOATING_TYPES( 639 | match_all.scalar_type(), "invoke_calculate_links_grad_2", [&] { 640 | calculate_links_kernel_grad_2<<>>( 641 | links_grad.packed_accessor64(), 642 | C_grad.packed_accessor64(), 643 | C.packed_accessor64(), 644 | 
match_all.packed_accessor64(), 645 | links.packed_accessor64(), 646 | output_length.packed_accessor64(), 647 | target_length.packed_accessor64(), 648 | bsz, prelen, tarlen, max_left, n_seg_2 649 | ); 650 | } 651 | ); 652 | } 653 | 654 | std::tuple pcfg_loss_backward(const torch::Tensor &grad_output,const torch::Tensor &S, const torch::Tensor &C, 655 | const torch::Tensor &match_all, const torch::Tensor &links, 656 | const torch::Tensor &output_length, const torch::Tensor &target_length, 657 | int config) 658 | { 659 | // CHECK_CUDA(match_all); // bsz * tarlen * prelen 660 | // CHECK_CUDA(links); // bsz * prelen * translen 661 | // CHECK_CUDA(output_length); // bsz 662 | // CHECK_CUDA(target_length); // bsz 663 | // TORCH_CHECK(match_all.dim() == 3, "match_all dim != 3"); 664 | // TORCH_CHECK(links.dim() == 4, "links dim != 3"); 665 | // TORCH_CHECK(output_length.dim() == 1, "output_length dim != 3"); 666 | // TORCH_CHECK(target_length.dim() == 1, "target_length dim != 3"); 667 | 668 | auto bsz = match_all.size(0); 669 | auto prelen = match_all.size(1); 670 | auto tarlen = match_all.size(2); 671 | auto max_left = links.size(2); 672 | max_left = max_left - 1; 673 | 674 | int main_chain_size = (prelen - 2) / (max_left + 1) + 1; 675 | int local_tree_size = prelen - main_chain_size; 676 | 677 | // TORCH_CHECK(links.size(0) == bsz && output_length.size(0) == bsz && target_length.size(0) == bsz, "batch size not match"); 678 | // TORCH_CHECK(links.size(1) == prelen, "prelen not match"); 679 | // TORCH_CHECK(output_length.scalar_type() == at::kLong && target_length.scalar_type() == at::kLong, "length should be long"); 680 | 681 | 682 | // printf("alpha0\n"); 683 | 684 | // calculate alpha 685 | // printf("%d %d %d\n", bsz, tarlen, prelen); 686 | torch::Tensor S_grad = at::zeros({bsz, main_chain_size, tarlen+2}, match_all.options()); 687 | torch::Tensor C_grad = at::zeros({bsz, local_tree_size, tarlen+2, max_left+2}, match_all.options()); 688 | 689 | torch::Tensor match_all_grad = at::zeros({bsz, prelen, tarlen}, match_all.options()); 690 | torch::Tensor links_grad = at::zeros({bsz, prelen, max_left+1, prelen}, match_all.options()); 691 | 692 | cudaStream_t current_stream = 0; 693 | // printf("invoke_calculate_S_grad\n"); 694 | switch(config){ 695 | case 1: invoke_calculate_S_grad<4, 128>(current_stream, grad_output, S_grad, S, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); break; 696 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 697 | } 698 | // cudaDeviceSynchronize(); 699 | // printf("invoke_calculate_C_grad\n"); 700 | switch(config){ 701 | case 1: invoke_calculate_C_grad<4, 128>(current_stream, S_grad, C_grad, S, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); break; 702 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 703 | } 704 | // cudaDeviceSynchronize(); 705 | // printf("invoke_calculate_match_all_grad\n"); 706 | switch(config){ 707 | case 1: invoke_calculate_match_all_grad<4, 128>(current_stream, match_all_grad, S_grad, C_grad, S, C, match_all, links, output_length, target_length, bsz, prelen, tarlen, max_left); break; 708 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 709 | } 710 | // cudaDeviceSynchronize(); 711 | // printf("invoke_calculate_links_grad\n"); 712 | switch(config){ 713 | case 1: invoke_calculate_links_grad<4, 128>(current_stream, links_grad, S_grad, C_grad, S, C, match_all, links, output_length, target_length, bsz, 
prelen, tarlen, max_left); break; 714 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 715 | } 716 | 717 | 718 | // printf("alpha4\n"); 719 | return std::make_tuple(match_all_grad, links_grad); 720 | // return std::make_tuple(S_grad, C_grad); 721 | } 722 | 723 | 724 | 725 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/pcfg_viterbi.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include "utilities.h" 16 | 17 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 18 | #define CHECK_CPU(x) TORCH_CHECK(x.type().is_cpu(), #x " must be a CPU tensor") 19 | 20 | #define BLOCK_BUCKET 16 21 | 22 | template 23 | __global__ void calculate_S_kernel( 24 | volatile int *bucket_queue, volatile int *accomplish_queue, 25 | Accessor1 S, 26 | Accessor1 C, 27 | Accessor2 R, 28 | Accessor2 L, 29 | Accessor2 M, 30 | Accessor3 ob_lprob, 31 | Accessor4 links, 32 | Accessor5 output_length, 33 | int bsz, int prelen, int max_left, int n_seg) 34 | { 35 | int bucket_idx = blockIdx.y % BLOCK_BUCKET; 36 | __shared__ volatile int bucket_no; 37 | 38 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 39 | if (main_thread){ 40 | // obtain task id 41 | bucket_no = atomicAdd((int*)bucket_queue + bucket_idx, 1); 42 | } 43 | __syncthreads(); 44 | 45 | int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 46 | int batch_id = ticket_no % bsz; 47 | 48 | 49 | int m = output_length[batch_id]; 50 | 51 | int seg_id = ticket_no / bsz; 52 | int id = seg_id * SEQ_BLOCK_SIZE + threadIdx.y; 53 | int a = id*(max_left+1)+1; 54 | 55 | // start = 0 56 | { 57 | if(seg_id == 0 && main_thread){ 58 | S[batch_id][m-1][1] = ob_lprob[batch_id][m-1]; 59 | } 60 | 61 | __threadfence(); 62 | __syncthreads(); 63 | if(main_thread){ 64 | atomicAdd((int*)accomplish_queue + batch_id, 1); 65 | } 66 | } 67 | for(int length = 2; length < m/4; length++){ 68 | if (main_thread){ 69 | while(accomplish_queue[batch_id] < (length-1)*n_seg); // wait for previous segment to accomplish 70 | } 71 | __syncthreads(); 72 | if (a > 0 && a < m){ 73 | scalar_t maxval = -std::numeric_limits::infinity(); 74 | int max_b=0, max_c=0, max_j=0; 75 | for(int c_id = id + threadIdx.x; c_id*(max_left+1)+1 ::infinity(); 85 | 86 | if(temp > maxval){maxval = temp; max_b=_b; max_c = c; max_j=j;} 87 | } 88 | } 89 | } 90 | unsigned shfl_mask = __activemask(); 91 | if_constexpr (TRANS_BLOCK_SIZE > 16) { 92 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 16, TRANS_BLOCK_SIZE); 93 | int next_c = __shfl_down_sync(shfl_mask, max_c, 16, TRANS_BLOCK_SIZE); 94 | int next_b = __shfl_down_sync(shfl_mask, max_b, 16, TRANS_BLOCK_SIZE); 95 | int next_j = __shfl_down_sync(shfl_mask, max_j, 16, TRANS_BLOCK_SIZE); 96 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_b = next_b; max_j = next_j;}} 97 | if_constexpr (TRANS_BLOCK_SIZE > 8) { 98 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 8, TRANS_BLOCK_SIZE); 99 | int next_c = __shfl_down_sync(shfl_mask, max_c, 8, TRANS_BLOCK_SIZE); 100 | int next_b = __shfl_down_sync(shfl_mask, max_b, 8, TRANS_BLOCK_SIZE); 101 | int next_j = __shfl_down_sync(shfl_mask, max_j, 8, TRANS_BLOCK_SIZE); 102 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_b = next_b; max_j = next_j;}} 103 | if_constexpr 
(TRANS_BLOCK_SIZE > 4) { 104 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 4, TRANS_BLOCK_SIZE); 105 | int next_c = __shfl_down_sync(shfl_mask, max_c, 4, TRANS_BLOCK_SIZE); 106 | int next_b = __shfl_down_sync(shfl_mask, max_b, 4, TRANS_BLOCK_SIZE); 107 | int next_j = __shfl_down_sync(shfl_mask, max_j, 4, TRANS_BLOCK_SIZE); 108 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_b = next_b; max_j = next_j;}} 109 | if_constexpr (TRANS_BLOCK_SIZE > 2) { 110 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 2, TRANS_BLOCK_SIZE); 111 | int next_c = __shfl_down_sync(shfl_mask, max_c, 2, TRANS_BLOCK_SIZE); 112 | int next_b = __shfl_down_sync(shfl_mask, max_b, 2, TRANS_BLOCK_SIZE); 113 | int next_j = __shfl_down_sync(shfl_mask, max_j, 2, TRANS_BLOCK_SIZE); 114 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_b = next_b; max_j = next_j;}} 115 | if_constexpr (TRANS_BLOCK_SIZE > 1) { 116 | scalar_t nextval = __shfl_down_sync(shfl_mask, maxval, 1, TRANS_BLOCK_SIZE); 117 | int next_c = __shfl_down_sync(shfl_mask, max_c, 1, TRANS_BLOCK_SIZE); 118 | int next_b = __shfl_down_sync(shfl_mask, max_b, 1, TRANS_BLOCK_SIZE); 119 | int next_j = __shfl_down_sync(shfl_mask, max_j, 1, TRANS_BLOCK_SIZE); 120 | if(nextval > maxval){ maxval = nextval; max_c = next_c; max_b = next_b; max_j = next_j;}} 121 | maxval = __shfl_sync(shfl_mask, maxval, 0, TRANS_BLOCK_SIZE); 122 | max_c = __shfl_sync(shfl_mask, max_c, 0, TRANS_BLOCK_SIZE); 123 | max_b = __shfl_sync(shfl_mask, max_b, 0, TRANS_BLOCK_SIZE); 124 | max_j = __shfl_sync(shfl_mask, max_j, 0, TRANS_BLOCK_SIZE); 125 | if(threadIdx.x == 0 && !isinf(maxval)){ 126 | 127 | S[batch_id][a][length] = maxval; 128 | L[batch_id][a][length] = max_b; 129 | R[batch_id][a][length] = max_c; 130 | M[batch_id][a][length] = max_j; 131 | } 132 | } 133 | __threadfence(); 134 | __syncthreads(); 135 | if (main_thread){ 136 | atomicAdd((int*)accomplish_queue + batch_id, 1); 137 | } 138 | } 139 | 140 | } 141 | 142 | 143 | template 144 | void invoke_calculate_S(cudaStream_t stream, torch::Tensor &S, torch::Tensor &C, torch::Tensor &R, torch::Tensor &L, torch::Tensor &M, const torch::Tensor &ob_lprob, const torch::Tensor &links, const torch::Tensor &output_length, \ 145 | int bsz, int prelen, int max_left) 146 | { 147 | int n_seg = ((prelen-2) / (max_left+1) + 1 - 1) / SEQ_BLOCK_SIZE + 1; 148 | 149 | dim3 dimGrid(1, n_seg * bsz); 150 | dim3 dimBlock(TRANS_BLOCK_SIZE, SEQ_BLOCK_SIZE); 151 | // assert(n_seg <= BLOCK_BUCKET); 152 | int *bucket_queue, *accomplish_queue; 153 | auto tmp_tensor = at::zeros({BLOCK_BUCKET + bsz}, ob_lprob.options().dtype(at::kInt)); 154 | bucket_queue = tmp_tensor.data_ptr(); 155 | accomplish_queue = bucket_queue + BLOCK_BUCKET; 156 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 157 | AT_DISPATCH_FLOATING_TYPES( 158 | ob_lprob.scalar_type(), "invoke_calculate_S", [&] { 159 | S.fill_(-std::numeric_limits::infinity()); 160 | calculate_S_kernel<<>>( 161 | bucket_queue, accomplish_queue, 162 | S.packed_accessor64(), 163 | C.packed_accessor64(), 164 | R.packed_accessor64(), 165 | L.packed_accessor64(), 166 | M.packed_accessor64(), 167 | ob_lprob.packed_accessor64(), 168 | links.packed_accessor64(), 169 | output_length.packed_accessor64(), 170 | bsz, prelen, max_left, n_seg 171 | ); 172 | } 173 | ); 174 | } 175 | 176 | 177 | template 178 | __global__ void calculate_C_kernel( 179 | Accessor1 C, 180 | Accessor2 R, 181 | Accessor2 L, 182 | Accessor2 M, 183 | Accessor3 ob_lprob, 184 | Accessor4 
links, 185 | Accessor5 output_length, 186 | int bsz, int prelen, int max_left, int n_seg) 187 | { 188 | 189 | bool main_thread = threadIdx.x == 0 && threadIdx.y == 0; 190 | 191 | 192 | // int ticket_no = bucket_no * BLOCK_BUCKET + bucket_idx; 193 | int ticket_no = blockIdx.y; 194 | int batch_id = ticket_no % bsz; 195 | int seg_id = ticket_no / bsz; 196 | int a = seg_id * SEQ_BLOCK_SIZE + threadIdx.y + 1; 197 | int max_left_a = ((a-1) / (max_left+1)) * (max_left+1) + 1; 198 | int max_right_a = ((a-1) / (max_left+1) +1) * (max_left+1) + 1; 199 | int m = output_length[batch_id]; 200 | 201 | // start = 0 202 | { 203 | if(a > 0 && a < m && ((a-max_left_a) % 2 == 1) && threadIdx.x==0){ 204 | C[batch_id][a][1] = ob_lprob[batch_id][a]; 205 | } 206 | if(seg_id==0 && main_thread){ 207 | C[batch_id][0][0] = 0; 208 | } 209 | 210 | 211 | __threadfence(); 212 | __syncthreads(); 213 | } 214 | for(int gap = 2; gap < max_left+1; gap++){ 215 | 216 | if (a > 0 && a < m && (a % (max_left+1)!=1)){ 217 | scalar_t maxval = -std::numeric_limits::infinity(); 218 | int max_b=0, max_c = 0, max_j=0; 219 | for(int j=0;j::infinity(); 227 | if(temp > maxval){maxval = temp; max_b=_b; max_c=c; max_j=j;} 228 | } 229 | } 230 | } 231 | if(!isinf(maxval)){ 232 | C[batch_id][a][gap] = maxval; 233 | R[batch_id][a][gap] = max_c; 234 | L[batch_id][a][gap] = max_b; 235 | M[batch_id][a][gap] = max_j; 236 | } 237 | } 238 | __threadfence(); 239 | __syncthreads(); 240 | } 241 | } 242 | 243 | 244 | template 245 | void invoke_calculate_C(cudaStream_t stream, torch::Tensor &C, torch::Tensor &R, torch::Tensor &L, torch::Tensor &M, const torch::Tensor &ob_lprob, const torch::Tensor &links, const torch::Tensor &output_length, \ 246 | int bsz, int prelen, int max_left) 247 | { 248 | 249 | int n_seg = (prelen - 1) / SEQ_BLOCK_SIZE + 1; 250 | dim3 dimGrid(1, n_seg * bsz); 251 | dim3 dimBlock(1, SEQ_BLOCK_SIZE); 252 | static_assert(TRANS_BLOCK_SIZE <= 32, "TRANS_BLOCK_SIZE should be less than warp size"); 253 | 254 | AT_DISPATCH_FLOATING_TYPES( 255 | ob_lprob.scalar_type(), "invoke_calculate_C", [&] { 256 | C.fill_(-std::numeric_limits::infinity()); 257 | calculate_C_kernel<<>>( 258 | C.packed_accessor64(), 259 | R.packed_accessor64(), 260 | L.packed_accessor64(), 261 | M.packed_accessor64(), 262 | ob_lprob.packed_accessor64(), 263 | links.packed_accessor64(), 264 | output_length.packed_accessor64(), 265 | bsz, prelen, max_left, n_seg 266 | ); 267 | } 268 | ); 269 | 270 | } 271 | 272 | std::tuple pcfg_viterbi(const torch::Tensor &ob_lprob, const torch::Tensor &links, 273 | const torch::Tensor &output_length, 274 | int config) 275 | { 276 | 277 | 278 | auto bsz = ob_lprob.size(0); 279 | auto prelen = ob_lprob.size(1); 280 | auto max_left = links.size(2); 281 | max_left = max_left - 1; 282 | 283 | 284 | 285 | torch::Tensor S = at::zeros({bsz, prelen, prelen/4}, ob_lprob.options()); 286 | torch::Tensor C = at::zeros({bsz, prelen, prelen/4}, ob_lprob.options()); 287 | torch::Tensor R = at::zeros({bsz, prelen, prelen/4}, output_length.options()); 288 | torch::Tensor L = at::zeros({bsz, prelen, prelen/4}, output_length.options()); 289 | torch::Tensor M = at::zeros({bsz, prelen, prelen/4}, output_length.options()); 290 | cudaStream_t current_stream = 0; 291 | switch(config){ 292 | case 1: invoke_calculate_C<1, 256>(current_stream, C, R, L, M, ob_lprob, links, output_length, bsz, prelen, max_left); break; 293 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 294 | } 295 | switch(config){ 296 | case 1: invoke_calculate_S<4, 
128>(current_stream, S, C, R, L, M, ob_lprob, links, output_length, bsz, prelen, max_left); break; 297 | default: TORCH_CHECK(config <= 4 && config >= 1, "config should be 1~4"); 298 | } 299 | 300 | return std::make_tuple(S, R, L, M); 301 | } 302 | 303 | 304 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/utilities.h: -------------------------------------------------------------------------------- 1 | #define GCC_VERSION (__GNUC__ * 10000 \ 2 | + __GNUC_MINOR__ * 100 \ 3 | + __GNUC_PATCHLEVEL__) 4 | 5 | #if GCC_VERSION >= 70000 6 | #define if_constexpr(expression) if constexpr (expression) 7 | #else 8 | #define if_constexpr(expression) if(expression) 9 | #endif 10 | -------------------------------------------------------------------------------- /fs_plugins/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.models." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/models/glat_decomposed_with_link_two_hands_tri_pcfg.py: -------------------------------------------------------------------------------- 1 | from fairseq.models.nat.fairseq_nat_model import FairseqNATModel 2 | import logging 3 | import random 4 | import copy 5 | import math 6 | from typing import Any, Dict, List, Optional, Tuple 7 | import numpy as np 8 | import torch 9 | from torch import Tensor, nn, jit 10 | import torch.nn.functional as F 11 | from fairseq import utils 12 | from fairseq.iterative_refinement_generator import DecoderOut 13 | from fairseq.models import register_model, register_model_architecture 14 | from fairseq.modules import ( 15 | PositionalEmbedding, 16 | ) 17 | from .lemon_tree import BinaryTreeNode 18 | from fairseq.modules.transformer_sentence_encoder import init_bert_params 19 | from fairseq.models.nat.nonautoregressive_transformer import NATransformerDecoder 20 | from contextlib import contextmanager 21 | from .lemon_tree import * 22 | from ..custom_ops import cuda_pcfg_viterbi, viterbi_decoding 23 | import sys 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | @contextmanager 28 | def torch_seed(seed): 29 | # modified from lunanlp 30 | state = torch.random.get_rng_state() 31 | state_cuda = torch.cuda.random.get_rng_state() 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | try: 35 | yield 36 | finally: 37 | torch.random.set_rng_state(state) 38 | torch.cuda.random.set_rng_state(state_cuda) 39 | 40 | # @jit.script 41 | def logsumexp(x: Tensor, dim: int) -> Tensor: 42 | m, _ = x.max(dim=dim) 43 | mask = m == -float('inf') 44 | 45 | s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim) 46 | return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float('inf')) 47 | 48 | @register_model("glat_decomposed_with_link_two_hands_tri_pcfg") 49 | class GlatDecomposedLinkPCFG(FairseqNATModel): 50 | 51 | def __init__(self, args, encoder, decoder): 52 | super().__init__(args, encoder, decoder) 53 | self.init_beam_search() 54 | self._left_tree_mask = torch.empty(0) 55 | self._right_tree_mask = torch.empty(0) 56 | self._main_chain = torch.empty(0) 57 | self.max_left = 2**self.args.left_tree_layer-1 58 | 
self.total_length = 0 59 | self.total_main_chain = 0 60 | self.layer_1_count = 0 61 | self.layer_2_count = 0 62 | # self.main_chain_subword = [0,0,0] 63 | # self.left_tree_subword = [0,0,0] 64 | self.main_chain_subword = 0 65 | self.left_tree_subword = 0 66 | self.main_chain_pos_dict = {} 67 | self.left_tree_pos_dict = {} 68 | 69 | 70 | def init_beam_search(self): 71 | if self.args.decode_strategy == "beamsearch": 72 | import dag_search 73 | self.dag_search = dag_search 74 | dag_search.beam_search_init(self.args.decode_max_batchsize, self.args.decode_beamsize, 75 | self.args.decode_top_cand_n, self.decoder.max_positions(), self.tgt_dict, self.args.decode_lm_path) 76 | 77 | @classmethod 78 | def build_decoder(cls, args, tgt_dict, embed_tokens): 79 | decoder = GlatLinkDecoderPCFG(args, tgt_dict, embed_tokens) 80 | if getattr(args, "apply_bert_init", False): 81 | decoder.apply(init_bert_params) 82 | return decoder 83 | 84 | @staticmethod 85 | def add_args(parser): 86 | FairseqNATModel.add_args(parser) 87 | GlatLinkDecoderPCFG.add_args(parser) 88 | 89 | # length prediction 90 | parser.add_argument( 91 | "--src-embedding-copy", 92 | action="store_true", 93 | help="copy encoder word embeddings as the initial input of the decoder", 94 | ) 95 | parser.add_argument( 96 | "--pred-length-offset", 97 | action="store_true", 98 | help="predicting the length difference between the target and source sentences", 99 | ) 100 | parser.add_argument( 101 | "--sg-length-pred", 102 | action="store_true", 103 | help="stop the gradients back-propagated from the length predictor", 104 | ) 105 | parser.add_argument( 106 | "--length-loss-factor", 107 | type=float, 108 | help="weights on the length prediction loss", 109 | ) 110 | 111 | parser.add_argument('--links-feature', type=str, default="feature:position", help="Features used to predict transition.") 112 | parser.add_argument('--max-transition-length', type=int, default=99999, help="Max transition distance. -1 means no limitation, \ 113 | which cannot be used for cuda custom operations. To use cuda operations with no limitation, please use a very large number such as 99999.") 114 | 115 | parser.add_argument("--left-tree-layer", type=int, default=None, help="tree layer size of left sub tree") 116 | 117 | parser.add_argument("--src-upsample-scale", type=float, default=None, help="Specify the graph size with a upsample factor (lambda). Graph Size = \\lambda * src_length") 118 | parser.add_argument("--src-upsample-fixed", type=int, default=None, help="Specify the graph size by a constant. Cannot use together with src-upsample-scale") 119 | parser.add_argument("--length-multiplier", type=float, default=None, help="Deprecated") # does not work now 120 | parser.add_argument('--max-decoder-batch-tokens', type=int, default=None, help="Max tokens for LightSeq Decoder when using --src-upsample-fixed") 121 | 122 | parser.add_argument('--filter-max-length', default=None, type=str, help='Filter the sample that above the max lengths, e.g., "128:256" indicating 128 for source, 256 for target. Default: None, for filtering according max-source-positions and max-target-positions') 123 | parser.add_argument("--filter-ratio", type=float, default=None, help="Deprecated") # does not work now; need support of trainer.py 124 | 125 | parser.add_argument('--decode-strategy', type=str, default="lookahead", help='One of "greedy", "lookahead", "beamsearch"') 126 | 127 | parser.add_argument('--decode-alpha', type=float, default=1.1, help="Used for length penalty. 
Beam Search finds the sentence maximize: 1 / |Y|^{alpha} [ log P(Y) + gamma log P_{n-gram}(Y)]") 128 | parser.add_argument('--decode-beta', type=float, default=1, help="Scale the score of logits. log P(Y, A) := sum P(y_i|a_i) + beta * sum log(a_i|a_{i-1})") 129 | parser.add_argument('--decode-top-cand-n', type=float, default=5, help="Numbers of top candidates when considering transition") 130 | parser.add_argument('--decode-gamma', type=float, default=0.1, help="Used for n-gram language model score. Beam Search finds the sentence maximize: 1 / |Y|^{alpha} [ log P(Y) + gamma log P_{n-gram}(Y)]") 131 | parser.add_argument('--decode-beamsize', type=float, default=100, help="Beam size") 132 | parser.add_argument('--decode-max-beam-per-length', type=float, default=10, help="Limits the number of beam that has a same length in each step") 133 | parser.add_argument('--decode-top-p', type=float, default=0.9, help="Max probability of top candidates when considering transition") 134 | parser.add_argument('--decode-lm-path', type=str, default=None, help="Path to n-gram language model. None for not using n-gram LM") 135 | parser.add_argument('--decode-max-batchsize', type=int, default=32, help="Should not be smaller than the real batch size (the value is used for memory allocation)") 136 | parser.add_argument('--decode-dedup', type=bool, default=False, help="Use token deduplication in BeamSearch") 137 | 138 | 139 | 140 | def extract_links(self, features, prev_output_tokens, 141 | link_positional, query_linear, key_linear_left, key_linear_right, gate_linear, left_tree_mask, right_tree_mask, main_chain): 142 | 143 | links_feature = vars(self.args).get("links_feature", "feature:position").split(":") 144 | 145 | links_feature_arr = [] 146 | if "feature" in links_feature: 147 | links_feature_arr.append(features) 148 | if "position" in links_feature or "sinposition" in links_feature: 149 | links_feature_arr.append(link_positional(prev_output_tokens)) 150 | 151 | features_withpos = torch.cat(links_feature_arr, dim=-1) 152 | 153 | batch_size = features.shape[0] 154 | seqlen = features.shape[1] 155 | chunk_num = self.args.decoder_attention_heads 156 | chunk_size = self.args.decoder_embed_dim // self.args.decoder_attention_heads 157 | ninf = float("-inf") 158 | target_dtype = torch.float 159 | 160 | query_chunks = query_linear(features_withpos).reshape(batch_size, seqlen, chunk_num, chunk_size) 161 | key_chunks_left = key_linear_left(features_withpos).reshape(batch_size, seqlen, chunk_num, chunk_size) 162 | key_chunks_right = key_linear_right(features_withpos).reshape(batch_size, seqlen, chunk_num, chunk_size) 163 | log_gates = F.log_softmax(gate_linear(features_withpos), dim=-1, dtype=target_dtype) # batch_size * seqlen * chunk_num 164 | log_multi_content_ab = (torch.einsum("bicf,bjcf->bijc", query_chunks.to(dtype=target_dtype), key_chunks_left.to(dtype=target_dtype)) / ((chunk_size) ** 0.5)) 165 | log_multi_content_ac = (torch.einsum("bicf,bjcf->bijc", query_chunks.to(dtype=target_dtype), key_chunks_right.to(dtype=target_dtype)) / ((chunk_size) ** 0.5)) 166 | log_multi_content_bc_folded = (torch.einsum("bicf,bjcf->bijc", key_chunks_left.to(dtype=target_dtype), key_chunks_right.to(dtype=target_dtype)) / ((chunk_size) ** 0.5)) 167 | 168 | link_left_mask = torch.logical_or(prev_output_tokens.eq(self.pad).unsqueeze(1),left_tree_mask.unsqueeze(0)) 169 | link_right_mask = torch.logical_or(prev_output_tokens.eq(self.pad).unsqueeze(1),right_tree_mask.unsqueeze(0)) 170 | 171 | link_right_mask = ~(~link_right_mask & 
main_chain.unsqueeze(0).unsqueeze(0)) 172 | output_length = prev_output_tokens.ne(self.pad).sum(dim=-1) 173 | # assert((output_length - 2)) 174 | link_left_nouse_mask = (~link_left_mask).sum(dim=2, keepdim=True) == 0 175 | link_right_nouse_mask = (~link_right_mask).sum(dim=2, keepdim=True) == 0 176 | 177 | link_left_mask.masked_fill_(link_left_nouse_mask, False) 178 | link_right_mask.masked_fill_(link_right_nouse_mask, False) 179 | link_nouse_mask = torch.logical_or(link_left_nouse_mask, link_right_nouse_mask) 180 | 181 | log_multi_content_ab = log_multi_content_ab.masked_fill(link_left_mask.view(batch_size, seqlen, seqlen, 1), ninf) 182 | log_multi_content_ac = log_multi_content_ac.masked_fill(link_right_mask.view(batch_size, seqlen, seqlen, 1), ninf) 183 | 184 | index = self._max_left_index[:, :seqlen, :,].to(log_multi_content_ab.device) 185 | log_multi_content_ab = torch.gather(log_multi_content_ab, dim=2, index=index.view(1, seqlen, self.max_left+1, 1).expand(batch_size, -1, -1, chunk_num)) 186 | 187 | log_multi_content_bc = log_multi_content_bc_folded.unfold(1, self.max_left, 1) 188 | 189 | log_multi_content_bc = torch.roll(log_multi_content_bc, 1, 1) 190 | log_multi_content_bc = log_multi_content_bc.permute(0,1,4,2,3) 191 | log_multi_content_bc = torch.cat((log_multi_content_bc_folded[:,0,:,:].view(batch_size, 1, 1, seqlen, chunk_num).expand(-1, seqlen, -1, -1, -1) ,log_multi_content_bc), 2) 192 | 193 | log_multi_content_abc = log_multi_content_ab.unsqueeze(3) + log_multi_content_bc + log_multi_content_ac.unsqueeze(2) 194 | 195 | log_multi_content_abc = F.log_softmax(log_multi_content_abc.reshape(batch_size, seqlen, (self.max_left+1)*seqlen, chunk_num), dim=2) 196 | 197 | links = logsumexp(log_multi_content_abc + log_gates.unsqueeze(2), dim=-1) 198 | links = links.view(batch_size, seqlen, self.max_left+1, seqlen) 199 | links = links.masked_fill(link_nouse_mask.view(batch_size, seqlen, 1, 1), ninf) 200 | 201 | return links 202 | 203 | def buffered_tree_mask(self, tensor): 204 | dim = tensor.size(1) 205 | 206 | # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
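# Cached tree-structure buffers, rebuilt lazily whenever the current graph length `dim` exceeds the cached size:
#   _left_tree_mask[i, j] / _right_tree_mask[i, j]: True where position j is NOT allowed in the left / right subtree of
#     position i (extract_links later fills these entries with -inf); built by BinaryTreeNode.get_mask in lemon_tree.py.
#   _main_chain[i]: True for positions on the backbone (right-spine) chain of the decoding tree.
#   _max_left_index[0, i, k]: gather index max(i - k, 0), with k = 0 pinned to position 0, used to collect the
#     max_left + 1 candidate left-child scores for each position in extract_links.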
207 | if (self._left_tree_mask.size(0) == 0 or self._left_tree_mask.size(0) < dim): 208 | _left_tree_mask, _right_tree_mask, _main_chain = BinaryTreeNode.get_mask(dim, self.args.left_tree_layer, self.args.src_upsample_scale) 209 | 210 | self._left_tree_mask = _left_tree_mask 211 | self._right_tree_mask = _right_tree_mask 212 | self._main_chain = _main_chain 213 | self._max_left_index = torch.arange(0, dim).unsqueeze(0).unsqueeze(-1) 214 | self._max_left_index = self._max_left_index + torch.zeros(1, dim, self.max_left+1, dtype=torch.int64) 215 | self._max_left_index[:,:,0] = 0 216 | 217 | for i in range(1, self.max_left+1): 218 | self._max_left_index[:,:,i] = torch.where((self._max_left_index[:,:,i] - i) < 0, 0, self._max_left_index[:,:,i] - i) 219 | self._left_tree_mask = self._left_tree_mask.bool() 220 | self._right_tree_mask = self._right_tree_mask.bool() 221 | self._main_chain = self._main_chain.bool() 222 | 223 | self._left_tree_mask = self._left_tree_mask.to(tensor.device) 224 | self._right_tree_mask = self._right_tree_mask.to(tensor.device) 225 | self._main_chain = self._main_chain.to(tensor.device) 226 | return self._left_tree_mask[:dim, :dim], self._right_tree_mask[:dim, :dim], self._main_chain[:dim] 227 | 228 | def extract_features(self, prev_output_tokens, encoder_out, rand_seed, require_links=False): 229 | with torch_seed(rand_seed): 230 | features, _ = self.decoder.extract_features( 231 | prev_output_tokens, 232 | encoder_out=encoder_out, 233 | embedding_copy=False 234 | ) 235 | # word_ins_out = self.decoder.output_layer(features) 236 | left_tree_mask, right_tree_mask, main_chain = self.buffered_tree_mask(features) 237 | word_ins_out = self.decoder.output_projection(features) 238 | links = None 239 | if require_links: 240 | links = self.extract_links(features, \ 241 | prev_output_tokens, \ 242 | self.decoder.link_positional, \ 243 | self.decoder.query_linear, \ 244 | self.decoder.key_linear_left, \ 245 | self.decoder.key_linear_right, \ 246 | self.decoder.gate_linear, \ 247 | left_tree_mask, \ 248 | right_tree_mask, \ 249 | main_chain, \ 250 | ) 251 | 252 | 253 | 254 | return word_ins_out, links, left_tree_mask, right_tree_mask, main_chain 255 | 256 | def forward( 257 | self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, glat=None, glat_function=None, **kwargs 258 | ): 259 | # encoding 260 | encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) 261 | 262 | # length prediction 263 | length_out = self.decoder.forward_length( 264 | normalize=False, encoder_out=encoder_out 265 | ) 266 | length_tgt = self.decoder.forward_length_prediction( 267 | length_out, encoder_out, tgt_tokens 268 | ) 269 | rand_seed = random.randint(0, 19260817) 270 | glat_info = None 271 | if glat and tgt_tokens is not None: 272 | with torch.set_grad_enabled(glat.get('require_glance_grad', False)): 273 | word_ins_out, links, left_tree_mask, right_tree_mask, main_chain = self.extract_features(prev_output_tokens, encoder_out, rand_seed, require_links=True) 274 | prev_output_tokens, tgt_tokens, glat_info = glat_function(self, word_ins_out, tgt_tokens, prev_output_tokens, glat, links=links) 275 | word_ins_out = None 276 | 277 | word_ins_out, links, left_tree_mask, right_tree_mask, main_chain = self.extract_features(prev_output_tokens, encoder_out, rand_seed, require_links=True) 278 | 279 | ret = { 280 | "word_ins": { 281 | "out": word_ins_out, 282 | "tgt": tgt_tokens, 283 | "mask": tgt_tokens.ne(self.pad), 284 | "main_chain": main_chain, 285 | "nll_loss": True, 286 | } 287 | } 288 | 
ret['links'] = links 289 | ret['left_tree_mask'] = left_tree_mask 290 | ret['right_tree_mask'] = right_tree_mask 291 | ret["length"] = { 292 | "out": length_out, 293 | "tgt": length_tgt, 294 | "factor": self.decoder.length_loss_factor, 295 | } 296 | if glat_info is not None: 297 | ret.update(glat_info) 298 | return ret 299 | 300 | 301 | def initialize_output_tokens_with_length(self, src_tokens, length_tgt): 302 | max_length = length_tgt.max() 303 | if length_tgt.min() < 2: 304 | 305 | print(length_tgt) 306 | assert(False) 307 | idx_length = utils.new_arange(src_tokens, max_length) 308 | 309 | initial_output_tokens = src_tokens.new_zeros( 310 | src_tokens.size(0), max_length 311 | ).fill_(self.pad) 312 | initial_output_tokens.masked_fill_( 313 | idx_length[None, :] < length_tgt[:, None], self.unk 314 | ) 315 | initial_output_tokens[:, 0] = self.bos 316 | initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos) 317 | return initial_output_tokens 318 | 319 | def initialize_output_tokens_upsample_by_tokens(self, src_tokens): 320 | if vars(self.args).get("src_upsample_scale", None) is not None: 321 | length_tgt = torch.sum(src_tokens.ne(self.tgt_dict.pad_index), -1) 322 | length_tgt = (length_tgt * self.args.src_upsample_scale * (2**self.args.left_tree_layer)).long().clamp_(min=2) + 2 323 | else: 324 | length_tgt = torch.zeros(src_tokens.shape[0], device=src_tokens.device, dtype=src_tokens.dtype).fill_(self.args.src_upsample_fixed) 325 | return self.initialize_output_tokens_with_length(src_tokens, length_tgt) 326 | 327 | def initialize_output_tokens_multiplier_by_tokens(self, src_tokens, tgt_tokens): 328 | length_tgt = torch.sum(tgt_tokens.ne(self.tgt_dict.pad_index), -1) 329 | length_tgt = (length_tgt * self.args.length_multiplier).long().clamp_(min=2) 330 | return self.initialize_output_tokens_with_length(src_tokens, length_tgt) 331 | 332 | def initialize_output_tokens_by_tokens(self, src_tokens, tgt_tokens): 333 | if vars(self.args).get("src_upsample_scale", None) is not None or vars(self.args).get("src_upsample_fixed", None) is not None: 334 | return self.initialize_output_tokens_upsample_by_tokens(src_tokens) 335 | elif vars(self.args).get("length_multiplier", None) is not None: 336 | return self.initialize_output_tokens_multiplier_by_tokens(src_tokens, tgt_tokens) 337 | 338 | def initialize_output_tokens_upsample(self, encoder_out, src_tokens): 339 | # length prediction 340 | if vars(self.args).get("src_upsample_scale", None) is not None: 341 | length_tgt = torch.sum(src_tokens.ne(self.tgt_dict.pad_index), -1) 342 | length_tgt = (length_tgt * self.args.src_upsample_scale * (2**self.args.left_tree_layer)).long().clamp_(min=2) + 2 343 | else: 344 | length_tgt = torch.zeros(src_tokens.shape[0], device=src_tokens.device, dtype=src_tokens.dtype).fill_(self.args.src_upsample_fixed) 345 | initial_output_tokens = self.initialize_output_tokens_with_length(src_tokens, length_tgt) 346 | 347 | initial_output_scores = initial_output_tokens.new_zeros( 348 | *initial_output_tokens.size() 349 | ).type_as(encoder_out["encoder_out"][0]) 350 | 351 | return DecoderOut( 352 | output_tokens=initial_output_tokens, 353 | output_scores=initial_output_scores, 354 | attn=None, 355 | step=0, 356 | max_step=0, 357 | history=None, 358 | ) 359 | 360 | def initialize_output_tokens_multiplier(self, encoder_out, src_tokens): 361 | # length prediction 362 | length_tgt = self.decoder.forward_length_prediction( 363 | self.decoder.forward_length(normalize=True, encoder_out=encoder_out), 364 | 
encoder_out=encoder_out, 365 | ) 366 | length_tgt = (length_tgt * self.args.length_multiplier).long().clamp_(min=2) 367 | initial_output_tokens = self.initialize_output_tokens_with_length(src_tokens, length_tgt) 368 | 369 | initial_output_scores = initial_output_tokens.new_zeros( 370 | *initial_output_tokens.size() 371 | ).type_as(encoder_out["encoder_out"][0]) 372 | 373 | return DecoderOut( 374 | output_tokens=initial_output_tokens, 375 | output_scores=initial_output_scores, 376 | attn=None, 377 | step=0, 378 | max_step=0, 379 | history=None, 380 | ) 381 | 382 | def initialize_output_tokens(self, encoder_out, src_tokens): 383 | if vars(self.args).get("src_upsample_scale", None) is not None or vars(self.args).get("src_upsample_fixed", None) is not None: 384 | return self.initialize_output_tokens_upsample(encoder_out, src_tokens) 385 | elif vars(self.args).get("length_multiplier", None) is not None: 386 | return self.initialize_output_tokens_multiplier(encoder_out, src_tokens) 387 | 388 | def max_positions(self): 389 | if vars(self.args).get("filter_max_length", None) is not None: 390 | if ":" not in self.args.filter_max_length: 391 | a = b = int(self.args.filter_max_length) 392 | else: 393 | a, b = self.args.filter_max_length.split(":") 394 | a, b = int(a), int(b) 395 | return (a, b) 396 | else: 397 | if vars(self.args).get("src_upsample_fixed", None) is not None: 398 | return (self.encoder.max_positions(), self.decoder.max_positions()) 399 | elif vars(self.args).get("src_upsample_scale", None) is not None: 400 | return (min(self.encoder.max_positions(), int(self.decoder.max_positions() / self.args.src_upsample_scale)), self.decoder.max_positions()) 401 | else: 402 | return (min(self.encoder.max_positions(), int(self.decoder.max_positions() / self.args.length_multiplier)), self.decoder.max_positions()) 403 | 404 | def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): 405 | step = decoder_out.step 406 | output_tokens = decoder_out.output_tokens 407 | 408 | history = decoder_out.history 409 | rand_seed = random.randint(0, 19260817) 410 | 411 | # execute the decoder 412 | output_logits, links, left_tree_mask, right_tree_mask, main_chain = self.extract_features(output_tokens, encoder_out, rand_seed, require_links=True) 413 | 414 | output_logits_normalized = output_logits.log_softmax(dim=-1) 415 | unreduced_logits, unreduced_tokens_torch = output_logits_normalized.max(dim=-1) 416 | unreduced_tokens = unreduced_tokens_torch.tolist() 417 | bsz, prelen, left_span, _ = links.size() 418 | 419 | if self.args.decode_strategy in ["lookahead", "greedy"]: 420 | if self.args.decode_strategy == "lookahead": 421 | output_length = torch.sum(output_tokens.ne(self.tgt_dict.pad_index), dim=-1).tolist() 422 | links_lookahead_right = links + unreduced_logits.unsqueeze(1).unsqueeze(1) * self.args.decode_beta 423 | links_lookahead_left = torch.zeros(bsz, prelen, left_span).to(links) 424 | 425 | links_lookahead_left = torch.gather(unreduced_logits.unsqueeze(1).repeat(bsz,prelen,1), dim=2, index=self._max_left_index[:, :prelen, :,].to(links.device)) * self.args.decode_beta 426 | 427 | links_lookahead = links_lookahead_right + links_lookahead_left.unsqueeze(-1) 428 | links_idx = links_lookahead.view(bsz, prelen, left_span*prelen).max(dim=-1)[1].cpu().tolist() # batch * prelen 429 | elif self.args.decode_strategy == "greedy": 430 | output_length = torch.sum(output_tokens.ne(self.tgt_dict.pad_index), dim=-1).tolist() 431 | links_idx = links.view(bsz, prelen, 
-1).max(dim=-1)[1].cpu().tolist() # batch * prelen 432 | 433 | unpad_output_tokens = [] 434 | count_last = 0 435 | for i, length in enumerate(output_length): 436 | 437 | now = 1 438 | now_node = BinaryTreeNode(now) 439 | root_node = now_node 440 | tree_stack = [] 441 | stack = [] 442 | res = [] 443 | path = [] 444 | while (now < length and now != 0) or len(stack) > 0: 445 | 446 | while now < length and now != 0: 447 | stack.append(now) 448 | 449 | tree_stack.append(now_node) 450 | 451 | next_idx = links_idx[i][now] 452 | links_left_idx = next_idx // prelen 453 | links_left_idx = 0 if links_left_idx==0 else now - links_left_idx 454 | now_node.leftChild = BinaryTreeNode(links_left_idx) 455 | now_node = now_node.leftChild 456 | if left_tree_mask[now][links_left_idx]: 457 | break 458 | 459 | now = links_left_idx 460 | 461 | 462 | now = stack.pop() 463 | now_node = tree_stack.pop() 464 | now_token = unreduced_tokens[i][now] 465 | 466 | if now_token != self.tgt_dict.pad_index: 467 | res.append(now_token) 468 | path.append(now) 469 | now_node.order = (now_node.order, self.tgt_dict[now_token]) 470 | 471 | 472 | 473 | next_idx = links_idx[i][now] 474 | links_right_idx = next_idx % prelen 475 | now = links_right_idx 476 | now_node.rightChild = BinaryTreeNode(links_right_idx) 477 | now_node = now_node.rightChild 478 | 479 | unpad_output_tokens.append(res) 480 | output_seqlen = max([len(res) for res in unpad_output_tokens]) 481 | output_tokens = [res + [self.tgt_dict.pad_index] * (output_seqlen - len(res)) for res in unpad_output_tokens] 482 | output_tokens = torch.tensor(output_tokens, device=decoder_out.output_tokens.device, dtype=decoder_out.output_tokens.dtype) 483 | elif self.args.decode_strategy in ["viterbi"]: 484 | output_length = torch.sum(output_tokens.ne(self.tgt_dict.pad_index), dim=-1) 485 | 486 | scores, R_trace, L_trace, M_trace = cuda_pcfg_viterbi(unreduced_logits, links, output_length.cuda(links.get_device()).long()) 487 | # scores, R_trace, L_trace, M_trace = scores.cpu(), R_trace.cpu(), L_trace.cpu(), M_trace.cpu() 488 | lengths = torch.arange(prelen//4).unsqueeze(0) 489 | # print(scores.size(), lengths.size()) 490 | length_penalty = (lengths ** self.args.decode_viterbibeta).cuda(scores.get_device()) 491 | scores = scores / length_penalty 492 | 493 | invalid_masks = scores.isnan() 494 | scores.masked_fill_(invalid_masks, float("-inf")) 495 | 496 | max_score, pred_length = torch.max(scores[:,1:], dim = -1) 497 | pred_length = pred_length+1 498 | 499 | 500 | R_trace, L_trace, M_trace = R_trace.cpu(), L_trace.cpu(), M_trace.cpu() 501 | output_tokens = viterbi_decoding(pred_length.long(), output_length.long(), \ 502 | L_trace.long(), R_trace.long(), M_trace.long(), unreduced_tokens_torch.long(), left_tree_mask.long(), self.tgt_dict.pad_index).to(device=pred_length.device) 503 | if history is not None: 504 | history.append(output_tokens.clone()) 505 | 506 | return decoder_out._replace( 507 | output_tokens=output_tokens, 508 | output_scores=torch.full(output_tokens.size(), 1.0), 509 | attn=None, 510 | history=history, 511 | ) 512 | 513 | class GlatLinkDecoderPCFG(NATransformerDecoder): 514 | 515 | def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): 516 | super().__init__(args, dictionary, embed_tokens, no_encoder_attn) 517 | self.init_link_feature(args) 518 | 519 | def init_link_feature(self, args): 520 | links_feature = self.args.links_feature.split(":") 521 | links_dim = 0 522 | if "feature" in links_feature: 523 | links_dim += args.decoder_embed_dim 524 | if 
"position" in links_feature: 525 | self.link_positional = PositionalEmbedding(args.max_target_positions, args.decoder_embed_dim, self.padding_idx, True) 526 | links_dim += args.decoder_embed_dim 527 | elif "sinposition" in links_feature: 528 | self.link_positional = PositionalEmbedding(args.max_target_positions, args.decoder_embed_dim, self.padding_idx, False) 529 | links_dim += args.decoder_embed_dim 530 | else: 531 | self.link_positional = None 532 | 533 | 534 | self.query_linear = nn.Linear(links_dim, args.decoder_embed_dim) 535 | self.key_linear_left = nn.Linear(links_dim, args.decoder_embed_dim) 536 | self.key_linear_right = nn.Linear(links_dim, args.decoder_embed_dim) 537 | self.gate_linear = nn.Linear(links_dim, args.decoder_attention_heads) 538 | 539 | @staticmethod 540 | def add_args(parser): 541 | pass 542 | 543 | @register_model_architecture( 544 | "glat_decomposed_with_link_two_hands_tri_pcfg", "glat_decomposed_with_link_two_hands_tri_pcfg_6e6d512" 545 | ) 546 | def base_architecture(args): 547 | args.encoder_embed_path = getattr(args, "encoder_embed_path", None) 548 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 549 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 550 | args.encoder_layers = getattr(args, "encoder_layers", 6) 551 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) 552 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) 553 | args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) 554 | args.decoder_embed_path = getattr(args, "decoder_embed_path", None) 555 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) 556 | args.decoder_ffn_embed_dim = getattr( 557 | args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim 558 | ) 559 | args.decoder_layers = getattr(args, "decoder_layers", 6) 560 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) 561 | args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) 562 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) 563 | args.attention_dropout = getattr(args, "attention_dropout", 0.0) 564 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 565 | args.activation_fn = getattr(args, "activation_fn", "relu") 566 | args.dropout = getattr(args, "dropout", 0.1) 567 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) 568 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) 569 | args.share_decoder_input_output_embed = getattr( 570 | args, "share_decoder_input_output_embed", False 571 | ) 572 | args.share_all_embeddings = getattr(args, "share_all_embeddings", False) 573 | args.no_token_positional_embeddings = getattr( 574 | args, "no_token_positional_embeddings", False 575 | ) 576 | args.adaptive_input = getattr(args, "adaptive_input", False) 577 | args.apply_bert_init = getattr(args, "apply_bert_init", False) 578 | 579 | args.decoder_output_dim = getattr( 580 | args, "decoder_output_dim", args.decoder_embed_dim 581 | ) 582 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 583 | 584 | # --- special arguments --- 585 | args.sg_length_pred = getattr(args, "sg_length_pred", False) 586 | args.pred_length_offset = getattr(args, "pred_length_offset", False) 587 | args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) 588 | args.src_embedding_copy = getattr(args, "src_embedding_copy", False) 589 | 590 
| @register_model_architecture( 591 | "glat_decomposed_with_link_two_hands_tri_pcfg", "glat_decomposed_with_link_two_hands_tri_pcfg_base" 592 | ) 593 | def base_architecture2(args): 594 | base_architecture(args) 595 | 596 | @register_model_architecture("glat_decomposed_with_link_two_hands_tri_pcfg", "glat_decomposed_with_link_two_hands_tri_pcfg_iwslt_de_en") 597 | def nonautoregressive_transformer_iwslt_de_en(args): 598 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 599 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) 600 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) 601 | args.encoder_layers = getattr(args, "encoder_layers", 6) 602 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) 603 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) 604 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) 605 | args.decoder_layers = getattr(args, "decoder_layers", 6) 606 | base_architecture(args) 607 | -------------------------------------------------------------------------------- /fs_plugins/models/lemon_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | class BinaryTreeNode: 5 | def __init__(self, order=None): 6 | self.order = order 7 | self.leftChild = None 8 | self.rightChild = None 9 | self.isMainChain = False 10 | self.leftArray = [] 11 | self.rightArray = [] 12 | 13 | def __str__(self, level=0): 14 | ret = " "*level+repr(self.order)+"\n" 15 | if self.leftChild is not None: 16 | ret += self.leftChild.__str__(level+1) 17 | if self.rightChild is not None: 18 | ret += self.rightChild.__str__(level+1) 19 | 20 | return ret 21 | 22 | @staticmethod 23 | def subtree(layer_size): 24 | if layer_size == 0: 25 | return None 26 | root = BinaryTreeNode() 27 | stack = [root] 28 | for i in range(layer_size-1): 29 | stack_new = [] 30 | for node in stack: 31 | left = BinaryTreeNode() 32 | right = BinaryTreeNode() 33 | stack_new.extend([left, right]) 34 | node.leftChild = left 35 | node.rightChild = right 36 | stack = stack_new 37 | 38 | return root 39 | 40 | @staticmethod 41 | def build_tree(dim, left_tree_layer, src_upsample_scale): 42 | left_sub_tree_size = 2**left_tree_layer - 1 43 | root = BinaryTreeNode() 44 | root.leftChild = BinaryTreeNode() 45 | now = root 46 | now.isMainChain = True 47 | main_chain_size = torch.div(dim-2, (2**left_tree_layer), rounding_mode='trunc').long() 48 | # print(main_chain_size) 49 | # print(main_chain_size) 50 | # main_chain_size = (dim-2) // (src_upsample_scale + left_sub_tree_size*src_upsample_scale) * src_upsample_scale 51 | for i in range(main_chain_size): 52 | now.rightChild = BinaryTreeNode() 53 | now = now.rightChild 54 | now.isMainChain = True 55 | now.leftChild = BinaryTreeNode.subtree(left_tree_layer) 56 | # now = now.rightChild 57 | # now.rightChild = BinaryTreeNode() 58 | # now.rightChild.isMainChain = True 59 | return root 60 | 61 | @staticmethod 62 | def inorder(root, array): 63 | 64 | if root is None: 65 | return array 66 | 67 | array = BinaryTreeNode.inorder(root.leftChild, array) 68 | root.order = len(array) 69 | array.append(root) 70 | array = BinaryTreeNode.inorder(root.rightChild, array) 71 | 72 | if root.leftChild is not None: 73 | 74 | root.leftArray.append(root.leftChild.order) 75 | root.leftArray.extend(root.leftChild.leftArray), root.leftArray.extend(root.leftChild.rightArray) 76 | 77 | if root.rightChild is not None: 78 | 
root.rightArray.append(root.rightChild.order) 79 | 80 | root.rightArray.extend(root.rightChild.leftArray), root.rightArray.extend(root.rightChild.rightArray) 81 | # print(root.order, root.rightArray) 82 | return array 83 | 84 | 85 | @staticmethod 86 | def get_root(dim, left_tree_layer, src_upsample_scale): 87 | root = BinaryTreeNode.build_tree(dim, left_tree_layer, src_upsample_scale) 88 | BinaryTreeNode.inorder(root, []) 89 | return root 90 | 91 | @staticmethod 92 | def get_mask(dim, left_tree_layer, src_upsample_scale): 93 | sys.setrecursionlimit(2048+16) 94 | root = BinaryTreeNode.build_tree(dim, left_tree_layer, src_upsample_scale) 95 | BinaryTreeNode.inorder(root, []) 96 | _left_tree_mask = torch.ones([dim, dim]) 97 | _right_tree_mask = torch.ones([dim, dim]) 98 | _stop_mask = torch.zeros(dim) 99 | _main_chain = torch.zeros(dim) 100 | stack = [root] 101 | while len(stack) > 0: 102 | now = stack.pop() 103 | if now is None: 104 | continue 105 | _left_tree_mask[now.order, now.leftArray] = 0 106 | _right_tree_mask[now.order, now.rightArray] = 0 107 | stack.append(now.rightChild) 108 | stack.append(now.leftChild) 109 | if now.rightChild is None and now.leftChild is None: 110 | _stop_mask[now.order] = 1 111 | if now.isMainChain: 112 | _main_chain[now.order] = 1 113 | # if now.order+1`_. 35 | """ 36 | 37 | cfg: TranslationLevenshteinConfig 38 | 39 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 40 | """Load a given dataset split. 41 | 42 | Args: 43 | split (str): name of the split (e.g., train, valid, test) 44 | """ 45 | paths = utils.split_paths(self.cfg.data) 46 | assert len(paths) > 0 47 | data_path = paths[(epoch - 1) % len(paths)] 48 | 49 | # infer langcode 50 | src, tgt = self.cfg.source_lang, self.cfg.target_lang 51 | 52 | self.datasets[split] = load_langpair_dataset( 53 | data_path, 54 | split, 55 | src, 56 | self.src_dict, 57 | tgt, 58 | self.tgt_dict, 59 | combine=combine, 60 | dataset_impl=self.cfg.dataset_impl, 61 | upsample_primary=self.cfg.upsample_primary, 62 | left_pad_source=self.cfg.left_pad_source, 63 | left_pad_target=self.cfg.left_pad_target, 64 | max_source_positions=self.cfg.max_source_positions, 65 | max_target_positions=self.cfg.max_target_positions, 66 | prepend_bos=True, 67 | ) 68 | 69 | def inject_noise(self, target_tokens): 70 | def _random_delete(target_tokens): 71 | pad = self.tgt_dict.pad() 72 | bos = self.tgt_dict.bos() 73 | eos = self.tgt_dict.eos() 74 | 75 | max_len = target_tokens.size(1) 76 | target_mask = target_tokens.eq(pad) 77 | target_score = target_tokens.clone().float().uniform_() 78 | target_score.masked_fill_( 79 | target_tokens.eq(bos) | target_tokens.eq(eos), 0.0 80 | ) 81 | target_score.masked_fill_(target_mask, 1) 82 | target_score, target_rank = target_score.sort(1) 83 | target_length = target_mask.size(1) - target_mask.float().sum( 84 | 1, keepdim=True 85 | ) 86 | 87 | # do not delete and (we assign 0 score for them) 88 | target_cutoff = ( 89 | 2 90 | + ( 91 | (target_length - 2) 92 | * target_score.new_zeros(target_score.size(0), 1).uniform_() 93 | ).long() 94 | ) 95 | target_cutoff = target_score.sort(1)[1] >= target_cutoff 96 | 97 | prev_target_tokens = ( 98 | target_tokens.gather(1, target_rank) 99 | .masked_fill_(target_cutoff, pad) 100 | .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1]) 101 | ) 102 | prev_target_tokens = prev_target_tokens[ 103 | :, : prev_target_tokens.ne(pad).sum(1).max() 104 | ] 105 | 106 | return prev_target_tokens 107 | 108 | def _random_mask(target_tokens): 109 | pad = 
self.tgt_dict.pad() 110 | bos = self.tgt_dict.bos() 111 | eos = self.tgt_dict.eos() 112 | unk = self.tgt_dict.unk() 113 | 114 | target_masks = ( 115 | target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) 116 | ) 117 | target_score = target_tokens.clone().float().uniform_() 118 | target_score.masked_fill_(~target_masks, 2.0) 119 | target_length = target_masks.sum(1).float() 120 | target_length = target_length * target_length.clone().uniform_() 121 | target_length = target_length + 1 # make sure to mask at least one token. 122 | 123 | _, target_rank = target_score.sort(1) 124 | target_cutoff = new_arange(target_rank) < target_length[:, None].long() 125 | prev_target_tokens = target_tokens.masked_fill( 126 | target_cutoff.scatter(1, target_rank, target_cutoff), unk 127 | ) 128 | return prev_target_tokens 129 | 130 | def _full_mask(target_tokens): 131 | pad = self.tgt_dict.pad() 132 | bos = self.tgt_dict.bos() 133 | eos = self.tgt_dict.eos() 134 | unk = self.tgt_dict.unk() 135 | 136 | target_mask = ( 137 | target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad) 138 | ) 139 | return target_tokens.masked_fill(~target_mask, unk) 140 | 141 | if self.cfg.noise == "random_delete": 142 | return _random_delete(target_tokens) 143 | elif self.cfg.noise == "random_mask": 144 | return _random_mask(target_tokens) 145 | elif self.cfg.noise == "full_mask": 146 | return _full_mask(target_tokens) 147 | elif self.cfg.noise == "no_noise": 148 | return target_tokens 149 | else: 150 | raise NotImplementedError 151 | 152 | def build_generator(self, models, args, **unused): 153 | # add models input to match the API for SequenceGenerator 154 | from fairseq.iterative_refinement_generator import IterativeRefinementGenerator 155 | 156 | return IterativeRefinementGenerator( 157 | self.target_dictionary, 158 | eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0), 159 | max_iter=getattr(args, "iter_decode_max_iter", 10), 160 | beam_size=getattr(args, "iter_decode_with_beam", 1), 161 | reranking=getattr(args, "iter_decode_with_external_reranker", False), 162 | decoding_format=getattr(args, "decoding_format", None), 163 | adaptive=not getattr(args, "iter_decode_force_max_iter", False), 164 | retain_history=getattr(args, "retain_iter_history", False), 165 | ) 166 | 167 | def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): 168 | if constraints is not None: 169 | # Though see Susanto et al. 
(ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/ 170 | raise NotImplementedError( 171 | "Constrained decoding with the translation_lev task is not supported" 172 | ) 173 | 174 | return LanguagePairDataset( 175 | src_tokens, src_lengths, self.source_dictionary, append_bos=True 176 | ) 177 | 178 | def train_step( 179 | self, sample, model, criterion, optimizer, update_num, ignore_grad=False 180 | ): 181 | model.train() 182 | # print(update_num) 183 | sample['update_num'] = update_num 184 | sample["prev_target"] = self.inject_noise(sample["target"]) 185 | if ignore_grad: 186 | sample['dummy'] = True 187 | with torch.autograd.profiler.record_function("forward"): 188 | loss, sample_size, logging_output = criterion(model, sample) 189 | if ignore_grad: 190 | loss *= 0 191 | with torch.autograd.profiler.record_function("backward"): 192 | optimizer.backward(loss) 193 | return loss, sample_size, logging_output 194 | 195 | def valid_step(self, sample, model, criterion): 196 | model.eval() 197 | with torch.no_grad(): 198 | sample["prev_target"] = self.inject_noise(sample["target"]) 199 | loss, sample_size, logging_output = criterion(model, sample) 200 | EVAL_BLEU_ORDER = 4 201 | if self.cfg.eval_bleu: 202 | bleu = self._inference_with_bleu(self.sequence_generator, sample, model) 203 | logging_output["_bleu_sys_len"] = bleu.sys_len 204 | logging_output["_bleu_ref_len"] = bleu.ref_len 205 | # we split counts into separate entries so that they can be 206 | # summed efficiently across workers using fast-stat-sync 207 | assert len(bleu.counts) == EVAL_BLEU_ORDER 208 | for i in range(EVAL_BLEU_ORDER): 209 | logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] 210 | logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] 211 | return loss, sample_size, logging_output 212 | 213 | 214 | def _inference_with_bleu(self, generator, sample, model): 215 | import sacrebleu 216 | 217 | def decode(toks, escape_unk=False): 218 | s = self.tgt_dict.string( 219 | toks.int().cpu(), 220 | self.cfg.eval_bleu_remove_bpe, 221 | # The default unknown string in fairseq is `<unk>`, but 222 | # this is tokenized by sacrebleu as `< unk >`, inflating 223 | # BLEU scores. Instead, we use a somewhat more verbose 224 | # alternative that is unlikely to appear in the real 225 | # reference, but doesn't get split into multiple tokens.
226 | unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), 227 | ) 228 | if self.tokenizer: 229 | s = self.tokenizer.decode(s) 230 | return s 231 | 232 | gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) 233 | hyps, refs = [], [] 234 | for i in range(len(gen_out)): 235 | hyps.append(decode(gen_out[i][0]["tokens"])) 236 | refs.append( 237 | decode( 238 | utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), 239 | escape_unk=True, # don't count as matches to the hypo 240 | ) 241 | ) 242 | if self.cfg.eval_bleu_print_samples: 243 | logger.info("example hypothesis: " + hyps[0]) 244 | logger.info("example reference: " + refs[0]) 245 | if self.cfg.eval_tokenized_bleu: 246 | return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") 247 | else: 248 | if self.cfg.target_lang == "ja": 249 | return sacrebleu.corpus_bleu(hyps, [refs], tokenize="ja-mecab") 250 | elif self.cfg.target_lang == "zh": 251 | return sacrebleu.corpus_bleu(hyps, [refs], tokenize="zh") 252 | else: 253 | return sacrebleu.corpus_bleu(hyps, [refs]) 254 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt14_deen.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | fairseq-generate ${data_dir} \ 8 | --source-lang de \ 9 | --target-lang en \ 10 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 11 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 12 | --remove-bpe --max-tokens 1024 --seed 0 \ 13 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 14 | --path $checkpoint_dir/average_best_5.pt 15 | 16 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt14_ende.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | fairseq-generate ${data_dir} \ 8 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 9 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 10 | --remove-bpe --batch-size 1 --seed 0 \ 11 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 12 | --path $checkpoint_dir/average_best_5.pt 13 | 14 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt16_enro.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | fairseq-generate ${data_dir} \ 8 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 9 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 10 | --remove-bpe --max-tokens 1024 --seed 0 \ 11 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 12 | --path ${checkpoint_path} 13 | 14 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt16_roen.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | 
fairseq-generate ${data_dir} \ 8 | --source-lang ro \ 9 | --target-lang en \ 10 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 11 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 12 | --remove-bpe --max-tokens 1024 --seed 0 \ 13 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 14 | --path ${checkpoint_path} 15 | 16 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt17_enzh.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | fairseq-generate ${data_dir} \ 8 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 9 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 10 | --remove-bpe --max-tokens 1024 --seed 0 \ 11 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 12 | --source-lang en --target-lang zh --tokenizer moses --scoring sacrebleu --sacrebleu-tokenizer zh \ 13 | --path $checkpoint_dir/average_best_5.pt 14 | 15 | -------------------------------------------------------------------------------- /test_scripts/test_pcfg_viterbi_wmt17_zhen.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | 7 | fairseq-generate ${data_dir} \ 8 | --gen-subset test --user-dir $user_dir --task translation_lev_modified \ 9 | --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \ 10 | --remove-bpe --max-tokens 1024 --seed 0 \ 11 | --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \ 12 | --path $checkpoint_dir/average_best_5.pt 13 | 14 | -------------------------------------------------------------------------------- /train_scripts/train_wmt14_deen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --source-lang de \ 9 | --target-lang en \ 10 | --task translation_lev_modified --noise full_mask \ 11 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 12 | --decoder-learned-pos --encoder-learned-pos \ 13 | --share-all-embeddings --activation-fn gelu \ 14 | --apply-bert-init \ 15 | --links-feature feature:position --decode-strategy lookahead \ 16 | --max-source-positions 128 --max-target-positions 1030 --src-upsample-scale 4.0 \ 17 | --left-tree-layer 1 \ 18 | --criterion nat_pcfg_loss \ 19 | --length-loss-factor 0 --max-transition-length 99999 \ 20 | --glat-p 0.5:0.1@200k --glance-strategy number-random \ 21 | --no-force-emit \ 22 | --optimizer adam --adam-betas '(0.9,0.999)' \ 23 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \ 24 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 25 | --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 26 | --min-loss-scale 0 --ddp-backend c10d \ 27 | --max-tokens 2730 --update-freq 3 --grouped-shuffling \ 28 | --max-update 300000 --max-tokens-valid 1024 \ 29 | --save-interval 1 --save-interval-updates 10000 \ 30 | --seed 0 --fp16 \ 31 | --validate-interval 1 --validate-interval-updates 10000 \ 32 |
--skip-invalid-size-inputs-valid-test \ 33 | --fixed-validation-seed 7 \ 34 | --best-checkpoint-metric loss \ 35 | --keep-last-epochs 32 \ 36 | --keep-best-checkpoints 10 --save-dir ${checkpoint_dir} \ 37 | --log-format 'simple' --log-interval 100 -------------------------------------------------------------------------------- /train_scripts/train_wmt14_ende_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --task translation_lev_modified --noise full_mask \ 9 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 10 | --decoder-learned-pos --encoder-learned-pos \ 11 | --share-all-embeddings --activation-fn gelu \ 12 | --apply-bert-init \ 13 | --links-feature feature:position --decode-strategy lookahead \ 14 | --max-source-positions 128 --max-target-positions 1030 --src-upsample-scale 4.0 \ 15 | --left-tree-layer 1 \ 16 | --criterion nat_pcfg_loss \ 17 | --length-loss-factor 0 --max-transition-length 99999 \ 18 | --glat-p 0.5:0.1@200k --glance-strategy number-random \ 19 | --no-force-emit \ 20 | --optimizer adam --adam-betas '(0.9,0.999)' \ 21 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \ 22 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 23 | --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 24 | --min-loss-scale 0 --ddp-backend c10d \ 25 | --max-tokens 2730 --update-freq 3 --grouped-shuffling \ 26 | --max-update 300000 --max-tokens-valid 1024 \ 27 | --save-interval 1 --save-interval-updates 10000 \ 28 | --seed 0 --fp16 \ 29 | --validate-interval 1 --validate-interval-updates 10000 \ 30 | --skip-invalid-size-inputs-valid-test \ 31 | --fixed-validation-seed 7 \ 32 | --best-checkpoint-metric loss \ 33 | --keep-last-epochs 32 \ 34 | --keep-best-checkpoints 10 --save-dir ${checkpoint_dir} \ 35 | --log-format 'simple' --log-interval 100 -------------------------------------------------------------------------------- /train_scripts/train_wmt16_enro_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --task translation_lev_modified --noise full_mask \ 9 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 10 | --decoder-learned-pos --encoder-learned-pos \ 11 | --share-all-embeddings --activation-fn gelu \ 12 | --apply-bert-init \ 13 | --links-feature feature:position --decode-strategy lookahead \ 14 | --max-source-positions 256 --max-target-positions 2048 --src-upsample-scale 4.0 \ 15 | --left-tree-layer 1 \ 16 | --criterion nat_pcfg_loss \ 17 | --length-loss-factor 0 --max-transition-length 99999 \ 18 | --glat-p 0.5:0.1@30k --glance-strategy number-random \ 19 | --optimizer adam --adam-betas '(0.9,0.999)' \ 20 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.3 \ 21 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 22 | --clip-norm 0.1 --lr 0.0007 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 23 | --min-loss-scale 0 --ddp-backend c10d \ 24 | --max-tokens 4096 --update-freq 4 --grouped-shuffling \ 25 | --max-update 300000 --max-tokens-valid 1024 \ 26 | --save-interval 1 --save-interval-updates 10000 \ 27 | --patience 32 \ 28 | 
--seed 0 --fp16 \ 29 | --validate-interval 1 --validate-interval-updates 10000 \ 30 | --fixed-validation-seed 7 \ 31 | --best-checkpoint-metric loss \ 32 | --keep-last-epochs 32 \ 33 | --keep-best-checkpoints 5 --save-dir ${checkpoint_dir} \ 34 | --log-format 'simple' --log-interval 100 -------------------------------------------------------------------------------- /train_scripts/train_wmt16_roen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --source-lang ro \ 9 | --target-lang en \ 10 | --task translation_lev_modified --noise full_mask \ 11 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 12 | --decoder-learned-pos --encoder-learned-pos \ 13 | --share-all-embeddings --activation-fn gelu \ 14 | --apply-bert-init \ 15 | --links-feature feature:position --decode-strategy lookahead \ 16 | --max-source-positions 256 --max-target-positions 2048 --src-upsample-scale 4.0 \ 17 | --left-tree-layer 1 \ 18 | --criterion nat_pcfg_loss \ 19 | --length-loss-factor 0 --max-transition-length 99999 \ 20 | --glat-p 0.5:0.1@30k --glance-strategy number-random \ 21 | --optimizer adam --adam-betas '(0.9,0.999)' \ 22 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.3 \ 23 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 24 | --clip-norm 0.1 --lr 0.0007 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 25 | --min-loss-scale 0 --ddp-backend c10d \ 26 | --max-tokens 4096 --update-freq 4 --grouped-shuffling \ 27 | --max-update 300000 --max-tokens-valid 1024 \ 28 | --save-interval 1 --save-interval-updates 10000 \ 29 | --patience 32 \ 30 | --seed 0 --fp16 \ 31 | --validate-interval 1 --validate-interval-updates 10000 \ 32 | --fixed-validation-seed 7 \ 33 | --best-checkpoint-metric loss \ 34 | --keep-last-epochs 32 \ 35 | --keep-best-checkpoints 5 --save-dir ${checkpoint_dir} \ 36 | --log-format 'simple' --log-interval 100 -------------------------------------------------------------------------------- /train_scripts/train_wmt17_enzh_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --task translation_lev_modified --noise full_mask \ 9 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 10 | --decoder-learned-pos --encoder-learned-pos \ 11 | --share-decoder-input-output-embed --activation-fn gelu \ 12 | --apply-bert-init \ 13 | --links-feature feature:position --decode-strategy lookahead \ 14 | --max-source-positions 128 --max-target-positions 1030 --src-upsample-scale 4.0 \ 15 | --left-tree-layer 1 \ 16 | --criterion nat_pcfg_loss \ 17 | --length-loss-factor 0 --max-transition-length 99999 \ 18 | --glat-p 0.5:0.1@200k --glance-strategy number-random \ 19 | --no-force-emit \ 20 | --optimizer adam --adam-betas '(0.9,0.999)' \ 21 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \ 22 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 23 | --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 24 | --min-loss-scale 0 --ddp-backend c10d \ 25 | --max-tokens 2730 --update-freq 3 --grouped-shuffling \ 26 | --max-update 300000 --max-tokens-valid 1024 \ 27 | 
--save-interval 1 --save-interval-updates 10000 \ 28 | --seed 0 --fp16 \ 29 | --validate-interval 1 --validate-interval-updates 10000 \ 30 | --skip-invalid-size-inputs-valid-test \ 31 | --fixed-validation-seed 7 \ 32 | --best-checkpoint-metric loss \ 33 | --keep-last-epochs 32 \ 34 | --keep-best-checkpoints 10 --save-dir ${checkpoint_dir} \ 35 | --log-format 'simple' --log-interval 100 -------------------------------------------------------------------------------- /train_scripts/train_wmt17_zhen_pcfg_two_hands_tri_layer_1_glat_0.5_0.1.sh: -------------------------------------------------------------------------------- 1 | exp=exp_name 2 | root=fairseq 3 | data_dir=data_dir 4 | checkpoint_dir=checkpoint_dir 5 | user_dir=fs_plugins 6 | fairseq-train ${data_dir} \ 7 | --user-dir $user_dir \ 8 | --task translation_lev_modified --noise full_mask \ 9 | --arch glat_decomposed_with_link_two_hands_tri_pcfg_base \ 10 | --decoder-learned-pos --encoder-learned-pos \ 11 | --share-decoder-input-output-embed --activation-fn gelu \ 12 | --apply-bert-init \ 13 | --links-feature feature:position --decode-strategy lookahead \ 14 | --max-source-positions 128 --max-target-positions 1030 --src-upsample-scale 4.0 \ 15 | --left-tree-layer 1 \ 16 | --criterion nat_pcfg_loss \ 17 | --length-loss-factor 0 --max-transition-length 99999 \ 18 | --glat-p 0.5:0.1@200k --glance-strategy number-random \ 19 | --no-force-emit \ 20 | --optimizer adam --adam-betas '(0.9,0.999)' \ 21 | --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \ 22 | --lr-scheduler inverse_sqrt --warmup-updates 10000 \ 23 | --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \ 24 | --min-loss-scale 0 --ddp-backend c10d \ 25 | --max-tokens 2730 --update-freq 3 --grouped-shuffling \ 26 | --max-update 300000 --max-tokens-valid 1024 \ 27 | --save-interval 1 --save-interval-updates 10000 \ 28 | --seed 0 --fp16 \ 29 | --validate-interval 1 --validate-interval-updates 10000 \ 30 | --skip-invalid-size-inputs-valid-test \ 31 | --fixed-validation-seed 7 \ 32 | --best-checkpoint-metric loss \ 33 | --keep-last-epochs 32 \ 34 | --keep-best-checkpoints 10 --save-dir ${checkpoint_dir} \ 35 | --log-format 'simple' --log-interval 100 --------------------------------------------------------------------------------
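The test scripts above all point `--path` at an averaged checkpoint, `average_best_5.pt`, which the training scripts do not create themselves. A minimal sketch for producing it with fairseq's bundled `scripts/average_checkpoints.py` is shown below; the `fairseq/` source path and the `checkpoint.best_loss_*.pt` naming are assumptions (they depend on where fairseq is checked out and on how `--keep-best-checkpoints` names its files) and may need adjusting.

```bash
checkpoint_dir=checkpoint_dir   # same placeholder as in the scripts above

# Average the five best checkpoints (by validation loss) into the file that the
# test scripts load via --path. Adjust the glob if your checkpoints are named differently.
python fairseq/scripts/average_checkpoints.py \
    --inputs $(ls ${checkpoint_dir}/checkpoint.best_loss_*.pt | head -n 5) \
    --output ${checkpoint_dir}/average_best_5.pt
```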