├── utility
│   ├── simultaneous_translation
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── waitk_transformer_layers.py
│   │   │   └── sinkhorn_attention.py
│   │   ├── __init__.py
│   │   ├── tasks
│   │   │   ├── __init__.py
│   │   │   ├── inference_config.py
│   │   │   └── translation_infer.py
│   │   ├── criterion
│   │   │   ├── __init__.py
│   │   │   └── label_smoothed_ctc_criterion.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── causal_encoder.py
│   │   │   ├── nat_utils.py
│   │   │   └── sinkhorn_waitk.py
│   │   └── eval
│   │       ├── lm
│   │       │   └── train_kenlm.sh
│   │       ├── run_all_simuleval.sh
│   │       ├── anticipation
│   │       │   ├── run_aligner.sh
│   │       │   ├── avg_reorder_distance.py
│   │       │   └── count_anticipation.py
│   │       ├── run_wmt15_bleueval.sh
│   │       ├── simuleval_fullsentence.sh
│   │       ├── oracle_order
│   │       │   ├── run_cwmt_oracle.sh
│   │       │   └── nat_utils.py
│   │       ├── simuleval.sh
│   │       ├── run_cwmt_bleueval.sh
│   │       ├── run_cwmt_bleueval copy.sh
│   │       └── agents
│   │           └── simul_t2t_waitk.py
│   ├── README.md
│   └── scripts
│       ├── run_aligner.sh
│       ├── avg_reorder_distance.py
│       ├── count_anticipation.py
│       ├── reorder.py
│       └── average_checkpoints.py
├── requirements.txt
├── train
│   ├── cwmt-enzh
│   │   ├── infer_mt.yaml
│   │   ├── data_path.sh
│   │   ├── 1-vanilla_wait_k.sh
│   │   ├── 0-teacher.sh
│   │   ├── 2s-proc_generate.py
│   │   ├── 2s-encode_test.sh
│   │   ├── 1s-wait_k.sh
│   │   ├── 0-distill_enzh.sh
│   │   ├── 0-distill_enzh_mono.sh
│   │   ├── 2-test_model_full.sh
│   │   └── 2-test_model.sh
│   ├── wmt21-enja
│   │   ├── infer_mt.yaml
│   │   ├── data_path.sh
│   │   ├── 1-vanilla_wait_k.sh
│   │   ├── 2s-proc_generate.py
│   │   ├── 0-teacher.sh
│   │   ├── 2s-encode_test.sh
│   │   ├── 0-distill_enja.sh
│   │   ├── 0-distill_enja_mono.sh
│   │   ├── 1s-wait_k.sh
│   │   ├── 2-test_model_full.sh
│   │   └── 2-test_model.sh
│   ├── get_score.py
│   └── test_test.py
├── data
│   ├── data_path.sh
│   ├── 2-train_kenlm.sh
│   ├── 2-train_align.sh
│   ├── 2-fast_align.sh
│   ├── 2-k_anticipation.sh
│   ├── 3s-generate_subset.sh
│   ├── 3s-generate_raw.sh
│   ├── 2-get_uncertainty.py
│   ├── 0-get_en_mono.sh
│   ├── 1s-preprocess_tokenizer.sh
│   ├── 0-get_en_mono_scaling.sh
│   ├── 1-preprocess_distill.py
│   └── 0-get_data_cwmt.sh
├── .gitignore
├── README.md
├── extra_installation.md
└── LICENSE
--------------------------------------------------------------------------------
/utility/simultaneous_translation/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .waitk_transformer_layers import *
2 | from .sinkhorn_attention import SinkhornAttention
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch>=1.8.1
 2 | editdistance
 3 | sacrebleu==1.5.1
 4 | sacremoses
 5 | sentencepiece
 6 | tqdm
 7 | tensorboard
 8 | tensorboardX
 9 | wandb
10 | kaggle
11 | seaborn
12 | ninja
13 | fast_align
14 | mosesdecoder
15 | SimulEval
--------------------------------------------------------------------------------
/utility/README.md:
--------------------------------------------------------------------------------
 1 | Requirements:
 2 | 
 3 | - fairseq
 4 | - SimulEval
 5 | - fast_align
 6 | - mosesdecoder
 7 | - kenlm
 8 | 
 9 | Folder simultaneous_translation was adapted from [Anticipation-free Training for Simultaneous Translation](https://github.com/George0828Zhang/sinkhorn-simultrans).
10 | 
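
fast_align, mosesdecoder and kenlm are built from source rather than pip-installed; the
pipeline scripts expect them under /root/Mono4SiM/utility/. A minimal build sketch,
assuming the upstream CMake defaults (only the clone destinations are actually fixed by
the scripts in this repo):

    cd /root/Mono4SiM/utility
    git clone https://github.com/clab/fast_align.git
    git clone https://github.com/kpu/kenlm.git
    git clone https://github.com/moses-smt/mosesdecoder.git  # only scripts/ is used
    (cd fast_align && mkdir -p build && cd build && cmake .. && make)
    (cd kenlm && mkdir -p build && cd build && cmake .. && make -j4)
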
--------------------------------------------------------------------------------
/train/cwmt-enzh/infer_mt.yaml:
--------------------------------------------------------------------------------
 1 | eval_bleu: true
 2 | generation_args:
 3 |   beam: 1
 4 |   max_len_a: 1.2
 5 |   max_len_b: 10
 6 |   post_process: sentencepiece
 7 |   print_samples: true
 8 | eval_bleu_args:
 9 |   sacrebleu_tokenizer: zh
10 |   sacrebleu_lowercase: false
11 |   sacrebleu_char_level: false
--------------------------------------------------------------------------------
/train/wmt21-enja/infer_mt.yaml:
--------------------------------------------------------------------------------
 1 | eval_bleu: true
 2 | generation_args:
 3 |   beam: 1
 4 |   max_len_a: 1.2
 5 |   max_len_b: 10
 6 |   post_process: sentencepiece
 7 |   print_samples: true
 8 | eval_bleu_args:
 9 |   sacrebleu_tokenizer: ja-mecab
10 |   sacrebleu_lowercase: false
11 |   sacrebleu_char_level: false
--------------------------------------------------------------------------------
/utility/simultaneous_translation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | from . import models
7 | from . import tasks
8 | from . import criterion
--------------------------------------------------------------------------------
/data/data_path.sh:
--------------------------------------------------------------------------------
1 | export SRC=en
2 | export TGT=zh
3 | export BASE=/root/Mono4SiM/data/cwmt-${SRC}${TGT}
4 | export MONO=/root/Mono4SiM/data/mono
5 | export DATA=/root/Mono4SiM/generate/teacher_cwmt_mono
6 | export FAIRSEQ=/root/Mono4SiM/utility/fairseq
7 | export PYTHONPATH="$FAIRSEQ:$PYTHONPATH"
8 | # . ~/envs/apex/bin/activate
--------------------------------------------------------------------------------
/train/cwmt-enzh/data_path.sh:
--------------------------------------------------------------------------------
 1 | export SRC=en
 2 | export TGT=zh
 3 | export DATASET=cwmt-${SRC}${TGT}
 4 | export BASE=/root/Mono4SiM/data/${DATASET}
 5 | export DATA=/root/Mono4SiM/generate/teacher_cwmt_mono/data-bin
 6 | export FAIRSEQ=/root/Mono4SiM/utility/fairseq
 7 | export USERDIR=/root/Mono4SiM/utility/simultaneous_translation
 8 | export PYTHONPATH="$FAIRSEQ:$PYTHONPATH"
 9 | # . ~/envs/apex/bin/activate
10 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/data_path.sh:
--------------------------------------------------------------------------------
 1 | export SRC=en
 2 | export TGT=ja
 3 | export DATASET=wmt21-${SRC}${TGT}
 4 | export BASE=/root/Mono4SiM/data/${DATASET}
 5 | export DATA=/root/Mono4SiM/data/teacher_wmt21_mono/data-bin
 6 | export FAIRSEQ=/root/Mono4SiM/utility/fairseq
 7 | export USERDIR=/root/Mono4SiM/utility/simultaneous_translation
 8 | export PYTHONPATH="$FAIRSEQ:$PYTHONPATH"
 9 | # . ~/envs/apex/bin/activate
10 | 
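
Every stage script below begins with `source ./data_path.sh`, so these exported
variables (SRC, TGT, BASE, DATA, FAIRSEQ, USERDIR) are the only per-dataset
configuration. A hypothetical session; the GPU list and subset name are illustrative:

    cd /root/Mono4SiM/train/cwmt-enzh
    mkdir -p log                         # 1-vanilla_wait_k.sh appends logs here
    bash 0-teacher.sh 0,1,2,3            # $1 becomes CUDA_VISIBLE_DEVICES
    bash 1-vanilla_wait_k.sh raw full 0  # wait-k students on the raw bitext
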
--------------------------------------------------------------------------------
/utility/scripts/run_aligner.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | ALIGNOUT=$1
 3 | ALIGNOUT_R=$2
 4 | cd /root/Mono4SiM/utility/scripts
 5 | for k in {1..9}; do
 6 |     if [ -f "${ALIGNOUT}" ]; then
 7 |         echo "calculating $k anticipation"
 8 |         python count_anticipation.py -k $k < ${ALIGNOUT}
 9 |     elif [ -f "${ALIGNOUT_R}" ]; then
10 |         echo "calculating $k anticipation rev"
11 |         python count_anticipation.py -k $k -r < ${ALIGNOUT_R}
12 |     fi
13 | done
14 | 
--------------------------------------------------------------------------------
/data/2-train_kenlm.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | KENLMBIN=/root/Mono4SiM/utility/kenlm/build/bin
 4 | CORPUS=${BASE}/ready/train.clean.${TGT}
 5 | LMDATA=${BASE}/score
 6 | NGRAM=3
 7 | 
 8 | mkdir -p ${LMDATA}
 9 | 
10 | export PATH="/root/.local/bin:$PATH"
11 | # estimate ngram
12 | ${KENLMBIN}/lmplz -o ${NGRAM} -S 50% < ${CORPUS} > ${LMDATA}/$(basename $BASE)_$TGT.arpa
13 | 
14 | # binarize
15 | ${KENLMBIN}/build_binary -s ${LMDATA}/$(basename $BASE)_$TGT.arpa ${LMDATA}/$(basename $BASE)_$TGT.bin
--------------------------------------------------------------------------------
/train/cwmt-enzh/1-vanilla_wait_k.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | DATA=$1
 4 | SUBSET=$2
 5 | GPU=$3
 6 | if [ ! $3 ]; then
 7 |     GPU='0,1,2,3,4,5,6,7'
 8 | fi
 9 | 
10 | for i in 1 3 5 7 9; do
11 |     TASK=wait_${i}_${SRC}${TGT}_distill
12 |     if [ "$DATA" == "raw" ]; then
13 |         TASK=wait_${i}_${SRC}${TGT}
14 |     fi
15 |     echo ">> Begin training ${TASK}_${SUBSET}"
16 |     bash 1s-wait_k.sh \
17 |         ${i} ${DATA} ${TASK} ${GPU} \
18 |         ${SUBSET} >> log/${TASK}_${SUBSET}.log
19 | done
20 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/1-vanilla_wait_k.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | DATA=$1
 4 | SUBSET=$2
 5 | GPU=$3
 6 | if [ ! $3 ]; then
 7 |     GPU='0,1,2,3,4,5,6,7'
 8 | fi
 9 | 
10 | for i in 1 3 5 7 9; do
11 |     TASK=wait_${i}_${SRC}${TGT}_distill
12 |     if [ "$DATA" == "raw" ]; then
13 |         TASK=wait_${i}_${SRC}${TGT}
14 |     fi
15 |     echo ">> Begin training ${TASK}_${SUBSET}"
16 |     bash 1s-wait_k.sh \
17 |         ${i} ${DATA} ${TASK} ${GPU} \
18 |         ${SUBSET} >> log/${TASK}_${SUBSET}.log
19 | done
20 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/tasks/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import os
 8 | 
 9 | 
10 | for file in sorted(os.listdir(os.path.dirname(__file__))):
11 |     if file.endswith(".py") and not file.startswith("_"):
12 |         file_name = file[: file.find(".py")]
13 |         importlib.import_module(
14 |             "simultaneous_translation.tasks." + file_name
15 |         )
16 | 
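
This file (and the criterion and models counterparts below) follows fairseq's
--user-dir plugin convention: importing every non-underscore module in the package
runs its registration decorators, which is what makes names like translation_infer
and waitk_transformer resolvable by fairseq_cli. A hypothetical sketch of the pattern
inside one of the imported modules (class body illustrative; only the task name
translation_infer is fixed by the training scripts):

    from fairseq.tasks import register_task
    from fairseq.tasks.translation import TranslationTask

    @register_task("translation_infer")
    class TranslationInferTask(TranslationTask):
        ...  # assumed: adds validation-time BLEU configured by infer_mt.yaml
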
--------------------------------------------------------------------------------
/utility/simultaneous_translation/criterion/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import os
 8 | 
 9 | 
10 | for file in sorted(os.listdir(os.path.dirname(__file__))):
11 |     if file.endswith(".py") and not file.startswith("_"):
12 |         file_name = file[: file.find(".py")]
13 |         importlib.import_module(
14 |             "simultaneous_translation.criterion." + file_name
15 |         )
16 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/models/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import os
 8 | 
 9 | ignores = ["simple_nat.py", "toy_transformer.py"]
10 | for file in sorted(os.listdir(os.path.dirname(__file__))):
11 |     if file.endswith(".py") and not file.startswith("_") and file not in ignores:
12 |         model_name = file[: file.find(".py")]
13 |         importlib.import_module(
14 |             "simultaneous_translation.models." + model_name
15 |         )
16 | 
--------------------------------------------------------------------------------
/data/2-train_align.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | OUTDIR=${BASE}/score
 4 | PREFIX=${BASE}/ready/train.clean
 5 | mkdir -p ${OUTDIR}
 6 | 
 7 | SRCTOK=$(mktemp)
 8 | TGTTOK=$(mktemp)
 9 | CORPUS=${OUTDIR}/align_process.${SRC}-${TGT}
10 | 
11 | cat ${PREFIX}.${SRC} | sed 's/▁//g' > ${SRCTOK}
12 | cat ${PREFIX}.${TGT} | sed 's/▁//g' > ${TGTTOK}
13 | 
14 | paste ${SRCTOK} ${TGTTOK} | sed "s/\t/ ||| /" > ${CORPUS}
15 | 
16 | rm -f $SRCTOK
17 | rm -f $TGTTOK
18 | 
19 | cd /root/Mono4SiM/utility/fast_align/build
20 | ./fast_align -i ${CORPUS} -d -v -o -p fwd_params >${OUTDIR}/fwd_align 2>${OUTDIR}/fwd_err
21 | ./fast_align -i ${CORPUS} -r -d -v -o -p rev_params >${OUTDIR}/rev_align 2>${OUTDIR}/rev_err
--------------------------------------------------------------------------------
/data/2-fast_align.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | OUTDIR=${DATA}/score
 4 | PREFIX=${DATA}/ready/train.clean
 5 | mkdir -p ${OUTDIR}
 6 | 
 7 | SRCTOK=$(mktemp)
 8 | TGTTOK=$(mktemp)
 9 | CORPUS=${OUTDIR}/align_process.${SRC}-${TGT}
10 | 
11 | cat ${PREFIX}.${SRC} | sed 's/▁//g' > ${SRCTOK}
12 | cat ${PREFIX}.${TGT} | sed 's/▁//g' > ${TGTTOK}
13 | 
14 | paste ${SRCTOK} ${TGTTOK} | sed "s/\t/ ||| /" > ${CORPUS}
15 | 
16 | rm -f $SRCTOK
17 | rm -f $TGTTOK
18 | 
19 | cd /root/Mono4SiM/utility/fast_align/build
20 | ./fast_align -i ${CORPUS} -d -o -v > ${OUTDIR}/forward.align
21 | ./fast_align -i ${CORPUS} -d -o -v -r > ${OUTDIR}/reverse.align
22 | ./atools -i ${OUTDIR}/forward.align -j ${OUTDIR}/reverse.align -c grow-diag-final-and > ${OUTDIR}/diag.align
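
Both alignment scripts feed fast_align the standard "source ||| target" one-pair-per-line
corpus and receive Pharaoh-format alignments: space-separated "i-j" pairs of 0-based
source/target token indices. atools then symmetrizes the forward and reverse runs with
the grow-diag-final-and heuristic. An illustrative (made-up) pair:

    align_process.en-zh:  he has a dog ||| 他 有 一 只 狗
    diag.align:           0-0 1-1 2-2 3-3 3-4
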
--------------------------------------------------------------------------------
/data/2-k_anticipation.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | OUTPUT=${DATA}/score/k-anticipation.log
 4 | rm -f ${OUTPUT}
 5 | for k in {1..9}; do
 6 |     echo "calculating forward ${k} anticipation:" >> ${OUTPUT}
 7 |     python /root/Mono4SiM/utility/simultaneous_translation/eval/anticipation/count_anticipation.py \
 8 |         -k $k < ${DATA}/score/forward.align >> ${OUTPUT}
 9 |     echo "calculating reverse ${k} anticipation:" >> ${OUTPUT}
10 |     python /root/Mono4SiM/utility/simultaneous_translation/eval/anticipation/count_anticipation.py \
11 |         -k $k < ${DATA}/score/reverse.align >> ${OUTPUT}
12 |     echo "calculating diag ${k} anticipation:" >> ${OUTPUT}
13 |     python /root/Mono4SiM/utility/simultaneous_translation/eval/anticipation/count_anticipation.py \
14 |         -k $k < ${DATA}/score/diag.align >> ${OUTPUT}
15 | done
16 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/lm/train_kenlm.sh:
--------------------------------------------------------------------------------
 1 | KENLMBIN=/home/XXXX-2/utility/kenlm/build/bin
 2 | FAIRSEQ=/home/XXXX-2/utility/fairseq
 3 | DATA=/media/XXXX-2/Data/cwmt/en-zh
 4 | TGT=zh
 5 | CORPUS=${DATA}/prep/train.dirty.${TGT}
 6 | SPM_MODEL=${DATA}/prep/spm_unigram32000_zh.model
 7 | LMDATA=./
 8 | NGRAM=3
 9 | 
10 | mkdir -p ${LMDATA}
11 | 
12 | # split bpe
13 | if [ -f "${LMDATA}/corpus" ]; then
14 |     echo "${LMDATA}/corpus exists. skipping spm_encode."
15 | else
16 |     python ${FAIRSEQ}/scripts/spm_encode.py \
17 |         --model=$SPM_MODEL \
18 |         --output_format=piece \
19 |         < ${CORPUS} > ${LMDATA}/corpus
20 | fi
21 | 
22 | # estimate ngram
23 | ${KENLMBIN}/lmplz -o ${NGRAM} -S 50% < ${LMDATA}/corpus > ${LMDATA}/lm.arpa
24 | 
25 | # binarize
26 | ${KENLMBIN}/build_binary -s ${LMDATA}/lm.arpa ${LMDATA}/lm.bin
--------------------------------------------------------------------------------
/data/3s-generate_subset.sh:
--------------------------------------------------------------------------------
 1 | SCORE=$1
 2 | folder=$2
 3 | SRC=$3
 4 | TGT=$4
 5 | ROOT=$5
 6 | BASE=$6
 7 | ADD=$7
 8 | RAW=$8
 9 | workers=4
10 | bin=${ROOT}/data-bin/${SCORE}_${folder}${RAW}
11 | SPM_PREFIX=${BASE}/prep/spm_unigram32000
12 | ready=${BASE}/ready
13 | 
14 | 
15 | cd ${ROOT}
16 | cat ${ROOT}/${folder}/${SCORE}.${SRC} ${ADD}.${SRC} > ${ROOT}/${folder}/${SCORE}_tmp.${SRC}
17 | cat ${ROOT}/${folder}/${SCORE}.${TGT} ${ADD}.${TGT} > ${ROOT}/${folder}/${SCORE}_tmp.${TGT}
18 | 
19 | python -m fairseq_cli.preprocess \
20 |     --source-lang ${SRC} \
21 |     --target-lang ${TGT} \
22 |     --trainpref ${ROOT}/${folder}/${SCORE}_tmp \
23 |     --validpref ${ready}/valid \
24 |     --testpref ${ready}/test \
25 |     --destdir ${bin} \
26 |     --workers ${workers} \
27 |     --srcdict ${SPM_PREFIX}_${SRC}.txt \
28 |     --tgtdict ${SPM_PREFIX}_${TGT}.txt
29 | 
30 | rm ${ROOT}/${folder}/${SCORE}_tmp.${SRC}
31 | rm ${ROOT}/${folder}/${SCORE}_tmp.${TGT}
32 | 
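
3s-generate_subset.sh is configured entirely by positional arguments: SCORE (basename
of the selected subset), folder, SRC, TGT, ROOT (generation directory), BASE (bitext
directory), ADD (extra bitext prefix to concatenate) and RAW (suffix for the output
data-bin). A hypothetical invocation with illustrative values:

    bash 3s-generate_subset.sh uncertainty subset en zh \
        /root/Mono4SiM/generate/teacher_cwmt_mono \
        /root/Mono4SiM/data/cwmt-enzh \
        /root/Mono4SiM/data/cwmt-enzh/ready/train.clean ""
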
--------------------------------------------------------------------------------
/data/3s-generate_raw.sh:
--------------------------------------------------------------------------------
 1 | SCORE=$1
 2 | folder=$2
 3 | SRC=$3
 4 | TGT=$4
 5 | ROOT=$5
 6 | BASE=$6
 7 | ADD=$7
 8 | ADDD=$8
 9 | RAW=$9
10 | workers=4
11 | bin=${ROOT}/data-bin/${SCORE}_${folder}${RAW}
12 | SPM_PREFIX=${BASE}/prep/spm_unigram32000
13 | ready=${BASE}/ready
14 | 
15 | 
16 | cd ${ROOT}
17 | cat ${ROOT}/${folder}/${SCORE}.${SRC} ${ADD}.${SRC} ${ADDD}.${SRC} > ${ROOT}/${folder}/${SCORE}_tmp.${SRC}
18 | cat ${ROOT}/${folder}/${SCORE}.${TGT} ${ADD}.${TGT} ${ADDD}.${TGT} > ${ROOT}/${folder}/${SCORE}_tmp.${TGT}
19 | 
20 | python -m fairseq_cli.preprocess \
21 |     --source-lang ${SRC} \
22 |     --target-lang ${TGT} \
23 |     --trainpref ${ROOT}/${folder}/${SCORE}_tmp \
24 |     --validpref ${ready}/valid \
25 |     --testpref ${ready}/test \
26 |     --destdir ${bin} \
27 |     --workers ${workers} \
28 |     --srcdict ${SPM_PREFIX}_${SRC}.txt \
29 |     --tgtdict ${SPM_PREFIX}_${TGT}.txt
30 | 
31 | rm ${ROOT}/${folder}/${SCORE}_tmp.${SRC}
32 | rm ${ROOT}/${folder}/${SCORE}_tmp.${TGT}
33 | 
--------------------------------------------------------------------------------
/train/cwmt-enzh/0-teacher.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_cwmt_${SRC}${TGT}
 4 | 
 5 | CUDA_VISIBLE_DEVICES=$1 python -m fairseq_cli.train \
 6 |     ${DATA} --user-dir ${USERDIR} \
 7 |     -s ${SRC} -t ${TGT} \
 8 |     --max-tokens 25600 \
 9 |     --task translation_infer \
10 |     --inference-config-yaml infer_mt.yaml \
11 |     --arch transformer \
12 |     --encoder-normalize-before --decoder-normalize-before \
13 |     --share-decoder-input-output-embed \
14 |     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
15 |     --clip-norm 10.0 \
16 |     --weight-decay 0.0001 \
17 |     --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt \
18 |     --warmup-updates 4000 \
19 |     --max-update 50000 \
20 |     --save-dir checkpoints/${TASK} \
21 |     --no-epoch-checkpoints \
22 |     --save-interval-updates 500 \
23 |     --keep-best-checkpoints 5 \
24 |     --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
25 |     --patience 50 \
26 |     --log-format simple --log-interval 50 \
27 |     --seed 1 \
28 |     --fp16
29 | 
--------------------------------------------------------------------------------
/train/cwmt-enzh/2s-proc_generate.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | 
 4 | 
 5 | def execCmd(cmd, *args):
 6 |     for arg in args:
 7 |         cmd = f"{cmd} {arg}"
 8 |     r = os.popen(cmd)
 9 |     text = r.read()
10 |     r.close()
11 |     print(text)
12 |     return text
13 | 
14 | 
15 | def proc_generate(root, file):
16 |     record = {}
17 |     with open(f'{root}/{file}', 'r', encoding='utf-8') as f:
18 |         for i in f.readlines():
19 |             if '\t' in i and i[0] == 'D':
20 |                 num, sen = i.rstrip().split('\t', 1)
21 |                 record[int(num[2:])] = sen
22 |     with open(f'{root}/detok.txt', 'w', encoding='utf-8') as f:
23 |         i = 0
24 |         while i in record:
25 |             f.write(record[i].split('\t', 1)[1] + '\n')
26 |             i += 1
27 |     print(i)
28 | 
29 | 
30 | if __name__ == "__main__":
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument("--root", "-r", type=str, default=None)
33 |     parser.add_argument("--file", "-f", type=str, default='generate-test.txt')
34 |     args = parser.parse_args()
35 |     proc_generate(args.root, args.file)
36 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/2s-proc_generate.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | 
 4 | 
 5 | def execCmd(cmd, *args):
 6 |     for arg in args:
 7 |         cmd = f"{cmd} {arg}"
 8 |     r = os.popen(cmd)
 9 |     text = r.read()
10 |     r.close()
11 |     print(text)
12 |     return text
13 | 
14 | 
15 | def proc_generate(root, file):
16 |     record = {}
17 |     with open(f'{root}/{file}', 'r', encoding='utf-8') as f:
18 |         for i in f.readlines():
19 |             if '\t' in i and i[0] == 'D':
20 |                 num, sen = i.rstrip().split('\t', 1)
21 |                 record[int(num[2:])] = sen
22 |     with open(f'{root}/detok.txt', 'w', encoding='utf-8') as f:
23 |         i = 0
24 |         while i in record:
25 |             f.write(record[i].split('\t', 1)[1] + '\n')
26 |             i += 1
27 |     print(i)
28 | 
29 | 
30 | if __name__ == "__main__":
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument("--root", "-r", type=str, default=None)
33 |     parser.add_argument("--file", "-f", type=str, default='generate-test.txt')
34 |     args = parser.parse_args()
35 |     proc_generate(args.root, args.file)
36 | 
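
proc_generate assumes the fairseq generate/interactive log format, in which every
sentence yields tab-separated S- (source), T- (reference), H- (tokenized hypothesis)
and D- (detokenized hypothesis) lines. Only the D- lines are kept, re-sorted by
sentence index, and stripped of the score column. Illustrative excerpt of a
generate-test.txt:

    S-17    ▁he ▁said ...
    T-17    他说...
    H-17    -0.4735 ▁他 ▁说 ...
    D-17    -0.4735 他说...
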
--------------------------------------------------------------------------------
/train/cwmt-enzh/2s-encode_test.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | OUTDIR=$1
 4 | PREFIX=$1
 5 | mkdir -p ${OUTDIR}
 6 | 
 7 | SRCTOK=$(mktemp)
 8 | TGTTOK=$(mktemp)
 9 | CORPUS=${OUTDIR}/align_process
10 | 
11 | vocab=32000
12 | vtype=unigram
13 | spm_encode=$FAIRSEQ/scripts/spm_encode.py
14 | 
15 | SPM_PREFIX=/root/Mono4SiM/data/cwmt-enzh/prep/spm_${vtype}${vocab}
16 | for l in ${SRC} ${TGT}; do
17 |     SPM_MODEL=${SPM_PREFIX}_${l}.model
18 |     echo "Using SPM model $SPM_MODEL"
19 |     if [ -f $ready/$split.$l ]; then
20 |         echo "found $ready/$split.$l, skipping spm_encode"
21 |     else
22 |         echo "spm_encode to $split.$l..."
23 |         python $spm_encode --model=$SPM_MODEL \
24 |             --output_format=piece \
25 |             < ${PREFIX}/$l.txt > ${PREFIX}/$l.tok
26 |     fi
27 | done
28 | 
29 | cat ${PREFIX}/${SRC}.tok | sed 's/▁//g' > ${SRCTOK}
30 | cat ${PREFIX}/${TGT}.tok | sed 's/▁//g' > ${TGTTOK}
31 | 
32 | paste ${SRCTOK} ${TGTTOK} | sed "s/\t/ ||| /" > ${CORPUS}
33 | 
34 | rm -f $SRCTOK
35 | rm -f $TGTTOK
36 | rm -f ${PREFIX}/${SRC}.txt
37 | rm -f ${PREFIX}/${TGT}.txt
38 | rm -f ${PREFIX}/${SRC}.tok
39 | rm -f ${PREFIX}/${TGT}.tok
40 | 
--------------------------------------------------------------------------------
/train/cwmt-enzh/1s-wait_k.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | source ./data_path.sh
 4 | 
 5 | WAITK=$1
 6 | DATA=$2
 7 | TASK=$3
 8 | SUBSET=$5
 9 | 
10 | if [ "$DATA" == "raw" ]; then
11 |     DATA=/root/Mono4SiM/data/${DATASET}/data-bin
12 | fi
13 | 
14 | CUDA_VISIBLE_DEVICES=$4 python -m fairseq_cli.train \
15 |     ${DATA}/${SUBSET} --user-dir ${USERDIR} \
16 |     -s ${SRC} -t ${TGT} \
17 |     --max-tokens 25600 \
18 |     --task translation_infer \
19 |     --inference-config-yaml infer_mt.yaml \
20 |     --arch waitk_transformer \
21 |     --waitk ${WAITK} \
22 |     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
23 |     --clip-norm 10.0 \
24 |     --weight-decay 0.0001 \
25 |     --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt \
26 |     --warmup-updates 4000 \
27 |     --max-update 50000 \
28 |     --save-dir checkpoints/${TASK}_${SUBSET} \
29 |     --no-epoch-checkpoints \
30 |     --save-interval-updates 500 \
31 |     --keep-interval-updates 1 \
32 |     --keep-best-checkpoints 5 \
33 |     --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
34 |     --patience 50 \
35 |     --log-format simple --log-interval 50 \
36 |     --fp16 --local_rank $SLURM_LOCALID \
37 |     --seed 2
38 | 
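
--arch waitk_transformer trains the standard wait-k policy: the decoder first reads k
source tokens, then alternates one read with one write, so the t-th target token
(1-based) is predicted from at most g(t) source tokens. A sketch of the schedule (not
this repo's code):

    # wait-k schedule: source tokens visible when writing target token t (1-based)
    def source_context(t, k, src_len):
        return min(k + t - 1, src_len)
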
--------------------------------------------------------------------------------
/train/wmt21-enja/0-teacher.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_wmt21_${SRC}${TGT}
 4 | 
 5 | pip install sacrebleu[ja]
 6 | CUDA_VISIBLE_DEVICES=$1 python -m fairseq_cli.train \
 7 |     ${DATA} --user-dir ${USERDIR} \
 8 |     -s ${SRC} -t ${TGT} \
 9 |     --fp16 --ddp-backend=no_c10d \
10 |     --optimizer adam --adam-betas '(0.9,0.98)' \
11 |     --attention-dropout 0.0 --activation-dropout 0.0 --dropout 0.3 \
12 |     --max-tokens 25600 --update-freq 2 \
13 |     --task translation_infer \
14 |     --inference-config-yaml infer_mt.yaml \
15 |     --arch transformer \
16 |     --lr 1e-3 --lr-scheduler cosine --warmup-init-lr 1e-07 --weight-decay 0.0 \
17 |     --lr-shrink 1 --lr-period-updates 20000 --min-lr 1e-09 \
18 |     --warmup-updates 10000 --clip-norm 0.1 \
19 |     --max-update 50000 \
20 |     --share-decoder-input-output-embed \
21 |     --criterion label_smoothed_cross_entropy \
22 |     --max-source-positions 10000 --max-target-positions 10000 \
23 |     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
24 |     --save-dir checkpoints/${TASK} \
25 |     --no-epoch-checkpoints \
26 |     --save-interval-updates 500 \
27 |     --keep-best-checkpoints 5 \
28 |     --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
29 |     --log-format simple --log-interval 50 \
30 |     --seed 1
31 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/2s-encode_test.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | OUTDIR=$1
 4 | PREFIX=$1
 5 | NAME=$2
 6 | TOK=$3
 7 | mkdir -p ${OUTDIR}
 8 | 
 9 | SRCTOK=$(mktemp)
10 | TGTTOK=$(mktemp)
11 | CORPUS=${OUTDIR}/align_process
12 | 
13 | vocab=32000
14 | vtype=unigram
15 | spm_encode=$FAIRSEQ/scripts/spm_encode.py
16 | 
17 | if [ "$TOK" == "true" ]; then
18 |     cat ${PREFIX}/$NAME.$SRC | sed 's/▁//g' > ${SRCTOK}
19 |     cat ${PREFIX}/$NAME.$TGT | sed 's/▁//g' > ${TGTTOK}
20 | else
21 |     SPM_PREFIX=/root/Mono4SiM/data/wmt21-enja/prep/spm_${vtype}${vocab}
22 |     for l in ${SRC} ${TGT}; do
23 |         SPM_MODEL=${SPM_PREFIX}_${l}.model
24 |         echo "Using SPM model $SPM_MODEL"
25 |         if [ -f $ready/$split.$l ]; then
26 |             echo "found $ready/$split.$l, skipping spm_encode"
27 |         else
28 |             echo "spm_encode to $split.$l..."
29 |             python $spm_encode --model=$SPM_MODEL \
30 |                 --output_format=piece \
31 |                 < ${PREFIX}/$NAME.$l > ${PREFIX}/$l.tok
32 |         fi
33 |     done
34 |     cat ${PREFIX}/${SRC}.tok | sed 's/▁//g' > ${SRCTOK}
35 |     cat ${PREFIX}/${TGT}.tok | sed 's/▁//g' > ${TGTTOK}
36 | fi
37 | 
38 | paste ${SRCTOK} ${TGTTOK} | sed "s/\t/ ||| /" > ${CORPUS}
39 | 
40 | rm -f $SRCTOK
41 | rm -f $TGTTOK
42 | rm -f ${PREFIX}/${SRC}.tok
43 | rm -f ${PREFIX}/${TGT}.tok
44 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/run_all_simuleval.sh:
--------------------------------------------------------------------------------
 1 | DATA=cwmt
 2 | TGT=zh
 3 | EXP=../expcwmt
 4 | SRC_FILE=/media/XXXX-2/Data/cwmt/zh-en/prep/test.en-zh.en
 5 | TGT_FILE=/media/XXXX-2/Data/cwmt/zh-en/prep/test.en-zh.zh.1
 6 | 
 7 | for t in 2 3; do
 8 |     for k in 1 3 5 7 9; do
 9 |         MODEL=sinkhorn_delay${k}_ft
10 |         bash simuleval.sh \
11 |             -a agents/simul_t2t_ctc.py \
12 |             -m ${MODEL} \
13 |             -k ${k} \
14 |             -e ${EXP} \
15 |             -s ${SRC_FILE} \
16 |             -t ${TGT_FILE}
17 | 
18 |         OUTPUT=${DATA}_${TGT}-results/${MODEL}.${DATA}
19 |         mv ${OUTPUT}/scores ${OUTPUT}/scores.${t}
20 |     done
21 | done
22 | bash run_cwmt_bleueval.sh
23 | 
24 | 
25 | DATA=wmt15
26 | TGT=en
27 | EXP=../expwmt15
28 | SRC_FILE=/media/XXXX-2/Data/wmt15/de-en/prep/test.de
29 | TGT_FILE=/media/XXXX-2/Data/wmt15/de-en/prep/test.en
30 | 
31 | for t in 1 2 3; do
32 |     for k in 1 3 5 7 9; do
33 |         MODEL=sinkhorn_delay${k}_ft
34 |         bash simuleval.sh \
35 |             -a agents/simul_t2t_ctc.py \
36 |             -m ${MODEL} \
37 |             -k ${k} \
38 |             -e ${EXP} \
39 |             -s ${SRC_FILE} \
40 |             -t ${TGT_FILE}
41 | 
42 |         OUTPUT=${DATA}_${TGT}-results/${MODEL}.${DATA}
43 |         mv ${OUTPUT}/scores ${OUTPUT}/scores.${t}
44 |     done
45 | done
46 | bash run_wmt15_bleueval.sh
--------------------------------------------------------------------------------
/utility/scripts/avg_reorder_distance.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import sys
 3 | import argparse
 4 | import re
 5 | import numpy as np
 6 | 
 7 | 
 8 | def distance(align_lines, reverse):
 9 |     dists = []
10 |     pattern = re.compile(r"(?P<si>[0-9]+)-(?P<sj>[0-9]+)")
11 |     for line in align_lines:
12 |         all_i = []
13 |         all_j = []
14 |         for si, sj in pattern.findall(line):
15 |             i = int(sj if reverse else si)
16 |             j = int(si if reverse else sj)
17 |             all_i.append(i)
18 |             all_j.append(j)
19 | 
20 |         min_i = min(all_i)
21 |         min_j = min(all_j)
22 |         max_i = max(all_i)
23 |         max_j = max(all_j)
24 |         for i, j in zip(all_i, all_j):
25 |             tgt = (i - min_i) / (max_i - min_i + 1e-9) * (max_j - min_j) + min_j
26 |             dists.append(abs(tgt - j))
27 | 
28 |     return np.mean(dists), np.std(dists)
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     parser = argparse.ArgumentParser()
33 |     parser.add_argument("--input", "-i", type=str, default=None)
34 |     parser.add_argument("--reverse", "-r", action="store_true")
35 |     args = parser.parse_args()
36 | 
37 |     if args.input is not None:
38 |         with open(args.input, "r") as f:
39 |             print(distance(f.readlines(), reverse=args.reverse))
40 |     else:
41 |         print(distance(sys.stdin.readlines(), reverse=args.reverse))
42 | 
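
distance() quantifies how non-monotonic an alignment is: each aligned pair (i, j) gets
an expected target position by linearly mapping i from the source span onto the target
span, tgt = (i - min_i) / (max_i - min_i) * (max_j - min_j) + min_j, and the script
reports the mean and standard deviation of |tgt - j| over all pairs (0 for a perfectly
monotonic alignment). Typical usage on a fast_align output (path illustrative):

    python avg_reorder_distance.py -i /root/Mono4SiM/data/cwmt-enzh/score/forward.align
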
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/anticipation/run_aligner.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | MODEL=bert-base-multilingual-cased
 3 | PREFIX=$1
 4 | SRC=$2
 5 | TGT=$3
 6 | N=1000000
 7 | 
 8 | OUTDIR=./alignments
 9 | 
10 | SRCTOK=$(mktemp)
11 | TGTTOK=$(mktemp)
12 | CORPUS=$(mktemp)
13 | 
14 | head -n ${N} ${PREFIX}.${SRC} | sed 's/▁//g' > ${SRCTOK}
15 | head -n ${N} ${PREFIX}.${TGT} | sed 's/▁//g' > ${TGTTOK}
16 | 
17 | echo "aligning ..."
18 | mkdir -p ${OUTDIR}
19 | ALIGNOUT=${OUTDIR}/$(basename ${PREFIX}).${SRC}-${TGT}_${N}
20 | ALIGNOUT_R=${OUTDIR}/$(basename ${PREFIX}).${TGT}-${SRC}_${N}
21 | if [ -f "${ALIGNOUT}" ]; then
22 |     echo "${ALIGNOUT} exists, skipping alignment"
23 | elif [ -f "${ALIGNOUT_R}" ]; then
24 |     echo "${ALIGNOUT_R} exists, skipping alignment"
25 | else
26 |     paste ${SRCTOK} ${TGTTOK} | sed "s/\t/ ||| /" > ${CORPUS}
27 |     python -m awesome_align.run_align \
28 |         --output_file=${ALIGNOUT} \
29 |         --model_name_or_path=${MODEL} \
30 |         --data_file=${CORPUS} \
31 |         --extraction 'softmax' \
32 |         --batch_size 128
33 | fi
34 | 
35 | echo "calculating anticipation"
36 | for k in {1..9}; do
37 |     if [ -f "${ALIGNOUT}" ]; then
38 |         python count_anticipation.py -k $k < ${ALIGNOUT}
39 |     elif [ -f "${ALIGNOUT_R}" ]; then
40 |         python count_anticipation.py -k $k -r < ${ALIGNOUT_R}
41 |     fi
42 | done
43 | 
44 | 
45 | rm -f $SRCTOK
46 | rm -f $TGTTOK
47 | rm -f $CORPUS
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/anticipation/avg_reorder_distance.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import sys
 3 | import argparse
 4 | import re
 5 | import numpy as np
 6 | 
 7 | 
 8 | def distance(align_lines, reverse):
 9 |     dists = []
10 |     pattern = re.compile(r"(?P<si>[0-9]+)-(?P<sj>[0-9]+)")
11 |     for line in align_lines:
12 |         all_i = []
13 |         all_j = []
14 |         for si, sj in pattern.findall(line):
15 |             i = int(sj if reverse else si)
16 |             j = int(si if reverse else sj)
17 |             all_i.append(i)
18 |             all_j.append(j)
19 | 
20 |         min_i = min(all_i)
21 |         min_j = min(all_j)
22 |         max_i = max(all_i)
23 |         max_j = max(all_j)
24 |         for i, j in zip(all_i, all_j):
25 |             tgt = (i - min_i) / (max_i - min_i + 1e-9) * (max_j - min_j) + min_j
26 |             dists.append(abs(tgt - j))
27 | 
28 |     return np.mean(dists), np.std(dists)
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     parser = argparse.ArgumentParser()
33 |     parser.add_argument("--input", "-i", type=str, default=None)
34 |     parser.add_argument("--reverse", "-r", action="store_true")
35 |     args = parser.parse_args()
36 | 
37 |     if args.input is not None:
38 |         with open(args.input, "r") as f:
39 |             print(distance(f.readlines(), reverse=args.reverse))
40 |     else:
41 |         print(distance(sys.stdin.readlines(), reverse=args.reverse))
42 | 
--------------------------------------------------------------------------------
/utility/scripts/count_anticipation.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import sys
 3 | import argparse
 4 | import re
 5 | 
 6 | 
 7 | def kAR(align, k, reverse=False, sent=False):
 8 |     if sent:
 9 |         corpus = align.strip().split("\n")
10 |     else:
11 |         corpus = [align]
12 | 
13 |     output = []
14 | 
15 |     for line in corpus:
16 |         inv, tot = 0, 1e-9
17 |         itr = re.finditer(r"(?P<i>[0-9]+)-(?P<j>[0-9]+)", line)
18 |         for m in itr:
19 |             i = int(m.group("j" if reverse else "i"))
20 |             j = int(m.group("i" if reverse else "j"))
21 |             tot += 1
22 |             if i - k + 1 > j:
23 |                 inv += 1
24 |         output.append(str(inv / tot))
25 |     return "\n".join(output)
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     parser = argparse.ArgumentParser()
30 |     parser.add_argument("--input", "-i", type=str, default=None)
31 |     parser.add_argument("--reverse", "-r", action="store_true")
32 |     parser.add_argument("--delay", "-k", type=int, required=True)
33 |     parser.add_argument("--sentence-level", "-s", action="store_true")
34 |     args = parser.parse_args()
35 | 
36 |     if args.input is not None:
37 |         with open(args.input, "r") as f:
38 |             print(kAR(f.read(), k=args.delay,
39 |                       reverse=args.reverse, sent=args.sentence_level))
40 |     else:
41 |         print(kAR(sys.stdin.read(),
42 |                   k=args.delay, reverse=args.reverse, sent=args.sentence_level))
43 | 
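
kAR computes the k-anticipation rate: an aligned pair i-j counts as an anticipation
under wait-k when i - k + 1 > j, i.e. target token j would have to be written before
source token i has been read. With k = 3, the pair 5-1 counts (5 - 3 + 1 = 3 > 1)
while 3-1 does not (1 > 1 is false). 2-k_anticipation.sh drives it like this:

    python count_anticipation.py -k 3 < ${DATA}/score/forward.align
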
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/anticipation/count_anticipation.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import sys
 3 | import argparse
 4 | import re
 5 | 
 6 | 
 7 | def kAR(align, k, reverse=False, sent=False):
 8 |     if sent:
 9 |         corpus = align.strip().split("\n")
10 |     else:
11 |         corpus = [align]
12 | 
13 |     output = []
14 | 
15 |     for line in corpus:
16 |         inv, tot = 0, 1e-9
17 |         itr = re.finditer(r"(?P<i>[0-9]+)-(?P<j>[0-9]+)", line)
18 |         for m in itr:
19 |             i = int(m.group("j" if reverse else "i"))
20 |             j = int(m.group("i" if reverse else "j"))
21 |             tot += 1
22 |             if i - k + 1 > j:
23 |                 inv += 1
24 |         output.append(str(inv / tot))
25 |     return "\n".join(output)
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     parser = argparse.ArgumentParser()
30 |     parser.add_argument("--input", "-i", type=str, default=None)
31 |     parser.add_argument("--reverse", "-r", action="store_true")
32 |     parser.add_argument("--delay", "-k", type=int, required=True)
33 |     parser.add_argument("--sentence-level", "-s", action="store_true")
34 |     args = parser.parse_args()
35 | 
36 |     if args.input is not None:
37 |         with open(args.input, "r") as f:
38 |             print(kAR(f.read(), k=args.delay,
39 |                       reverse=args.reverse, sent=args.sentence_level))
40 |     else:
41 |         print(kAR(sys.stdin.read(),
42 |                   k=args.delay, reverse=args.reverse, sent=args.sentence_level))
43 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/0-distill_enja.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_wmt21
 4 | CHECKDIR=/root/Mono4SiM/train/checkpoints/${TASK}_${SRC}${TGT}
 5 | AVG=true
 6 | 
 7 | GENARGS="--beam 6 --lenpen 1.0 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece"
 8 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer ja-mecab --sacrebleu-lowercase"
 9 | 
10 | if [[ $AVG == "true" ]]; then
11 |     CHECKPOINT_FILENAME=avg_best_5_checkpoint.pt
12 |     python /root/Mono4SiM/utility/scripts/average_checkpoints.py \
13 |         --inputs ${CHECKDIR} --num-best-checkpoints 5 \
14 |         --output "${CHECKDIR}/${CHECKPOINT_FILENAME}"
15 | else
16 |     CHECKPOINT_FILENAME=checkpoint_best.pt
17 | fi
18 | 
19 | ROOT=/root/Mono4SiM/generate/${TASK}
20 | FILES=/root/Mono4SiM/data/${DATASET}
21 | mkdir -p ${ROOT}
22 | mkdir -p ${ROOT}/interactive
23 | mkdir -p ${FILES}/split
24 | 
25 | split -l $(($((`wc -l < ${FILES}/ready/train.clean.en`/8))+1)) -d -a 1 ${FILES}/ready/train.clean.en ${FILES}/split/train.en.
26 | 
27 | for i in {0..7}
28 | do
29 |     cat ${FILES}/split/train.en.${i} | CUDA_VISIBLE_DEVICES=${i} \
30 |     python -m fairseq_cli.interactive /root/Mono4SiM/data/${DATASET}/data-bin \
31 |         -s ${SRC} -t ${TGT} \
32 |         --user-dir ${USERDIR} \
33 |         --skip-invalid-size-inputs-valid-test \
34 |         --task translation \
35 |         --path ${CHECKDIR}/${CHECKPOINT_FILENAME} \
36 |         --batch-size 64 --buffer-size 128 --fp16 \
37 |         ${GENARGS} ${EXTRAARGS} > ${ROOT}/interactive/generate-train.${i}.txt 2>&1 &
38 | done
39 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/0-distill_enja_mono.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_wmt21
 4 | CHECKDIR=/root/Mono4SiM/train/checkpoints/${TASK}_${SRC}${TGT}
 5 | AVG=true
 6 | 
 7 | GENARGS="--beam 6 --lenpen 1.0 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece"
 8 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer ja-mecab --sacrebleu-lowercase"
 9 | 
10 | if [[ $AVG == "true" ]]; then
11 |     CHECKPOINT_FILENAME=avg_best_5_checkpoint.pt
12 |     python /root/Mono4SiM/utility/scripts/average_checkpoints.py \
13 |         --inputs ${CHECKDIR} --num-best-checkpoints 5 \
14 |         --output "${CHECKDIR}/${CHECKPOINT_FILENAME}"
15 | else
16 |     CHECKPOINT_FILENAME=checkpoint_best.pt
17 | fi
18 | 
19 | ROOT=/root/Mono4SiM/generate/${TASK}_mono
20 | FILES=/root/Mono4SiM/data/mono-en
21 | mkdir -p ${ROOT}
22 | mkdir -p ${ROOT}/interactive
23 | mkdir -p ${FILES}/split
24 | 
25 | split -l $(($((`wc -l < ${FILES}/ready/train.clean.en`/8))+1)) -d -a 1 ${FILES}/ready/train.clean.en ${FILES}/split/train.en.
26 | 
27 | for i in {0..7}
28 | do
29 |     cat ${FILES}/split/train.en.${i} | CUDA_VISIBLE_DEVICES=${i} \
30 |     python -m fairseq_cli.interactive /root/Mono4SiM/data/${DATASET}/data-bin \
31 |         -s ${SRC} -t ${TGT} \
32 |         --user-dir ${USERDIR} \
33 |         --skip-invalid-size-inputs-valid-test \
34 |         --task translation \
35 |         --path ${CHECKDIR}/${CHECKPOINT_FILENAME} \
36 |         --batch-size 64 --buffer-size 128 --fp16 \
37 |         ${GENARGS} ${EXTRAARGS} > ${ROOT}/interactive/generate-train.${i}.txt 2>&1 &
38 | done
39 | 
--------------------------------------------------------------------------------
/data/2-get_uncertainty.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | import math
 3 | from collections import defaultdict
 4 | from fairseq.data import Dictionary
 5 | 
 6 | SRC = 'en'
 7 | TGT = 'zh'
 8 | DATASET = 'cwmt'
 9 | 
10 | root = f'/root/Mono4SiM/data/cwmt-{SRC}{TGT}'
11 | align = f'{root}/score/forward.align'
12 | src_file = f'{root}/ready/train.clean.{SRC}'
13 | tgt_file = f'{root}/ready/train.clean.{TGT}'
14 | dict_file = f'{root}/data-bin/dict.{SRC}.txt'
15 | 
16 | d = defaultdict(lambda: defaultdict(int))
17 | 
18 | with open(align, 'r', encoding='utf-8') as fa, \
19 |         open(src_file, 'r', encoding='utf-8') as fs, \
20 |         open(tgt_file, 'r', encoding='utf-8') as ft:
21 |     for a, s, t in zip(fa, fs, ft):
22 |         itr = re.finditer(r"(?P<i>[0-9]+)-(?P<j>[0-9]+)", a.strip())
23 |         left = []
24 |         right = []
25 |         for m in itr:
26 |             left.append(int(m.group("i")))
27 |             right.append(int(m.group("j")))
28 |         s = s.strip().split()
29 |         t = t.strip().split()
30 |         for i, j in zip(left, right):
31 |             d[s[i]][t[j]] += 1
32 | 
33 | dct = Dictionary.load(dict_file)
34 | for key, value in d.items():
35 |     if key in dct:
36 |         tot = 0
37 |         score = 0
38 |         for _, j in value.items():
39 |             tot += j
40 |         for _, j in value.items():
41 |             p = j / tot
42 |             score += p * math.log(p)
43 |         dct.count[dct.index(key)] = score
44 | dct.save(f'{root}/score/uncertainty.{SRC}.txt')
45 | 
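
The script scores every source word type by the negative entropy of its aligned
target-word distribution, score(w) = sum_y p(y|w) * log p(y|w), so scores lie in
(-inf, 0] and a value near 0 means the word translates nearly deterministically. For
example, a word aligned to one target word 6 times and another 2 times scores
0.75*ln(0.75) + 0.25*ln(0.25) ≈ -0.56. The numbers are stored in the count field of a
fairseq Dictionary (uncertainty.en.txt), so downstream selection can reuse the
dictionary loader.
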
--------------------------------------------------------------------------------
/train/cwmt-enzh/0-distill_enzh.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_cwmt
 4 | CHECKDIR=/root/Mono4SiM/train/cwmt-${SRC}${TGT}/checkpoints/${TASK}_${SRC}${TGT}
 5 | AVG=true
 6 | 
 7 | GENARGS="--beam 5 --lenpen 1.5 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece"
 8 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer zh --sacrebleu-lowercase"
 9 | 
10 | if [[ $AVG == "true" ]]; then
11 |     CHECKPOINT_FILENAME=avg_best_5_checkpoint.pt
12 |     python /root/Mono4SiM/utility/scripts/average_checkpoints.py \
13 |         --inputs ${CHECKDIR} --num-best-checkpoints 5 \
14 |         --output "${CHECKDIR}/${CHECKPOINT_FILENAME}"
15 | else
16 |     CHECKPOINT_FILENAME=checkpoint_best.pt
17 | fi
18 | 
19 | ROOT=/root/Mono4SiM/generate/${TASK}
20 | FILES=/root/Mono4SiM/data/cwmt-enzh
21 | mkdir -p ${ROOT}
22 | mkdir -p ${ROOT}/interactive
23 | mkdir -p ${FILES}/split
24 | 
25 | split -l $(($((`wc -l < ${FILES}/ready/train.clean.en`/8))+1)) -d -a 1 ${FILES}/ready/train.clean.en ${FILES}/split/train.en.
26 | 
27 | for i in {0..7}
28 | do
29 |     cat ${FILES}/split/train.en.${i} | CUDA_VISIBLE_DEVICES=${i} \
30 |     python -m fairseq_cli.interactive /root/Mono4SiM/data/cwmt-enzh/data-bin \
31 |         -s ${SRC} -t ${TGT} \
32 |         --user-dir ${USERDIR} \
33 |         --skip-invalid-size-inputs-valid-test \
34 |         --task translation \
35 |         --path ${CHECKDIR}/${CHECKPOINT_FILENAME} \
36 |         --batch-size 64 --buffer-size 128 --fp16 \
37 |         ${GENARGS} ${EXTRAARGS} > ${ROOT}/interactive/generate-train.${i}.txt 2>&1 &
38 | done
39 | 
--------------------------------------------------------------------------------
/train/cwmt-enzh/0-distill_enzh_mono.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | source ./data_path.sh
 3 | TASK=teacher_cwmt
 4 | CHECKDIR=/root/Mono4SiM/train/cwmt-${SRC}${TGT}/checkpoints/${TASK}_${SRC}${TGT}
 5 | AVG=true
 6 | 
 7 | GENARGS="--beam 5 --lenpen 1.5 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece"
 8 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer zh --sacrebleu-lowercase"
 9 | 
10 | if [[ $AVG == "true" ]]; then
11 |     CHECKPOINT_FILENAME=avg_best_5_checkpoint.pt
12 |     python /root/Mono4SiM/utility/scripts/average_checkpoints.py \
13 |         --inputs ${CHECKDIR} --num-best-checkpoints 5 \
14 |         --output "${CHECKDIR}/${CHECKPOINT_FILENAME}"
15 | else
16 |     CHECKPOINT_FILENAME=checkpoint_best.pt
17 | fi
18 | 
19 | ROOT=/root/Mono4SiM/generate/${TASK}_mono
20 | FILES=/root/Mono4SiM/data/mono-en
21 | mkdir -p ${ROOT}
22 | mkdir -p ${ROOT}/interactive
23 | mkdir -p ${FILES}/split
24 | 
25 | split -l $(($((`wc -l < ${FILES}/ready/train.clean.en`/8))+1)) -d -a 1 ${FILES}/ready/train.clean.en ${FILES}/split/train.en.
26 | 
27 | for i in {0..7}
28 | do
29 |     cat ${FILES}/split/train.en.${i} | CUDA_VISIBLE_DEVICES=${i} \
30 |     python -m fairseq_cli.interactive /root/Mono4SiM/data/${DATASET}/data-bin \
31 |         -s ${SRC} -t ${TGT} \
32 |         --user-dir ${USERDIR} \
33 |         --skip-invalid-size-inputs-valid-test \
34 |         --task translation \
35 |         --path ${CHECKDIR}/${CHECKPOINT_FILENAME} \
36 |         --batch-size 64 --buffer-size 128 --fp16 \
37 |         ${GENARGS} ${EXTRAARGS} > ${ROOT}/interactive/generate-train.${i}.txt 2>&1 &
38 | done
39 | 
--------------------------------------------------------------------------------
/train/wmt21-enja/1s-wait_k.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | source ./data_path.sh
 4 | 
 5 | WAITK=$1
 6 | DATA=$2
 7 | TASK=$3
 8 | SUBSET=$5
 9 | 
10 | if [ "$DATA" == "raw" ]; then
11 |     DATA=/root/Mono4SiM/data/${DATASET}/data-bin
12 | fi
13 | 
14 | pip install sacrebleu[ja]
15 | CUDA_VISIBLE_DEVICES=$4 python -m fairseq_cli.train \
16 |     ${DATA}/${SUBSET} --user-dir ${USERDIR} \
17 |     -s ${SRC} -t ${TGT} \
18 |     --fp16 --ddp-backend=no_c10d \
19 |     --task translation_infer \
20 |     --inference-config-yaml infer_mt.yaml \
21 |     --arch waitk_transformer \
22 |     --waitk ${WAITK} \
23 |     --optimizer adam --adam-betas '(0.9,0.98)' \
24 |     --attention-dropout 0.0 --activation-dropout 0.0 --dropout 0.3 \
25 |     --max-tokens 25600 --update-freq 2 \
26 |     --lr 1e-3 --lr-scheduler cosine --warmup-init-lr 1e-07 --weight-decay 0.0 \
27 |     --lr-shrink 1 --lr-period-updates 20000 --min-lr 1e-09 \
28 |     --warmup-updates 10000 --clip-norm 0.1 \
29 |     --max-update 50000 \
30 |     --share-decoder-input-output-embed \
31 |     --criterion label_smoothed_cross_entropy \
32 |     --max-source-positions 10000 --max-target-positions 10000 \
33 |     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
34 |     --save-dir checkpoints/${TASK}_${SUBSET} \
35 |     --no-epoch-checkpoints \
36 |     --save-interval-updates 500 \
37 |     --keep-interval-updates 1 \
38 |     --keep-best-checkpoints 5 \
39 |     --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
40 |     --log-format simple --log-interval 50 \
41 |     --seed 1
42 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/run_wmt15_bleueval.sh:
--------------------------------------------------------------------------------
 1 | source ~/utility/sacrebleu/sacrebleu2/bin/activate
 2 | SRC=de
 3 | TGT=en
 4 | DIR=wmt15_${TGT}-results
 5 | WORKERS=2
 6 | REF=(
 7 |     "/media/XXXX-2/Data/wmt15/de-en/prep/test.${TGT}"
 8 | )
 9 | 
10 | # Normal
11 | for DELAY in 1 3 5 7 9; do
12 |     BASELINE="${DIR}/wait_${DELAY}_${SRC}${TGT}_distill.wmt15/prediction"
13 |     SYSTEMS=(
14 |         "${DIR}/wait_${DELAY}_${SRC}${TGT}_mon.wmt15/prediction"
15 |         "${DIR}/wait_${DELAY}_${SRC}${TGT}_reorder.wmt15/prediction"
16 |         "${DIR}/ctc_delay${DELAY}.wmt15/prediction"
17 |         "${DIR}/ctc_delay${DELAY}_mon.wmt15/prediction"
18 |         "${DIR}/ctc_delay${DELAY}_reorder.wmt15/prediction"
19 |         "${DIR}/sinkhorn_delay${DELAY}.wmt15/prediction"
20 |         "${DIR}/sinkhorn_delay${DELAY}_ft.wmt15/prediction"
21 |     )
22 | 
23 |     OUTPUT=${DIR}/quality-results.wmt15/delay${DELAY}-systems
24 |     mkdir -p $(dirname ${OUTPUT})
25 |     python -m sacrebleu ${REF[@]} -i ${BASELINE} ${SYSTEMS[@]} \
26 |         --paired-jobs ${WORKERS} \
27 |         -m bleu chrf \
28 |         --width 2 \
29 |         --tok 13a -lc \
30 |         --chrf-lowercase \
31 |         --paired-bs | tee ${OUTPUT}
32 | done
33 | 
34 | # Full-sentence
35 | TEACHER="${DIR}/teacher_wmt15_${SRC}${TGT}.wmt15/prediction"
36 | OUTPUT=${DIR}/quality-results.wmt15/full_sentence-system
37 | mkdir -p $(dirname ${OUTPUT})
38 | python -m sacrebleu ${REF[@]} -i ${TEACHER} \
39 |     --paired-jobs ${WORKERS} \
40 |     -m bleu chrf \
41 |     --width 2 \
42 |     --tok 13a -lc \
43 |     --chrf-lowercase \
44 |     --confidence | tee ${OUTPUT}
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/simuleval_fullsentence.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | POSITIONAL=()
 3 | while [[ $# -gt 0 ]]; do
 4 |     key="$1"
 5 | 
 6 |     case $key in
 7 |     -m|--model)
 8 |         MODEL="$2"
 9 |         shift # past argument
10 |         shift # past value
11 |         ;;
12 |     -e|--expdir)
13 |         EXP="$2"
14 |         shift # past argument
15 |         shift # past value
16 |         ;;
17 |     -s|--source)
18 |         SRC_FILE="$2"
19 |         shift # past argument
20 |         shift # past value
21 |         ;;
22 |     -t|--target)
23 |         TGT_FILE="$2"
24 |         shift # past argument
25 |         shift # past value
26 |         ;;
27 |     *) # unknown option
28 |         POSITIONAL+=("$1") # save it in an array for later
29 |         shift # past argument
30 |         ;;
31 |     esac
32 | done
33 | 
34 | set -- "${POSITIONAL[@]}" # restore positional parameters
35 | 
36 | AGENT=./agents/simul_t2t_waitk.py
37 | source ${EXP}/data_path.sh
38 | 
39 | CHECKPOINT=${EXP}/checkpoints/${MODEL}/checkpoint_best.pt
40 | SPM_PREFIX=${DATA}/spm_unigram32000
41 | 
42 | PORT=23451
43 | WORKERS=2
44 | BLEU_TOK=13a
45 | UNIT=word
46 | DATANAME=$(basename $(dirname $(dirname ${DATA})))
47 | OUTPUT=${DATANAME}_${TGT}-results/${MODEL}.${DATANAME}
48 | mkdir -p ${OUTPUT}
49 | 
50 | if [[ ${TGT} == "zh" ]]; then
51 |     BLEU_TOK=zh
52 |     UNIT=char
53 |     NO_SPACE="--no-space"
54 | fi
55 | 
56 | simuleval \
57 |     --agent ${AGENT} \
58 |     --user-dir ${USERDIR} \
59 |     --source ${SRC_FILE} \
60 |     --target ${TGT_FILE} \
61 |     --data-bin ${DATA} \
62 |     --model-path ${CHECKPOINT} \
63 |     --src-splitter-path ${SPM_PREFIX}_${SRC}.model \
64 |     --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \
65 |     --output ${OUTPUT} \
66 |     --sacrebleu-tokenizer ${BLEU_TOK} \
67 |     --eval-latency-unit ${UNIT} \
68 |     --segment-type ${UNIT} \
69 |     ${NO_SPACE} \
70 |     --scores \
71 |     --full-sentence \
72 |     --port ${PORT} \
73 |     --workers ${WORKERS}
74 | 
--------------------------------------------------------------------------------
/utility/scripts/reorder.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | 
 3 | import argparse
 4 | import numpy as np
 5 | from multiprocessing import Pool
 6 | import tqdm
 7 | 
 8 | def reorder(inputs):
 9 |     """
10 |     srcseq, tgtseq: ["tokens", ...]
11 |     alnseq: ["0-0", "6-0", ...]
12 |     """
13 |     srcstr, tgtstr, alnstr = inputs
14 |     srcseq = inputs[0].strip().split()
15 |     tgtseq = inputs[1].strip().split()
16 |     alnseq = inputs[2].strip().split()
17 |     tlen = len(tgtseq)
18 |     null = -1
19 |     new_order = np.full(tlen, null)
20 |     for s_t in alnseq:
21 |         s, t = tuple(map(int, s_t.split('-')))  # (0,0)
22 |         new_order[t] = s
23 | 
24 |     for i in range(tlen):
25 |         if new_order[i] == null:
26 |             new_order[i] = new_order[i - 1] if i > 0 else 0
27 | 
28 |     reordered = [tgtseq[i] for i in new_order.argsort(kind='stable')]
29 | 
30 |     return reordered
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     parser = argparse.ArgumentParser()
35 |     parser.add_argument("-s", "--source", type=str, help="source file")
36 |     parser.add_argument("-t", "--target", type=str, help="target file")
37 |     parser.add_argument("-a", "--align", type=str, help="alignment file")
38 |     parser.add_argument("-o", "--output", type=str, help="output file")
39 |     parser.add_argument("-j", "--jobs", type=int, help="launch j parallel jobs.")
40 |     args = parser.parse_args()
41 |     print(args)
42 | 
43 |     def file_len(fname):
44 |         import subprocess
45 |         intstr = subprocess.getoutput(f'cat {fname} | wc -l')
46 |         return int(intstr)
47 | 
48 |     srcs = open(args.source, "r")
49 |     tgts = open(args.target, "r")
50 |     alns = open(args.align, "r")
51 | 
52 |     srclen = file_len(args.source)
53 |     tgtlen = file_len(args.target)
54 |     alnlen = file_len(args.align)
55 |     assert srclen == tgtlen and alnlen == tgtlen
56 | 
57 |     with Pool(args.jobs) as p:
58 |         results = list(tqdm.tqdm(p.imap(reorder, zip(srcs, tgts, alns)), total=srclen))
59 | 
60 |     srcs.close()
61 |     tgts.close()
62 |     alns.close()
63 | 
64 |     with open(args.output, "w") as f:
65 |         for line in results:
66 |             f.write(" ".join(line) + "\n")
67 | 
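
reorder.py rewrites each target sentence into source order: every target position takes
the source index of its aligned pair (unaligned tokens inherit the previous token's
index), and a stable argsort of those indices yields a pseudo-monotonic target,
presumably the input for the *_reorder systems in the evaluation scripts. A
hypothetical run over the alignments from 2-fast_align.sh:

    python reorder.py -s ready/train.clean.en -t ready/train.clean.zh \
        -a score/diag.align -o ready/train.reorder.zh -j 8
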
--------------------------------------------------------------------------------
/data/0-get_en_mono.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | # Adapted from https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/translation/prepare-iwslt14.sh
 3 | source ./data_path.sh
 4 | SCRIPTS=/root/Mono4SiM/utility/mosesdecoder/scripts
 5 | # source ~/envs/apex/bin/activate
 6 | 
 7 | vocab=32000
 8 | vtype=unigram
 9 | workers=4
10 | 
11 | # TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13 | # NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15 | LC=$SCRIPTS/tokenizer/lowercase.perl
16 | 
17 | spm_train=$FAIRSEQ/scripts/spm_train.py
18 | spm_encode=$FAIRSEQ/scripts/spm_encode.py
19 | 
20 | CORPORA=(
21 |     "news.2017.en.shuffled.deduped"
22 |     "news.2016.en.shuffled"
23 | )
24 | 
25 | orig=${MONO}/orig
26 | prep=${MONO}/prep
27 | ready=${MONO}/ready
28 | mkdir -p $orig $prep $ready
29 | 
30 | echo "downloading data"
31 | cd $orig
32 | 
33 | wget http://data.statmt.org/wmt17/translation-task/news.2016.en.shuffled.gz
34 | wget http://data.statmt.org/wmt18/translation-task/news.2017.en.shuffled.deduped.gz
35 | gzip -dk news.2016.en.shuffled.gz
36 | gzip -dk news.2017.en.shuffled.deduped.gz
37 | cd ..
38 | 
39 | echo "pre-processing train data..."
40 | for l in ${SRC}; do
41 |     rm -f $prep/train.dirty.$l
42 |     for f in "${CORPORA[@]}"; do
43 |         echo "preprocess train $f"
44 |         cat $orig/$f | \
45 |             perl $REM_NON_PRINT_CHAR | \
46 |             perl $LC >> $prep/train.dirty.$l
47 |     done
48 | done
49 | 
50 | # filter empty pairs
51 | perl $CLEAN -ratio 1000 $prep/train.dirty ${SRC} ${SRC} $prep/train 1 10000
52 | 
53 | # SPM
54 | SPM_PREFIX=/root/Mono4SiM/data/cwmt-enzh/prep/spm_${vtype}${vocab}
55 | for l in ${SRC}; do
56 |     SPM_MODEL=${SPM_PREFIX}_${l}.model
57 |     echo "Using SPM model $SPM_MODEL"
58 |     for split in train; do
59 |         if [ -f $ready/$split.$l ]; then
60 |             echo "found $ready/$split.$l, skipping spm_encode"
61 |         else
62 |             echo "spm_encode to $split.$l..."
63 |             python $spm_encode --model=$SPM_MODEL \
64 |                 --output_format=piece \
65 |                 < $prep/$split.$l > $ready/$split.$l
66 |         fi
67 |     done
68 | done
69 | 
70 | # filter ratio and maxlen < 1024
71 | perl $CLEAN -ratio 9 $ready/train ${SRC} ${SRC} $ready/train.clean 1 1024
72 | 
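
The monolingual pipeline stops after SentencePiece encoding because only the English
side exists; the teacher later forward-translates it (0-distill_*_mono.sh) to create
distilled bitext. A quick sanity check once this script finishes (commands
illustrative):

    wc -l ${MONO}/ready/train.clean.en      # sentences available for distillation
    head -n 1 ${MONO}/ready/train.clean.en  # should show SentencePiece pieces (▁...)
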
--------------------------------------------------------------------------------
/data/1s-preprocess_tokenizer.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | # Adapted from https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/translation/prepare-iwslt14.sh
 3 | source ./data_path.sh
 4 | SCRIPTS=/root/Mono4SiM/utility/mosesdecoder/scripts
 5 | LC_ALL=en_US.UTF-8
 6 | LANG=en_US.UTF-8
 7 | # source ~/envs/apex/bin/activate
 8 | 
 9 | DATA=$1
10 | SRC=$3
11 | TGT=$4
12 | vocab=32000
13 | vtype=unigram
14 | workers=4
15 | SPM_PREFIX=$2
16 | 
17 | # TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
18 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl
19 | # NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
20 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
21 | LC=$SCRIPTS/tokenizer/lowercase.perl
22 | 
23 | spm_train=$FAIRSEQ/scripts/spm_train.py
24 | spm_encode=$FAIRSEQ/scripts/spm_encode.py
25 | 
26 | prep=${DATA}/prep
27 | ready=${DATA}/ready
28 | bin=${DATA}/data-bin
29 | mkdir -p $prep $ready $bin
30 | 
31 | echo "pre-processing train data..."
32 | for l in ${SRC} ${TGT}; do
33 |     rm -f $prep/train.dirty.$l
34 |     cat ${DATA}/interactive/detok.$l | \
35 |         perl $REM_NON_PRINT_CHAR | \
36 |         perl $LC >> $prep/train.dirty.$l
37 | done
38 | 
39 | # filter empty pairs
40 | perl $CLEAN -ratio 1000 $prep/train.dirty ${SRC} ${TGT} $prep/train 1 10000
41 | 
42 | # SPM
43 | for l in ${SRC} ${TGT}; do
44 |     SPM_MODEL=${SPM_PREFIX}_${l}.model
45 |     echo "Using SPM model $SPM_MODEL"
46 |     for split in train; do
47 |         if [ -f $ready/$split.$l ]; then
48 |             echo "found $ready/$split.$l, skipping spm_encode"
49 |         else
50 |             echo "spm_encode to $split.$l..."
51 |             python $spm_encode --model=$SPM_MODEL \
52 |                 --output_format=piece \
53 |                 < $prep/$split.$l > $ready/$split.$l
54 |         fi
55 |     done
56 | done
57 | 
58 | # filter ratio and maxlen < 256
59 | perl $CLEAN -ratio 9 $ready/train ${SRC} ${TGT} $ready/train.clean 1 256
60 | 
61 | if [[ $5 == "true" ]]; then
62 |     python -m fairseq_cli.preprocess \
63 |         --source-lang ${SRC} \
64 |         --target-lang ${TGT} \
65 |         --trainpref ${ready}/train.clean \
66 |         --validpref ${ready}/valid \
67 |         --testpref ${ready}/test \
68 |         --destdir ${bin} \
69 |         --workers ${workers} \
70 |         --srcdict ${SPM_PREFIX}_${SRC}.txt \
71 |         --tgtdict ${SPM_PREFIX}_${TGT}.txt
72 | fi
73 | 
--------------------------------------------------------------------------------
/utility/simultaneous_translation/eval/oracle_order/run_cwmt_oracle.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | SPLIT=test
 3 | EXP=/home/XXXX-2/Projects/sinkhorn-simultrans/expcwmt
 4 | SRC=en
 5 | TGT=zh
 6 | DATA=/media/XXXX-2/Data/cwmt/zh-en/data-bin
 7 | FAIRSEQ=/home/XXXX-2/Projects/sinkhorn-simultrans/fairseq
 8 | USERDIR=`realpath ../../simultaneous_translation`
 9 | export PYTHONPATH="$USERDIR:$FAIRSEQ:$PYTHONPATH"
10 | 
11 | function run_oracle () {
12 |     CHECKDIR=${EXP}/checkpoints/${1}
13 |     DATANAME=$(basename $(dirname $(dirname ${DATA})))
14 |     OUTPUT=${DATANAME}_${TGT}-results/${1}
15 |     AVG=false
16 |     BLEU_TOK=13a
17 |     WORKERS=2
18 | 
19 |     if [[ ${TGT} == "zh" ]]; then
20 |         BLEU_TOK=zh
21 |     fi
22 |     GENARGS="--beam 1 --remove-bpe sentencepiece"
23 | 
24 |     if [[ $AVG == "true" ]]; then
25 |         CHECKPOINT_FILENAME=avg_best_5_checkpoint.pt
26 |         python ../scripts/average_checkpoints.py \
27 |             --inputs ${CHECKDIR} --num-best-checkpoints 5 \
28 |             --output "${CHECKDIR}/${CHECKPOINT_FILENAME}"
29 |     else
30 |         CHECKPOINT_FILENAME=checkpoint_best.pt
31 |     fi
32 | 
33 |     # python -m fairseq_cli.generate ${DATA} \
34 |     python generate.py ${DATA} \
35 |         --user-dir ${USERDIR} \
36 |         -s ${SRC} -t ${TGT} \
37 |         --gen-subset ${SPLIT} \
38 |         --task translation_infer \
39 |         --max-tokens 8000 --fp16 \
40 |         --inference-config-yaml ../exp/infer_mt.yaml \
41 |         --path ${CHECKDIR}/${CHECKPOINT_FILENAME} \
42 |         --model-overrides '{"load_pretrained_encoder_from": None}' \
43 |         --results-path ${OUTPUT} \
44 |         ${GENARGS}
45 | 
46 |     grep -E "D-[0-9]+" ${OUTPUT}/generate-${SPLIT}.txt | \
47 |         sed "s/^D-//" | \
48 |         sort -k1 -n | \
49 |         cut -f3 > ${OUTPUT}/oracle_prediction
50 | 
51 |     REF=(
52 |         "/media/XXXX-2/Data/cwmt/zh-en/prep/test.${SRC}-${TGT}.${TGT}.1"
53 |         "/media/XXXX-2/Data/cwmt/zh-en/prep/test.${SRC}-${TGT}.${TGT}.2"
54 |         "/media/XXXX-2/Data/cwmt/zh-en/prep/test.${SRC}-${TGT}.${TGT}.3"
55 |     )
56 |     SYSTEMS=(
57 |         # "../${OUTPUT}/prediction"
58 |         "${OUTPUT}/oracle_prediction"
59 |     )
60 | 
61 |     python -m sacrebleu ${REF[@]} -i ${SYSTEMS[@]} \
62 |         -m bleu \
63 |         --width 2 \
64 |         --tok zh -lc | tee ${OUTPUT}/score
65 | }
66 | 
67 | for k in 1 3 5 7 9; do
68 |     run_oracle sinkhorn_delay${k}_ft
69 | done
--------------------------------------------------------------------------------
/train/wmt21-enja/2-test_model_full.sh:
--------------------------------------------------------------------------------
 1 | source ./data_path.sh
 2 | 
 3 | SPLIT=test
 4 | DATA=${BASE}/data-bin
 5 | SRC_FILE=${BASE}/prep/test.${SRC}
 6 | TGT_FILE=${BASE}/prep/test.${TGT}
 7 | 
 8 | GENARGS="--beam 6 --lenpen 1.0 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece"
 9 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer ja-mecab --sacrebleu-lowercase"
10 | REF=(
11 | "${BASE}/prep/test.${TGT}" 12 | ) 13 | 14 | PORT=12000 15 | TASK=teacher_wmt21_${SRC}${TGT} 16 | CHECKDIR=checkpoints/${TASK} 17 | CHECKPOINT=${CHECKDIR}/avg_best_5_checkpoint.pt 18 | OUTPUT=${CHECKDIR}/log 19 | mkdir -p ${OUTPUT} 20 | 21 | echo "Evaluating ${TASK}!" 22 | 23 | AGENT=/root/Mono4SiM/utility/simultaneous_translation/eval/agents/simul_t2t_waitk.py 24 | SPM_PREFIX=${DATA}/spm_unigram32000 25 | WORKERS=2 26 | 27 | BLEU_TOK=13a 28 | UNIT=word 29 | if [[ ${TGT} == "zh" ]]; then 30 | BLEU_TOK=zh 31 | UNIT=char 32 | NO_SPACE="--no-space" 33 | fi 34 | 35 | simuleval --gpu 0 \ 36 | --agent ${AGENT} \ 37 | --user-dir ${USERDIR} \ 38 | --source ${SRC_FILE} \ 39 | --target ${TGT_FILE} \ 40 | --data-bin ${DATA} \ 41 | --model-path ${CHECKPOINT} \ 42 | --src-splitter-path ${SPM_PREFIX}_${SRC}.model \ 43 | --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \ 44 | --output ${OUTPUT} \ 45 | --incremental-encoder \ 46 | --sacrebleu-tokenizer ${BLEU_TOK} \ 47 | --eval-latency-unit ${UNIT} \ 48 | --segment-type ${UNIT} \ 49 | ${NO_SPACE} \ 50 | --scores \ 51 | --full-sentence \ 52 | --port ${PORT} \ 53 | --workers ${WORKERS} > simul.tmp 54 | 55 | echo "Simuleval ${TASK} finished!" 56 | 57 | CUDA_VISIBLE_DEVICES=0 python -m \ 58 | fairseq_cli.generate ${BASE}/data-bin \ 59 | -s ${SRC} -t ${TGT} \ 60 | --user-dir ${USERDIR} \ 61 | --gen-subset ${SPLIT} \ 62 | --skip-invalid-size-inputs-valid-test \ 63 | --task translation_infer \ 64 | --inference-config-yaml pre_monotonic.yaml \ 65 | --path ${CHECKDIR}/checkpoint_best.pt \ 66 | --max-tokens 16000 --fp16 \ 67 | --results-path ${OUTPUT} \ 68 | ${GENARGS} ${EXTRAARGS} 69 | 70 | python 2b-proc_generate.py -r ${OUTPUT} -f generate-${SPLIT}.txt 71 | echo "Generation ${TASK} finished!" 72 | 73 | python -m sacrebleu ${REF[@]} -i ${OUTPUT}/detok.txt \ 74 | -m bleu chrf \ 75 | --chrf-lowercase \ 76 | --width 2 \ 77 | --tok ja-mecab -lc | tee ${OUTPUT}/sacrebleu.txt 78 | 79 | echo "SacreBLEU ${TASK} finished!" 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/eval/simuleval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # credits: https://stackoverflow.com/questions/192249/how-do-i-parse-command-line-arguments-in-bash 3 | POSITIONAL=() 4 | while [[ $# -gt 0 ]]; do 5 | key="$1" 6 | 7 | case $key in 8 | -a|--agent) 9 | AGENT="$2" 10 | shift # past argument 11 | shift # past value 12 | ;; 13 | -m|--model) 14 | MODEL="$2" 15 | shift # past argument 16 | shift # past value 17 | ;; 18 | -k|--waitk) 19 | WAITK="$2" 20 | shift # past argument 21 | shift # past value 22 | ;; 23 | -e|--expdir) 24 | EXP="$2" 25 | shift # past argument 26 | shift # past value 27 | ;; 28 | -s|--source) 29 | SRC_FILE="$2" 30 | shift # past argument 31 | shift # past value 32 | ;; 33 | -t|--target) 34 | TGT_FILE="$2" 35 | shift # past argument 36 | shift # past value 37 | ;; 38 | -l|--lm-path) 39 | KENLM="$2" 40 | shift # past argument 41 | shift # past value 42 | ;; 43 | -w|--lm-weight) 44 | LM_WEIGHT="$2" 45 | shift # past argument 46 | shift # past value 47 | ;; 48 | *) # unknown option 49 | POSITIONAL+=("$1") # save it in an array for later 50 | shift # past argument 51 | ;; 52 | esac 53 | done 54 | 55 | set -- "${POSITIONAL[@]}" # restore positional parameters 56 | 57 | 58 | source ${EXP}/data_path.sh 59 | 60 | CHECKPOINT=${EXP}/checkpoints/${MODEL}/checkpoint_best.pt 61 | SPM_PREFIX=${DATA}/spm_unigram32000 62 | 63 | PORT=12345 64 | WORKERS=2 65 | BLEU_TOK=13a 66 | UNIT=word 67 | DATANAME=$(basename $(dirname $(dirname ${DATA}))) 68 | OUTPUT=${DATANAME}_${TGT}-results/${MODEL}.${DATANAME} 69 | mkdir -p 
${OUTPUT} 70 | 71 | if [[ ${TGT} == "zh" ]]; then 72 | BLEU_TOK=zh 73 | UNIT=char 74 | NO_SPACE="--no-space" 75 | fi 76 | 77 | simuleval \ 78 | --agent ${AGENT} \ 79 | --user-dir ${USERDIR} \ 80 | --source ${SRC_FILE} \ 81 | --target ${TGT_FILE} \ 82 | --data-bin ${DATA} \ 83 | --model-path ${CHECKPOINT} \ 84 | --src-splitter-path ${SPM_PREFIX}_${SRC}.model \ 85 | --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \ 86 | --output ${OUTPUT} \ 87 | --incremental-encoder \ 88 | --sacrebleu-tokenizer ${BLEU_TOK} \ 89 | --eval-latency-unit ${UNIT} \ 90 | --segment-type ${UNIT} \ 91 | ${NO_SPACE} \ 92 | --scores \ 93 | --test-waitk ${WAITK} \ 94 | --port ${PORT} \ 95 | --lm-path ${KENLM} \ 96 | --lm-weight ${LM_WEIGHT} \ 97 | --workers ${WORKERS} 98 | -------------------------------------------------------------------------------- /data/0-get_en_mono_scaling.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Adapted from https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/translation/prepare-iwslt14.sh 3 | source ./data_path.sh 4 | SCRIPTS=/root/Mono4SiM/utility/mosesdecoder/scripts 5 | # source ~/envs/apex/bin/activate 6 | 7 | vocab=32000 8 | vtype=unigram 9 | workers=4 10 | 11 | # TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | # NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | LC=$SCRIPTS/tokenizer/lowercase.perl 16 | 17 | spm_train=$FAIRSEQ/scripts/spm_train.py 18 | spm_encode=$FAIRSEQ/scripts/spm_encode.py 19 | 20 | CORPORA=( 21 | "news.2007.en.shuffled" 22 | "news.2008.en.shuffled" 23 | "news.2009.en.shuffled" 24 | "news.2010.en.shuffled" 25 | "news.2011.en.shuffled" 26 | "news.2012.en.shuffled" 27 | "news.2013.en.shuffled" 28 | "news.2014.en.shuffled.v2" 29 | "news.2015.en.shuffled" 30 | "news.2016.en.shuffled" 31 | "news.2017.en.shuffled.deduped" 32 | "news-discuss-v1.en.txt" 33 | "news-discuss.2015-2016.en.shuffled" 34 | "news-discuss.2017.en.shuffled.deduped" 35 | ) 36 | 37 | orig=${MONO}/orig 38 | prep=${MONO}/prep 39 | ready=${MONO}/ready 40 | mkdir -p $orig $prep $ready 41 | 42 | echo "pre-processing train data..." 43 | for l in ${SRC}; do 44 | rm -f $prep/train.dirty.$l 45 | for f in "${CORPORA[@]}"; do 46 | echo "precprocess train $f" 47 | cd $orig 48 | gzip -dk $orig/$f.gz 49 | cd .. 50 | cat $orig/$f | \ 51 | perl $REM_NON_PRINT_CHAR | \ 52 | perl $LC >> $prep/train.dirty.$l 53 | done 54 | done 55 | 56 | # filter empty pairs 57 | perl $CLEAN -ratio 1000 $prep/train.dirty ${SRC} ${SRC} $prep/train 1 10000 58 | 59 | # SPM 60 | SPM_PREFIX=/root/Mono4SiM/data/cwmt-enzh/prep/spm_${vtype}${vocab} 61 | for l in ${SRC}; do 62 | SPM_MODEL=${SPM_PREFIX}_${l}.model 63 | echo "Using SPM model $SPM_MODEL" 64 | for split in train; do 65 | if [ -f $ready/$split.$l ]; then 66 | echo "found $ready/$split.$l, skipping spm_encode" 67 | else 68 | echo "spm_encode to $split.$l..." 
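 | # Apply the SentencePiece model trained on the bilingual cwmt-enzh corpus to
 | # the cleaned monolingual text, emitting one line of subword pieces per sentence.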
69 | python $spm_encode --model=$SPM_MODEL \ 70 | --output_format=piece \ 71 | < $prep/$split.$l > $ready/$split.$l 72 | fi 73 | done 74 | done 75 | 76 | # filter ratio and maxlen < 1024 77 | perl $CLEAN -ratio 9 $ready/train ${SRC} ${SRC} $ready/train.clean 1 1024 78 | -------------------------------------------------------------------------------- /train/cwmt-enzh/2-test_model_full.sh: -------------------------------------------------------------------------------- 1 | source ./data_path.sh 2 | 3 | SPLIT=test 4 | ROOT=/root/Mono4SiM/data/${DATASET} 5 | DATA=${ROOT}/data-bin 6 | SRC_FILE=${ROOT}/prep/test.${SRC}-${TGT}.${SRC} 7 | TGT_FILE=${ROOT}/prep/test.${SRC}-${TGT}.${TGT}.1 8 | 9 | GENARGS="--beam 5 --lenpen 1.5 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece" 10 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer zh --sacrebleu-lowercase" 11 | REF=( 12 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.1" 13 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.2" 14 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.3" 15 | ) 16 | 17 | PORT=11000 18 | TASK=teacher_cwmt_${SRC}${TGT} 19 | CHECKDIR=checkpoints/${TASK} 20 | CHECKPOINT=${CHECKDIR}/checkpoint_best.pt 21 | OUTPUT=${CHECKDIR}/log 22 | mkdir -p ${OUTPUT} 23 | 24 | echo "Evaluating ${TASK}!" 25 | 26 | AGENT=/root/Mono4SiM/utility/simultaneous_translation/eval/agents/simul_t2t_waitk.py 27 | SPM_PREFIX=${DATA}/spm_unigram32000 28 | WORKERS=2 29 | 30 | BLEU_TOK=13a 31 | UNIT=word 32 | if [[ ${TGT} == "zh" ]]; then 33 | BLEU_TOK=zh 34 | UNIT=char 35 | NO_SPACE="--no-space" 36 | fi 37 | 38 | simuleval --gpu 0 \ 39 | --agent ${AGENT} \ 40 | --user-dir ${USERDIR} \ 41 | --source ${SRC_FILE} \ 42 | --target ${TGT_FILE} \ 43 | --data-bin ${DATA} \ 44 | --model-path ${CHECKPOINT} \ 45 | --src-splitter-path ${SPM_PREFIX}_${SRC}.model \ 46 | --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \ 47 | --output ${OUTPUT} \ 48 | --sacrebleu-tokenizer ${BLEU_TOK} \ 49 | --eval-latency-unit ${UNIT} \ 50 | --segment-type ${UNIT} \ 51 | ${NO_SPACE} \ 52 | --scores \ 53 | --full-sentence \ 54 | --port ${PORT} \ 55 | --workers ${WORKERS} 56 | 57 | echo "Simuleval ${TASK} finished!" 58 | 59 | CUDA_VISIBLE_DEVICES=0 python -m \ 60 | fairseq_cli.generate ${ROOT}/data-bin \ 61 | -s ${SRC} -t ${TGT} \ 62 | --user-dir ${USERDIR} \ 63 | --gen-subset ${SPLIT} \ 64 | --skip-invalid-size-inputs-valid-test \ 65 | --task translation_infer \ 66 | --inference-config-yaml pre_monotonic.yaml \ 67 | --path ${CHECKDIR}/checkpoint_best.pt \ 68 | --max-tokens 4096 --fp16 \ 69 | --results-path ${OUTPUT} \ 70 | ${GENARGS} ${EXTRAARGS} 71 | 72 | python 2s-proc_generate.py -r ${OUTPUT} -f generate-${SPLIT}.txt 73 | echo "Generation ${TASK} finished!" 74 | 75 | python -m sacrebleu ${REF[@]} -i ${OUTPUT}/detok.txt \ 76 | -m bleu chrf \ 77 | --chrf-lowercase \ 78 | --width 2 \ 79 | --tok zh -lc | tee ${OUTPUT}/sacrebleu.txt 80 | 81 | echo "SacreBLEU ${TASK} finished!" 
82 | -------------------------------------------------------------------------------- /train/wmt21-enja/2-test_model.sh: -------------------------------------------------------------------------------- 1 | source ./data_path.sh 2 | 3 | SUBSET=$1 4 | SPLIT=test 5 | DATA=${BASE}/data-bin 6 | SRC_FILE=${BASE}/prep/test.${SRC} 7 | TGT_FILE=${BASE}/prep/test.${TGT} 8 | 9 | GENARGS="--beam 6 --lenpen 1.0 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece" 10 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer ja-mecab --sacrebleu-lowercase" 11 | REF=( 12 | "${BASE}/prep/test.${TGT}" 13 | ) 14 | 15 | PORT=12000 16 | for k in 1 3 5 7 9; do 17 | TASK=wait_${k}_${SRC}${TGT}_distill_${SUBSET} 18 | if [ "$SUBSET" == "raw" ]; then 19 | TASK=wait_${k}_${SRC}${TGT}_ 20 | fi 21 | CHECKDIR=checkpoints/${TASK} 22 | CHECKPOINT=${CHECKDIR}/checkpoint_best.pt 23 | OUTPUT=${CHECKDIR}/log 24 | mkdir -p ${OUTPUT} 25 | 26 | echo "Evaluating ${TASK}!" 27 | 28 | AGENT=/root/Mono4SiM/utility/simultaneous_translation/eval/agents/simul_t2t_waitk.py 29 | SPM_PREFIX=${DATA}/spm_unigram32000 30 | WORKERS=2 31 | 32 | BLEU_TOK=13a 33 | UNIT=word 34 | if [[ ${TGT} == "zh" ]]; then 35 | BLEU_TOK=zh 36 | UNIT=char 37 | NO_SPACE="--no-space" 38 | fi 39 | 40 | simuleval --gpu 0 \ 41 | --agent ${AGENT} \ 42 | --user-dir ${USERDIR} \ 43 | --source ${SRC_FILE} \ 44 | --target ${TGT_FILE} \ 45 | --data-bin ${DATA} \ 46 | --model-path ${CHECKPOINT} \ 47 | --src-splitter-path ${SPM_PREFIX}_${SRC}.model \ 48 | --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \ 49 | --output ${OUTPUT} \ 50 | --incremental-encoder \ 51 | --sacrebleu-tokenizer ${BLEU_TOK} \ 52 | --eval-latency-unit ${UNIT} \ 53 | --segment-type ${UNIT} \ 54 | ${NO_SPACE} \ 55 | --scores \ 56 | --test-waitk ${k} \ 57 | --port ${PORT} \ 58 | --workers ${WORKERS} > simul.tmp 59 | 60 | echo "Simuleval ${TASK} finished!" 61 | 62 | CUDA_VISIBLE_DEVICES=0 python -m \ 63 | fairseq_cli.generate ${BASE}/data-bin \ 64 | -s ${SRC} -t ${TGT} \ 65 | --user-dir ${USERDIR} \ 66 | --gen-subset ${SPLIT} \ 67 | --skip-invalid-size-inputs-valid-test \ 68 | --task translation_infer \ 69 | --inference-config-yaml pre_monotonic.yaml \ 70 | --path ${CHECKDIR}/checkpoint_best.pt \ 71 | --max-tokens 4096 --fp16 \ 72 | --results-path ${OUTPUT} \ 73 | ${GENARGS} ${EXTRAARGS} 74 | 75 | python 2s-proc_generate.py -r ${OUTPUT} -f generate-${SPLIT}.txt 76 | 77 | python -m sacrebleu ${REF[@]} -i ${OUTPUT}/detok.txt \ 78 | -m bleu chrf \ 79 | --chrf-lowercase \ 80 | --width 2 \ 81 | --tok zh -lc | tee ${OUTPUT}/sacrebleu.txt 82 | echo "Generation ${TASK} finished!" 
83 | done 84 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/tasks/inference_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as op 3 | from argparse import Namespace 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class InferenceConfig(object): 9 | """Wrapper class for bleu config YAML""" 10 | 11 | def __init__(self, yaml_path): 12 | try: 13 | import yaml 14 | except ImportError: 15 | print("Please install PyYAML to load YAML files for " "S2T data config") 16 | self.config = {} 17 | if op.isfile(yaml_path): 18 | try: 19 | with open(yaml_path) as f: 20 | self.config = yaml.load(f, Loader=yaml.FullLoader) 21 | except Exception as e: 22 | logger.info(f"Failed to load config from {yaml_path}: {e}") 23 | else: 24 | logger.info(f"Cannot find {yaml_path}") 25 | 26 | @property 27 | def eval_wer(self): 28 | """evaluation with WER score in validation step.""" 29 | return self.config.get("eval_wer", False) 30 | 31 | @property 32 | def eval_bleu(self): 33 | """evaluation with BLEU score in validation step.""" 34 | return self.config.get("eval_bleu", False) 35 | 36 | @property 37 | def eval_any(self): 38 | return self.eval_bleu or self.eval_wer 39 | 40 | @property 41 | def generation_args(self): 42 | """generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string""" 43 | args = self.config.get("generation_args", {}) 44 | return Namespace(**args) 45 | 46 | @property 47 | def post_process(self): 48 | """post-process text by removing pre-processing such as BPE, letter segmentation, etc 49 | (valid options are: sentencepiece, wordpiece, letter, _EOW, none, otherwise treated as BPE symbol) 50 | """ 51 | return self.config.get("post_process", None) 52 | 53 | @property 54 | def print_samples(self): 55 | """print sample generations during validation""" 56 | return self.config.get("print_samples", False) 57 | 58 | @property 59 | def eval_bleu_args(self): 60 | """args for bleu scoring""" 61 | args = self.config.get("eval_bleu_args", { 62 | "sacrebleu_tokenizer": "13a", 63 | "sacrebleu_lowercase": False, 64 | "sacrebleu_char_level": False 65 | }) 66 | return Namespace(**args) 67 | 68 | @property 69 | def eval_wer_args(self): 70 | """args for wer scoring""" 71 | args = self.config.get("eval_wer_args", { 72 | "wer_tokenizer": "13a", 73 | "wer_remove_punct": True, 74 | "wer_lowercase": True, 75 | "wer_char_level": False 76 | }) 77 | return Namespace(**args) 78 | -------------------------------------------------------------------------------- /train/cwmt-enzh/2-test_model.sh: -------------------------------------------------------------------------------- 1 | source ./data_path.sh 2 | 3 | SUBSET=$1 4 | SPLIT=test 5 | DATA=${BASE}/data-bin 6 | SRC_FILE=${BASE}/prep/test.${SRC}-${TGT}.${SRC} 7 | TGT_FILE=${BASE}/prep/test.${SRC}-${TGT}.${TGT}.1 8 | 9 | GENARGS="--beam 5 --lenpen 1.5 --max-len-a 1.2 --max-len-b 10 --remove-bpe sentencepiece" 10 | EXTRAARGS="--scoring sacrebleu --sacrebleu-tokenizer zh --sacrebleu-lowercase" 11 | REF=( 12 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.1" 13 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.2" 14 | "${BASE}/prep/test.${SRC}-${TGT}.${TGT}.3" 15 | ) 16 | 17 | PORT=11001 18 | for k in 1 3 5 7 9; do 19 | TASK=wait_${k}_${SRC}${TGT}_distill_${SUBSET} 20 | if [ "$SUBSET" == "raw" ]; then 21 | TASK=wait_${k}_${SRC}${TGT}_ 22 | fi 23 | CHECKDIR=checkpoints/${TASK} 24 | CHECKPOINT=${CHECKDIR}/checkpoint_best.pt 
25 | OUTPUT=${CHECKDIR}/log 26 | mkdir -p ${OUTPUT} 27 | 28 | echo "Evaluating ${TASK}!" 29 | 30 | AGENT=/root/Mono4SiM/utility/simultaneous_translation/eval/agents/simul_t2t_waitk.py 31 | SPM_PREFIX=${DATA}/spm_unigram32000 32 | WORKERS=2 33 | 34 | BLEU_TOK=13a 35 | UNIT=word 36 | if [[ ${TGT} == "zh" ]]; then 37 | BLEU_TOK=zh 38 | UNIT=char 39 | NO_SPACE="--no-space" 40 | fi 41 | 42 | simuleval --gpu 0 \ 43 | --agent ${AGENT} \ 44 | --user-dir ${USERDIR} \ 45 | --source ${SRC_FILE} \ 46 | --target ${TGT_FILE} \ 47 | --data-bin ${DATA} \ 48 | --model-path ${CHECKPOINT} \ 49 | --src-splitter-path ${SPM_PREFIX}_${SRC}.model \ 50 | --tgt-splitter-path ${SPM_PREFIX}_${TGT}.model \ 51 | --output ${OUTPUT} \ 52 | --incremental-encoder \ 53 | --sacrebleu-tokenizer ${BLEU_TOK} \ 54 | --eval-latency-unit ${UNIT} \ 55 | --segment-type ${UNIT} \ 56 | ${NO_SPACE} \ 57 | --scores \ 58 | --test-waitk ${k} \ 59 | --port ${PORT} \ 60 | --workers ${WORKERS} 61 | 62 | echo "Simuleval ${TASK} finished!" 63 | 64 | CUDA_VISIBLE_DEVICES=0 python -m \ 65 | fairseq_cli.generate ${BASE}/data-bin \ 66 | -s ${SRC} -t ${TGT} \ 67 | --user-dir ${USERDIR} \ 68 | --gen-subset ${SPLIT} \ 69 | --skip-invalid-size-inputs-valid-test \ 70 | --task translation_infer \ 71 | --inference-config-yaml pre_monotonic.yaml \ 72 | --path ${CHECKDIR}/checkpoint_best.pt \ 73 | --max-tokens 4096 --fp16 \ 74 | --results-path ${OUTPUT} \ 75 | ${GENARGS} ${EXTRAARGS} 76 | 77 | python 2s-proc_generate.py -r ${OUTPUT} -f generate-${SPLIT}.txt 78 | 79 | python -m sacrebleu ${REF[@]} -i ${OUTPUT}/detok.txt \ 80 | -m bleu chrf \ 81 | --chrf-lowercase \ 82 | --width 2 \ 83 | --tok zh -lc | tee ${OUTPUT}/sacrebleu.txt 84 | echo "Generation ${TASK} finished!" 85 | done 86 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/eval/run_cwmt_bleueval.sh: -------------------------------------------------------------------------------- 1 | source ~/utility/sacrebleu/sacrebleu2/bin/activate 2 | SRC=en 3 | TGT=zh 4 | DIR=cwmt_${TGT}-results 5 | WORKERS=2 6 | REF=( 7 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.1" 8 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.2" 9 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.3" 10 | ) 11 | 12 | # Normal 13 | for DELAY in 1 3 5 7 9; do 14 | BASELINE="${DIR}/wait_${DELAY}_${SRC}${TGT}_distill.cwmt/prediction" 15 | SYSTEMS=( 16 | "${DIR}/wait_${DELAY}_${SRC}${TGT}_mon.cwmt/prediction" 17 | "${DIR}/wait_${DELAY}_${SRC}${TGT}_reorder.cwmt/prediction" 18 | "${DIR}/ctc_delay${DELAY}.cwmt/prediction" 19 | "${DIR}/ctc_delay${DELAY}_mon.cwmt/prediction" 20 | "${DIR}/ctc_delay${DELAY}_reorder.cwmt/prediction" 21 | "${DIR}/sinkhorn_delay${DELAY}.cwmt/prediction" 22 | "${DIR}/sinkhorn_delay${DELAY}_ft.cwmt/prediction" 23 | ) 24 | OUTPUT=${DIR}/quality-results.cwmt/delay${DELAY}-systems 25 | mkdir -p $(dirname ${OUTPUT}) 26 | python -m sacrebleu ${REF[@]} -i ${BASELINE} ${SYSTEMS[@]} \ 27 | --paired-jobs ${WORKERS} \ 28 | -m bleu chrf \ 29 | --width 2 \ 30 | --tok zh -lc \ 31 | --chrf-lowercase \ 32 | --paired-bs | tee ${OUTPUT} 33 | done 34 | 35 | # Full-sentence 36 | TEACHER="${DIR}/teacher_cwmt_${SRC}${TGT}.cwmt/prediction" 37 | OUTPUT=${DIR}/quality-results.cwmt/full_sentence-systems 38 | mkdir -p $(dirname ${OUTPUT}) 39 | python -m sacrebleu ${REF[@]} -i ${TEACHER} \ 40 | --paired-jobs ${WORKERS} \ 41 | -m bleu chrf \ 42 | --width 2 \ 43 | --tok zh -lc \ 44 | --chrf-lowercase \ 45 | --confidence | tee ${OUTPUT} 46 | 47 | # # 
Ablation 48 | # BASELINE="${DIR}/sinkhorn_delay3.cwmt/prediction" 49 | # SYSTEMS=( 50 | # "${DIR}/sinkhorn_delay3_unittemp.cwmt/prediction" 51 | # "${DIR}/sinkhorn_delay3_nonoise.cwmt/prediction" 52 | # "${DIR}/sinkhorn_delay3_softmax.cwmt/prediction" 53 | # ) 54 | # OUTPUT=${DIR}/quality-results.cwmt/ablation-systems 55 | # mkdir -p $(dirname ${OUTPUT}) 56 | # python -m sacrebleu ${REF[@]} -i ${BASELINE} ${SYSTEMS[@]} \ 57 | # --paired-jobs ${WORKERS} \ 58 | # -m bleu chrf \ 59 | # --width 2 \ 60 | # --tok zh -lc \ 61 | # --chrf-lowercase \ 62 | # --paired-bs | tee ${OUTPUT} 63 | 64 | # # Verbose scores 65 | # OUTDIR=${DIR}/quality-results.cwmt/verbose 66 | # mkdir -p ${OUTDIR} 67 | # for DELAY in 1 3 5 7 9; do 68 | # SYSTEMS=( 69 | # "sinkhorn_delay${DELAY}" 70 | # "sinkhorn_delay${DELAY}_ft" 71 | # ) 72 | # for s in "${SYSTEMS[@]}"; do 73 | # python -m sacrebleu ${REF[@]} \ 74 | # -i ${DIR}/${s}.cwmt/prediction \ 75 | # -m bleu \ 76 | # --width 2 \ 77 | # --tok zh -lc | tee ${OUTDIR}/${s} 78 | # done 79 | # done -------------------------------------------------------------------------------- /utility/simultaneous_translation/eval/run_cwmt_bleueval copy.sh: -------------------------------------------------------------------------------- 1 | source ~/utility/sacrebleu/sacrebleu2/bin/activate 2 | SRC=en 3 | TGT=zh 4 | DIR=cwmt_${TGT}-results 5 | WORKERS=2 6 | REF=( 7 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.1" 8 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.2" 9 | "/root/Datasets/cwmt-enzh/prep/test.${SRC}-${TGT}.${TGT}.3" 10 | ) 11 | 12 | # Normal 13 | for DELAY in 1 3 5 7 9; do 14 | BASELINE="${DIR}/wait_${DELAY}_${SRC}${TGT}_distill.cwmt/prediction" 15 | SYSTEMS=( 16 | "${DIR}/wait_${DELAY}_${SRC}${TGT}_mon.cwmt/prediction" 17 | "${DIR}/wait_${DELAY}_${SRC}${TGT}_reorder.cwmt/prediction" 18 | "${DIR}/ctc_delay${DELAY}.cwmt/prediction" 19 | "${DIR}/ctc_delay${DELAY}_mon.cwmt/prediction" 20 | "${DIR}/ctc_delay${DELAY}_reorder.cwmt/prediction" 21 | "${DIR}/sinkhorn_delay${DELAY}.cwmt/prediction" 22 | "${DIR}/sinkhorn_delay${DELAY}_ft.cwmt/prediction" 23 | ) 24 | OUTPUT=${DIR}/quality-results.cwmt/delay${DELAY}-systems 25 | mkdir -p $(dirname ${OUTPUT}) 26 | python -m sacrebleu ${REF[@]} -i ${BASELINE} ${SYSTEMS[@]} \ 27 | --paired-jobs ${WORKERS} \ 28 | -m bleu chrf \ 29 | --width 2 \ 30 | --tok zh -lc \ 31 | --chrf-lowercase \ 32 | --paired-bs | tee ${OUTPUT} 33 | done 34 | 35 | # Full-sentence 36 | TEACHER="${DIR}/teacher_cwmt_${SRC}${TGT}.cwmt/prediction" 37 | OUTPUT=${DIR}/quality-results.cwmt/full_sentence-systems 38 | mkdir -p $(dirname ${OUTPUT}) 39 | python -m sacrebleu ${REF[@]} -i ${TEACHER} \ 40 | --paired-jobs ${WORKERS} \ 41 | -m bleu chrf \ 42 | --width 2 \ 43 | --tok zh -lc \ 44 | --chrf-lowercase \ 45 | --confidence | tee ${OUTPUT} 46 | 47 | # # Ablation 48 | # BASELINE="${DIR}/sinkhorn_delay3.cwmt/prediction" 49 | # SYSTEMS=( 50 | # "${DIR}/sinkhorn_delay3_unittemp.cwmt/prediction" 51 | # "${DIR}/sinkhorn_delay3_nonoise.cwmt/prediction" 52 | # "${DIR}/sinkhorn_delay3_softmax.cwmt/prediction" 53 | # ) 54 | # OUTPUT=${DIR}/quality-results.cwmt/ablation-systems 55 | # mkdir -p $(dirname ${OUTPUT}) 56 | # python -m sacrebleu ${REF[@]} -i ${BASELINE} ${SYSTEMS[@]} \ 57 | # --paired-jobs ${WORKERS} \ 58 | # -m bleu chrf \ 59 | # --width 2 \ 60 | # --tok zh -lc \ 61 | # --chrf-lowercase \ 62 | # --paired-bs | tee ${OUTPUT} 63 | 64 | # # Verbose scores 65 | # OUTDIR=${DIR}/quality-results.cwmt/verbose 66 | # mkdir -p ${OUTDIR} 67 | # for DELAY in 1 3 
5 7 9; do 68 | # SYSTEMS=( 69 | # "sinkhorn_delay${DELAY}" 70 | # "sinkhorn_delay${DELAY}_ft" 71 | # ) 72 | # for s in "${SYSTEMS[@]}"; do 73 | # python -m sacrebleu ${REF[@]} \ 74 | # -i ${DIR}/${s}.cwmt/prediction \ 75 | # -m bleu \ 76 | # --width 2 \ 77 | # --tok zh -lc | tee ${OUTDIR}/${s} 78 | # done 79 | # done -------------------------------------------------------------------------------- /train/get_score.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from test_test import Aligner 3 | from subprocess import Popen, PIPE, STDOUT 4 | 5 | dataset = 'cwmt-enzh' 6 | root = '/root/Mono4SiM/train/cwmt-enzh/checkpoints' 7 | 8 | lst = ['raw', '', 'train_random_subset', 9 | 'sentence_frequency_low_subset', 10 | 'sentence_uncertainty_high_subset', 11 | '3_anticipation_rate_low_subset', 12 | 'chunking_align_high_subset', 13 | 'chunking_LM_high_subset', 14 | '3_anticipation_rate_low_chunking_align_filter', 15 | '3_anticipation_rate_low_chunking_LM_filter'] 16 | 17 | csvfile = open(f'/root/Mono4SiM/train/{dataset}.csv', 'a', encoding='utf-8') 18 | writer = csv.writer(csvfile) 19 | align = Aligner(f'/root/Mono4SiM/data/{dataset}/score/fwd_align', 20 | f'/root/Mono4SiM/data/{dataset}/score/fwd_err', 21 | f'/root/Mono4SiM/data/{dataset}/score/rev_align', 22 | f'/root/Mono4SiM/data/{dataset}/score/rev_err') 23 | 24 | 25 | def execCmd(cmd, *args): 26 | for arg in args: 27 | cmd = f'{cmd} {arg}' 28 | p = Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) 29 | while p.poll() is None: 30 | print(p.stdout.readline().decode("utf-8").rstrip()) 31 | 32 | 33 | def proc(str): 34 | return str[str.rfind(':') + 1:].strip() 35 | 36 | 37 | for i in lst: 38 | execCmd(f'bash /root/Mono4SiM/train/{dataset}/2-test_model.sh', i) 39 | for k in range(1, 10, 2): 40 | try: 41 | base = f'{root}/wait_{k}_enzh_distill_{i}/log' 42 | if i == 'raw': 43 | base = f'{root}/wait_{k}_enzh_/log' 44 | align.align(base) 45 | 46 | with open(f'{base}/sacrebleu.txt', 'r', encoding='utf-8') as f: 47 | tmp = f.read().split(',') 48 | q = proc(tmp[1]) 49 | asw = proc(tmp[3]) 50 | e = proc(tmp[11]) 51 | asw = asw[1:asw.rfind('(')].strip().split('/') 52 | asw.insert(0, q) 53 | asw.insert(0, e) 54 | 55 | with open(f'{base}/scores', 'r', encoding='utf-8') as f: 56 | tmp = f.read().split(',') 57 | asw.append(proc(tmp[1])) 58 | asw.append(proc(tmp[3])) 59 | asw.append(proc(tmp[5])) 60 | 61 | with open(f'{base}/align_scores.txt', 'r', encoding='utf-8') as f: 62 | tmp = f.read().strip().split() 63 | anti = tmp[3:36:4] 64 | chunk = tmp[-1] 65 | sumn = 0 66 | for m in anti: 67 | sumn += float(m) 68 | sumn /= len(anti) 69 | asw.append(anti[k - 1]) 70 | asw.append(chunk) 71 | 72 | writer.writerow([f'wait_{k}_{i}'] + asw) 73 | 74 | except Exception as e: 75 | print(f'wait_{k}_enzh_distill_{i} Error!') 76 | print(e) 77 | 78 | writer.writerow('') 79 | 80 | csvfile.close() 81 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/models/causal_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | 9 | from fairseq.models import ( 10 | register_model, 11 | register_model_architecture, 12 | FairseqEncoderModel, 13 | ) 14 | 15 | # user 16 | from simultaneous_translation.models.sinkhorn_encoder import ( 17 | SinkhornEncoderModel, 18 | sinkhorn_encoder 19 | ) 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | @register_model("causal_encoder") 25 | class CausalEncoderModel(SinkhornEncoderModel): 26 | @staticmethod 27 | def add_args(parser): 28 | """Add model-specific arguments to the parser.""" 29 | FairseqEncoderModel.add_args(parser) 30 | parser.add_argument( 31 | "--load-pretrained-encoder-from", 32 | type=str, 33 | metavar="STR", 34 | help="model to take encoder weights from (for initialization)", 35 | ) 36 | parser.add_argument( 37 | "--upsample-ratio", 38 | type=int, 39 | help=( 40 | 'number of upsampling factor before ctc loss. used for mt.' 41 | ), 42 | ) 43 | parser.add_argument( 44 | '--delay', type=int, help='delay for incremental reading') 45 | 46 | def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False, **unused): 47 | 48 | encoder_out = self.encoder.forward( 49 | src_tokens=src_tokens, 50 | src_lengths=src_lengths, 51 | return_all_hiddens=return_all_hiddens 52 | ) 53 | x = self.output_projection(encoder_out["encoder_out"][0]) 54 | x = x.transpose(1, 0) # force batch first 55 | 56 | padding_mask = encoder_out["encoder_padding_mask"][0] \ 57 | if len(encoder_out["encoder_padding_mask"]) > 0 else None 58 | extra = { 59 | "padding_mask": padding_mask, 60 | "encoder_out": encoder_out, 61 | "attn": encoder_out["attn"], 62 | "log_alpha": encoder_out["log_alpha"], 63 | } 64 | return x, extra 65 | 66 | 67 | @register_model_architecture( 68 | "causal_encoder", "causal_encoder" 69 | ) 70 | def causal_encoder(args): 71 | args.non_causal_layers = 0 72 | args.sinkhorn_tau = 1 73 | args.sinkhorn_iters = 1 74 | args.sinkhorn_noise_factor = 0 75 | args.sinkhorn_bucket_size = 1 76 | args.sinkhorn_energy = "dot" 77 | args.mask_ratio = 1 78 | args.mask_uniform = False 79 | 80 | sinkhorn_encoder(args) 81 | 82 | 83 | @register_model_architecture( 84 | "causal_encoder", "causal_encoder_small" 85 | ) 86 | def causal_encoder_small(args): 87 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) 88 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 89 | args.encoder_layers = getattr(args, "encoder_layers", 5) 90 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) 91 | 92 | causal_encoder(args) 93 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/eval/oracle_order/nat_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List 3 | from torch import Tensor 4 | from fairseq.utils import new_arange 5 | 6 | 7 | def inject_noise(prev_output_tokens, dictionary, ratio=-1, uniform=False): 8 | """ mask out tokens. 
uniform: uniform length masking """ 9 | pad = dictionary.pad() 10 | bos = dictionary.bos() 11 | eos = dictionary.eos() 12 | unk = dictionary.unk() 13 | 14 | # move eos to the back 15 | N, T = prev_output_tokens.shape[:2] 16 | target_tokens = torch.cat( 17 | ( 18 | prev_output_tokens[:, 1:], 19 | prev_output_tokens.new_full((N, 1), pad) 20 | ), dim=1 21 | ) 22 | target_length = target_tokens.ne(pad).sum(1, keepdim=True) 23 | target_tokens.scatter_(1, target_length, eos) 24 | 25 | if not uniform: 26 | assert 0 <= ratio <= 1, "mask ratio invalid." 27 | if ratio == 0: 28 | return target_tokens, target_tokens.eq(pad) 29 | 30 | target_masks = ( 31 | target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) 32 | ) 33 | target_score = target_tokens.clone().float().uniform_() 34 | target_score.masked_fill_(~target_masks, 2.0) 35 | target_length = target_masks.sum(1).float() 36 | if uniform: 37 | target_length = target_length * target_length.clone().uniform_() 38 | else: 39 | target_length = target_length * target_length.clone().fill_(ratio) 40 | target_length = target_length + 1 # make sure to mask at least one token. 41 | 42 | _, target_rank = target_score.sort(1) 43 | target_cutoff = new_arange(target_rank) < target_length[:, None].long() 44 | target_tokens.masked_fill_( 45 | target_cutoff.scatter(1, target_rank, target_cutoff), unk 46 | ) 47 | 48 | return target_tokens, target_tokens.eq(pad) 49 | 50 | 51 | def generate(model, src_tokens, src_lengths, net_output=None, blank_idx=0, blank_penalty=0, collapse=True, **unused): 52 | """ 53 | lprobs is expected to be batch first. (from model forward output, or net_output) 54 | """ 55 | 56 | if net_output is None: 57 | net_output = model.forward(src_tokens, src_lengths, None) 58 | lprobs = model.get_normalized_probs( 59 | net_output, log_probs=True 60 | ) 61 | 62 | if blank_penalty > 0.0: 63 | lprobs[:, :, blank_idx] -= blank_penalty 64 | 65 | # get subsampling padding mask & lengths 66 | if net_output[1]["padding_mask"] is not None: 67 | non_padding_mask = ~net_output[1]["padding_mask"] 68 | input_lengths = non_padding_mask.long().sum(-1) 69 | else: 70 | sum_dim = 1 71 | input_lengths = lprobs.new_ones( 72 | lprobs.shape[:2], dtype=torch.long).sum(sum_dim) 73 | 74 | bsz = lprobs.size(0) 75 | 76 | # list of completed sentences 77 | finalized = torch.jit.annotate( 78 | List[List[Dict[str, Tensor]]], 79 | [torch.jit.annotate(List[Dict[str, Tensor]], []) 80 | for i in range(bsz)], 81 | ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step 82 | 83 | # TODO faster argmax before for loop? 
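 | # Greedy (best-path) CTC decoding, one sentence at a time: take the argmax
 | # label at every encoder position, then (when collapse=True) merge runs of
 | # consecutive repeats and drop blank tokens to obtain the final hypothesis.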
84 | for sent, lp, inp_l in zip( 85 | range(bsz), 86 | lprobs, 87 | input_lengths, 88 | ): 89 | lp = lp[:inp_l] 90 | 91 | toks = lp.argmax(dim=-1) 92 | score = torch.index_select( 93 | lp.view(inp_l, -1), -1, toks.view(-1)).sum() 94 | if collapse: 95 | toks = toks.unique_consecutive() 96 | if toks.eq(blank_idx).all(): 97 | toks = toks[:1] 98 | else: 99 | toks = toks[toks != blank_idx] 100 | 101 | p_score = torch.zeros_like(toks).float() 102 | 103 | finalized[sent].append( 104 | { 105 | "tokens": toks, 106 | "score": score, 107 | "attention": None, # src_len x tgt_len 108 | "alignment": torch.empty(0), 109 | "positional_scores": p_score, 110 | } 111 | ) 112 | return finalized 113 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/models/nat_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List 3 | from torch import Tensor 4 | from fairseq.utils import new_arange 5 | 6 | 7 | def inject_noise(prev_output_tokens, dictionary, ratio=-1, uniform=False): 8 | """ mask out tokens. uniform: uniform length masking """ 9 | pad = dictionary.pad() 10 | bos = dictionary.bos() 11 | eos = dictionary.eos() 12 | unk = dictionary.unk() 13 | 14 | # move eos to the back 15 | N, T = prev_output_tokens.shape[:2] 16 | target_tokens = torch.cat( 17 | ( 18 | prev_output_tokens[:, 1:], 19 | prev_output_tokens.new_full((N, 1), pad) 20 | ), dim=1 21 | ) 22 | target_length = target_tokens.ne(pad).sum(1, keepdim=True) 23 | target_tokens.scatter_(1, target_length, eos) 24 | 25 | if not uniform: 26 | assert 0 <= ratio <= 1, "mask ratio invalid." 27 | if ratio == 0: 28 | return target_tokens, target_tokens.eq(pad) 29 | 30 | target_masks = ( 31 | target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) 32 | ) 33 | target_score = target_tokens.clone().float().uniform_() 34 | target_score.masked_fill_(~target_masks, 2.0) 35 | target_length = target_masks.sum(1).float() 36 | if uniform: 37 | target_length = target_length * target_length.clone().uniform_() 38 | else: 39 | target_length = target_length * target_length.clone().fill_(ratio) 40 | target_length = target_length + 1 # make sure to mask at least one token. 41 | 42 | _, target_rank = target_score.sort(1) 43 | target_cutoff = new_arange(target_rank) < target_length[:, None].long() 44 | target_tokens.masked_fill_( 45 | target_cutoff.scatter(1, target_rank, target_cutoff), unk 46 | ) 47 | 48 | return target_tokens, target_tokens.eq(pad) 49 | 50 | 51 | def generate(model, src_tokens, src_lengths, net_output=None, blank_idx=0, collapse=True, **unused): 52 | """ 53 | lprobs is expected to be batch first. 
(from model forward output, or net_output) 54 | """ 55 | 56 | if net_output is None: 57 | net_output = model.forward(src_tokens, src_lengths, None) 58 | lprobs = model.get_normalized_probs( 59 | net_output, log_probs=True 60 | ) 61 | 62 | # eos_penalty = 1 63 | # if eos_penalty > 0.0: 64 | # lprobs[:, :, blank_idx] -= eos_penalty 65 | 66 | # get subsampling padding mask & lengths 67 | if net_output[1]["padding_mask"] is not None: 68 | non_padding_mask = ~net_output[1]["padding_mask"] 69 | input_lengths = non_padding_mask.long().sum(-1) 70 | else: 71 | sum_dim = 1 72 | input_lengths = lprobs.new_ones( 73 | lprobs.shape[:2], dtype=torch.long).sum(sum_dim) 74 | 75 | bsz = lprobs.size(0) 76 | 77 | # list of completed sentences 78 | finalized = torch.jit.annotate( 79 | List[List[Dict[str, Tensor]]], 80 | [torch.jit.annotate(List[Dict[str, Tensor]], []) 81 | for i in range(bsz)], 82 | ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step 83 | 84 | # TODO faster argmax before for loop? 85 | for sent, lp, inp_l in zip( 86 | range(bsz), 87 | lprobs, 88 | input_lengths, 89 | ): 90 | lp = lp[:inp_l] 91 | 92 | toks = lp.argmax(dim=-1) 93 | score = torch.index_select( 94 | lp.view(inp_l, -1), -1, toks.view(-1)).sum() 95 | if collapse: 96 | toks = toks.unique_consecutive() 97 | if toks.eq(blank_idx).all(): 98 | toks = toks[:1] 99 | else: 100 | toks = toks[toks != blank_idx] 101 | 102 | p_score = torch.zeros_like(toks).float() 103 | 104 | finalized[sent].append( 105 | { 106 | "tokens": toks, 107 | "score": score, 108 | "attention": None, # src_len x tgt_len 109 | "alignment": torch.empty(0), 110 | "positional_scores": p_score, 111 | } 112 | ) 113 | return finalized 114 | -------------------------------------------------------------------------------- /data/1-preprocess_distill.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from subprocess import Popen, PIPE, STDOUT 4 | 5 | 6 | src = 'en' 7 | tgt = 'zh' 8 | dataset = 'cwmt' 9 | base = f'/root/Mono4SiM/data/{dataset}-{src}{tgt}' 10 | root = f'/root/Mono4SiM/generate/teacher_{dataset}_mono' 11 | spm_prefix = f'{base}/prep/spm_unigram32000' 12 | 13 | 14 | def execCmd(cmd, *args): 15 | for arg in args: 16 | cmd = f"{cmd} {arg}" 17 | r = os.popen(cmd) 18 | text = r.read().rstrip() 19 | r.close() 20 | if len(text.strip()): 21 | print(text) 22 | return text 23 | 24 | 25 | def execInteractive(cmd, *args): 26 | for arg in args: 27 | cmd = f"{cmd} {arg}" 28 | p = Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) 29 | while p.poll() is None: 30 | print(p.stdout.readline().decode("utf-8"), end = "") 31 | 32 | 33 | def cddir(dir): 34 | execCmd("mkdir -p", dir) 35 | os.chdir(dir) 36 | dir = execCmd("pwd").rstrip() 37 | if not dir.endswith('/'): 38 | dir += '/' 39 | return dir 40 | 41 | 42 | def proc_interactive(root, n): 43 | print(f"Processing interactive {n}") 44 | lst = ['S', 'D', 'P'] 45 | record = defaultdict(dict) 46 | 47 | with open(f'{root}/interactive/generate-train.{n}.txt', 'r', encoding='utf-8') as f: 48 | for i in f.readlines(): 49 | if '\t' in i and i[0] in lst: 50 | num, sen = i.rstrip().split('\t', 1) 51 | record[i[0]][int(num[2:])] = sen 52 | with open(f'{root}/score/generate_ppl.txt', 'a', encoding='utf-8') as p,\ 53 | open(f'{root}/interactive/detok.{tgt}', 'a', encoding='utf-8') as d,\ 54 | open(f'{root}/interactive/detok.{src}', 'a', encoding='utf-8') as s: 55 | i = 0 56 | while i in record['S']: 57 | try: 58 
| lst = [float(num) for num in record['P'][i].split()]
59 |                     p.write(f'{sum(lst)}\t{len(lst)}\n')
60 |                     d.write(record['D'][i].split('\t', 1)[1] + '\n')
61 |                     s.write(record['S'][i] + '\n')
62 |                 except:
63 |                     print(record['S'][i])
64 |                 i += 1
65 |         # The sharded generation log holds up to four runs of consecutive
66 |         # indices separated by single missing indices; process the three
67 |         # remaining runs exactly like the first one above.
68 |         for _ in range(3):
69 |             i += 1
70 |             while i in record['S']:
71 |                 try:
72 |                     lst = [float(num) for num in record['P'][i].split()]
73 |                     p.write(f'{sum(lst)}\t{len(lst)}\n')
74 |                     d.write(record['D'][i].split('\t', 1)[1] + '\n')
75 |                     s.write(record['S'][i] + '\n')
76 |                 except:
77 |                     print(record['S'][i])
78 |                 i += 1
79 |     print(i)
80 |     return i
81 | 
82 | 
83 | def join_file(root, subset, lang, rm=False):
84 |     cmd = 'cat'
85 |     for i in range(8):
86 |         cmd += f' {root}/{subset}.{i}.{lang}'
87 |     execCmd(cmd, f'> {root}/{subset}.{lang}')
88 |     if rm:
89 |         for i in range(8):
90 |             execCmd(f'rm {root}/{subset}.{i}.{lang}')
91 | 
92 | 
93 | if __name__ == '__main__':
94 |     os.chdir('/root/Mono4SiM/generate')
95 |     execInteractive(f'mkdir -p {root}/score')
96 |     execInteractive(f'mkdir -p {root}/ready')
97 |     execInteractive(f'rm -f {root}/score/generate_ppl.txt')
98 |     execInteractive(f'rm -f {root}/interactive/detok.{src}')
99 |     execInteractive(f'rm -f {root}/interactive/detok.{tgt}')
100 | 
101 |     print("Running Tokenizer...")
102 |     for i in range(8):
103 |         proc_interactive(root, i)
104 |         execCmd(f'wc -l < /root/Mono4SiM/data/mono-en/split/train.en.{i}')
105 | 
106 |     preprocess = 'false' if 'mono' in root else 'true'
107 |     execCmd(f'cp {base}/ready/valid.{src} {root}/ready/valid.{src}')
108 |     execCmd(f'cp {base}/ready/valid.{tgt} {root}/ready/valid.{tgt}')
109 |     execCmd(f'cp {base}/ready/test.{src} {root}/ready/test.{src}')
110 |     execCmd(f'cp {base}/ready/test.{tgt} {root}/ready/test.{tgt}')
111 |     execInteractive(
112 |         f'bash 1-preprocess_tokenizer.sh {root} {spm_prefix} {src} {tgt} {preprocess}')
113 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Improving Simultaneous Machine Translation with Monolingual Data
2 | 
3 | ## Setup
4 | 
5 | 1. Install fairseq
6 |    Stick to the specified checkout version to avoid compatibility issues.
7 | 
8 | ```bash
9 | git clone https://github.com/pytorch/fairseq.git
10 | cd fairseq
11 | git checkout 8b861be
12 | python setup.py build_ext --inplace
13 | pip install .
14 | ```
15 | 
16 | 2. (Optional) [Install apex](extra_installation.md) for faster mixed-precision (fp16) training.
17 | 
18 | 3. Install the dependencies (clone them into the [utility](utility/README.md) folder if possible).
19 | 
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 | 
24 | For the installation guide, see [extra_installation](extra_installation.md).
25 | 
26 | ## Data Preparation
27 | 
28 | All corresponding scripts are in the data folder.
29 | 
30 | 1. 
To download the corresponding datasets, go to [Google Drive](https://drive.google.com/drive/folders/1HbzxBD0klgX-EugVGB36CFVdObJJ5Uk7?usp=sharing) for the cleaned datasets, or run the scripts beginning with 0.
31 | 
32 | ```bash
33 | cd data
34 | bash 0-get_data_cwmt.sh
35 | bash 0-get_en_mono.sh
36 | ```
37 | 
38 | 2. After distilling, run [1-preprocess_distill.py](data/1-preprocess_distill.py) to preprocess the data, and then run the scripts beginning with 2 to calculate the corresponding scores.
39 | 
40 | ```bash
41 | cd data
42 | python 1-preprocess_distill.py
43 | bash 2-train_align.sh
44 | bash 2-train_kenlm.sh
45 | bash 2-fast_align.sh
46 | bash 2-k_anticipation.sh
47 | python 2-get_uncertainty.py
48 | ```
49 | 
50 | 3. Finally, run [3-scoring_preprocessing.py](data/3-scoring_preprocessing.py) to calculate the scores of the distilled data and extract subsets according to the metrics we propose.
51 | 
52 | ```bash
53 | cd data
54 | python 3-scoring_preprocessing.py
55 | ```
56 | 
57 | **Note** that you need to change the [data path](data/data_path.sh) manually.
58 | 
59 | ## Training
60 | 
61 | We need a full-sentence model as the teacher for sequence-level knowledge distillation (Seq-KD).
62 | 
63 | The following command will train the teacher model.
64 | 
65 | ```bash
66 | cd train/cwmt-enzh
67 | bash 0-teacher.sh
68 | ```
69 | 
70 | To distill the training set, run
71 | 
72 | ```bash
73 | cd train/cwmt-enzh
74 | bash 0-distill_enzh_mono.sh
75 | ```
76 | 
77 | We provide our datasets, including the distilled set and the pseudo-reference set, for easier reproducibility.
78 | 
79 | We can now train the vanilla wait-k model. To do this, run
80 | 
81 | ```bash
82 | bash 1b-distill_all_wait_k.sh generate/teacher_cwmt_mono/data-bin 3_anticipation_rate_low_chunking_LM_filter
83 | ```
84 | 
85 | *3_anticipation_rate_low_chunking_LM_filter* is the default name of our best strategy; change this field to run wait-k on any subset (*raw* for the original bilingual datasets).
86 | 
87 | Our models are released at [Google Drive](https://drive.google.com/drive/folders/19aPnAPvT75KmlLA2Y0VipNJVF3cf3CaP?usp=sharing).
88 | 
89 | ## Evaluation (SimulEval)
90 | 
91 | Install [SimulEval](extra_installation.md).
92 | 
93 | ### Full-sentence model
94 | 
95 | ```bash
96 | cd train/cwmt-enzh
97 | bash 2-test_model_full.sh
98 | ```
99 | 
100 | ### Wait-k models
101 | 
102 | ```bash
103 | cd train/cwmt-enzh
104 | bash 2-test_model.sh 3_anticipation_rate_low_chunking_LM_filter
105 | ```
106 | 
107 | Change *3_anticipation_rate_low_chunking_LM_filter* to run the evaluation on any subset (*raw* for the original bilingual datasets),
108 | 
109 | or simply run:
110 | 
111 | ```bash
112 | cd train
113 | python get_score.py
114 | ```
115 | 
116 | to evaluate all subsets.
117 | 
118 | 
119 | ## Citation
120 | If you find this work helpful, please consider citing as follows:
121 | ```bibtex
122 | @article{Deng_Ding_Liu_Zhang_Tao_Zhang_2023,
123 |   title={Improving Simultaneous Machine Translation with Monolingual Data},
124 |   volume={37},
125 |   url={https://ojs.aaai.org/index.php/AAAI/article/view/26497},
126 |   DOI={10.1609/aaai.v37i11.26497},
127 |   abstractNote={Simultaneous machine translation (SiMT) is usually done via sequence-level knowledge distillation (Seq-KD) from a full-sentence neural machine translation (NMT) model. However, there is still a significant performance gap between NMT and SiMT. In this work, we propose to leverage monolingual data to improve SiMT, which trains a SiMT student on the combination of bilingual data and external monolingual data distilled by Seq-KD. 
Preliminary experiments on En-Zh and En-Ja news domain corpora demonstrate that monolingual data can significantly improve translation quality (e.g., +3.15 BLEU on En-Zh). Inspired by the behavior of human simultaneous interpreters, we propose a novel monolingual sampling strategy for SiMT, considering both chunk length and monotonicity. Experimental results show that our sampling strategy consistently outperforms the random sampling strategy (and other conventional typical NMT monolingual sampling strategies) by avoiding the key problem of SiMT -- hallucination, and has better scalability. We achieve +0.72 BLEU improvements on average against random sampling on En-Zh and En-Ja. Data and codes can be found at https://github.com/hexuandeng/Mono4SiMT.}, 128 | number={11}, 129 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 130 | author={Deng, Hexuan and Ding, Liang and Liu, Xuebo and Zhang, Meishan and Tao, Dacheng and Zhang, Min}, 131 | year={2023}, 132 | month={Jun.}, 133 | pages={12728-12736} 134 | } 135 | ``` 136 | -------------------------------------------------------------------------------- /train/test_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from collections import defaultdict 4 | from subprocess import Popen, PIPE, STDOUT 5 | 6 | def exec(cmd, *args): 7 | for arg in args: 8 | cmd = f'{cmd} {arg}' 9 | p = Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) 10 | while p.poll() is None: 11 | tmp = p.stdout.readline().decode("utf-8").rstrip() 12 | if(len(tmp.strip())): 13 | print(tmp) 14 | 15 | 16 | def get_dict(file): 17 | lst = ['S', 'T', 'D'] 18 | record = defaultdict(dict) 19 | with open(file, 'r', encoding='utf-8') as f: 20 | for i in f.readlines(): 21 | if '\t' in i and i[0] in lst: 22 | num, sen = i.rstrip().split('\t', 1) 23 | record[int(num[2:])][i[0]] = sen 24 | return record 25 | 26 | 27 | def chunking_align(align, src, reverse=False): 28 | with open(align, 'r', encoding='utf-8') as corpus,\ 29 | open(src, 'r', encoding='utf-8') as en: 30 | num_chunk, tot = 0, 1e-9 31 | num_align, tot_align = 0, 1e-9 32 | for line, sen in zip(corpus, en): 33 | line = line.strip() 34 | tot_align += float(len(sen.strip().split())) 35 | itr = re.finditer(r"(?P[0-9]+)-(?P[0-9]+)", line) 36 | left = [] 37 | right = [] 38 | for m in itr: 39 | left.append(int(m.group("j" if reverse else "i"))) 40 | right.append(int(m.group("i" if reverse else "j"))) 41 | num_align += float(len(set(left))) 42 | 43 | while len(left) != 0: 44 | pair = del_pair_mono(left, right, 0) 45 | num_chunk += 1 46 | tot += len(set(pair[0])) 47 | return tot / num_chunk, num_align / tot_align 48 | 49 | 50 | def del_pair_mono(left, right, i): 51 | l = [left[i]] 52 | r = [right[i]] 53 | del left[i], right[i] 54 | Flag = True 55 | while Flag: 56 | Flag = False 57 | for i in range(min(l), max(l) + 1): 58 | if i in left: 59 | tmp = del_pair_mono(left, right, left.index(i)) 60 | l += tmp[0] 61 | r += tmp[1] 62 | Flag = True 63 | for i in range(min(r), max(r) + 1): 64 | if i in right: 65 | tmp = del_pair_mono(left, right, right.index(i)) 66 | l += tmp[0] 67 | r += tmp[1] 68 | Flag = True 69 | 70 | return [l, r] 71 | 72 | 73 | class Aligner: 74 | def __init__(self, fwd_params, fwd_err, rev_params, rev_err, heuristic='grow-diag-final-and'): 75 | self.fast_align = '/root/Mono4SiM/utility/fast_align/build/fast_align' 76 | self.atools = '/root/Mono4SiM/utility/fast_align/build/atools' 77 | self.fwd_params = fwd_params 78 | 
self.rev_params = rev_params 79 | self.heuristic = heuristic 80 | (self.fwd_T, self.fwd_m) = self.read_err(fwd_err) 81 | (self.rev_T, self.rev_m) = self.read_err(rev_err) 82 | 83 | def align(self, prefix): 84 | dict = get_dict(f'{prefix}/generate-test.txt') 85 | with open(f'{prefix}/tmp.en', 'w', encoding='utf-8') as en,\ 86 | open(f'{prefix}/tmp.zh', 'w', encoding='utf-8') as zh: 87 | for j in range(3003): 88 | en.write(dict[j]['S'] + '\n') 89 | zh.write(dict[j]['D'].split('\t')[-1] + '\n') 90 | exec('bash /root/Mono4SiM/train/cwmt-enzh/2s-encode_test.sh', prefix, 'tmp') 91 | exec(self.fast_align, '-i', f'{prefix}/align_process', '-d', '-v', '-o', 92 | '-T', self.fwd_T, '-m', self.fwd_m, '-f', self.fwd_params, f'> {prefix}/tmp.fwd') 93 | with open(f'{prefix}/tmp.fwd', 'r', encoding='utf-8') as r,\ 94 | open(f'{prefix}/fwd', 'w', encoding='utf-8') as w: 95 | for i in r: 96 | w.write(i.split('|||')[2].strip()+'\n') 97 | exec(self.fast_align, '-i', f'{prefix}/align_process', '-d', '-v', '-o', 98 | '-T', self.rev_T, '-m', self.rev_m, '-f', self.rev_params, f'-r > {prefix}/tmp.rev') 99 | with open(f'{prefix}/tmp.rev', 'r', encoding='utf-8') as r,\ 100 | open(f'{prefix}/rev', 'w', encoding='utf-8') as w: 101 | for i in r: 102 | w.write(i.split('|||')[2].strip()+'\n') 103 | exec(self.atools, '-i', f'{prefix}/fwd', '-j', 104 | f'{prefix}/rev', '-c', self.heuristic, f'> {prefix}/align_process') 105 | exec('/root/Mono4SiM/utility/scripts/run_aligner.sh', f'{prefix}/align_process > {prefix}/align_scores.txt') 106 | with open(f'{prefix}/align_scores.txt', 'a', encoding='utf-8') as f: 107 | score1, score2 = chunking_align(f'{prefix}/align_process', f'{prefix}/tmp.en') 108 | f.write(f'\nChunk Avg Lengths:\n{score2}\t{score1}\n') 109 | 110 | exec(f'rm {prefix}/fwd') 111 | exec(f'rm {prefix}/rev') 112 | exec(f'rm {prefix}/tmp.fwd') 113 | exec(f'rm {prefix}/tmp.rev') 114 | exec(f'rm {prefix}/align_process') 115 | 116 | def read_err(self, err): 117 | (T, m) = ('', '') 118 | for line in open(err): 119 | # expected target length = source length * N 120 | if 'expected target length' in line: 121 | m = line.split()[-1] 122 | # final tension: N 123 | elif 'final tension' in line: 124 | T = line.split()[-1] 125 | return (T, m) 126 | -------------------------------------------------------------------------------- /extra_installation.md: -------------------------------------------------------------------------------- 1 | # Extra Installation 2 | 3 | ## Install PyTorch versions 4 | 5 | First install pytorch binaries with specific version of cuda. We will use `pytorch 1.8.1 + CUDA 11.1` as our example. 6 | 7 | ``` 8 | torch 1.8.X+cu102/cu111 ==> CUDA 10.2 / 11.1 9 | torch 1.7.X+cu92/cu101/cu102/cu110 ==> CUDA 9.2 / 10.1 / 10.2 / 11.0 10 | torch 1.6.X+cu92/cu101/cu102 ==> CUDA 9.2 / 10.1 / 10.2 11 | ``` 12 | 13 | ## Install matching nvcc compiler 14 | 15 | Apex will ask for same nvcc compiler version as that used to compile pytorch binary. If your system's CUDA Toolkit is of a different version than that of your pytorch binary, you need to install a matching one. 16 | 17 | ### Check NVCC Compatibility 18 | 19 | Check your nvcc version (if any) by running 20 | 21 | ```bash 22 | nvcc -V 23 | ``` 24 | 25 | If this version is the same as your pytorch cuda version (in our case 11.1), then you can skip to [Install apex](#install-apex). 
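 | A quick sanity check, assuming `torch` is already installed: query the CUDA
 | version the binary was built with directly from Python and compare it with
 | the output of `nvcc -V`.
 | 
 | ```python
 | import torch
 | 
 | print(torch.__version__)   # e.g. 1.8.1+cu111
 | print(torch.version.cuda)  # e.g. 11.1 (None for CPU-only builds)
 | ```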
26 | 
27 | ### Download Legacy CUDA Toolkit
28 | 
29 | Go to the [CUDA Archive](https://developer.nvidia.com/cuda-toolkit-archive) to download the specific version of CUDA used to compile your installed pytorch binary. For example, download the runfile installer for Ubuntu 20.04 with
30 | 
31 | ```bash
32 | wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
33 | ```
34 | 
35 | ### User install
36 | 
37 | We only need the `nvcc` compiler, so we will use a user installation, which does not require root.
38 | 
39 | 1. Run the installation as a non-root user
40 | 
41 | ```bash
42 | bash cuda_11.1.1_455.32.00_linux.run
43 | ```
44 | 
45 | 2. De-select **Driver installation, Samples, Demo** and **Documentation**, as we don't need them.
46 | 
47 | ```
48 | ┌──────────────────────────────────────────────────────────────────────────────┐
49 | │ CUDA Installer                                                               │
50 | │ - [ ] Driver                                                               < │
51 | │      [ ] 455.32.00                                                           │
52 | │ + [X] CUDA Toolkit 11.1                                                      │
53 | │   [ ] CUDA Samples 11.1                                                    < │
54 | │   [ ] CUDA Demo Suite 11.1                                                 < │
55 | │   [ ] CUDA Documentation 11.1                                              < │
56 | │   Options                                                                    │
57 | │   Install                                                                    │
58 | ```
59 | 
60 | 3. Set the **Options --> Library install** path to a non-root path
61 | 
62 | ```
63 | ┌──────────────────────────────────────────────────────────────────────────────┐
64 | │ Options                                                                      │
65 | │   Driver Options                                                             │
66 | │   Toolkit Options                                                            │
67 | │   Samples Options                                                            │
68 | │   Library install path (Blank for system default)                          < │
69 | │   Done                                                                       │
70 | ```
71 | 
72 | 4. Set the **Options --> Toolkit** options: 1) set the **install path** to a non-root path, and 2) de-select the **symbolic link, shortcuts** and **manpage documents** options.
73 | 
74 | ```
75 | ┌──────────────────────────────────────────────────────────────────────────────┐
76 | │ CUDA Toolkit                                                                 │
77 | │   Change Toolkit Install Path                                              < │
78 | │   [ ] Create symbolic link from /usr/local/cuda                            < │
79 | │ - [ ] Create desktop menu shortcuts                                        < │
80 | │      [ ] Yes                                                                 │
81 | │      [ ] No                                                                  │
82 | │   [ ] Install manpage documents to /usr/share/man                          < │
83 | │   Done                                                                      │
84 | ```
85 | 
86 | 5. Install. You should get the following summary:
87 | 
88 | ```
89 | ===========
90 | = Summary =
91 | ===========
92 | 
93 | Driver:   Not Selected
94 | Toolkit:  Installed in <your install path>
95 | Samples:  Not Selected
96 | ```
97 | 
98 | ## Install apex
99 | 
100 | Install apex by following the official instructions in the [NVIDIA apex repository](https://github.com/NVIDIA/apex).
101 | 
102 | ## SimulEval
103 | 
104 | We updated SimulEval with two functionalities:
105 | 
106 | 1. Evaluating computation-aware (CA) latency metrics for text.
107 | 2. Saving actual system predictions to a file, so that multi-reference BLEU can be calculated (we found that `instances.log` forcefully adds whitespace, which is undesired for Chinese).
108 | 
109 | ```bash
110 | git clone XXXX-1
111 | ```
112 | 
113 | Alternatively, you can use the official repository if you're skeptical of our modifications, though you will need to extract predictions manually, and the results for Chinese might be inaccurate. 
114 | 115 | ```bash 116 | git clone https://github.com/facebookresearch/SimulEval.git 117 | ``` 118 | 119 | You need to add the following lines to the class `TextInstance` in `SimulEval/simuleval/scorer/instance.py` in order to obtain computational aware (CA) latency metrics: 120 | 121 | ```python 122 | # class TextInstance(Instance): 123 | # add following function to TextInstance 124 | def sentence_level_eval(self): 125 | super().sentence_level_eval() 126 | # For the computation-aware latency 127 | self.metrics["latency_ca"] = eval_all_latency( 128 | self.elapsed, self.source_length(), self.reference_length() + 1) 129 | ``` 130 | 131 | Regardless of which approach you use, proceed to install the package via pip: 132 | 133 | ```bash 134 | cd SimulEval 135 | pip install -e . 136 | ``` 137 | 138 | ## SacreBLEU 139 | 140 | To evaluate Translation Edit Rate (TER) or enable bootstrap resampling, we need to use SacreBLEU v2.0.0. However, version 2 currently **breaks compatibility** with the version of fairseq that we use. The solution is to use python venv to create an environment only for evaluation: 141 | 142 | ```bash 143 | python -m venv ~/envs/sacrebleu2 144 | ``` 145 | 146 | Activate it by: 147 | 148 | ```bash 149 | source ~/envs/sacrebleu2/bin/activate 150 | ``` 151 | 152 | Install sacrebleu version 2 153 | 154 | ```bash 155 | git clone https://github.com/mjpost/sacrebleu.git 156 | cd sacrebleu 157 | pip install . 158 | ``` 159 | 160 | Then you can use sacrebleu v2, without breaking fairseq. 161 | -------------------------------------------------------------------------------- /data/0-get_data_cwmt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Adapted from https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/translation/prepare-iwslt14.sh 3 | source ./data_path.sh 4 | SCRIPTS=/root/Mono4SiM/utility/mosesdecoder/scripts 5 | # source ~/envs/apex/bin/activate 6 | 7 | vocab=32000 8 | vtype=unigram 9 | workers=4 10 | 11 | # TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | # NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | LC=$SCRIPTS/tokenizer/lowercase.perl 16 | 17 | spm_train=$FAIRSEQ/scripts/spm_train.py 18 | spm_encode=$FAIRSEQ/scripts/spm_encode.py 19 | 20 | CORPORA=( 21 | "cwmt/parallel/casia2015/casia2015/casia2015" 22 | "cwmt/parallel/casict2011/casict2011/casict-A" 23 | "cwmt/parallel/casict2011/casict2011/casict-B" 24 | "cwmt/parallel/casict2015/casict2015/casict2015" 25 | "cwmt/parallel/neu2017/neu2017/NEU" 26 | ) 27 | 28 | orig=${BASE}/orig 29 | prep=${BASE}/prep 30 | ready=${BASE}/ready 31 | bin=${BASE}/data-bin 32 | mkdir -p $orig $prep $ready $bin 33 | 34 | echo "downloading data" 35 | cd $orig 36 | 37 | file=cwmt-data.zip 38 | if [ -f $file ]; then 39 | echo "$file already exists, skipping download" 40 | else 41 | kaggle datasets download -d warmth/cwmt-data 42 | if [ -f $file ]; then 43 | echo "$file successfully downloaded." 44 | else 45 | echo "$file not successfully downloaded." 46 | exit -1 47 | fi 48 | unzip $file 49 | fi 50 | cd .. 51 | 52 | echo "pre-processing train data..." 
53 | for l in ${SRC} ${TGT}; do
54 |     rm -f $prep/train.dirty.$l
55 |     for f in "${CORPORA[@]}"; do
56 |         if [ "$l" == "zh" ]; then
57 |             if [[ "$f" == *"NEU"* ]]; then
58 |                 t="_cn.txt"
59 |             else
60 |                 t="_ch.txt"
61 |             fi
62 |         else
63 |             t="_en.txt"
64 |         fi
65 | 
66 |         echo "preprocess train $f$t"
67 |         cat $orig/$f$t | \
68 |             perl $REM_NON_PRINT_CHAR | \
69 |             perl $LC >> $prep/train.dirty.$l
70 |     done
71 | done
72 | 
73 | echo "pre-processing valid data..."
74 | for l in ${SRC} ${TGT}; do
75 |     if [ "$l" == "zh" ]; then
76 |         DEV=(
77 |             "$orig/cwmt/dev/NJU-newsdev2018-zhen/NJU-newsdev2018-zhen/CWMT2017-ce-news-test-src.xml"
78 |             "$orig/cwmt/dev/NJU-newsdev2018-enzh/NJU-newsdev2018-enzh/CWMT2017-ec-news-test-ref.xml"
79 |         )
80 |     else
81 |         DEV=(
82 |             "$orig/cwmt/dev/NJU-newsdev2018-zhen/NJU-newsdev2018-zhen/CWMT2017-ce-news-test-ref.xml"
83 |             "$orig/cwmt/dev/NJU-newsdev2018-enzh/NJU-newsdev2018-enzh/CWMT2017-ec-news-test-src.xml"
84 |         )
85 |     fi
86 | 
87 |     rm -f $prep/valid.dirty.$l
88 |     for f in "${DEV[@]}"; do
89 |         echo "preprocess valid $f"
90 |         grep '<seg' $f | \
91 |             sed -e 's/\s*<seg id="[0-9]*">\s*//g' | \
92 |             sed -e 's/\s*<\/seg>\s*//g' | \
93 |             sed -e "s/\’/\'/g" | \
94 |             perl $REM_NON_PRINT_CHAR | \
95 |             perl $LC >> $prep/valid.dirty.$l
96 |     done
97 | done
98 | 
99 | # testset en -> zh
100 | rm -f $prep/test.*
101 | for y in 2008 2009 2011; do
102 |     tail -n +2 $orig/cwmt${y}_ec_news.tsv | cut -f6 | \
103 |         perl $REM_NON_PRINT_CHAR | \
104 |         perl $LC >> $prep/test.en-zh.en
105 |     for c in 1 2 3; do
106 |         tail -n +2 $orig/cwmt${y}_ec_news.tsv | cut -f$(($c+6)) | \
107 |             perl $REM_NON_PRINT_CHAR | \
108 |             perl $LC >> $prep/test.en-zh.zh.$c
109 |     done
110 | done
111 | 
112 | # zh -> en
113 | for y in 2008 2009; do
114 |     tail -n +2 $orig/cwmt${y}_ce_news.tsv | cut -f6 | \
115 |         perl $REM_NON_PRINT_CHAR | \
116 |         perl $LC >> $prep/test.zh-en.zh
117 |     for c in 1 2 3; do
118 |         tail -n +2 $orig/cwmt${y}_ce_news.tsv | cut -f$(($c+6)) | \
119 |             perl $REM_NON_PRINT_CHAR | \
120 |             perl $LC >> $prep/test.zh-en.en.$c
121 |     done
122 | done
123 | 
124 | cp $prep/test.${SRC}-${TGT}.${SRC} $prep/test.${SRC}
125 | cp $prep/test.${SRC}-${TGT}.${TGT}.1 $prep/test.${TGT}
126 | 
127 | # filter empty pairs
128 | perl $CLEAN -ratio 1000 $prep/train.dirty ${SRC} ${TGT} $prep/train 1 10000
129 | perl $CLEAN -ratio 1000 $prep/valid.dirty ${SRC} ${TGT} $prep/valid 1 10000
130 | 
131 | # SPM
132 | SPM_PREFIX=$prep/spm_${vtype}${vocab}
133 | for l in ${SRC} ${TGT}; do
134 |     SPM_MODEL=${SPM_PREFIX}_${l}.model
135 |     DICT=${SPM_PREFIX}_${l}.txt
136 |     BPE_TRAIN=$prep/bpe-train.$l
137 | 
138 |     if [[ ! -f $SPM_MODEL ]]; then
139 |         if [ -f $BPE_TRAIN ]; then
140 |             echo "$BPE_TRAIN found, skipping concat."
141 |         else
142 |             train=$prep/train.$l
143 |             default=1000000
144 |             total=$(cat $train | wc -l)
145 |             echo "lang $l total: $total."
146 |             if [ "$total" -gt "$default" ]; then
147 |                 cat $train | \
148 |                     shuf -r -n $default >> $BPE_TRAIN
149 |             else
150 |                 cat $train >> $BPE_TRAIN
151 |             fi
152 |         fi
153 | 
154 |         echo "spm_train on $BPE_TRAIN..."
155 | ccvg=1.0 156 | if [[ ${l} == "zh" ]]; then 157 | ccvg=0.9995 158 | fi 159 | python $spm_train --input=$BPE_TRAIN \ 160 | --model_prefix=${SPM_PREFIX}_${l} \ 161 | --vocab_size=$vocab \ 162 | --character_coverage=$ccvg \ 163 | --model_type=$vtype \ 164 | --normalization_rule_name=nmt_nfkc_cf 165 | 166 | cut -f1 ${SPM_PREFIX}_${l}.vocab | tail -n +4 | sed "s/$/ 100/g" > $DICT 167 | cp $SPM_MODEL $bin/$(basename $SPM_MODEL) 168 | cp $DICT $bin/$(basename $DICT) 169 | fi 170 | 171 | echo "Using SPM model $SPM_MODEL" 172 | for split in train valid test; do 173 | if [ -f $ready/$split.$l ]; then 174 | echo "found $ready/$split.$l, skipping spm_encode" 175 | else 176 | echo "spm_encode to $split.$l..." 177 | python $spm_encode --model=$SPM_MODEL \ 178 | --output_format=piece \ 179 | < $prep/$split.$l > $ready/$split.$l 180 | fi 181 | done 182 | done 183 | 184 | # filter ratio and maxlen < 1024 185 | perl $CLEAN -ratio 9 $ready/train ${SRC} ${TGT} $ready/train.clean 1 1024 186 | 187 | python -m fairseq_cli.preprocess \ 188 | --source-lang ${SRC} \ 189 | --target-lang ${TGT} \ 190 | --trainpref ${ready}/train.clean \ 191 | --validpref ${ready}/valid \ 192 | --testpref ${ready}/test \ 193 | --destdir ${bin} \ 194 | --workers ${workers} \ 195 | --srcdict ${SPM_PREFIX}_${SRC}.txt \ 196 | --tgtdict ${SPM_PREFIX}_${TGT}.txt 197 | -------------------------------------------------------------------------------- /utility/scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import collections 9 | import os 10 | import re 11 | 12 | import torch 13 | from fairseq.file_io import PathManager 14 | 15 | 16 | def average_checkpoints(inputs): 17 | """Loads checkpoints from inputs and returns a model with averaged weights. 18 | 19 | Args: 20 | inputs: An iterable of string paths of checkpoints to load from. 21 | 22 | Returns: 23 | A dict of string keys mapping to various values. The 'model' key 24 | from the returned dict should correspond to an OrderedDict mapping 25 | string parameter names to torch Tensors. 
26 |     """
27 |     params_dict = collections.OrderedDict()
28 |     params_keys = None
29 |     new_state = None
30 |     num_models = len(inputs)
31 | 
32 |     for fpath in inputs:
33 |         with PathManager.open(fpath, "rb") as f:
34 |             state = torch.load(
35 |                 f,
36 |                 map_location=(
37 |                     lambda s, _: torch.serialization.default_restore_location(s, "cpu")
38 |                 ),
39 |             )
40 |         # Copies over the settings from the first checkpoint
41 |         if new_state is None:
42 |             new_state = state
43 | 
44 |         model_params = state["model"]
45 | 
46 |         model_params_keys = list(model_params.keys())
47 |         if params_keys is None:
48 |             params_keys = model_params_keys
49 |         elif params_keys != model_params_keys:
50 |             raise KeyError(
51 |                 "For checkpoint {}, expected list of params: {}, "
52 |                 "but found: {}".format(fpath, params_keys, model_params_keys)
53 |             )
54 | 
55 |         for k in params_keys:
56 |             p = model_params[k]
57 |             if isinstance(p, torch.HalfTensor):
58 |                 p = p.float()
59 |             if k not in params_dict:
60 |                 params_dict[k] = p.clone()
61 |                 # NOTE: clone() is needed in case p is a shared parameter
62 |             else:
63 |                 params_dict[k] += p
64 | 
65 |     averaged_params = collections.OrderedDict()
66 |     for k, v in params_dict.items():
67 |         averaged_params[k] = v
68 |         if averaged_params[k].is_floating_point():
69 |             averaged_params[k].div_(num_models)
70 |         else:
71 |             averaged_params[k] //= num_models
72 |     new_state["model"] = averaged_params
73 |     return new_state
74 | 
75 | 
76 | def last_n_checkpoints(paths, n,
77 |                        best_based, update_based, upper_bound=None):
78 |     assert len(paths) == 1
79 |     path = paths[0]
80 |     if best_based:
81 |         pt_regexp = re.compile(r"checkpoint.best_[a-z]+_([+-]?(\d*[.])?\d+)\.pt")
82 |     elif update_based:
83 |         pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
84 |     else:
85 |         pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
86 |     files = PathManager.ls(path)
87 | 
88 |     entries = []
89 |     for f in files:
90 |         m = pt_regexp.fullmatch(f)
91 |         if m is not None:
92 |             sort_key = float(m.group(1))
93 |             if upper_bound is None or sort_key <= upper_bound:
94 |                 entries.append((sort_key, m.group(0)))
95 |     if len(entries) < n:
96 |         raise Exception(
97 |             "Found {} checkpoint files but need at least {}".format(len(entries), n)
98 |         )
99 |     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
100 | 
101 | 
102 | def main():
103 |     parser = argparse.ArgumentParser(
104 |         description="Tool to average the params of input checkpoints to "
105 |         "produce a new checkpoint",
106 |     )
107 |     # fmt: off
108 |     parser.add_argument('--inputs', required=True, nargs='+',
109 |                         help='Input checkpoint file paths.')
110 |     parser.add_argument('--output', required=True, metavar='FILE',
111 |                         help='Write the new checkpoint containing the averaged weights to this path.')
112 |     num_group = parser.add_mutually_exclusive_group()
113 |     num_group.add_argument('--num-epoch-checkpoints', type=int,
114 |                            help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
115 |                            'and average the last that many of them.')
116 |     num_group.add_argument('--num-update-checkpoints', type=int,
117 |                            help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
118 |                            'and average the last that many of them.')
119 |     num_group.add_argument('--num-best-checkpoints', type=int,
120 |                            help='if set, will try to find checkpoints with names checkpoint_best_xx.pt in the path specified by input, '
121 |                            'and average the last that many of them.')
122 |     parser.add_argument('--checkpoint-upper-bound', type=int,
123 |                         help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
124 |                         'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
125 |                         'E.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged; '
126 |                         'with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged, assuming --save-interval-updates 500.'
127 |                         )
128 |     # fmt: on
129 |     args = parser.parse_args()
130 |     print(args)
131 | 
132 |     num = None
133 |     is_update_based = False
134 |     is_best_based = False
135 |     if args.num_best_checkpoints is not None:
136 |         num = args.num_best_checkpoints
137 |         is_best_based = True
138 |     elif args.num_update_checkpoints is not None:
139 |         num = args.num_update_checkpoints
140 |         is_update_based = True
141 |     elif args.num_epoch_checkpoints is not None:
142 |         num = args.num_epoch_checkpoints
143 | 
144 |     assert args.checkpoint_upper_bound is None or (
145 |         args.num_epoch_checkpoints is not None
146 |         or args.num_update_checkpoints is not None
147 |     ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
148 |     assert (
149 |         args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
150 |     ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
151 | 
152 |     if num is not None:
153 |         args.inputs = last_n_checkpoints(
154 |             args.inputs,
155 |             num,
156 |             is_best_based,
157 |             is_update_based,
158 |             upper_bound=args.checkpoint_upper_bound,
159 |         )
160 |         print("averaging checkpoints: ", args.inputs)
161 | 
162 |     new_state = average_checkpoints(args.inputs)
163 |     with PathManager.open(args.output, "wb") as f:
164 |         torch.save(new_state, f)
165 |     print("Finished writing averaged checkpoint to {}".format(args.output))
166 | 
167 | 
168 | if __name__ == "__main__":
169 |     main()
170 | 
-------------------------------------------------------------------------------- /utility/simultaneous_translation/models/sinkhorn_waitk.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import logging
8 | from fairseq import checkpoint_utils
9 | 
10 | from fairseq.models import (
11 |     register_model,
12 |     register_model_architecture
13 | )
14 | from fairseq.models.transformer import (
15 |     base_architecture,
16 |     DEFAULT_MAX_SOURCE_POSITIONS,
17 |     DEFAULT_MAX_TARGET_POSITIONS
18 | )
19 | from fairseq.modules.transformer_sentence_encoder import init_bert_params
20 | # user
21 | from .waitk_transformer import (
22 |     CausalTransformerEncoder,
23 |     WaitkTransformerModel,
24 | )
25 | from .sinkhorn_encoder import ASNAugmentedEncoder
26 | 
27 | logger = logging.getLogger(__name__)
28 | 
29 | 
30 | @register_model("sinkhorn_waitk")
31 | class SinkhornWaitkTransformerModel(WaitkTransformerModel):
32 |     """
33 |     Waitk transformer with sinkhorn encoder
34 |     """
35 | 
36 |     @staticmethod
37 |     def add_args(parser):
38 |         """Add model-specific arguments to the parser."""
39 |         super(SinkhornWaitkTransformerModel,
40 |               SinkhornWaitkTransformerModel).add_args(parser)
41 |         parser.add_argument(
42 |             "--non-causal-layers",
43 |             type=int,
44 |             help=(
45 |                 'number of layers for non-causal encoder.'
46 |             ),
47 |         )
48 |         parser.add_argument(
49 |             '--sinkhorn-tau',
50 |             type=float,
51 |             required=True,
52 |             help='temperature for gumbel sinkhorn.'
53 |         )
54 |         parser.add_argument(
55 |             "--sinkhorn-iters",
56 |             type=int,
57 |             required=True,
58 |             help=(
59 |                 'number of sinkhorn normalization iterations to perform.'
60 |             ),
61 |         )
62 |         parser.add_argument(
63 |             "--sinkhorn-noise-factor",
64 |             type=float,
65 |             required=True,
66 |             help=(
67 |                 'how much gumbel randomness to use in training.'
68 |             ),
69 |         )
70 |         parser.add_argument(
71 |             "--sinkhorn-bucket-size",
72 |             type=int,
73 |             required=True,
74 |             help=(
75 |                 'number of elements to group before performing sinkhorn sorting.'
76 |             ),
77 |         )
78 |         parser.add_argument(
79 |             "--sinkhorn-energy",
80 |             type=str,
81 |             required=True,
82 |             choices=["dot", "cos", "l2"],
83 |             help=(
84 |                 'type of energy function used to calculate attention. available: dot, cos, l2'
85 |             ),
86 |         )
87 |         parser.add_argument(
88 |             "--mask-ratio",
89 |             required=True,
90 |             type=float,
91 |             help=(
92 |                 'ratio of target tokens to mask when feeding to the sorting network.'
93 |             ),
94 |         )
95 |         parser.add_argument(
96 |             "--mask-uniform",
97 |             action="store_true",
98 |             default=False,
99 |             help=(
100 |                 'mask target tokens uniformly when feeding to the aligner.'
101 |             ),
102 |         )
103 | 
104 |     @classmethod
105 |     def build_encoder(cls, args, src_dict, tgt_dict, encoder_embed_tokens, decoder_embed_tokens):
106 |         encoder = CausalTransformerEncoder(
107 |             args, src_dict, encoder_embed_tokens)
108 |         encoder.apply(init_bert_params)
109 |         if getattr(args, "load_pretrained_encoder_from", None):
110 |             encoder = checkpoint_utils.load_pretrained_component_from_model(
111 |                 component=encoder, checkpoint=args.load_pretrained_encoder_from
112 |             )
113 |             logger.info(
114 |                 f"loaded pretrained encoder from: "
115 |                 f"{args.load_pretrained_encoder_from}"
116 |             )
117 |         cascade = ASNAugmentedEncoderSlice(
118 |             args, encoder, tgt_dict, decoder_embed_tokens)
119 |         return cascade
120 | 
121 |     @classmethod
122 |     def build_model(cls, args, task):
123 |         """Build a new model instance."""
124 | 
125 |         # make sure all arguments are present in older models
126 |         base_architecture(args)
127 | 
128 |         if args.encoder_layers_to_keep:
129 |             args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
130 |         if args.decoder_layers_to_keep:
131 |             args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
132 | 
133 |         if getattr(args, "max_source_positions", None) is None:
134 |             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
135 |         if getattr(args, "max_target_positions", None) is None:
136 |             args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
137 | 
138 |         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
139 | 
140 |         if args.share_all_embeddings:
141 |             if src_dict != tgt_dict:
142 |                 raise ValueError(
143 |                     "--share-all-embeddings requires a joined dictionary")
144 |             if args.encoder_embed_dim != args.decoder_embed_dim:
145 |                 raise ValueError(
146 |                     "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
147 |                 )
148 |             if args.decoder_embed_path and (
149 |                 args.decoder_embed_path != args.encoder_embed_path
150 |             ):
151 |                 raise ValueError(
152 |                     "--share-all-embeddings not compatible with --decoder-embed-path"
153 |                 )
154 |             encoder_embed_tokens = cls.build_embedding(
155 |                 args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
156 |             )
157 |             decoder_embed_tokens = encoder_embed_tokens
158 |             args.share_decoder_input_output_embed = True
159 |         else:
160 | 
encoder_embed_tokens = cls.build_embedding( 161 | args, src_dict, args.encoder_embed_dim, args.encoder_embed_path 162 | ) 163 | decoder_embed_tokens = cls.build_embedding( 164 | args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path 165 | ) 166 | if getattr(args, "offload_activations", False): 167 | args.checkpoint_activations = True # offloading implies checkpointing 168 | encoder = cls.build_encoder( 169 | args, src_dict, tgt_dict, encoder_embed_tokens, decoder_embed_tokens) 170 | decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) 171 | 172 | return cls(args, encoder, decoder) 173 | 174 | def forward(self, src_tokens, src_lengths, prev_output_tokens): 175 | """ changed encoder forward to forward_train """ 176 | encoder_out = self.encoder.forward_train( 177 | src_tokens=src_tokens, 178 | src_lengths=src_lengths, 179 | prev_output_tokens=prev_output_tokens, 180 | ) 181 | x, extra = self.decoder( 182 | prev_output_tokens=prev_output_tokens, 183 | encoder_out=encoder_out, 184 | features_only=True, 185 | ) 186 | extra["decoder_states"] = x 187 | extra["attn"] = encoder_out["attn"] 188 | extra["log_alpha"] = encoder_out["log_alpha"] 189 | logits = self.decoder.output_projection(x) 190 | return logits, extra 191 | 192 | 193 | class ASNAugmentedEncoderSlice(ASNAugmentedEncoder): 194 | def slice_encoder_out(self, encoder_out, context_size): 195 | return self.causal_encoder.slice_encoder_out(encoder_out, context_size) 196 | 197 | 198 | @register_model_architecture( 199 | "sinkhorn_waitk", "sinkhorn_waitk" 200 | ) 201 | def sinkhorn_waitk(args): 202 | args.waitk = getattr(args, 'waitk', 60000) # default is wait-until-end 203 | 204 | args.max_source_positions = getattr(args, "max_source_positions", 1024) 205 | args.max_target_positions = getattr(args, "max_target_positions", 1024) 206 | 207 | args.decoder_layers = getattr(args, "decoder_layers", 1) 208 | args.non_causal_layers = getattr(args, "non_causal_layers", 3) 209 | args.dropout = getattr(args, "dropout", 0.1) 210 | 211 | args.share_decoder_input_output_embed = True 212 | args.upsample_ratio = 1 213 | args.delay = 1 214 | base_architecture(args) 215 | 216 | 217 | @register_model_architecture( 218 | "sinkhorn_waitk", "sinkhorn_waitk_small" 219 | ) 220 | def sinkhorn_waitk_small(args): 221 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) 222 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 223 | args.encoder_layers = getattr(args, "encoder_layers", 5) 224 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) 225 | args.encoder_normalize_before = getattr( 226 | args, "encoder_normalize_before", True) 227 | 228 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) 229 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) 230 | args.decoder_layers = getattr(args, "decoder_layers", 1) 231 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) 232 | args.decoder_normalize_before = getattr( 233 | args, "decoder_normalize_before", True) 234 | 235 | args.non_causal_layers = getattr(args, "non_causal_layers", 3) 236 | 237 | sinkhorn_waitk(args) 238 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/criterion/label_smoothed_ctc_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass, field 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributions.categorical import Categorical 10 | from typing import Optional 11 | from fairseq import utils, metrics 12 | from fairseq.criterions import ( 13 | register_criterion, 14 | ) 15 | from fairseq.criterions.label_smoothed_cross_entropy import ( 16 | LabelSmoothedCrossEntropyCriterionConfig, 17 | LabelSmoothedCrossEntropyCriterion 18 | ) 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def calc_recall_precision(predict, target, pad_idx=1, eps=1e-8): 25 | N, S = predict.size() 26 | N, T = target.size() 27 | 28 | uniq, inverse = torch.unique( 29 | torch.cat((predict, target), dim=1), 30 | return_inverse=True, 31 | ) 32 | src = target.new_ones(1) 33 | 34 | def collect(tokens): 35 | return tokens.new_zeros( 36 | (N, uniq.size(-1)) 37 | ).scatter_add_(1, tokens, src.expand_as(tokens)) 38 | 39 | pred_words = collect(inverse[:, :S]) 40 | target_words = collect(inverse[:, S:]) 41 | 42 | match = torch.min(target_words, pred_words).sum(-1) 43 | recall = match / (target.ne(pad_idx).sum(-1) + eps) 44 | precision = match / (predict.ne(pad_idx).sum(-1) + eps) 45 | return recall.sum(), precision.sum() 46 | 47 | 48 | @dataclass 49 | class LabelSmoothedCTCCriterionConfig(LabelSmoothedCrossEntropyCriterionConfig): 50 | decoder_use_ctc: bool = field( 51 | default=False, 52 | metadata={"help": "use ctcloss for decoder loss."}, 53 | ) 54 | zero_infinity: Optional[bool] = field( 55 | default=True, 56 | metadata={"help": "zero inf loss when source length <= target length"}, 57 | ) 58 | report_sinkhorn_dist: bool = field( 59 | default=False, 60 | metadata={"help": "print sinkhorn distance value."}, 61 | ) 62 | eos_loss: bool = field( 63 | default=False, 64 | metadata={"help": "calculate loss for eos token. 
default is to treat eos as pad."}, 65 | ) 66 | 67 | 68 | @register_criterion( 69 | "label_smoothed_ctc", dataclass=LabelSmoothedCTCCriterionConfig 70 | ) 71 | class LabelSmoothedCTCCriterion(LabelSmoothedCrossEntropyCriterion): 72 | def __init__(self, cfg, task): 73 | super().__init__( 74 | task, 75 | cfg.sentence_avg, 76 | cfg.label_smoothing, 77 | ignore_prefix_size=cfg.ignore_prefix_size, 78 | report_accuracy=cfg.report_accuracy 79 | ) 80 | self.decoder_use_ctc = cfg.decoder_use_ctc 81 | if self.decoder_use_ctc: 82 | logger.info("Using ctc loss for decoder!") 83 | 84 | self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0 85 | self.pad_idx = task.target_dictionary.pad() 86 | self.eos_idx = task.target_dictionary.eos() 87 | self.zero_infinity = cfg.zero_infinity 88 | self.report_sinkhorn_dist = cfg.report_sinkhorn_dist 89 | self.eos_loss = cfg.eos_loss 90 | 91 | def forward(self, model, sample, reduce=True): 92 | net_output = model(**sample["net_input"]) 93 | if self.decoder_use_ctc: 94 | loss, nll_loss = self.compute_ctc_loss(model, net_output, sample["target"], reduce=reduce) 95 | else: 96 | # original label smoothed xentropy loss by fairseq 97 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 98 | 99 | if self.report_sinkhorn_dist: 100 | 101 | with torch.no_grad(): 102 | attn = net_output[1]["attn"][0].float() 103 | cost = -net_output[1]["log_alpha"][0].float() 104 | 105 | B, S, denom = attn.size() 106 | dist = (cost * attn).mean() * B * S 107 | 108 | # compute inversion rate 109 | # expected value of position in source that aligns to each target 110 | alignment = (utils.new_arange(attn) * attn).sum(-1) # (N, L1) 111 | inv_rate = alignment[:, :-1] - alignment[:, 1:] 112 | inv_rate = (inv_rate / denom).clamp(min=0).float().sum() 113 | 114 | try: 115 | entropy = Categorical(probs=attn.float()).entropy().sum() / denom 116 | except ValueError: 117 | logger.warning("entropy calculation failed because of invalid input!") 118 | entropy = 0 119 | 120 | else: 121 | dist = inv_rate = entropy = 0 122 | 123 | if self.report_accuracy: 124 | encoder_out = net_output[1]["encoder_out"] 125 | encoder_states = encoder_out["causal_out"][0] \ 126 | if "causal_out" in encoder_out else encoder_out["encoder_out"][0] 127 | with torch.no_grad(): 128 | x = encoder_states 129 | logits = model.output_layer(x.permute(1, 0, 2)) 130 | y_pred = logits.argmax(-1) 131 | recall, precision = calc_recall_precision( 132 | y_pred, sample["target"], pad_idx=self.pad_idx) 133 | blank_rate = y_pred.eq(self.blank_idx).float().mean(-1).sum() 134 | else: 135 | recall = 0 136 | precision = 0 137 | blank_rate = 0 138 | 139 | sample_size = ( 140 | sample["target"].size(0) if self.sentence_avg else sample["ntokens"] 141 | ) 142 | logging_output = { 143 | "loss": loss.data, 144 | "nll_loss": nll_loss.data, 145 | "ntokens": sample["ntokens"], 146 | "nsentences": sample["target"].size(0), 147 | "sample_size": sample_size, 148 | "sinkhorn_dist": dist, 149 | "inv_rate": inv_rate, 150 | "matching_entropy": entropy, 151 | 152 | "recall": recall, 153 | "precision": precision, 154 | "blank_rate": blank_rate, 155 | } 156 | return loss, sample_size, logging_output 157 | 158 | def compute_ctc_loss(self, model, net_output, target, reduce=True): 159 | """ 160 | lprobs is expected to be batch first. 
(from model forward output, or net_output) 161 | """ 162 | lprobs = model.get_normalized_probs( 163 | net_output, log_probs=True 164 | ) 165 | bsz = target.size(0) 166 | # reshape lprobs to (L,B,X) for torch.ctc 167 | if lprobs.size(0) != bsz: 168 | raise RuntimeError( 169 | f'batch size error: lprobs shape={lprobs.size()}, bsz={bsz}') 170 | max_src = lprobs.size(1) 171 | lprobs = lprobs.transpose(1, 0).contiguous() 172 | 173 | # get subsampling padding mask & lengths 174 | if net_output[1]["padding_mask"] is not None: 175 | non_padding_mask = ~net_output[1]["padding_mask"] 176 | input_lengths = non_padding_mask.long().sum(-1) 177 | else: 178 | input_lengths = lprobs.new_ones( 179 | (bsz, max_src), dtype=torch.long).sum(-1) 180 | 181 | pad_mask = target.ne(self.pad_idx) 182 | if not self.eos_loss: 183 | pad_mask &= target.ne(self.eos_idx) 184 | targets_flat = target.masked_select(pad_mask) 185 | target_lengths = pad_mask.long().sum(-1) 186 | 187 | with torch.backends.cudnn.flags(enabled=False): 188 | nll_loss = F.ctc_loss( 189 | lprobs, 190 | targets_flat, 191 | input_lengths, 192 | target_lengths, 193 | blank=self.blank_idx, 194 | reduction="sum", 195 | zero_infinity=self.zero_infinity, 196 | ) 197 | 198 | # label smoothing 199 | smooth_loss = -lprobs.sum(dim=-1).transpose(1, 0) # (L,B) -> (B,L) 200 | if net_output[1]["padding_mask"] is not None: 201 | smooth_loss.masked_fill_( 202 | net_output[1]["padding_mask"], 203 | 0.0 204 | ) 205 | eps_i = self.eps / lprobs.size(-1) 206 | loss = (1.0 - self.eps) * nll_loss + eps_i * smooth_loss.sum() 207 | 208 | return loss, nll_loss 209 | 210 | @classmethod 211 | def reduce_metrics(cls, logging_outputs) -> None: 212 | """Aggregate logging outputs from data parallel training.""" 213 | super().reduce_metrics(logging_outputs) 214 | 215 | def sum_logs(key): 216 | import torch 217 | result = sum(log.get(key, 0) for log in logging_outputs) 218 | if torch.is_tensor(result): 219 | result = result.cpu() 220 | return result 221 | 222 | inv_rate = sum_logs("inv_rate") 223 | sinkhorn_dist_sum = sum_logs("sinkhorn_dist") 224 | nsentences = sum_logs("nsentences") 225 | matching_entropy = sum_logs("matching_entropy") 226 | recall = sum_logs("recall") 227 | precision = sum_logs("precision") 228 | blank_rate = sum_logs("blank_rate") 229 | 230 | metrics.log_scalar( 231 | "inversion_rate", inv_rate / nsentences, nsentences, round=3 232 | ) 233 | metrics.log_scalar( 234 | "sinkhorn_dist", sinkhorn_dist_sum / nsentences, nsentences, round=3 235 | ) 236 | metrics.log_scalar( 237 | "matching_entropy", matching_entropy / nsentences, nsentences, round=3 238 | ) 239 | 240 | metrics.log_scalar( 241 | "recall", recall / nsentences, nsentences, round=3 242 | ) 243 | metrics.log_scalar( 244 | "precision", precision / nsentences, nsentences, round=3 245 | ) 246 | metrics.log_scalar( 247 | "blank_rate", blank_rate / nsentences, nsentences, round=3 248 | ) 249 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/tasks/translation_infer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .waitk_sequence_generator import WaitkSequenceGenerator 7 | from .inference_config import InferenceConfig 8 | import logging 9 | import torch 10 | import numpy as np 11 | from fairseq import metrics, utils 12 | from fairseq.tasks import register_task, LegacyFairseqTask 13 | from fairseq.tasks.translation import TranslationTask 14 | from fairseq.logging.meters import safe_round 15 | 16 | from fairseq.scoring.bleu import SacrebleuScorer 17 | from fairseq.scoring.wer import WerScorer 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | EVAL_BLEU_ORDER = 4 22 | 23 | 24 | @register_task("translation_infer") 25 | class TranslationWInferenceTask(TranslationTask): 26 | @staticmethod 27 | def add_args(parser): 28 | TranslationTask.add_args(parser) 29 | 30 | parser.add_argument( 31 | "--inference-config-yaml", 32 | type=str, 33 | default="inference.yaml", 34 | help="Configuration YAML filename for bleu or wer eval (under exp/)", 35 | ) 36 | 37 | parser.add_argument( 38 | "--from-encoder", action="store_true", 39 | help=( 40 | 'decode prediction from model.encoder.' 41 | ), 42 | ) 43 | 44 | def __init__(self, args, src_dict, tgt_dict): 45 | super().__init__(args, src_dict, tgt_dict) 46 | # for waitk models, left padding is bad for inference. 47 | self.cfg.left_pad_source = False 48 | 49 | self.inference_cfg = InferenceConfig(args.inference_config_yaml) 50 | self.pre_tokenizer = self.build_tokenizer(args) 51 | 52 | self.from_encoder = getattr(args, "from_encoder", False) 53 | 54 | def build_model(self, args): 55 | model = super().build_model(args) 56 | if self.inference_cfg.eval_any: 57 | self.sequence_generator = self.build_generator( 58 | [model], 59 | self.inference_cfg.generation_args, 60 | ) 61 | return model 62 | 63 | def build_generator( 64 | self, 65 | models, 66 | args, 67 | seq_gen_cls=None, 68 | extra_gen_cls_kwargs=None, 69 | ): 70 | """ speech_to_text ignores seq_gen_cls and overrides 71 | extra_gen_cls_kwargs. So we will call LegacyFairseqTask's 72 | method. """ 73 | waitk = getattr(models[0], "waitk", None) 74 | test_waitk = getattr(self.inference_cfg.generation_args, "waitk", None) 75 | if test_waitk is not None and test_waitk != waitk: 76 | # test override. 
77 |             logger.warning(
78 |                 f"Train test mismatch: training wait-{waitk}, while testing wait-{test_waitk}.")
79 |             waitk = test_waitk
80 |         pre_ratio = 1
81 |         if waitk is not None:
82 |             pre_ratio = models[0].pre_decision_ratio
83 |             seq_gen_cls = WaitkSequenceGenerator
84 |             extra = {"waitk": waitk, "pre_decision_ratio": pre_ratio}
85 |             if extra_gen_cls_kwargs:
86 |                 extra_gen_cls_kwargs.update(extra)
87 |             else:
88 |                 extra_gen_cls_kwargs = extra
89 |         return LegacyFairseqTask.build_generator(
90 |             self, models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs
91 |         )
92 | 
93 |     def valid_step(self, sample, model, criterion):
94 |         loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
95 |         if self.inference_cfg.eval_any:
96 |             _metrics = self._inference_with_metrics(self.sequence_generator, sample, model)
97 | 
98 |             if self.inference_cfg.eval_bleu:
99 |                 bleu = _metrics["bleu"]
100 |                 logging_output["_bleu_sys_len"] = bleu.sys_len
101 |                 logging_output["_bleu_ref_len"] = bleu.ref_len
102 |                 # we split counts into separate entries so that they can be
103 |                 # summed efficiently across workers using fast-stat-sync
104 |                 assert len(bleu.counts) == EVAL_BLEU_ORDER
105 |                 for i in range(EVAL_BLEU_ORDER):
106 |                     logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
107 |                     logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
108 |             if self.inference_cfg.eval_wer:
109 |                 logging_output.update(_metrics["wer"])
110 |         return loss, sample_size, logging_output
111 | 
112 |     @torch.no_grad()
113 |     def inference_step(
114 |         self, generator, models, sample, prefix_tokens=None, constraints=None
115 |     ):
116 |         if getattr(models[0], "one_pass_decoding", False):
117 |             # one-pass decoding
118 |             if hasattr(self, 'blank_symbol'):
119 |                 sample["net_input"]["blank_idx"] = self.tgt_dict.index(self.blank_symbol)
120 |             sample["net_input"]["from_encoder"] = self.from_encoder
121 |             return models[0].generate(**sample["net_input"])
122 |         else:
123 |             # incremental decoding
124 |             return generator.generate(
125 |                 models, sample, prefix_tokens=prefix_tokens, constraints=constraints
126 |             )
127 | 
128 |     def _inference_with_metrics(self, generator, sample, model):
129 | 
130 |         def decode(toks, escape_unk=False):
131 |             s = self.tgt_dict.string(
132 |                 toks.int().cpu(),
133 |                 self.inference_cfg.post_process,  # this will handle bpe for us.
134 |                 # The default unknown string in fairseq is `<unk>`, but
135 |                 # this is tokenized by sacrebleu as `< unk >`, inflating
136 |                 # BLEU scores. Instead, we use a somewhat more verbose
137 |                 # alternative that is unlikely to appear in the real
138 |                 # reference, but doesn't get split into multiple tokens.
139 | unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), 140 | ) 141 | if self.pre_tokenizer is not None: 142 | s = self.pre_tokenizer.decode(s) 143 | return s if s else "UNKNOWNTOKENINHYP" 144 | 145 | gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) 146 | hyps, refs = [], [] 147 | for i in range(len(gen_out)): 148 | hyps.append( 149 | decode(gen_out[i][0]["tokens"]) 150 | ) 151 | refs.append( 152 | decode( 153 | utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), 154 | escape_unk=True, # don't count as matches to the hypo 155 | ) 156 | ) 157 | if self.inference_cfg.print_samples: 158 | logger.info("example hypothesis: " + hyps[0]) 159 | logger.info("example reference: " + refs[0]) 160 | 161 | ret = {} 162 | if self.inference_cfg.eval_bleu: 163 | bleu_scorer = SacrebleuScorer(self.inference_cfg.eval_bleu_args) 164 | for h, r in zip(hyps, refs): 165 | bleu_scorer.add_string(ref=r, pred=h) 166 | 167 | ret["bleu"] = bleu_scorer.sacrebleu.corpus_bleu( 168 | bleu_scorer.pred, [bleu_scorer.ref], 169 | tokenize="none" # use none because it's handled by SacrebleuScorer 170 | ) 171 | 172 | if self.inference_cfg.eval_wer: 173 | wer_scorer = WerScorer(self.inference_cfg.eval_wer_args) 174 | for h, r in zip(hyps, refs): 175 | wer_scorer.add_string(ref=r, pred=h) 176 | 177 | ret["wer"] = { 178 | "wv_errors": wer_scorer.distance, 179 | "w_errors": wer_scorer.distance, 180 | "w_total": wer_scorer.ref_length 181 | } 182 | 183 | return ret 184 | 185 | def reduce_metrics(self, logging_outputs, criterion): 186 | super().reduce_metrics(logging_outputs, criterion) 187 | 188 | def sum_logs(key): 189 | import torch 190 | result = sum(log.get(key, 0) for log in logging_outputs) 191 | if torch.is_tensor(result): 192 | result = result.cpu() 193 | return result 194 | 195 | if self.inference_cfg.eval_bleu: 196 | 197 | counts, totals = [], [] 198 | for i in range(EVAL_BLEU_ORDER): 199 | counts.append(sum_logs("_bleu_counts_" + str(i))) 200 | totals.append(sum_logs("_bleu_totals_" + str(i))) 201 | 202 | if max(totals) > 0: 203 | # log counts as numpy arrays -- log_scalar will sum them correctly 204 | metrics.log_scalar("_bleu_counts", np.array(counts)) 205 | metrics.log_scalar("_bleu_totals", np.array(totals)) 206 | metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) 207 | metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) 208 | 209 | def compute_bleu(meters): 210 | import inspect 211 | import sacrebleu 212 | 213 | fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] 214 | if "smooth_method" in fn_sig: 215 | smooth = {"smooth_method": "exp"} 216 | else: 217 | smooth = {"smooth": "exp"} 218 | bleu = sacrebleu.compute_bleu( 219 | correct=meters["_bleu_counts"].sum, 220 | total=meters["_bleu_totals"].sum, 221 | sys_len=meters["_bleu_sys_len"].sum, 222 | ref_len=meters["_bleu_ref_len"].sum, 223 | **smooth 224 | ) 225 | return round(bleu.score, 2) 226 | 227 | metrics.log_derived("bleu", compute_bleu) 228 | 229 | if self.inference_cfg.eval_wer: 230 | 231 | w_errors = sum_logs("w_errors") 232 | wv_errors = sum_logs("wv_errors") 233 | w_total = sum_logs("w_total") 234 | 235 | metrics.log_scalar("_w_errors", w_errors) 236 | metrics.log_scalar("_wv_errors", wv_errors) 237 | metrics.log_scalar("_w_total", w_total) 238 | 239 | if w_total > 0: 240 | metrics.log_derived( 241 | "wer", 242 | lambda meters: safe_round( 243 | meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 244 | ) 245 | if meters["_w_total"].sum > 0 246 | else 
float("nan"), 247 | ) 248 | metrics.log_derived( 249 | "raw_wer", 250 | lambda meters: safe_round( 251 | meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 252 | ) 253 | if meters["_w_total"].sum > 0 254 | else float("nan"), 255 | ) 256 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/modules/waitk_transformer_layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # Adapted from https://github.com/elbayadm/attn2d/blob/master/examples/waitk 7 | # Implementation of the papers: 8 | # *Efficient Wait-k Models for Simultaneous Machine Translation 9 | # http://www.interspeech2020.org/uploadfile/pdf/Tue-1-1-2.pdf 10 | 11 | import logging 12 | from typing import Dict, List, Optional 13 | 14 | import torch 15 | 16 | from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer 17 | from torch import Tensor 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class NonCausalTransformerEncoderLayer(TransformerEncoderLayer): 23 | """ Enhance encoder layer by 24 | 1. adding log-distance penalty for speech 25 | 2. handle encoder padding mask 26 | """ 27 | def __init__(self, args): 28 | super().__init__(args) 29 | self._future_mask = torch.empty(0) 30 | self.log_penalty = getattr(args, "encoder_log_penalty", False) 31 | 32 | def buffered_future_mask(self, tensor): 33 | dim = tensor.size(0) 34 | # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 35 | if ( 36 | self._future_mask.size(0) == 0 37 | or (not self._future_mask.device == tensor.device) 38 | or self._future_mask.size(0) < dim 39 | ): 40 | self._future_mask = torch.zeros([dim, dim]) 41 | if self.log_penalty: 42 | penalty = torch.arange(dim).type_as(self._future_mask) 43 | penalty = torch.abs( 44 | penalty.unsqueeze(1) - penalty 45 | ).clamp(min=1) 46 | self._future_mask -= penalty.log() 47 | self._future_mask = self._future_mask.to(tensor) 48 | if self._future_mask.any(): 49 | return self._future_mask[:dim, :dim] 50 | else: 51 | return None 52 | 53 | def forward( 54 | self, x, encoder_padding_mask, 55 | ): 56 | attn_mask = self.buffered_future_mask(x) 57 | 58 | ###################################### 59 | # below is same as original # 60 | ###################################### 61 | 62 | residual = x 63 | if self.normalize_before: 64 | x = self.self_attn_layer_norm(x) 65 | x, _ = self.self_attn( 66 | query=x, 67 | key=x, 68 | value=x, 69 | key_padding_mask=encoder_padding_mask, 70 | attn_mask=attn_mask, 71 | ) 72 | x = self.dropout_module(x) 73 | x = self.residual_connection(x, residual) 74 | if not self.normalize_before: 75 | x = self.self_attn_layer_norm(x) 76 | 77 | residual = x 78 | if self.normalize_before: 79 | x = self.final_layer_norm(x) 80 | x = self.activation_fn(self.fc1(x)) 81 | x = self.activation_dropout_module(x) 82 | x = self.fc2(x) 83 | x = self.dropout_module(x) 84 | x = self.residual_connection(x, residual) 85 | if not self.normalize_before: 86 | x = self.final_layer_norm(x) 87 | return x 88 | 89 | 90 | class CausalTransformerEncoderLayer(TransformerEncoderLayer): 91 | """ Similar to NonCausal above, but adds 92 | 1. future masking for causal encoding 93 | 2. incremental states for incremental encoding in inference 94 | """ 95 | def __init__(self, args, delay=1): 96 | super().__init__(args) 97 | self._future_mask = torch.empty(0) 98 | self.log_penalty = getattr(args, "encoder_log_penalty", False) 99 | self.delay = delay 100 | assert self.delay > 0, "Cannot be faster than delay=1." 101 | 102 | def buffered_future_mask(self, tensor): 103 | dim = tensor.size(0) 104 | # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
105 | if ( 106 | self._future_mask.size(0) == 0 107 | or (not self._future_mask.device == tensor.device) 108 | or self._future_mask.size(0) < dim 109 | ): 110 | neg_inf = -torch.finfo(tensor.dtype).max 111 | self._future_mask = torch.triu( 112 | torch.full([dim, dim], neg_inf), self.delay 113 | # utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 114 | ) 115 | if self.log_penalty: 116 | penalty = torch.arange(dim).type_as(self._future_mask) 117 | penalty = torch.abs( 118 | penalty.unsqueeze(1) - penalty 119 | ).clamp(min=1) 120 | self._future_mask -= penalty.log() 121 | self._future_mask = self._future_mask.to(tensor) 122 | return self._future_mask[:dim, :dim] 123 | 124 | def forward( 125 | self, x, encoder_padding_mask, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 126 | ): 127 | """ 128 | In inference, prev states are cached so we need to 129 | compute mask from cached states and input x together. 130 | """ 131 | if incremental_state is not None: 132 | proto = x 133 | saved_state = self.self_attn._get_input_buffer(incremental_state) 134 | if "prev_key" in saved_state: 135 | prev_key = saved_state["prev_key"] 136 | assert prev_key is not None 137 | prev_len = prev_key.size(2) 138 | new_len = x.size(0) 139 | proto = x.new_zeros( 140 | (prev_len + new_len, 1)) # only dim 0 is used. 141 | attn_mask = self.buffered_future_mask( 142 | proto)[-x.size(0):] # keep mask for x only 143 | else: 144 | attn_mask = self.buffered_future_mask(x) 145 | 146 | ################################################# 147 | # below is same as original + incremental_State # 148 | ################################################# 149 | 150 | residual = x 151 | if self.normalize_before: 152 | x = self.self_attn_layer_norm(x) 153 | x, _ = self.self_attn( 154 | query=x, 155 | key=x, 156 | value=x, 157 | key_padding_mask=encoder_padding_mask, 158 | incremental_state=incremental_state, 159 | attn_mask=attn_mask, 160 | ) 161 | x = self.dropout_module(x) 162 | x = self.residual_connection(x, residual) 163 | if not self.normalize_before: 164 | x = self.self_attn_layer_norm(x) 165 | 166 | residual = x 167 | if self.normalize_before: 168 | x = self.final_layer_norm(x) 169 | x = self.activation_fn(self.fc1(x)) 170 | x = self.activation_dropout_module(x) 171 | x = self.fc2(x) 172 | x = self.dropout_module(x) 173 | x = self.residual_connection(x, residual) 174 | if not self.normalize_before: 175 | x = self.final_layer_norm(x) 176 | return x 177 | 178 | def make_generation_fast_(self, need_attn: bool = False, **kwargs): 179 | self.need_attn = need_attn 180 | 181 | def prune_incremental_state( 182 | self, 183 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], 184 | keep: Optional[int] = None, 185 | ): 186 | if keep is None: 187 | return 188 | 189 | input_buffer = self.self_attn._get_input_buffer(incremental_state) 190 | for key in ["prev_key", "prev_value"]: 191 | input_buffer_key = input_buffer[key] 192 | assert input_buffer_key is not None 193 | # if input_buffer_key.size(2) > prune: 194 | if keep > 0: 195 | input_buffer[key] = input_buffer_key[:, :, :keep, :] 196 | else: 197 | typed_empty_dict: Dict[str, Optional[Tensor]] = {} 198 | input_buffer = typed_empty_dict 199 | break 200 | assert incremental_state is not None 201 | self.self_attn._set_input_buffer(incremental_state, input_buffer) 202 | 203 | 204 | class WaitkTransformerDecoderLayer(TransformerDecoderLayer): 205 | """Wait-k Decoder layer block. 206 | 1. added encoder_attn_mask for wait-k masking 207 | 2. 
for simul trans, we CANNOT cache encoder states! in inference, 208 | the encoder states dicts should be constantly updated. 209 | """ 210 | def forward( 211 | self, 212 | x, 213 | encoder_out: Optional[torch.Tensor] = None, 214 | encoder_padding_mask: Optional[torch.Tensor] = None, 215 | encoder_attn_mask: Optional[torch.Tensor] = None, 216 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 217 | prev_self_attn_state: Optional[List[torch.Tensor]] = None, 218 | prev_attn_state: Optional[List[torch.Tensor]] = None, 219 | self_attn_mask: Optional[torch.Tensor] = None, 220 | self_attn_padding_mask: Optional[torch.Tensor] = None, 221 | need_attn: bool = False, 222 | need_head_weights: bool = False, 223 | cache_encoder: bool = True, 224 | cache_decoder: bool = True, 225 | ): 226 | """ 227 | Args: 228 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 229 | encoder_padding_mask (ByteTensor, optional): binary 230 | ByteTensor of shape `(batch, src_len)` where padding 231 | elements are indicated by ``1``. 232 | need_attn (bool, optional): return attention weights 233 | need_head_weights (bool, optional): return attention weights 234 | for each head (default: return average over heads). 235 | Returns: 236 | encoded output of shape `(seq_len, batch, embed_dim)` 237 | """ 238 | if need_head_weights: 239 | need_attn = True 240 | 241 | residual = x 242 | if self.normalize_before: 243 | x = self.self_attn_layer_norm(x) 244 | if prev_self_attn_state is not None: 245 | # if incremental_state is None: 246 | # incremental_state = {} 247 | prev_key, prev_value = prev_self_attn_state[:2] 248 | saved_state: Dict[str, Optional[Tensor]] = { 249 | "prev_key": prev_key, 250 | "prev_value": prev_value, 251 | } 252 | if len(prev_self_attn_state) >= 3: 253 | saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] 254 | assert incremental_state is not None 255 | self.self_attn._set_input_buffer(incremental_state, saved_state) 256 | 257 | x, attn = self.self_attn( 258 | query=x, 259 | key=x, 260 | value=x, 261 | key_padding_mask=self_attn_padding_mask, 262 | incremental_state=incremental_state if cache_decoder else None, 263 | need_weights=False, 264 | attn_mask=self_attn_mask, 265 | ) 266 | x = self.dropout_module(x) 267 | x = self.residual_connection(x, residual) 268 | if not self.normalize_before: 269 | x = self.self_attn_layer_norm(x) 270 | 271 | if self.encoder_attn is not None and encoder_out is not None: 272 | residual = x 273 | if self.normalize_before: 274 | x = self.encoder_attn_layer_norm(x) 275 | if prev_attn_state is not None: 276 | # if incremental_state is None: 277 | # incremental_state = {} 278 | prev_key, prev_value = prev_attn_state[:2] 279 | saved_state: Dict[str, Optional[Tensor]] = { 280 | "prev_key": prev_key, 281 | "prev_value": prev_value, 282 | } 283 | if len(prev_attn_state) >= 3: 284 | saved_state["prev_key_padding_mask"] = prev_attn_state[2] 285 | assert incremental_state is not None 286 | self.encoder_attn._set_input_buffer(incremental_state, saved_state) 287 | # for simul trans, you CANNOT cache encoder states! in inference, 288 | # the encoder should be constantly updated. 
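            # Added sketch (not in the original): for wait-k, the caller is
            # expected to pass an encoder_attn_mask in which target step t may
            # only attend to the first t + k source states, e.g. roughly:
            #   mask = torch.triu(x.new_full((tgt_len, src_len), neg_inf), k)
            # so that row t has -inf for source positions >= t + k.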
289 | x, attn = self.encoder_attn( 290 | query=x, 291 | key=encoder_out, 292 | value=encoder_out, 293 | key_padding_mask=encoder_padding_mask, 294 | attn_mask=encoder_attn_mask, 295 | incremental_state=None, 296 | static_kv=False, 297 | need_weights=need_attn or (not self.training and self.need_attn), 298 | need_head_weights=need_head_weights, 299 | ) 300 | x = self.dropout_module(x) 301 | x = self.residual_connection(x, residual) 302 | if not self.normalize_before: 303 | x = self.encoder_attn_layer_norm(x) 304 | 305 | residual = x 306 | if self.normalize_before: 307 | x = self.final_layer_norm(x) 308 | 309 | x = self.activation_fn(self.fc1(x)) 310 | x = self.activation_dropout_module(x) 311 | x = self.fc2(x) 312 | x = self.dropout_module(x) 313 | x = self.residual_connection(x, residual) 314 | if not self.normalize_before: 315 | x = self.final_layer_norm(x) 316 | return x, attn 317 | 318 | def make_generation_fast_(self, need_attn: bool = False, **kwargs): 319 | self.need_attn = need_attn 320 | 321 | def prune_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): 322 | input_buffer = self.self_attn._get_input_buffer(incremental_state) 323 | for key in ["prev_key", "prev_value"]: 324 | input_buffer_key = input_buffer[key] 325 | assert input_buffer_key is not None 326 | if input_buffer_key.size(2) > 1: 327 | input_buffer[key] = input_buffer_key[:, :, :-1, :] 328 | else: 329 | typed_empty_dict: Dict[str, Optional[Tensor]] = {} 330 | input_buffer = typed_empty_dict 331 | break 332 | assert incremental_state is not None 333 | self.self_attn._set_input_buffer(incremental_state, input_buffer) 334 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/modules/sinkhorn_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
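# Note added for exposition (not part of the original file): Sinkhorn
# normalization alternately normalizes the rows and columns of exp(log_alpha)
# in log space, converging toward a doubly stochastic matrix; dividing by a
# temperature tau and adding Gumbel noise first ("Gumbel-Sinkhorn") yields
# differentiable approximate samples of permutation matrices. A quick sanity
# check of the helpers below:
#
#     x = torch.randn(2, 5, 5)
#     p = gumbel_sinkhorn(x, tau=0.7, n_iter=20, noise_factor=0.0)
#     assert torch.allclose(p.sum(-1), torch.ones(2, 5), atol=1e-2)
#     assert torch.allclose(p.sum(-2), torch.ones(2, 5), atol=1e-2)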
6 | import math
7 | from typing import Optional, Tuple
8 | 
9 | import torch
10 | import torch.nn.functional as F
11 | from fairseq.modules.fairseq_dropout import FairseqDropout
12 | from torch import Tensor, nn
13 | from torch.nn import Parameter
14 | 
15 | 
16 | def sample_gumbel(proto, eps=1e-8):
17 |     u = torch.rand_like(proto, dtype=torch.float32)
18 |     return -torch.log(-torch.log(u + eps) + eps)
19 | 
20 | 
21 | def log_sinkhorn_norm(log_alpha: torch.Tensor, n_iter: int = 20) -> torch.Tensor:
22 |     for _ in range(n_iter):
23 |         log_alpha = log_alpha - torch.logsumexp(log_alpha, -1, keepdim=True)
24 |         log_alpha = log_alpha - torch.logsumexp(log_alpha, -2, keepdim=True)
25 |     return log_alpha.exp()
26 | 
27 | 
28 | def gumbel_sinkhorn(
29 |     log_alpha: torch.Tensor,
30 |     tau: float = 0.7,
31 |     n_iter: int = 20,
32 |     noise_factor: float = 1.0
33 | ) -> torch.Tensor:
34 |     if noise_factor > 0:
35 |         noise = noise_factor * sample_gumbel(log_alpha)
36 |         log_alpha = log_alpha + noise.type_as(log_alpha)
37 |     log_alpha = log_alpha / tau
38 |     sampled_perm_mat = log_sinkhorn_norm(log_alpha, n_iter) if n_iter > 0 else log_alpha.softmax(-1)
39 |     return sampled_perm_mat
40 | 
41 | 
42 | class GaussianBlur(nn.Conv2d):
43 |     """ Blur the attention map before sinkhorn normalization """
44 |     def __init__(self, kernel_size=3):
45 |         super().__init__(
46 |             1, 1, kernel_size,
47 |             padding=kernel_size // 2,
48 |             bias=False,
49 |             padding_mode='replicate',
50 |         )
51 |         mu = (kernel_size - 1) / 2.
52 |         var = (kernel_size / 2.)**2
53 |         grid = torch.arange(kernel_size) - mu
54 |         grid_x, grid_y = torch.meshgrid(grid, grid)
55 |         grid_xy = grid_x**2 + grid_y**2
56 | 
57 |         gaussian = torch.exp(
58 |             -grid_xy / (2 * var)
59 |         ).view(1, 1, kernel_size, kernel_size)
60 |         gaussian = gaussian / gaussian.sum()
61 | 
62 |         self.weight.data = gaussian
63 |         self.weight.requires_grad = False
64 | 
65 |     def forward(self, x):
66 |         return super().forward(
67 |             x.unsqueeze(1)
68 |         ).squeeze(1)
69 | 
70 | 
71 | class SinkhornAttention(nn.Module):
72 |     """Single-head attention with sinkhorn normalization. 
73 |     """
74 |     ENERGY_FNS = ["dot", "cos", "l2"]
75 | 
76 |     def __init__(
77 |         self,
78 |         embed_dim,
79 |         bucket_size,
80 |         kdim=None,
81 |         vdim=None,
82 |         dropout=0.0,
83 |         bias=True,
84 |         add_bias_kv=False,
85 |         no_query_proj=False,
86 |         no_key_proj=False,
87 |         no_value_proj=False,
88 |         no_out_proj=False,
89 |         blurr_kernel=1,
90 |         sinkhorn_tau=0.75,
91 |         sinkhorn_iters=8,
92 |         sinkhorn_noise_factor=1.0,
93 |         energy_fn='dot',
94 |     ):
95 |         super().__init__()
96 |         self.embed_dim = embed_dim
97 |         self.bucket_size = bucket_size
98 |         self.kdim = kdim if kdim is not None else embed_dim
99 |         self.vdim = vdim if vdim is not None else embed_dim
100 | 
101 |         self.dropout_module = FairseqDropout(
102 |             dropout, module_name=self.__class__.__name__
103 |         )
104 |         self.scaling = self.embed_dim ** -0.5
105 | 
106 |         if no_query_proj:
107 |             self.q_proj = None
108 |         else:
109 |             self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
110 | 
111 |         if no_key_proj:
112 |             self.k_proj = None
113 |         else:
114 |             self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
115 | 
116 |         if no_value_proj:
117 |             self.v_proj = None
118 |         else:
119 |             self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
120 | 
121 |         if no_out_proj:
122 |             self.out_proj = None
123 |         else:
124 |             self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
125 | 
126 |         if add_bias_kv:
127 |             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
128 |             self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
129 |         else:
130 |             self.bias_k = self.bias_v = None
131 | 
132 |         self.tau = sinkhorn_tau
133 |         self.iters = sinkhorn_iters
134 |         self.noise_factor = sinkhorn_noise_factor
135 |         self.energy_fn = energy_fn
136 |         assert self.energy_fn in self.ENERGY_FNS, f"{energy_fn} not in {self.ENERGY_FNS}"
137 | 
138 |         if blurr_kernel > 1:
139 |             self.blurr = GaussianBlur(blurr_kernel)
140 |         else:
141 |             self.blurr = None
142 | 
143 |         self.reset_parameters()
144 | 
145 |     def extra_repr(self):
146 |         s = "dim={}, bucket_size={}, tau={}, iters={}, noise_factor={}, energy_fn={}".format(
147 |             self.embed_dim,
148 |             self.bucket_size,
149 |             self.tau,
150 |             self.iters,
151 |             self.noise_factor,
152 |             self.energy_fn,
153 |         )
154 |         return s
155 | 
156 |     def reset_parameters(self):
157 |         # Empirically observed the convergence to be much better with
158 |         # the scaled initialization
159 |         if self.q_proj is not None:
160 |             nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
161 |         if self.k_proj is not None:
162 |             nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
163 |         if self.v_proj is not None:
164 |             nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
165 | 
166 |         if self.out_proj is not None:
167 |             nn.init.xavier_uniform_(self.out_proj.weight)
168 |             if self.out_proj.bias is not None:
169 |                 nn.init.constant_(self.out_proj.bias, 0.0)
170 |         if self.bias_k is not None:
171 |             nn.init.xavier_normal_(self.bias_k)
172 |         if self.bias_v is not None:
173 |             nn.init.xavier_normal_(self.bias_v)
174 | 
175 |     def pad_to_multiple(self, q, k, v, key_padding_mask):
176 |         """Input shape
177 |             q: (B, T, E),
178 |             k, v: (B, S, E),
179 |             key_padding_mask: (B, S),
180 |         """
181 |         B, T, E = q.size()
182 |         B, S, E = k.size()
183 | 
184 |         new_q = q
185 |         new_k = k
186 |         new_v = v
187 |         new_key_padding_mask = key_padding_mask
188 | 
189 |         buckets = math.ceil(T / self.bucket_size)
190 |         kv_buckets = math.ceil(S / self.bucket_size)
191 | 
192 |         # pad query
193 |         new_T = buckets * self.bucket_size
194 |         if new_T != T:
195 |             new_q = torch.cat([
196 |                 q,
197 |                 q.new_zeros((B, 
new_T - T, E)),
198 |             ], dim=1)
199 |         # if attn_mask is not None:
200 |         #     new_attn_mask = attn_mask.new_zeros((new_T, new_T))
201 |         #     new_attn_mask[:T, :T] = attn_mask
202 |         #     new_attn_mask[:, T:].fill_(float("-inf"))
203 | 
204 |         # pad key value
205 |         new_S = kv_buckets * self.bucket_size
206 |         if new_S != S:
207 |             new_k = torch.cat([
208 |                 k,
209 |                 k.new_zeros((B, new_S - S, E)),
210 |             ], dim=1)
211 |             new_v = torch.cat([
212 |                 v,
213 |                 v.new_zeros((B, new_S - S, E)),
214 |             ], dim=1)
215 |             if key_padding_mask is None:
216 |                 key_padding_mask = k.new_zeros((B, S), dtype=torch.bool)
217 | 
218 |             new_key_padding_mask = torch.cat([
219 |                 key_padding_mask,
220 |                 key_padding_mask.new_ones((B, new_S - S)),
221 |             ], dim=1)
222 | 
223 |         return (
224 |             new_q,
225 |             new_k,
226 |             new_v,
227 |             new_key_padding_mask,
228 |             new_T - T,
229 |             new_S - S
230 |         )
231 | 
232 |     def aggregate_buckets(self, q, k, v, key_padding_mask):
233 |         """Input shape
234 |             q: (B, T, E),
235 |             k, v: (B, S, E),
236 |             key_padding_mask: (B, S),
237 |         """
238 |         B, T, E = q.size()
239 |         B, S, E = k.size()
240 |         buckets = T // self.bucket_size
241 |         kv_buckets = S // self.bucket_size
242 | 
243 |         # aggregate query & key by taking the mean of each bucket (the paper sums?)
244 |         new_q = q.view(B, buckets, self.bucket_size, E).mean(dim=2)
245 |         new_k = k.view(B, kv_buckets, self.bucket_size, E).mean(dim=2)
246 | 
247 |         # aggregate value by concatenating each bucket into a single vector
248 |         new_v = v.contiguous().view(B, kv_buckets, self.bucket_size * E)
249 | 
250 |         # aggregate padding mask: a bucket is masked only if all its positions are pads.
251 |         new_key_padding_mask = key_padding_mask
252 |         if key_padding_mask is not None:
253 |             new_key_padding_mask = key_padding_mask.view(
254 |                 B, kv_buckets, self.bucket_size).prod(dim=2).type_as(key_padding_mask)
255 | 
256 |         return (
257 |             new_q,
258 |             new_k,
259 |             new_v,
260 |             new_key_padding_mask
261 |         )
262 | 
263 |     def undo_aggregate_buckets(self, v, tail_v):
264 |         """Input shape
265 |             v: (B, new_S, E),
266 |         """
267 |         B, kv_buckets, bucket_size_E = v.size()
268 |         E = bucket_size_E // self.bucket_size
269 |         new_v = v.view(B, kv_buckets * self.bucket_size, E)
270 |         return new_v[:, :-tail_v, :] if tail_v > 0 else new_v
271 | 
272 |     def forward(
273 |         self,
274 |         query,
275 |         key: Optional[Tensor],
276 |         value: Optional[Tensor],
277 |         key_padding_mask: Optional[Tensor] = None,
278 |         # attn_mask: Optional[Tensor] = None,
279 |         **unused,
280 |     ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
281 |         """Input shape: Time x Batch x Channel
282 |         Args:
283 |             key_padding_mask (ByteTensor, optional): mask to exclude
284 |                 keys that are pads, of shape `(batch, src_len)`, where
285 |                 padding elements are indicated by 1s. 
286 |         """
287 |         tgt_len, bsz, embed_dim = query.size()
288 |         src_len, key_bsz, _ = key.size()
289 | 
290 |         assert embed_dim == self.embed_dim
291 |         assert key is not None and value is not None
292 |         assert key_bsz == bsz
293 |         assert (src_len, bsz) == value.shape[:2]
294 | 
295 |         q = query
296 |         k = key
297 |         v = value
298 | 
299 |         if self.q_proj is not None:
300 |             q = self.q_proj(query)
301 | 
302 |         if self.k_proj is not None:
303 |             k = self.k_proj(key)
304 | 
305 |         if self.v_proj is not None:
306 |             v = self.v_proj(value)
307 | 
308 |         if self.bias_k is not None:
309 |             assert self.bias_v is not None
310 |             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
311 |             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
312 | 
313 |             if key_padding_mask is not None:
314 |                 key_padding_mask = torch.cat(
315 |                     [
316 |                         key_padding_mask,
317 |                         key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
318 |                     ],
319 |                     dim=1,
320 |                 )
321 | 
322 |         q = q.transpose(0, 1)
323 |         k = k.transpose(0, 1)
324 |         v = v.transpose(0, 1)
325 | 
326 |         q, k, v, key_padding_mask, q_tail, v_tail = self.pad_to_multiple(
327 |             q, k, v, key_padding_mask)
328 | 
329 |         q, k, v, new_key_padding_mask = self.aggregate_buckets(
330 |             q, k, v, key_padding_mask)
331 | 
332 |         tgt_len = q.size(1)
333 |         src_len = k.size(1)
334 | 
335 |         # This is part of a workaround to get around fork/join parallelism
336 |         # not supporting Optional types.
337 |         if new_key_padding_mask is not None and new_key_padding_mask.dim() == 0:
338 |             new_key_padding_mask = None
339 | 
340 |         if self.energy_fn == "dot":
341 |             attn_weights = torch.bmm(q, k.transpose(1, 2)) * self.scaling
342 |         elif self.energy_fn == "cos":
343 |             # serious underflow for half.
344 |             q = q.float()
345 |             k = k.float()
346 |             attn_weights = F.cosine_similarity(
347 |                 q.unsqueeze(2),  # (bsz, tgt_len, 1, embed_dim)
348 |                 k.unsqueeze(1),  # (bsz, 1, src_len, embed_dim)
349 |                 dim=-1,
350 |             ).type_as(v)
351 |         elif self.energy_fn == "l2":
352 |             # cdist is not implemented for half. 
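            # Added note (not in the original): negating the pairwise L2
            # distance turns "closer" into "higher energy": -torch.cdist(q, q,
            # p=2) is maximal (zero) on the diagonal, so sinkhorn normalization
            # concentrates mass on nearby query/key buckets, mirroring the dot
            # and cos energies above.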
353 | q = q.float() 354 | k = k.float() 355 | attn_weights = -torch.cdist(q, k, p=2).type_as(v) 356 | else: 357 | raise NotImplementedError() 358 | 359 | # add blurring 360 | if self.blurr is not None: 361 | attn_weights = self.blurr(attn_weights) 362 | 363 | # save a copy before masking 364 | log_alpha = attn_weights.type_as(v) 365 | 366 | assert list(attn_weights.size()) == [bsz, tgt_len, src_len] 367 | 368 | if new_key_padding_mask is not None: 369 | assert list(new_key_padding_mask.size()) == [bsz, src_len] 370 | new_key_padding_mask = new_key_padding_mask.bool() 371 | 372 | final_mask = new_key_padding_mask.unsqueeze(1) & (~new_key_padding_mask).unsqueeze(2) 373 | neg_inf = -torch.finfo(attn_weights.dtype).max 374 | # mask out normal -> pad attentions 375 | attn_weights = attn_weights.masked_fill( 376 | final_mask, 377 | neg_inf, 378 | ) 379 | # mask out pad -> normal attentions 380 | attn_weights = attn_weights.masked_fill( 381 | final_mask.transpose(2, 1), 382 | neg_inf, 383 | ) 384 | 385 | attn_weights_float = gumbel_sinkhorn( 386 | attn_weights, 387 | tau=self.tau, 388 | n_iter=self.iters, 389 | noise_factor=self.noise_factor if self.training else 0, 390 | ) 391 | 392 | # convert back to half/float 393 | attn_weights = attn_weights_float.type_as(v) 394 | attn_probs = self.dropout_module(attn_weights) 395 | 396 | attn = torch.bmm(attn_probs, v) 397 | 398 | attn = self.undo_aggregate_buckets(attn, v_tail) 399 | 400 | attn = attn.transpose(0, 1) 401 | 402 | if self.out_proj is not None: 403 | attn = self.out_proj(attn) 404 | 405 | return attn, attn_weights, log_alpha 406 | -------------------------------------------------------------------------------- /utility/simultaneous_translation/eval/agents/simul_t2t_waitk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
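# Added overview (commentary, not part of the original file): a SimulEval
# text agent exposes policy(states) -> READ_ACTION | WRITE_ACTION and
# predict(states) -> next token id. The evaluator alternates: on READ it
# feeds one more source segment (segment_to_units), on WRITE it takes the
# predicted token and merges subwords back into words (units_to_segment).
# The wait-k policy below reads until src_len - tgt_len >= k, then writes.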
5 | 
6 | # from torchinfo import summary
7 | import os
8 | import logging
9 | from fairseq import checkpoint_utils, tasks, utils
10 | import sentencepiece as spm
11 | import torch
12 | 
13 | try:
14 |     from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
15 |     from simuleval.agents import TextAgent
16 |     from simuleval.states import ListEntry, TextStates
17 | except ImportError:
18 |     print("Please install SimulEval: pip install simuleval")
19 | 
20 | logger = logging.getLogger(__name__)
21 | # logger.setLevel(logging.DEBUG)
22 | BOW_PREFIX = "\u2581"
23 | 
24 | 
25 | class SimulTransTextAgentWaitk(TextAgent):
26 |     """
27 |     Simultaneous Translation
28 |     Text agent for wait-k models
29 |     """
30 |     @staticmethod
31 |     def add_args(parser):
32 |         # fmt: off
33 |         parser.add_argument('--model-path', type=str, required=True,
34 |                             help='path to your pretrained model.')
35 |         parser.add_argument("--data-bin", type=str, required=True,
36 |                             help="Path of data binary")
37 |         # parser.add_argument("--max-len", type=int, default=100,
38 |         #                     help="Max length of translation")
39 |         parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
40 |                             help="Subword splitter type for target text.")
41 |         parser.add_argument("--tgt-splitter-path", type=str, default=None,
42 |                             help="Subword splitter model path for target text.")
43 |         parser.add_argument("--src-splitter-type", type=str, default="SentencePiece",
44 |                             help="Subword splitter type for source text.")
45 |         parser.add_argument("--src-splitter-path", type=str, default=None,
46 |                             help="Subword splitter model path for source text.")
47 |         parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
48 |                             help="User directory for simultaneous translation")
49 |         parser.add_argument("--max-len-a", type=float, default=1.2,
50 |                             help="a in max length of translation ax+b")
51 |         parser.add_argument("--max-len-b", type=int, default=10,
52 |                             help="b in max length of translation ax+b")
53 |         parser.add_argument("--force-finish", default=False, action="store_true",
54 |                             help="Force the model to finish the hypothesis if the source is not finished")
55 |         parser.add_argument("--test-waitk", type=int, default=1)
56 |         parser.add_argument("--incremental-encoder", default=False, action="store_true",
57 |                             help="Update the model incrementally without recomputation of history.")
58 |         parser.add_argument("--full-sentence", default=False, action="store_true",
59 |                             help="use full sentence strategy, "
60 |                                  "updating the encoder only once after reading is finished.")
61 |         parser.add_argument("--segment-type", type=str, default="word", choices=["word", "char"],
62 |                             help="Agent can send a word or a char to the server at a time.")
63 |         parser.add_argument("--workers", type=int, default=1)
64 |         # unused for wait-k
65 |         parser.add_argument('--lm-path', type=str, default="",
66 |                             help='path to your kenlm model.')
67 |         parser.add_argument("--lm-weight", type=float, default=0,
68 |                             help='the log prob is interpolated as: model_logp + lm_weight * lm_logp.')
69 |         return parser
70 | 
71 |     def __init__(self, args):
72 | 
73 |         self.test_waitk = args.test_waitk
74 |         self.force_finish = args.force_finish
75 |         self.incremental_encoder = args.incremental_encoder
76 |         self.full_sentence = args.full_sentence
77 |         self.segment_type = args.segment_type
78 |         self.workers = args.workers
79 | 
80 |         if self.full_sentence:
81 |             logger.info("Full sentence mode overrides waitk to 1024.")
82 |             self.test_waitk = 1024
83 | 
84 |         # Whether to use gpu
85 |         self.gpu = getattr(args, "gpu", False)
86 | 
87 |         # 
Load Model
88 |         self.load_model_vocab(args)
89 | 
90 |         # build word splitter
91 |         self.build_word_splitter(args)
92 | 
93 |         self.eos = DEFAULT_EOS
94 | 
95 |         # Max len
96 |         self.max_len = lambda x: args.max_len_a * x + args.max_len_b
97 | 
98 |         torch.set_grad_enabled(False)
99 |         torch.set_num_threads(self.workers)
100 | 
101 |     def load_model_vocab(self, args):
102 |         filename = args.model_path
103 |         if not os.path.exists(filename):
104 |             raise IOError("Model file not found: {}".format(filename))
105 | 
106 |         state = checkpoint_utils.load_checkpoint_to_cpu(
107 |             path=filename, arg_overrides=None, load_on_all_ranks=False)
108 | 
109 |         cfg = state["cfg"]
110 | 
111 |         # apply config overrides:
112 |         cfg.common.user_dir = args.user_dir
113 |         cfg.task.data = args.data_bin
114 |         cfg.model.load_pretrained_encoder_from = None
115 |         cfg.model.load_pretrained_decoder_from = None
116 | 
117 |         utils.import_user_module(cfg.common)
118 |         # Setup task, e.g., translation, language modeling, etc.
119 |         task = tasks.setup_task(cfg.task)
120 |         # Build model and criterion
121 |         model = task.build_model(cfg.model)
122 | 
123 |         model.load_state_dict(
124 |             state["model"], strict=True, model_cfg=cfg.model
125 |         )
126 | 
127 |         # Optimize ensemble for generation
128 |         if self.gpu:
129 |             model.cuda()
130 |         model.prepare_for_inference_(cfg)
131 | 
132 |         self.model = model
133 | 
134 |         # Set dictionary
135 |         self.dict = {}
136 |         self.dict["tgt"] = task.target_dictionary
137 |         self.dict["src"] = task.source_dictionary
138 | 
139 |         self.pre_tokenizer = task.pre_tokenizer
140 | 
141 |         # logger.info(summary(self.model))
142 |         logger.info("task: {}".format(task.__class__.__name__))
143 |         logger.info("model: {}".format(self.model.__class__.__name__))
144 |         logger.info("pre_tokenizer: {}".format(self.pre_tokenizer))
145 | 
146 |     def build_word_splitter(self, args):
147 |         self.spm = {}
148 |         for lang in ['src', 'tgt']:
149 |             if getattr(args, f'{lang}_splitter_type', None):
150 |                 path = getattr(args, f'{lang}_splitter_path', None)
151 |                 if path:
152 |                     self.spm[lang] = spm.SentencePieceProcessor()
153 |                     self.spm[lang].Load(path)
154 | 
155 |     def initialize_states(self, states):
156 |         states.units.source = ListEntry()
157 |         states.units.target = ListEntry()
158 |         states.enc_incremental_states = dict()
159 |         states.dec_incremental_states = dict()
160 | 
161 |     def build_states(self, args, client, sentence_id):
162 |         # Initialize states here, e.g., add customized entries to states.
163 |         # This function is called at the beginning of every new sentence.
164 |         states = TextStates(args, client, sentence_id, self)
165 |         self.initialize_states(states)
166 |         return states
167 | 
168 |     def to_device(self, tensor):
169 |         if self.gpu:
170 |             return tensor.cuda()
171 |         else:
172 |             return tensor.cpu()
173 | 
174 |     def segment_to_units(self, segment, states):
175 |         # Split a full word (segment) into subwords (units)
176 |         return self.spm['src'].EncodeAsPieces(segment)
177 | 
178 |     def update_model_encoder(self, states):
179 |         src_len = len(states.units.source)
180 |         enc_len = 0
181 | 
182 |         if getattr(states, "encoder_states", None) is not None:
183 |             enc_len = states.encoder_states["encoder_out"][0].size(0)
184 | 
185 |         src_indices = [
186 |             self.dict['src'].index(x)
187 |             for x in states.units.source.value
188 |         ]
189 | 
190 |         if states.finish_read() and src_indices[-1] != self.dict["tgt"].eos():
191 |             # Append the eos index once the source is fully read
192 |             src_indices += [self.dict["tgt"].eos()]
193 |             src_len += 1
194 |             logger.debug("ADD EOS")
195 | 
196 |         if src_len 
<= enc_len:
197 |             logger.debug("Redundant read")
198 |             return
199 | 
200 |         src_indices = self.to_device(
201 |             torch.LongTensor(src_indices).unsqueeze(0)
202 |         )
203 |         src_lengths = self.to_device(
204 |             torch.LongTensor([src_indices.size(1)])
205 |         )
206 | 
207 |         if self.incremental_encoder:
208 |             encoder_out = self.model.encoder(
209 |                 src_indices,
210 |                 src_lengths,
211 |                 incremental_state=states.enc_incremental_states,
212 |                 incremental_step=src_len - enc_len,
213 |             )
214 |             if getattr(states, "encoder_states", None) is None:
215 |                 states.encoder_states = {
216 |                     # List[T x B x C]
217 |                     "encoder_out": encoder_out["encoder_out"],
218 |                     "encoder_padding_mask": [],  # B x T
219 |                     "encoder_embedding": [],  # B x T x C
220 |                     "encoder_states": [],  # List[T x B x C]
221 |                     "src_tokens": [],
222 |                     "src_lengths": [],
223 |                 }
224 |             else:
225 |                 states.encoder_states["encoder_out"][0] = torch.cat(
226 |                     (
227 |                         states.encoder_states["encoder_out"][0],
228 |                         encoder_out["encoder_out"][0]
229 |                     ), dim=0
230 |                 )
231 |         else:
232 |             states.encoder_states = self.model.encoder(
233 |                 src_indices, src_lengths)
234 | 
235 |         torch.cuda.empty_cache()
236 | 
237 |     def update_states_read(self, states):
238 |         # Happens after a read action.
239 |         if not self.full_sentence or states.finish_read():
240 |             self.update_model_encoder(states)
241 | 
242 |     def units_to_segment(self, unit_queue, states):
243 |         """Merge subwords into a full word.
244 |         queue: stores bpe tokens.
245 |         server: accepts words.
246 | 
247 |         Therefore, we need to merge subwords into words: find the first
248 |         subword that starts with BOW_PREFIX, merge the subwords before
249 |         it into a word, remove them from the queue, and send it to the server.
250 |         """
251 |         if self.segment_type == "char":
252 |             return self.units_to_segment_char(unit_queue, states)
253 |         tgt_dict = self.dict["tgt"]
254 | 
255 |         # if segment starts with eos, send EOS
256 |         if tgt_dict.eos() == unit_queue[0]:
257 |             return DEFAULT_EOS
258 | 
259 |         string_to_return = None
260 | 
261 |         def decode(tok_idx):
262 |             hyp = tgt_dict.string(
263 |                 tok_idx,
264 |                 "sentencepiece",
265 |             )
266 |             if self.pre_tokenizer is not None:
267 |                 hyp = self.pre_tokenizer.decode(hyp)
268 |             return hyp
269 | 
270 |         # if force finish, there will be None's
271 |         segment = []
272 |         if None in unit_queue.value:
273 |             unit_queue.value.remove(None)
274 | 
275 |         src_len = len(states.units.source)
276 |         if (
277 |             (len(unit_queue) > 0 and tgt_dict.eos() == unit_queue[-1])
278 |             or
279 |             (states.finish_read() and len(states.units.target) > self.max_len(src_len))
280 |         ):
281 |             hyp = decode(unit_queue)
282 |             string_to_return = ([hyp] if hyp else []) + [DEFAULT_EOS]
283 |         else:
284 |             space_p = None
285 |             for p, unit_id in enumerate(unit_queue):
286 |                 if p == 0:
287 |                     continue
288 |                 token = tgt_dict.string([unit_id])
289 |                 if token.startswith(BOW_PREFIX):
290 |                     """
291 |                     find the first token with the escape symbol
292 |                     """
293 |                     space_p = p
294 |                     break
295 |             if space_p is not None:
296 |                 for j in range(space_p):
297 |                     segment += [unit_queue.pop()]
298 | 
299 |                 hyp = decode(segment)
300 |                 string_to_return = [hyp] if hyp else []
301 | 
302 |                 if tgt_dict.eos() == unit_queue[0]:
303 |                     string_to_return += [DEFAULT_EOS]
304 | 
305 |         return string_to_return
306 | 
307 |     def units_to_segment_char(self, unit_queue, states):
308 |         """ For Chinese, directly send tokens. 
""" 309 | 310 | tgt_dict = self.dict["tgt"] 311 | 312 | if None in unit_queue.value: 313 | unit_queue.value.remove(None) 314 | 315 | src_len = len(states.units.source) 316 | if ( 317 | (len(unit_queue) > 0 and tgt_dict.eos() == unit_queue[-1]) 318 | or 319 | (states.finish_read() and len(states.units.target) > self.max_len(src_len)) 320 | ): 321 | return DEFAULT_EOS 322 | 323 | unit_id = unit_queue.value.pop() 324 | token = tgt_dict.string([unit_id]) 325 | 326 | # even if replace with space, it will be stripped by the server :( 327 | return token.replace(BOW_PREFIX, "") 328 | 329 | def policy(self, states): 330 | if not getattr(states, "encoder_states", None) and not states.finish_read(): 331 | return READ_ACTION 332 | 333 | waitk = self.test_waitk 334 | src_len = len(states.units.source) 335 | enc_len = 0 336 | tgt_len = len(states.units.target) 337 | 338 | if getattr(states, "encoder_states", None) is not None: 339 | enc_len = states.encoder_states["encoder_out"][0].size(0) 340 | 341 | if src_len - tgt_len < waitk and not states.finish_read(): 342 | return READ_ACTION 343 | else: 344 | if states.finish_read() and enc_len < src_len + 1: 345 | # encode the last few sources (+1 eos) 346 | self.update_model_encoder(states) 347 | enc_len = states.encoder_states["encoder_out"][0].size(0) 348 | 349 | tgt_indices = self.to_device( 350 | torch.LongTensor( 351 | [self.dict["tgt"].eos()] 352 | + [x for x in states.units.target.value if x is not None] 353 | ).unsqueeze(0) 354 | ) 355 | 356 | logits, extra = self.model.forward_decoder( 357 | prev_output_tokens=tgt_indices, 358 | encoder_out=states.encoder_states, 359 | incremental_state=states.dec_incremental_states, 360 | ) 361 | 362 | states.decoder_out = logits 363 | 364 | torch.cuda.empty_cache() 365 | 366 | return WRITE_ACTION 367 | 368 | def predict(self, states): 369 | 370 | lprobs = self.model.get_normalized_probs( 371 | [states.decoder_out[:, -1:]], log_probs=True 372 | ) 373 | 374 | index = lprobs.argmax(dim=-1) 375 | 376 | index = index[0, 0].item() 377 | 378 | if ( 379 | self.force_finish 380 | and index == self.dict["tgt"].eos() 381 | and not states.finish_read() 382 | ): 383 | # If we want to force finish the translation 384 | # (don't stop before finish reading), return a None 385 | self.model.decoder.clear_cache(states.dec_incremental_states) 386 | index = None 387 | 388 | return index 389 | --------------------------------------------------------------------------------