├── src ├── __init__.py ├── modules │ ├── __init__.py │ ├── layer_norm.py │ ├── label_smoothed_cross_entropy.py │ ├── sinusoidal_positional_embedding.py │ └── multihead_attention.py ├── __pycache__ │ ├── logger.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── data │ ├── __pycache__ │ │ ├── loader.cpython-36.pyc │ │ ├── loader.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dataset.cpython-36.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── dictionary.cpython-36.pyc │ │ └── dictionary.cpython-37.pyc │ ├── loader.py │ ├── dictionary.py │ └── dataset.py ├── model │ ├── __init__.py │ └── attention.py ├── fairseq_utils.py ├── utils.py ├── logger.py ├── adam_inverse_sqrt_with_warmup.py ├── distributed_utils.py └── trainer.py ├── train.sh ├── scripts ├── mosesdecoder │ ├── README.md │ └── scripts │ │ ├── share │ │ └── nonbreaking_prefixes │ │ │ ├── README.txt │ │ │ ├── nonbreaking_prefix.ro │ │ │ ├── nonbreaking_prefix.ga │ │ │ ├── nonbreaking_prefix.sv │ │ │ ├── nonbreaking_prefix.ca │ │ │ ├── nonbreaking_prefix.sl │ │ │ ├── nonbreaking_prefix.yue │ │ │ ├── nonbreaking_prefix.zh │ │ │ ├── nonbreaking_prefix.es │ │ │ ├── nonbreaking_prefix.lv │ │ │ ├── nonbreaking_prefix.fr │ │ │ ├── nonbreaking_prefix.en │ │ │ ├── nonbreaking_prefix.fi │ │ │ ├── nonbreaking_prefix.hu │ │ │ ├── nonbreaking_prefix.nl │ │ │ ├── nonbreaking_prefix.is │ │ │ ├── nonbreaking_prefix.it │ │ │ ├── nonbreaking_prefix.ru │ │ │ ├── nonbreaking_prefix.pl │ │ │ ├── nonbreaking_prefix.pt │ │ │ ├── nonbreaking_prefix.ta │ │ │ ├── nonbreaking_prefix.de │ │ │ ├── nonbreaking_prefix.cs │ │ │ ├── nonbreaking_prefix.sk │ │ │ └── nonbreaking_prefix.lt │ │ ├── tokenizer │ │ ├── lowercase.perl │ │ ├── remove-non-printing-char.perl │ │ ├── deescape-special-chars.perl │ │ ├── escape-special-chars.perl │ │ ├── replace-unicode-punctuation.perl │ │ └── normalize-punctuation.perl │ │ ├── generic │ │ ├── input-from-sgm.perl │ │ ├── wrap-xml.perl │ │ ├── multi-bleu.perl │ │ ├── bsbleu.py │ │ └── multi-bleu-detok.perl │ │ ├── recaser │ │ ├── detruecase.perl │ │ ├── truecase.perl │ │ └── train-truecaser.perl │ │ └── training │ │ └── clean-corpus-n.perl ├── multi-bleu.perl ├── learn_joint_bpe_and_vocab.py ├── multi-bleu-detok.perl └── learn_bpe.py ├── .gitignore ├── README.md ├── single_train.py ├── eval.sh ├── preprocess.py ├── multiprocessing_train.py ├── main_lstm.py ├── main.py └── translate.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd $(dirname $0) 3 | exec python3 main.py $@ -------------------------------------------------------------------------------- /src/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/loader.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/loader.cpython-37.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/dictionary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/dictionary.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/dictionary.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rgwt123/simple-fairseq/HEAD/src/data/__pycache__/dictionary.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/mosesdecoder/README.md: -------------------------------------------------------------------------------- 1 | # moses-scripts 2 | 3 | A number of preprocessing scripts extracted from Moses (https://github.com/moses-smt/mosesdecoder) 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.swp 3 | *.pyc 4 | *.egg-info/ 5 | .eggs/ 6 | .idea/ 7 | .tox/ 8 | .pytest_cache/* 9 | __pycache__/ 10 | /build 11 | /dist 12 | /all_models 13 | /all_datas 14 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/README.txt: -------------------------------------------------------------------------------- 1 | The language 
suffix can be found here: 2 | 3 | http://www.loc.gov/standards/iso639-2/php/code_list.php 4 | 5 | This code includes data from Daniel Naber's Language Tools (czech abbreviations). 6 | This code includes data from czech wiktionary (also czech abbreviations). 7 | 8 | 9 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.ro: -------------------------------------------------------------------------------- 1 | A 2 | B 3 | C 4 | D 5 | E 6 | F 7 | G 8 | H 9 | I 10 | J 11 | K 12 | L 13 | M 14 | N 15 | O 16 | P 17 | Q 18 | R 19 | S 20 | T 21 | U 22 | V 23 | W 24 | X 25 | Y 26 | Z 27 | dpdv 28 | etc 29 | șamd 30 | M.Ap.N 31 | dl 32 | Dl 33 | d-na 34 | D-na 35 | dvs 36 | Dvs 37 | pt 38 | Pt 39 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | 4 | LatentState = namedtuple('LatentState', 'dec_input, input_len') 5 | 6 | 7 | def build_mt_model(params, cuda=True): 8 | if params.attention: 9 | from .attention import build_attention_model 10 | return build_attention_model(params, cuda=cuda) 11 | else: 12 | return None 13 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.ga: -------------------------------------------------------------------------------- 1 | 2 | A 3 | B 4 | C 5 | D 6 | E 7 | F 8 | G 9 | H 10 | I 11 | J 12 | K 13 | L 14 | M 15 | N 16 | O 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | U 23 | V 24 | W 25 | X 26 | Y 27 | Z 28 | Á 29 | É 30 | Í 31 | Ó 32 | Ú 33 | 34 | Uacht 35 | Dr 36 | B.Arch 37 | 38 | m.sh 39 | .i 40 | Co 41 | Cf 42 | cf 43 | i.e 44 | r 45 | Chr 46 | lch #NUMERIC_ONLY# 47 | lgh #NUMERIC_ONLY# 48 | uimh #NUMERIC_ONLY# 49 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.sv: -------------------------------------------------------------------------------- 1 | #single upper case letter are usually initials 2 | A 3 | B 4 | C 5 | D 6 | E 7 | F 8 | G 9 | H 10 | I 11 | J 12 | K 13 | L 14 | M 15 | N 16 | O 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | U 23 | V 24 | W 25 | X 26 | Y 27 | Z 28 | #misc abbreviations 29 | AB 30 | G 31 | VG 32 | dvs 33 | etc 34 | from 35 | iaf 36 | jfr 37 | kl 38 | kr 39 | mao 40 | mfl 41 | mm 42 | osv 43 | pga 44 | tex 45 | tom 46 | vs 47 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/lowercase.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
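# (Summary: lowercases UTF-8 text read from STDIN and writes it to STDOUT; the -b flag turns off output buffering so each line is flushed immediately.)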
5 | 6 | use warnings; 7 | use strict; 8 | 9 | while (@ARGV) { 10 | $_ = shift; 11 | /^-b$/ && ($| = 1, next); # not buffered (flush each line) 12 | } 13 | 14 | binmode(STDIN, ":utf8"); 15 | binmode(STDOUT, ":utf8"); 16 | 17 | while(<STDIN>) { 18 | print lc($_); 19 | } 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simple-fairseq 2 | A seq2seq translation model, with code simplified from fairseq, opennmt-py, and similar frameworks. 3 | 4 | ## Features 5 | Supports Bi-LSTM and Transformer models 6 | 7 | Supports multi-GPU training 8 | 9 | Supports delayed updates (gradient accumulation) 10 | 11 | ## Usage 12 | ### Data preprocessing 13 | preprocess.py converts plain-text corpora into a binary file, which is easier to store and faster to load. 14 | 15 | Point it at the source- and target-language text files and their dictionaries, specify the output binary file, and run it (a short example appears further below). 16 | 17 | ### Training 18 | 19 | The training entry point is main.py. Specify the parameters you need; see the argument descriptions inside the file for details. 20 | 21 | Arguments control how many updates elapse between saved checkpoints and whether the optimizer state is saved with them. 22 | 23 | Arguments also control whether BLEU is evaluated during training; if enabled, provide the dictionary paths, a test file, and a reference file. 24 | 25 | ### Evaluating translations 26 | 27 | The translation entry point is translate.py. Adjust its parameters to your needs; see the notes inside the file. 28 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | use warnings; 7 | use utf8; 8 | 9 | binmode(STDIN, ":utf8"); 10 | binmode(STDOUT, ":utf8"); 11 | binmode(STDERR, ":utf8"); 12 | 13 | while (my $line = <STDIN>) { 14 | chomp($line); 15 | #$line =~ tr/\040-\176/ /c; 16 | #$line =~ s/[^[:print:]]/ /g; 17 | #$line =~ s/\s+/ /g; 18 | $line =~ s/\p{C}/ /g; 19 | 20 | print "$line\n"; 21 | } 22 | 23 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.ca: -------------------------------------------------------------------------------- 1 | Dr 2 | Dra 3 | pàg 4 | p 5 | c 6 | av 7 | Sr 8 | Sra 9 | adm 10 | esq 11 | Prof 12 | S.A 13 | S.L 14 | p.e 15 | ptes 16 | Sta 17 | St 18 | pl 19 | màx 20 | cast 21 | dir 22 | nre 23 | fra 24 | admdora 25 | Emm 26 | Excma 27 | espf 28 | dc 29 | admdor 30 | tel 31 | angl 32 | aprox 33 | ca 34 | dept 35 | dj 36 | dl 37 | dt 38 | ds 39 | dg 40 | dv 41 | ed 42 | entl 43 | al 44 | i.e 45 | maj 46 | smin 47 | n 48 | núm 49 | pta 50 | A 51 | B 52 | C 53 | D 54 | E 55 | F 56 | G 57 | H 58 | I 59 | J 60 | K 61 | L 62 | M 63 | N 64 | O 65 | P 66 | Q 67 | R 68 | S 69 | T 70 | U 71 | V 72 | W 73 | X 74 | Y 75 | Z 76 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.sl: -------------------------------------------------------------------------------- 1 | dr 2 | Dr 3 | itd 4 | itn 5 | št #NUMERIC_ONLY# 6 | Št #NUMERIC_ONLY# 7 | d 8 | jan 9 | Jan 10 | feb 11 | Feb 12 | mar 13 | Mar 14 | apr 15 | Apr 16 | jun 17 | Jun 18 | jul 19 | Jul 20 | avg 21 | Avg 22 | sept 23 | Sept 24 | sep 25 | Sep 26 | okt 27 | Okt 28 | nov 29 | Nov 30 | dec 31 | Dec 32 | tj 33 | Tj 34 | npr 35 | Npr 36 | sl 37 | Sl 38 | op 39 | Op 40 | gl 41 | Gl 42 | oz 43 | Oz 44 | prev 45 | dipl 46 | ing 47 | prim 48 | Prim 49 | cf 50 | Cf 51 | gl 52 | Gl 53 | A 54 | B 55 | C 56 | D 57 | E 58 | F 59 | G 60 | H 61 | I 62 | J 63 | K 64 | L 65 | M 66 | N 67 | O 68 | P 69 | Q 70 | R 71 | S 72 | T 73 | U 74 | V 75 | W 76 | X 77 | Y 78 | Z 79 |
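A minimal sketch of the preprocessing step described in the README above. It calls the same helpers that preprocess.py itself uses; the data paths are placeholders taken from the commented-out defaults in preprocess.py.

    # Equivalent command line (argument order: src text, tgt text, src vocab, tgt vocab, output binary):
    #   python3 preprocess.py data/all.zh.bpe data/all.en.bpe data/vocab.zh data/vocab.en data/cwmt.bin
    from src.logger import create_logger
    from src.data.dictionary import Dictionary

    logger = create_logger(None)                       # preprocess.py logs its statistics this way
    src_dico = Dictionary.read_vocab('data/vocab.zh')  # placeholder vocab paths
    tgt_dico = Dictionary.read_vocab('data/vocab.en')
    data = Dictionary.index_data('data/all.zh.bpe', 'data/all.en.bpe',
                                 src_dico, tgt_dico, 'data/cwmt.bin')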
-------------------------------------------------------------------------------- /single_train.py: -------------------------------------------------------------------------------- 1 | from src.model import build_mt_model 2 | from src.data.loader import load_data 3 | from src.trainer import TrainerMT 4 | from tqdm import tqdm 5 | from logging import getLogger 6 | logger = getLogger() 7 | 8 | def main(params): 9 | data = load_data(params, name='train') 10 | test_data = load_data(params, name='test') 11 | encoder, decoder, num_updates = build_mt_model(params) 12 | trainer = TrainerMT(encoder, decoder, data, test_data, params, num_updates) 13 | 14 | for i in range(trainer.epoch, params.max_epoch): 15 | logger.info("==== Starting epoch %i ...====" % trainer.epoch) 16 | trainer.train_epoch() 17 | tqdm.write('Finish epoch %i.' % i) 18 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/deescape-special-chars.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | use warnings; 7 | use strict; 8 | 9 | while(<STDIN>) { 10 | s/\&bar;/\|/g; # factor separator (legacy) 11 | s/\&#124;/\|/g; # factor separator 12 | s/\&lt;/\</g; # xml 13 | s/\&gt;/\>/g; # xml 14 | s/\&bra;/\[/g; # syntax non-terminal (legacy) 15 | s/\&ket;/\]/g; # syntax non-terminal (legacy) 16 | s/\&quot;/\"/g; # xml 17 | s/\&apos;/\'/g; # xml 18 | s/\&#91;/\[/g; # syntax non-terminal 19 | s/\&#93;/\]/g; # syntax non-terminal 20 | s/\&amp;/\&/g; # escape escape 21 | print $_; 22 | } 23 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_TRAIN=$ARNOLD_TRAIN 4 | DATA_VAL=/mnt/cephfs_new_wj/$ARNOLD_CEPH_TEST 5 | OUTPUT=$ARNOLD_OUTPUT 6 | MODEL_DIR=/mnt/cephfs_new_wj/$ARNOLD_CEPH_PREMODEL 7 | cd $(dirname $0) 8 | 9 | for name in `ls $MODEL_DIR` 10 | do 11 | if [ "${name##*.}"x = "pt"x ]; then 12 | echo $name 13 | if test -f $MODEL_DIR/predict_$name 14 | then 15 | echo 'File already exists!' 16 | else 17 | python3 translate.py --src_dico_file $DATA_VAL/src_dico_file --tgt_dico_file $DATA_VAL/tgt_dico_file --translate_file $DATA_VAL/translate_file --reference_file $DATA_VAL/reference_file --checkpoint_dir $MODEL_DIR --model_name $name $@ 18 | echo $name >> $MODEL_DIR/bleu.log 19 | perl scripts/multi-bleu.perl $DATA_VAL/reference_file < $MODEL_DIR/predict_${name:0:-3} | grep BLEU >> $MODEL_DIR/bleu.log 20 | fi 21 | fi 22 | done 23 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.yue: -------------------------------------------------------------------------------- 1 | # 2 | # Cantonese (Chinese) 3 | # 4 | # Anything in this file, followed by a period, 5 | # does NOT indicate an end-of-sentence marker. 6 | # 7 | # English/Euro-language given-name initials (appearing in 8 | # news, periodicals, etc.) 9 | A 10 | Ā 11 | B 12 | C 13 | Č 14 | D 15 | E 16 | Ē 17 | F 18 | G 19 | Ģ 20 | H 21 | I 22 | Ī 23 | J 24 | K 25 | Ķ 26 | L 27 | Ļ 28 | M 29 | N 30 | Ņ 31 | O 32 | P 33 | Q 34 | R 35 | S 36 | Š 37 | T 38 | U 39 | Ū 40 | V 41 | W 42 | X 43 | Y 44 | Z 45 | Ž 46 | 47 | # Numbers only. These should only induce breaks when followed by 48 | # a numeric sequence.
49 | # Add NUMERIC_ONLY after the word for this function. This case is 50 | # mostly for the english "No." which can either be a sentence of its 51 | # own, or if followed by a number, a non-breaking prefix. 52 | No #NUMERIC_ONLY# 53 | Nr #NUMERIC_ONLY# 54 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.zh: -------------------------------------------------------------------------------- 1 | # 2 | # Mandarin (Chinese) 3 | # 4 | # Anything in this file, followed by a period, 5 | # does NOT indicate an end-of-sentence marker. 6 | # 7 | # English/Euro-language given-name initials (appearing in 8 | # news, periodicals, etc.) 9 | A 10 | Ā 11 | B 12 | C 13 | Č 14 | D 15 | E 16 | Ē 17 | F 18 | G 19 | Ģ 20 | H 21 | I 22 | Ī 23 | J 24 | K 25 | Ķ 26 | L 27 | Ļ 28 | M 29 | N 30 | Ņ 31 | O 32 | P 33 | Q 34 | R 35 | S 36 | Š 37 | T 38 | U 39 | Ū 40 | V 41 | W 42 | X 43 | Y 44 | Z 45 | Ž 46 | 47 | # Numbers only. These should only induce breaks when followed by 48 | # a numeric sequence. 49 | # Add NUMERIC_ONLY after the word for this function. This case is 50 | # mostly for the english "No." which can either be a sentence of its 51 | # own, or if followed by a number, a non-breaking prefix. 52 | No #NUMERIC_ONLY# 53 | Nr #NUMERIC_ONLY# 54 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/escape-special-chars.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | use warnings; 7 | use strict; 8 | 9 | while() { 10 | chop; 11 | 12 | # avoid general madness 13 | s/[\000-\037]//g; 14 | s/\s+/ /g; 15 | s/^ //g; 16 | s/ $//g; 17 | 18 | # special characters in moses 19 | s/\&/\&/g; # escape escape 20 | s/\|/\|/g; # factor separator 21 | s/\/\>/g; # xml 23 | s/\'/\'/g; # xml 24 | s/\"/\"/g; # xml 25 | s/\[/\[/g; # syntax non-terminal 26 | s/\]/\]/g; # syntax non-terminal 27 | 28 | # restore xml instructions 29 | s/\<(\S+) translation="(.+?)"> (.+?) <\/(\S+)>/\<$1 translation=\"$2\"> $3 <\/$4>/g; 30 | print $_."\n"; 31 | } 32 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/generic/input-from-sgm.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
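# (Summary: reads an SGML-wrapped test set on STDIN, joins <seg> elements that span several lines, and prints the bare segment text, one segment per line.)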
5 | 6 | use warnings; 7 | use strict; 8 | 9 | die("ERROR syntax: input-from-sgm.perl < in.sgm > in.txt") 10 | unless scalar @ARGV == 0; 11 | 12 | while(my $line = ) { 13 | chop($line); 14 | while ($line =~ /]+>\s*$/i) { 15 | my $next_line = ; 16 | $line .= $next_line; 17 | chop($line); 18 | } 19 | while ($line =~ /]+>\s*(.*)\s*$/i && 20 | $line !~ /]+>\s*(.*)\s*<\/seg>/i) { 21 | my $next_line = ; 22 | $line .= $next_line; 23 | chop($line); 24 | } 25 | if ($line =~ /]+>\s*(.*)\s*<\/seg>/i) { 26 | my $input = $1; 27 | $input =~ s/\s+/ /g; 28 | $input =~ s/^ //g; 29 | $input =~ s/ $//g; 30 | print $input."\n"; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/replace-unicode-punctuation.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | use warnings; 7 | use strict; 8 | 9 | #binmode(STDIN, ":utf8"); 10 | #binmode(STDOUT, ":utf8"); 11 | 12 | while() { 13 | s/,/,/g; 14 | s/。 */. /g; 15 | s/、/,/g; 16 | s/”/"/g; 17 | s/“/"/g; 18 | s/∶/:/g; 19 | s/:/:/g; 20 | s/?/\?/g; 21 | s/《/"/g; 22 | s/》/"/g; 23 | s/)/\)/g; 24 | s/!/\!/g; 25 | s/(/\(/g; 26 | s/;/;/g; 27 | s/1/"/g; 28 | s/」/"/g; 29 | s/「/"/g; 30 | s/0/0/g; 31 | s/3/3/g; 32 | s/2/2/g; 33 | s/5/5/g; 34 | s/6/6/g; 35 | s/9/9/g; 36 | s/7/7/g; 37 | s/8/8/g; 38 | s/4/4/g; 39 | s/. */. /g; 40 | s/~/\~/g; 41 | s/’/\'/g; 42 | s/…/\.\.\./g; 43 | s/━/\-/g; 44 | s/〈/\/g; 46 | s/【/\[/g; 47 | s/】/\]/g; 48 | s/%/\%/g; 49 | print $_; 50 | } 51 | -------------------------------------------------------------------------------- /src/fairseq_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) 5 | 6 | 7 | def _get_full_incremental_state_key(module_instance, key): 8 | module_name = module_instance.__class__.__name__ 9 | 10 | # assign a unique ID to each module instance, so that incremental state is 11 | # not shared across module instances 12 | if not hasattr(module_instance, '_fairseq_instance_id'): 13 | INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 14 | module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] 15 | 16 | return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key) 17 | 18 | 19 | def get_incremental_state(module, incremental_state, key): 20 | """Helper for getting incremental state for an nn.Module.""" 21 | full_key = _get_full_incremental_state_key(module, key) 22 | if incremental_state is None or full_key not in incremental_state: 23 | return None 24 | return incremental_state[full_key] 25 | 26 | 27 | def set_incremental_state(module, incremental_state, key, value): 28 | """Helper for setting incremental state for an nn.Module.""" 29 | if incremental_state is not None: 30 | full_key = _get_full_incremental_state_key(module, key) 31 | incremental_state[full_key] = value 32 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/generic/wrap-xml.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. 
Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | use warnings; 7 | use strict; 8 | 9 | my ($language,$src,$system) = @ARGV; 10 | die("wrapping frame not found ($src)") unless -e $src; 11 | $system = "Edinburgh" unless $system; 12 | 13 | open(SRC,$src) or die "Cannot open: $!"; 14 | my @OUT = ; 15 | chomp(@OUT); 16 | #my @OUT = `cat $decoder_output`; 17 | my $missing_end_seg = 0; 18 | while() { 19 | chomp; 20 | if (/^/) { 34 | s/(]+> *).*(<\/seg>)/$1$line$2/i; 35 | $missing_end_seg = 0; 36 | } 37 | else { 38 | s/(]+> *)[^<]*/$1$line<\/seg>/i; 39 | $missing_end_seg = 1; 40 | } 41 | } 42 | elsif ($missing_end_seg) { 43 | if (/<\/doc>/) { 44 | $missing_end_seg = 0; 45 | } 46 | else { 47 | next; 48 | } 49 | } 50 | print $_."\n"; 51 | } 52 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.es: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | # Period-final abbreviation list from http://www.ctspanish.com/words/abbreviations.htm 34 | 35 | A.C 36 | Apdo 37 | Av 38 | Bco 39 | CC.AA 40 | Da 41 | Dep 42 | Dn 43 | Dr 44 | Dra 45 | EE.UU 46 | Excmo 47 | FF.CC 48 | Fil 49 | Gral 50 | J.C 51 | Let 52 | Lic 53 | N.B 54 | P.D 55 | P.V.P 56 | Prof 57 | Pts 58 | Rte 59 | S.A 60 | S.A.R 61 | S.E 62 | S.L 63 | S.R.C 64 | Sr 65 | Sra 66 | Srta 67 | Sta 68 | Sto 69 | T.V.E 70 | Tel 71 | Ud 72 | Uds 73 | V.B 74 | V.E 75 | Vd 76 | Vds 77 | a/c 78 | adj 79 | admón 80 | afmo 81 | apdo 82 | av 83 | c 84 | c.f 85 | c.g 86 | cap 87 | cm 88 | cta 89 | dcha 90 | doc 91 | ej 92 | entlo 93 | esq 94 | etc 95 | f.c 96 | gr 97 | grs 98 | izq 99 | kg 100 | km 101 | mg 102 | mm 103 | núm 104 | núm 105 | p 106 | p.a 107 | p.ej 108 | ptas 109 | pág 110 | págs 111 | pág 112 | págs 113 | q.e.g.e 114 | q.e.s.m 115 | s 116 | s.s.s 117 | vid 118 | vol 119 | -------------------------------------------------------------------------------- /src/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | """Applies Layer Normalization over the last dimension.""" 8 | 9 | def __init__(self, features, eps=1e-5): 10 | super().__init__() 11 | self.features = features 12 | self.eps = eps 13 | self.gain = nn.Parameter(torch.ones(features)) 14 | self.bias = nn.Parameter(torch.zeros(features)) 15 | self.dummy = None 16 | self.w = None 17 | self.b = None 18 | 19 | def forward(self, input): 20 | shape = input.size() 21 | 22 | # In order to force the cudnn path, everything needs to be 23 | # contiguous. Hence the check here and reallocation below. 24 | if not input.is_contiguous(): 25 | input = input.contiguous() 26 | input = input.view(1, -1, shape[-1]) 27 | 28 | # Expand w and b buffers if necessary. 
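# (Note: the reshape above turns the input into a (1, batch*seq_len, features) tensor, so F.batch_norm
# below treats every token as its own "channel" and computes per-channel statistics over the feature
# dimension only -- exactly the per-token mean/variance that layer normalization needs. The dummy/w/b
# buffers are the running-stats and per-channel scale/shift placeholders batch_norm expects, sized to
# batch*seq_len and filled with ones/zeros; the learned gain and bias are applied afterwards via addcmul.)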
29 | n = input.size(1) 30 | cur = self.dummy.numel() if self.dummy is not None else 0 31 | if cur == 0: 32 | self.dummy = input.data.new(n) 33 | self.w = input.data.new(n).fill_(1) 34 | self.b = input.data.new(n).zero_() 35 | elif n > cur: 36 | self.dummy.resize_(n) 37 | self.w.resize_(n) 38 | self.w[cur:n].fill_(1) 39 | self.b.resize_(n) 40 | self.b[cur:n].zero_() 41 | dummy = self.dummy[:n] 42 | w = self.w[:n] 43 | b = self.b[:n] 44 | output = F.batch_norm(input, dummy, dummy, w, b, True, 0., self.eps) 45 | return torch.addcmul(self.bias, 1, output.view(*shape), self.gain) -------------------------------------------------------------------------------- /src/modules/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LabelSmoothedCrossEntropyLoss(nn.Module): 6 | 7 | def __init__(self, eps, padding_idx=None, size_average=False, weight=None): 8 | super().__init__() 9 | self.eps = eps 10 | self.padding_idx = padding_idx 11 | self.size_average = size_average 12 | self.register_buffer('weight', weight) 13 | 14 | def forward(self, input, target): 15 | # lprobs,input -> [batch_size_tokens,target_vocab_size] 16 | lprobs = F.log_softmax(input, dim=-1) 17 | target = target.view(-1, 1) 18 | 19 | # nll_loss get [batch_sentence*seqlength(~=batch_size_tokens), 1] 20 | # nll means no label smooth loss 21 | nll_loss = -lprobs.gather(dim=-1, index=target) 22 | # smooth loss calculates the sum of non-target loss 23 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 24 | if self.padding_idx is not None: 25 | # non_pad_mask -> [batch_size_tokens,1] 26 | non_pad_mask = target.ne(self.padding_idx) 27 | # ignore pad word loss 28 | nll_loss = nll_loss[non_pad_mask] 29 | sample_size = nll_loss.size(0) 30 | smooth_loss = smooth_loss[non_pad_mask] 31 | 32 | if self.size_average: 33 | nll_loss = nll_loss.mean() 34 | smooth_loss = smooth_loss.mean() 35 | else: 36 | nll_loss = nll_loss.sum() 37 | smooth_loss = smooth_loss.sum() 38 | 39 | eps_i = self.eps / lprobs.size(-1) 40 | loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss 41 | 42 | return loss,sample_size 43 | 44 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | from logging import getLogger 4 | from src.adam_inverse_sqrt_with_warmup import AdamInverseSqrtWithWarmup 5 | 6 | logger = getLogger() 7 | 8 | 9 | def get_optimizer(parameters, s): 10 | if "," in s: 11 | method = s[:s.find(',')] 12 | optim_params = {} 13 | for x in s[s.find(',') + 1:].split(','): 14 | split = x.split('=') 15 | assert len(split) == 2 16 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 17 | optim_params[split[0]] = float(split[1]) 18 | else: 19 | method = s 20 | optim_params = {} 21 | 22 | if method == 'adam_inverse_sqrt': 23 | optim_fn = AdamInverseSqrtWithWarmup 24 | optim_params['lr'] = optim_params.get('lr', 0.0005) 25 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.98)) 26 | optim_params['warmup_updates'] = optim_params.get('warmup_updates', 4000) 27 | optim_params['weight_decay'] = optim_params.get('weight_decay', 0.0001) 28 | optim_params.pop('beta1', None) 29 | optim_params.pop('beta2', None) 30 | else: 31 | raise Exception('write yourself method: "%s"' % method) 32 | 33 | # check that we give good parameters to the optimizer 34 | expected_args = inspect.getfullargspec(optim_fn.__init__)[0] 35 | assert expected_args[:2] == ['self', 'params'] 36 | if not all(k in expected_args[2:] for k in optim_params.keys()): 37 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 38 | str(expected_args[2:]), str(optim_params.keys()))) 39 | logger.info(optim_params) 40 | return optim_fn(parameters, **optim_params) 41 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.lv: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | Ā 8 | B 9 | C 10 | Č 11 | D 12 | E 13 | Ē 14 | F 15 | G 16 | Ģ 17 | H 18 | I 19 | Ī 20 | J 21 | K 22 | Ķ 23 | L 24 | Ļ 25 | M 26 | N 27 | Ņ 28 | O 29 | P 30 | Q 31 | R 32 | S 33 | Š 34 | T 35 | U 36 | Ū 37 | V 38 | W 39 | X 40 | Y 41 | Z 42 | Ž 43 | 44 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 45 | dr 46 | Dr 47 | med 48 | prof 49 | Prof 50 | inž 51 | Inž 52 | ist.loc 53 | Ist.loc 54 | kor.loc 55 | Kor.loc 56 | v.i 57 | vietn 58 | Vietn 59 | 60 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 61 | a.l 62 | t.p 63 | pārb 64 | Pārb 65 | vec 66 | Vec 67 | inv 68 | Inv 69 | sk 70 | Sk 71 | spec 72 | Spec 73 | vienk 74 | Vienk 75 | virz 76 | Virz 77 | māksl 78 | Māksl 79 | mūz 80 | Mūz 81 | akad 82 | Akad 83 | soc 84 | Soc 85 | galv 86 | Galv 87 | vad 88 | Vad 89 | sertif 90 | Sertif 91 | folkl 92 | Folkl 93 | hum 94 | Hum 95 | 96 | #Numbers only. 
These should only induce breaks when followed by a numeric sequence 97 | # add NUMERIC_ONLY after the word for this function 98 | #This case is mostly for the english "No." which can either be a sentence of its own, or 99 | #if followed by a number, a non-breaking prefix 100 | Nr #NUMERIC_ONLY# 101 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.fr: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | # 4 | #any single upper case letter followed by a period is not a sentence ender 5 | #usually upper case letters are initials in a name 6 | #no French words end in single lower-case letters, so we throw those in too? 7 | A 8 | B 9 | C 10 | D 11 | E 12 | F 13 | G 14 | H 15 | I 16 | J 17 | K 18 | L 19 | M 20 | N 21 | O 22 | P 23 | Q 24 | R 25 | S 26 | T 27 | U 28 | V 29 | W 30 | X 31 | Y 32 | Z 33 | #a 34 | b 35 | c 36 | d 37 | e 38 | f 39 | g 40 | h 41 | i 42 | j 43 | k 44 | l 45 | m 46 | n 47 | o 48 | p 49 | q 50 | r 51 | s 52 | t 53 | u 54 | v 55 | w 56 | x 57 | y 58 | z 59 | 60 | # Period-final abbreviation list for French 61 | A.C.N 62 | A.M 63 | art 64 | ann 65 | apr 66 | av 67 | auj 68 | lib 69 | B.P 70 | boul 71 | ca 72 | c.-à-d 73 | cf 74 | ch.-l 75 | chap 76 | contr 77 | C.P.I 78 | C.Q.F.D 79 | C.N 80 | C.N.S 81 | C.S 82 | dir 83 | éd 84 | e.g 85 | env 86 | al 87 | etc 88 | E.V 89 | ex 90 | fasc 91 | fém 92 | fig 93 | fr 94 | hab 95 | ibid 96 | id 97 | i.e 98 | inf 99 | LL.AA 100 | LL.AA.II 101 | LL.AA.RR 102 | LL.AA.SS 103 | L.D 104 | LL.EE 105 | LL.MM 106 | LL.MM.II.RR 107 | loc.cit 108 | masc 109 | MM 110 | ms 111 | N.B 112 | N.D.A 113 | N.D.L.R 114 | N.D.T 115 | n/réf 116 | NN.SS 117 | N.S 118 | N.D 119 | N.P.A.I 120 | p.c.c 121 | pl 122 | pp 123 | p.ex 124 | p.j 125 | P.S 126 | R.A.S 127 | R.-V 128 | R.P 129 | R.I.P 130 | SS 131 | S.S 132 | S.A 133 | S.A.I 134 | S.A.R 135 | S.A.S 136 | S.E 137 | sec 138 | sect 139 | sing 140 | S.M 141 | S.M.I.R 142 | sq 143 | sqq 144 | suiv 145 | sup 146 | suppl 147 | tél 148 | T.S.V.P 149 | vb 150 | vol 151 | vs 152 | X.O 153 | Z.I 154 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from datetime import timedelta 4 | 5 | 6 | class LogFormatter(object): # format logging record 7 | 8 | def __init__(self): 9 | self.start_time = time.time() 10 | 11 | def format(self, record): 12 | elapsed_seconds = round(record.created - self.start_time) 13 | 14 | prefix = "%s - %s - %s" % ( 15 | record.levelname, 16 | time.strftime('%x %X'), 17 | timedelta(seconds=elapsed_seconds) 18 | ) 19 | message = record.getMessage() 20 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 21 | return "%s - %s" % (prefix, message) if message else '' 22 | 23 | 24 | def create_logger(filepath): # create a logger based on logging module, using special format 25 | """ 26 | Create a logger. 
27 | """ 28 | # create log formatter 29 | log_formatter = LogFormatter() 30 | 31 | # create file handler and set level to debug 32 | if filepath is not None: 33 | file_handler = logging.FileHandler(filepath, "a") 34 | file_handler.setLevel(logging.DEBUG) 35 | file_handler.setFormatter(log_formatter) 36 | 37 | # create console handler and set level to info 38 | console_handler = logging.StreamHandler() 39 | console_handler.setLevel(logging.INFO) 40 | console_handler.setFormatter(log_formatter) 41 | 42 | # create logger and set level to debug 43 | logger = logging.getLogger() 44 | logger.handlers = [] 45 | logger.setLevel(logging.DEBUG) 46 | logger.propagate = False 47 | if filepath is not None: 48 | logger.addHandler(file_handler) 49 | logger.addHandler(console_handler) 50 | 51 | # reset logger elapsed time 52 | def reset_time(): 53 | log_formatter.start_time = time.time() 54 | logger.reset_time = reset_time 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 34 | Adj 35 | Adm 36 | Adv 37 | Asst 38 | Bart 39 | Bldg 40 | Brig 41 | Bros 42 | Capt 43 | Cmdr 44 | Col 45 | Comdr 46 | Con 47 | Corp 48 | Cpl 49 | DR 50 | Dr 51 | Drs 52 | Ens 53 | Gen 54 | Gov 55 | Hon 56 | Hr 57 | Hosp 58 | Insp 59 | Lt 60 | MM 61 | MR 62 | MRS 63 | MS 64 | Maj 65 | Messrs 66 | Mlle 67 | Mme 68 | Mr 69 | Mrs 70 | Ms 71 | Msgr 72 | Op 73 | Ord 74 | Pfc 75 | Ph 76 | Prof 77 | Pvt 78 | Rep 79 | Reps 80 | Res 81 | Rev 82 | Rt 83 | Sen 84 | Sens 85 | Sfc 86 | Sgt 87 | Sr 88 | St 89 | Supt 90 | Surg 91 | 92 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 93 | v 94 | vs 95 | i.e 96 | rev 97 | e.g 98 | 99 | #Numbers only. These should only induce breaks when followed by a numeric sequence 100 | # add NUMERIC_ONLY after the word for this function 101 | #This case is mostly for the english "No." which can either be a sentence of its own, or 102 | #if followed by a number, a non-breaking prefix 103 | No #NUMERIC_ONLY# 104 | Nos 105 | Art #NUMERIC_ONLY# 106 | Nr 107 | pp #NUMERIC_ONLY# 108 | 109 | #month abbreviations 110 | Jan 111 | Feb 112 | Mar 113 | Apr 114 | #May is a full word 115 | Jun 116 | Jul 117 | Aug 118 | Sep 119 | Oct 120 | Nov 121 | Dec 122 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.fi: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT 2 | #indicate an end-of-sentence marker. 
Special cases are included for prefixes 3 | #that ONLY appear before 0-9 numbers. 4 | 5 | #This list is compiled from omorfi database 6 | #by Tommi A Pirinen. 7 | 8 | 9 | #any single upper case letter followed by a period is not a sentence ender 10 | A 11 | B 12 | C 13 | D 14 | E 15 | F 16 | G 17 | H 18 | I 19 | J 20 | K 21 | L 22 | M 23 | N 24 | O 25 | P 26 | Q 27 | R 28 | S 29 | T 30 | U 31 | V 32 | W 33 | X 34 | Y 35 | Z 36 | Å 37 | Ä 38 | Ö 39 | 40 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 41 | alik 42 | alil 43 | amir 44 | apul 45 | apul.prof 46 | arkkit 47 | ass 48 | assist 49 | dipl 50 | dipl.arkkit 51 | dipl.ekon 52 | dipl.ins 53 | dipl.kielenk 54 | dipl.kirjeenv 55 | dipl.kosm 56 | dipl.urk 57 | dos 58 | erikoiseläinl 59 | erikoishammasl 60 | erikoisl 61 | erikoist 62 | ev.luutn 63 | evp 64 | fil 65 | ft 66 | hallinton 67 | hallintot 68 | hammaslääket 69 | jatk 70 | jääk 71 | kansaned 72 | kapt 73 | kapt.luutn 74 | kenr 75 | kenr.luutn 76 | kenr.maj 77 | kers 78 | kirjeenv 79 | kom 80 | kom.kapt 81 | komm 82 | konst 83 | korpr 84 | luutn 85 | maist 86 | maj 87 | Mr 88 | Mrs 89 | Ms 90 | M.Sc 91 | neuv 92 | nimim 93 | Ph.D 94 | prof 95 | puh.joht 96 | pääll 97 | res 98 | san 99 | siht 100 | suom 101 | sähköp 102 | säv 103 | toht 104 | toim 105 | toim.apul 106 | toim.joht 107 | toim.siht 108 | tuom 109 | ups 110 | vänr 111 | vääp 112 | ye.ups 113 | ylik 114 | ylil 115 | ylim 116 | ylimatr 117 | yliop 118 | yliopp 119 | ylip 120 | yliv 121 | 122 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall 123 | #into this category - it sometimes ends a sentence) 124 | e.g 125 | ent 126 | esim 127 | huom 128 | i.e 129 | ilm 130 | l 131 | mm 132 | myöh 133 | nk 134 | nyk 135 | par 136 | po 137 | t 138 | v 139 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.hu: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | Á 33 | É 34 | Í 35 | Ó 36 | Ö 37 | Ő 38 | Ú 39 | Ü 40 | Ű 41 | 42 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 43 | Dr 44 | dr 45 | kb 46 | Kb 47 | vö 48 | Vö 49 | pl 50 | Pl 51 | ca 52 | Ca 53 | min 54 | Min 55 | max 56 | Max 57 | ún 58 | Ún 59 | prof 60 | Prof 61 | de 62 | De 63 | du 64 | Du 65 | Szt 66 | St 67 | 68 | #Numbers only. These should only induce breaks when followed by a numeric sequence 69 | # add NUMERIC_ONLY after the word for this function 70 | #This case is mostly for the english "No." 
which can either be a sentence of its own, or 71 | #if followed by a number, a non-breaking prefix 72 | 73 | # Month name abbreviations 74 | jan #NUMERIC_ONLY# 75 | Jan #NUMERIC_ONLY# 76 | Feb #NUMERIC_ONLY# 77 | feb #NUMERIC_ONLY# 78 | márc #NUMERIC_ONLY# 79 | Márc #NUMERIC_ONLY# 80 | ápr #NUMERIC_ONLY# 81 | Ápr #NUMERIC_ONLY# 82 | máj #NUMERIC_ONLY# 83 | Máj #NUMERIC_ONLY# 84 | jún #NUMERIC_ONLY# 85 | Jún #NUMERIC_ONLY# 86 | Júl #NUMERIC_ONLY# 87 | júl #NUMERIC_ONLY# 88 | aug #NUMERIC_ONLY# 89 | Aug #NUMERIC_ONLY# 90 | Szept #NUMERIC_ONLY# 91 | szept #NUMERIC_ONLY# 92 | okt #NUMERIC_ONLY# 93 | Okt #NUMERIC_ONLY# 94 | nov #NUMERIC_ONLY# 95 | Nov #NUMERIC_ONLY# 96 | dec #NUMERIC_ONLY# 97 | Dec #NUMERIC_ONLY# 98 | 99 | # Other abbreviations 100 | tel #NUMERIC_ONLY# 101 | Tel #NUMERIC_ONLY# 102 | Fax #NUMERIC_ONLY# 103 | fax #NUMERIC_ONLY# 104 | -------------------------------------------------------------------------------- /src/adam_inverse_sqrt_with_warmup.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | 4 | class AdamInverseSqrtWithWarmup(optim.Adam): 5 | """Decay the LR based on the inverse square root of the update number. 6 | 7 | We also support a warmup phase where we linearly increase the learning rate 8 | from some initial learning rate (`warmup-init-lr`) until the configured 9 | learning rate (`lr`). Thereafter we decay proportional to the number of 10 | updates, with a decay factor set to align with the configured learning rate. 11 | 12 | During warmup: 13 | 14 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 15 | lr = lrs[update_num] 16 | 17 | After warmup: 18 | 19 | lr = decay_factor / sqrt(update_num) 20 | 21 | where 22 | 23 | decay_factor = lr * sqrt(warmup_updates) 24 | """ 25 | 26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 27 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7): 28 | super().__init__( 29 | params, 30 | lr=warmup_init_lr, 31 | betas=betas, 32 | eps=eps, 33 | weight_decay=weight_decay, 34 | ) 35 | self.warmup_updates = warmup_updates 36 | self.warmup_init_lr = warmup_init_lr 37 | 38 | # linearly warmup for the first warmup_updates 39 | warmup_end_lr = lr 40 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 41 | 42 | # then, decay prop. to the inverse square root of the update number 43 | self.decay_factor = warmup_end_lr * warmup_updates**0.5 44 | 45 | self._num_updates = 0 46 | 47 | def get_lr_for_step(self, num_updates): 48 | if num_updates < self.warmup_updates: 49 | return self.warmup_init_lr + num_updates*self.lr_step 50 | else: 51 | return self.decay_factor * num_updates**-0.5 52 | 53 | def step(self, closure=None): 54 | super().step(closure) 55 | self._num_updates += 1 56 | 57 | # update learning rate 58 | new_lr = self.get_lr_for_step(self._num_updates) 59 | for param_group in self.param_groups: 60 | param_group['lr'] = new_lr 61 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.nl: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 
3 | #Sources: http://nl.wikipedia.org/wiki/Lijst_van_afkortingen 4 | # http://nl.wikipedia.org/wiki/Aanspreekvorm 5 | # http://nl.wikipedia.org/wiki/Titulatuur_in_het_Nederlands_hoger_onderwijs 6 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 7 | #usually upper case letters are initials in a name 8 | A 9 | B 10 | C 11 | D 12 | E 13 | F 14 | G 15 | H 16 | I 17 | J 18 | K 19 | L 20 | M 21 | N 22 | O 23 | P 24 | Q 25 | R 26 | S 27 | T 28 | U 29 | V 30 | W 31 | X 32 | Y 33 | Z 34 | 35 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 36 | bacc 37 | bc 38 | bgen 39 | c.i 40 | dhr 41 | dr 42 | dr.h.c 43 | drs 44 | drs 45 | ds 46 | eint 47 | fa 48 | Fa 49 | fam 50 | gen 51 | genm 52 | ing 53 | ir 54 | jhr 55 | jkvr 56 | jr 57 | kand 58 | kol 59 | lgen 60 | lkol 61 | Lt 62 | maj 63 | Mej 64 | mevr 65 | Mme 66 | mr 67 | mr 68 | Mw 69 | o.b.s 70 | plv 71 | prof 72 | ritm 73 | tint 74 | Vz 75 | Z.D 76 | Z.D.H 77 | Z.E 78 | Z.Em 79 | Z.H 80 | Z.K.H 81 | Z.K.M 82 | Z.M 83 | z.v 84 | 85 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 86 | #we seem to have a lot of these in dutch i.e.: i.p.v - in plaats van (in stead of) never ends a sentence 87 | a.g.v 88 | bijv 89 | bijz 90 | bv 91 | d.w.z 92 | e.c 93 | e.g 94 | e.k 95 | ev 96 | i.p.v 97 | i.s.m 98 | i.t.t 99 | i.v.m 100 | m.a.w 101 | m.b.t 102 | m.b.v 103 | m.h.o 104 | m.i 105 | m.i.v 106 | v.w.t 107 | 108 | #Numbers only. These should only induce breaks when followed by a numeric sequence 109 | # add NUMERIC_ONLY after the word for this function 110 | #This case is mostly for the english "No." which can either be a sentence of its own, or 111 | #if followed by a number, a non-breaking prefix 112 | Nr #NUMERIC_ONLY# 113 | Nrs 114 | nrs 115 | nr #NUMERIC_ONLY# 116 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
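# (Summary: rewrites Unicode punctuation such as curly quotes, dashes, ellipses, guillemets and pseudo-spaces to plain ASCII; a language code given via -l or as a bare argument selects language-specific quote/comma and decimal-separator handling, and -penn keeps Penn-Treebank-style quotes.)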
5 | 6 | use warnings; 7 | use strict; 8 | 9 | my $language = "en"; 10 | my $PENN = 0; 11 | 12 | while (@ARGV) { 13 | $_ = shift; 14 | /^-b$/ && ($| = 1, next); # not buffered (flush each line) 15 | /^-l$/ && ($language = shift, next); 16 | /^[^\-]/ && ($language = $_, next); 17 | /^-penn$/ && ($PENN = 1, next); 18 | } 19 | 20 | while() { 21 | s/\r//g; 22 | # remove extra spaces 23 | s/\(/ \(/g; 24 | s/\)/\) /g; s/ +/ /g; 25 | s/\) ([\.\!\:\?\;\,])/\)$1/g; 26 | s/\( /\(/g; 27 | s/ \)/\)/g; 28 | s/(\d) \%/$1\%/g; 29 | s/ :/:/g; 30 | s/ ;/;/g; 31 | # normalize unicode punctuation 32 | if ($PENN == 0) { 33 | s/\`/\'/g; 34 | s/\'\'/ \" /g; 35 | } 36 | 37 | s/„/\"/g; 38 | s/“/\"/g; 39 | s/”/\"/g; 40 | s/–/-/g; 41 | s/—/ - /g; s/ +/ /g; 42 | s/´/\'/g; 43 | s/([a-z])‘([a-z])/$1\'$2/gi; 44 | s/([a-z])’([a-z])/$1\'$2/gi; 45 | s/‘/\"/g; 46 | s/‚/\"/g; 47 | s/’/\"/g; 48 | s/''/\"/g; 49 | s/´´/\"/g; 50 | s/…/.../g; 51 | # French quotes 52 | s/ « / \"/g; 53 | s/« /\"/g; 54 | s/«/\"/g; 55 | s/ » /\" /g; 56 | s/ »/\"/g; 57 | s/»/\"/g; 58 | # handle pseudo-spaces 59 | s/ \%/\%/g; 60 | s/nº /nº /g; 61 | s/ :/:/g; 62 | s/ ºC/ ºC/g; 63 | s/ cm/ cm/g; 64 | s/ \?/\?/g; 65 | s/ \!/\!/g; 66 | s/ ;/;/g; 67 | s/, /, /g; s/ +/ /g; 68 | 69 | # English "quotation," followed by comma, style 70 | if ($language eq "en") { 71 | s/\"([,\.]+)/$1\"/g; 72 | } 73 | # Czech is confused 74 | elsif ($language eq "cs" || $language eq "cz") { 75 | } 76 | # German/Spanish/French "quotation", followed by comma, style 77 | else { 78 | s/,\"/\",/g; 79 | s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence 80 | } 81 | 82 | 83 | if ($language eq "de" || $language eq "es" || $language eq "cz" || $language eq "cs" || $language eq "fr") { 84 | s/(\d) (\d)/$1,$2/g; 85 | } 86 | else { 87 | s/(\d) (\d)/$1.$2/g; 88 | } 89 | print $_; 90 | } 91 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.is: -------------------------------------------------------------------------------- 1 | no #NUMERIC_ONLY# 2 | No #NUMERIC_ONLY# 3 | nr #NUMERIC_ONLY# 4 | Nr #NUMERIC_ONLY# 5 | nR #NUMERIC_ONLY# 6 | NR #NUMERIC_ONLY# 7 | a 8 | b 9 | c 10 | d 11 | e 12 | f 13 | g 14 | h 15 | i 16 | j 17 | k 18 | l 19 | m 20 | n 21 | o 22 | p 23 | q 24 | r 25 | s 26 | t 27 | u 28 | v 29 | w 30 | x 31 | y 32 | z 33 | ^ 34 | í 35 | á 36 | ó 37 | æ 38 | A 39 | B 40 | C 41 | D 42 | E 43 | F 44 | G 45 | H 46 | I 47 | J 48 | K 49 | L 50 | M 51 | N 52 | O 53 | P 54 | Q 55 | R 56 | S 57 | T 58 | U 59 | V 60 | W 61 | X 62 | Y 63 | Z 64 | ab.fn 65 | a.fn 66 | afs 67 | al 68 | alm 69 | alg 70 | andh 71 | ath 72 | aths 73 | atr 74 | ao 75 | au 76 | aukaf 77 | áfn 78 | áhrl.s 79 | áhrs 80 | ákv.gr 81 | ákv 82 | bh 83 | bls 84 | dr 85 | e.Kr 86 | et 87 | ef 88 | efn 89 | ennfr 90 | eink 91 | end 92 | e.st 93 | erl 94 | fél 95 | fskj 96 | fh 97 | f.hl 98 | físl 99 | fl 100 | fn 101 | fo 102 | forl 103 | frb 104 | frl 105 | frh 106 | frt 107 | fsl 108 | fsh 109 | fs 110 | fsk 111 | fst 112 | f.Kr 113 | ft 114 | fv 115 | fyrrn 116 | fyrrv 117 | germ 118 | gm 119 | gr 120 | hdl 121 | hdr 122 | hf 123 | hl 124 | hlsk 125 | hljsk 126 | hljv 127 | hljóðv 128 | hr 129 | hv 130 | hvk 131 | holl 132 | Hos 133 | höf 134 | hk 135 | hrl 136 | ísl 137 | kaf 138 | kap 139 | Khöfn 140 | kk 141 | kg 142 | kk 143 | km 144 | kl 145 | klst 146 | kr 147 | kt 148 | kgúrsk 149 | kvk 150 | leturbr 151 | lh 152 | lh.nt 153 | lh.þt 154 | lo 155 | ltr 156 | mlja 157 | mljó 158 | millj 159 | mm 
160 | mms 161 | m.fl 162 | miðm 163 | mgr 164 | mst 165 | mín 166 | nf 167 | nh 168 | nhm 169 | nl 170 | nk 171 | nmgr 172 | no 173 | núv 174 | nt 175 | o.áfr 176 | o.m.fl 177 | ohf 178 | o.fl 179 | o.s.frv 180 | ófn 181 | ób 182 | óákv.gr 183 | óákv 184 | pfn 185 | PR 186 | pr 187 | Ritstj 188 | Rvík 189 | Rvk 190 | samb 191 | samhlj 192 | samn 193 | samn 194 | sbr 195 | sek 196 | sérn 197 | sf 198 | sfn 199 | sh 200 | sfn 201 | sh 202 | s.hl 203 | sk 204 | skv 205 | sl 206 | sn 207 | so 208 | ss.us 209 | s.st 210 | samþ 211 | sbr 212 | shlj 213 | sign 214 | skál 215 | st 216 | st.s 217 | stk 218 | sþ 219 | teg 220 | tbl 221 | tfn 222 | tl 223 | tvíhlj 224 | tvt 225 | till 226 | to 227 | umr 228 | uh 229 | us 230 | uppl 231 | útg 232 | vb 233 | Vf 234 | vh 235 | vkf 236 | Vl 237 | vl 238 | vlf 239 | vmf 240 | 8vo 241 | vsk 242 | vth 243 | þt 244 | þf 245 | þjs 246 | þgf 247 | þlt 248 | þolm 249 | þm 250 | þml 251 | þýð 252 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.it: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 34 | Adj 35 | Adm 36 | Adv 37 | Amn 38 | Arch 39 | Asst 40 | Avv 41 | Bart 42 | Bcc 43 | Bldg 44 | Brig 45 | Bros 46 | C.A.P 47 | C.P 48 | Capt 49 | Cc 50 | Cmdr 51 | Co 52 | Col 53 | Comdr 54 | Con 55 | Corp 56 | Cpl 57 | DR 58 | Dott 59 | Dr 60 | Drs 61 | Egr 62 | Ens 63 | Gen 64 | Geom 65 | Gov 66 | Hon 67 | Hosp 68 | Hr 69 | Id 70 | Ing 71 | Insp 72 | Lt 73 | MM 74 | MR 75 | MRS 76 | MS 77 | Maj 78 | Messrs 79 | Mlle 80 | Mme 81 | Mo 82 | Mons 83 | Mr 84 | Mrs 85 | Ms 86 | Msgr 87 | N.B 88 | Op 89 | Ord 90 | P.S 91 | P.T 92 | Pfc 93 | Ph 94 | Prof 95 | Pvt 96 | RP 97 | RSVP 98 | Rag 99 | Rep 100 | Reps 101 | Res 102 | Rev 103 | Rif 104 | Rt 105 | S.A 106 | S.B.F 107 | S.P.M 108 | S.p.A 109 | S.r.l 110 | Sen 111 | Sens 112 | Sfc 113 | Sgt 114 | Sig 115 | Sigg 116 | Soc 117 | Spett 118 | Sr 119 | St 120 | Supt 121 | Surg 122 | V.P 123 | 124 | # other 125 | a.c 126 | acc 127 | all 128 | banc 129 | c.a 130 | c.c.p 131 | c.m 132 | c.p 133 | c.s 134 | c.v 135 | corr 136 | dott 137 | e.p.c 138 | ecc 139 | es 140 | fatt 141 | gg 142 | int 143 | lett 144 | ogg 145 | on 146 | p.c 147 | p.c.c 148 | p.es 149 | p.f 150 | p.r 151 | p.v 152 | post 153 | pp 154 | racc 155 | ric 156 | s.n.c 157 | seg 158 | sgg 159 | ss 160 | tel 161 | u.s 162 | v.r 163 | v.s 164 | 165 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 166 | v 167 | vs 168 | i.e 169 | rev 170 | e.g 171 | 172 | #Numbers only. These should only induce breaks when followed by a numeric sequence 173 | # add NUMERIC_ONLY after the word for this function 174 | #This case is mostly for the english "No." 
which can either be a sentence of its own, or 175 | #if followed by a number, a non-breaking prefix 176 | No #NUMERIC_ONLY# 177 | Nos 178 | Art #NUMERIC_ONLY# 179 | Nr 180 | pp #NUMERIC_ONLY# 181 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from src.logger import create_logger 5 | from src.data.dictionary import Dictionary 6 | 7 | if __name__ == '__main__': 8 | logger = create_logger(None) 9 | 10 | # src_voc_path = 'data/vocab.zh' 11 | # src_txt_path = 'data/all.zh.bpe' 12 | # tgt_voc_path = 'data/vocab.en' 13 | # tgt_txt_path = 'data/all.en.bpe' 14 | 15 | # bin_path = 'data/cwmt.bin' 16 | src_voc_path = sys.argv[3] 17 | src_txt_path = sys.argv[1] 18 | tgt_voc_path = sys.argv[4] 19 | tgt_txt_path = sys.argv[2] 20 | bin_path = sys.argv[5] 21 | assert os.path.isfile(src_voc_path) 22 | assert os.path.isfile(src_txt_path) 23 | assert os.path.isfile(tgt_voc_path) 24 | assert os.path.isfile(tgt_txt_path) 25 | 26 | src_dico = Dictionary.read_vocab(src_voc_path) 27 | tgt_dico = Dictionary.read_vocab(tgt_voc_path) 28 | 29 | data = Dictionary.index_data(src_txt_path, tgt_txt_path, src_dico, tgt_dico, bin_path) 30 | if data is None: 31 | exit(0) 32 | logger.info("%i words (%i unique) in %i sentences." % ( 33 | len(data['src_sentences']) - len(data['src_positions']), 34 | len(data['src_dico']), 35 | len(data['src_positions']) 36 | )) 37 | logger.info("%i words (%i unique) in %i sentences." % ( 38 | len(data['tgt_sentences']) - len(data['tgt_positions']), 39 | len(data['tgt_dico']), 40 | len(data['tgt_positions']) 41 | )) 42 | if len(data['src_unk_words']) > 0: 43 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % ( 44 | sum(data['src_unk_words'].values()), 45 | len(data['src_unk_words']), 46 | sum(data['src_unk_words'].values()) * 100. / (len(data['src_sentences']) - len(data['src_positions'])) 47 | )) 48 | if len(data['src_unk_words']) < 30: 49 | for w, c in sorted(data['src_unk_words'].items(), key=lambda x: x[1])[::-1]: 50 | logger.info("%s: %i" % (w, c)) 51 | else: 52 | logger.info("0 unknown word.") 53 | 54 | if len(data['tgt_unk_words']) > 0: 55 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % ( 56 | sum(data['tgt_unk_words'].values()), 57 | len(data['tgt_unk_words']), 58 | sum(data['tgt_unk_words'].values()) * 100. / (len(data['tgt_sentences']) - len(data['tgt_positions'])) 59 | )) 60 | if len(data['tgt_unk_words']) < 30: 61 | for w, c in sorted(data['tgt_unk_words'].items(), key=lambda x: x[1])[::-1]: 62 | logger.info("%s: %i" % (w, c)) 63 | else: 64 | logger.info("0 unknown word.") 65 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/recaser/detruecase.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
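# (Summary: restores surface casing in truecased/lowercased MT output by uppercasing the first word of every sentence, and of headline segments when an SGML source is passed via --headline, while keeping common function words lowercased inside headlines.)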
5 | 6 | use warnings; 7 | use strict; 8 | use Getopt::Long "GetOptions"; 9 | 10 | binmode(STDIN, ":utf8"); 11 | binmode(STDOUT, ":utf8"); 12 | 13 | my ($SRC,$INFILE,$UNBUFFERED); 14 | die("detruecase.perl < in > out") 15 | unless &GetOptions('headline=s' => \$SRC, 16 | 'in=s' => \$INFILE, 17 | 'b|unbuffered' => \$UNBUFFERED); 18 | if (defined($UNBUFFERED) && $UNBUFFERED) { $|=1; } 19 | 20 | my %SENTENCE_END = ("."=>1,":"=>1,"?"=>1,"!"=>1); 21 | my %DELAYED_SENTENCE_START = ("("=>1,"["=>1,"\""=>1,"'"=>1,"""=>1,"'"=>1,"["=>1,"]"=>1); 22 | 23 | # lowercase even in headline 24 | my %ALWAYS_LOWER; 25 | foreach ("a","after","against","al-.+","and","any","as","at","be","because","between","by","during","el-.+","for","from","his","in","is","its","last","not","of","off","on","than","the","their","this","to","was","were","which","will","with") { $ALWAYS_LOWER{$_} = 1; } 26 | 27 | # find out about the headlines 28 | my @HEADLINE; 29 | if (defined($SRC)) { 30 | open(SRC,$SRC); 31 | my $headline_flag = 0; 32 | while() { 33 | $headline_flag = 1 if //; 34 | $headline_flag = 0 if /<.hl>/; 35 | next unless /^) { 46 | &process($_,$sentence++); 47 | } 48 | close(IN); 49 | } 50 | else { 51 | while() { 52 | &process($_,$sentence++); 53 | } 54 | } 55 | 56 | sub process { 57 | my $line = $_[0]; 58 | chomp($line); 59 | $line =~ s/^\s+//; 60 | $line =~ s/\s+$//; 61 | my @WORD = split(/\s+/,$line); 62 | 63 | # uppercase at sentence start 64 | my $sentence_start = 1; 65 | for(my $i=0;$i self.weights.size(0): 70 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 71 | max_pos, 72 | self.embedding_dim, 73 | self.padding_idx, 74 | ).type_as(self.weights) 75 | self.weights = self.weights.type_as(self._float_tensor) 76 | weights = self.weights 77 | 78 | if incremental_state is not None: 79 | # positions is the same for every token when decoding a single step 80 | return weights[self.padding_idx + seq_len, :].expand(1, bsz, -1) 81 | 82 | positions = make_positions(input.data, self.padding_idx, self.left_pad) 83 | return weights.index_select(0, positions.view(-1)).view(seq_len, bsz, -1) 84 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/recaser/truecase.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id: train-recaser.perl 1326 2007-03-26 05:44:27Z bojar $ 7 | 8 | use warnings; 9 | use strict; 10 | use Getopt::Long "GetOptions"; 11 | 12 | binmode(STDIN, ":utf8"); 13 | binmode(STDOUT, ":utf8"); 14 | 15 | # apply switches 16 | # ASR input has no case, make sure it is lowercase, and make sure known are cased eg. 
'i' to be uppercased even if i is known 17 | my ($MODEL, $UNBUFFERED, $ASR); 18 | die("truecase.perl --model MODEL [-b] [-a] < in > out") 19 | unless &GetOptions('model=s' => \$MODEL,'b|unbuffered' => \$UNBUFFERED, 'a|asr' => \$ASR) 20 | && defined($MODEL); 21 | if (defined($UNBUFFERED) && $UNBUFFERED) { $|=1; } 22 | my $asr = 0; 23 | if (defined($ASR) && $ASR) { $asr = 1; } 24 | 25 | my (%BEST,%KNOWN); 26 | open(MODEL,$MODEL) || die("ERROR: could not open '$MODEL'"); 27 | binmode(MODEL, ":utf8"); 28 | while() { 29 | my ($word,@OPTIONS) = split; 30 | $BEST{ lc($word) } = $word; 31 | if ($asr == 0) { 32 | $KNOWN{ $word } = 1; 33 | for(my $i=1;$i<$#OPTIONS;$i+=2) { 34 | $KNOWN{ $OPTIONS[$i] } = 1; 35 | } 36 | } 37 | } 38 | close(MODEL); 39 | 40 | my %SENTENCE_END = ("."=>1,":"=>1,"?"=>1,"!"=>1); 41 | my %DELAYED_SENTENCE_START = ("("=>1,"["=>1,"\""=>1,"'"=>1,"'"=>1,"""=>1,"["=>1,"]"=>1); 42 | 43 | while() { 44 | chop; 45 | my ($WORD,$MARKUP) = split_xml($_); 46 | my $sentence_start = 1; 47 | for(my $i=0;$i<=$#$WORD;$i++) { 48 | print " " if $i && $$MARKUP[$i] eq ''; 49 | print $$MARKUP[$i]; 50 | 51 | my ($word,$otherfactors); 52 | if ($$WORD[$i] =~ /^([^\|]+)(.*)/) 53 | { 54 | $word = $1; 55 | $otherfactors = $2; 56 | } 57 | else 58 | { 59 | $word = $$WORD[$i]; 60 | $otherfactors = ""; 61 | } 62 | if ($asr){ 63 | $word = lc($word); #make sure ASR output is not uc 64 | } 65 | 66 | if ($sentence_start && defined($BEST{lc($word)})) { 67 | print $BEST{lc($word)}; # truecase sentence start 68 | } 69 | elsif (defined($KNOWN{$word})) { 70 | print $word; # don't change known words 71 | } 72 | elsif (defined($BEST{lc($word)})) { 73 | print $BEST{lc($word)}; # truecase otherwise unknown words 74 | } 75 | else { 76 | print $word; # unknown, nothing to do 77 | } 78 | print $otherfactors; 79 | 80 | if ( defined($SENTENCE_END{ $word })) { $sentence_start = 1; } 81 | elsif (!defined($DELAYED_SENTENCE_START{ $word })) { $sentence_start = 0; } 82 | } 83 | print $$MARKUP[$#$MARKUP]; 84 | print "\n"; 85 | } 86 | 87 | # store away xml markup 88 | sub split_xml { 89 | my ($line) = @_; 90 | my (@WORD,@MARKUP); 91 | my $i = 0; 92 | $MARKUP[0] = ""; 93 | while($line =~ /\S/) { 94 | # XML tag 95 | if ($line =~ /^\s*(<\S[^>]*>)(.*)$/) { 96 | my $potential_xml = $1; 97 | my $line_next = $2; 98 | # exception for factor that is an XML tag 99 | if ($line =~ /^\S/ && scalar(@WORD)>0 && $WORD[$i-1] =~ /\|$/) { 100 | $WORD[$i-1] .= $potential_xml; 101 | if ($line_next =~ /^(\|+)(.*)$/) { 102 | $WORD[$i-1] .= $1; 103 | $line_next = $2; 104 | } 105 | } 106 | else { 107 | $MARKUP[$i] .= $potential_xml." "; 108 | } 109 | $line = $line_next; 110 | } 111 | # non-XML text 112 | elsif ($line =~ /^\s*([^\s<>]+)(.*)$/) { 113 | $WORD[$i++] = $1; 114 | $MARKUP[$i] = ""; 115 | $line = $2; 116 | } 117 | # '<' or '>' occurs in word, but it's not an XML tag 118 | elsif ($line =~ /^\s*(\S+)(.*)$/) { 119 | $WORD[$i++] = $1; 120 | $MARKUP[$i] = ""; 121 | $line = $2; 122 | } 123 | else { 124 | die("ERROR: huh? 
$line\n"); 125 | } 126 | } 127 | chop($MARKUP[$#MARKUP]); 128 | return (\@WORD,\@MARKUP); 129 | } 130 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.cs: -------------------------------------------------------------------------------- 1 | Bc 2 | BcA 3 | Ing 4 | Ing.arch 5 | MUDr 6 | MVDr 7 | MgA 8 | Mgr 9 | JUDr 10 | PhDr 11 | RNDr 12 | PharmDr 13 | ThLic 14 | ThDr 15 | Ph.D 16 | Th.D 17 | prof 18 | doc 19 | CSc 20 | DrSc 21 | dr. h. c 22 | PaedDr 23 | Dr 24 | PhMr 25 | DiS 26 | abt 27 | ad 28 | a.i 29 | aj 30 | angl 31 | anon 32 | apod 33 | atd 34 | atp 35 | aut 36 | bd 37 | biogr 38 | b.m 39 | b.p 40 | b.r 41 | cca 42 | cit 43 | cizojaz 44 | c.k 45 | col 46 | čes 47 | čín 48 | čj 49 | ed 50 | facs 51 | fasc 52 | fol 53 | fot 54 | franc 55 | h.c 56 | hist 57 | hl 58 | hrsg 59 | ibid 60 | il 61 | ind 62 | inv.č 63 | jap 64 | jhdt 65 | jv 66 | koed 67 | kol 68 | korej 69 | kl 70 | krit 71 | lat 72 | lit 73 | m.a 74 | maď 75 | mj 76 | mp 77 | násl 78 | např 79 | nepubl 80 | něm 81 | no 82 | nr 83 | n.s 84 | okr 85 | odd 86 | odp 87 | obr 88 | opr 89 | orig 90 | phil 91 | pl 92 | pokrač 93 | pol 94 | port 95 | pozn 96 | př.kr 97 | př.n.l 98 | přel 99 | přeprac 100 | příl 101 | pseud 102 | pt 103 | red 104 | repr 105 | resp 106 | revid 107 | rkp 108 | roč 109 | roz 110 | rozš 111 | samost 112 | sect 113 | sest 114 | seš 115 | sign 116 | sl 117 | srv 118 | stol 119 | sv 120 | šk 121 | šk.ro 122 | špan 123 | tab 124 | t.č 125 | tis 126 | tj 127 | tř 128 | tzv 129 | univ 130 | uspoř 131 | vol 132 | vl.jm 133 | vs 134 | vyd 135 | vyobr 136 | zal 137 | zejm 138 | zkr 139 | zprac 140 | zvl 141 | n.p 142 | např 143 | než 144 | MUDr 145 | abl 146 | absol 147 | adj 148 | adv 149 | ak 150 | ak. sl 151 | akt 152 | alch 153 | amer 154 | anat 155 | angl 156 | anglosas 157 | arab 158 | arch 159 | archit 160 | arg 161 | astr 162 | astrol 163 | att 164 | bás 165 | belg 166 | bibl 167 | biol 168 | boh 169 | bot 170 | bulh 171 | círk 172 | csl 173 | č 174 | čas 175 | čes 176 | dat 177 | děj 178 | dep 179 | dět 180 | dial 181 | dór 182 | dopr 183 | dosl 184 | ekon 185 | epic 186 | etnonym 187 | eufem 188 | f 189 | fam 190 | fem 191 | fil 192 | film 193 | form 194 | fot 195 | fr 196 | fut 197 | fyz 198 | gen 199 | geogr 200 | geol 201 | geom 202 | germ 203 | gram 204 | hebr 205 | herald 206 | hist 207 | hl 208 | hovor 209 | hud 210 | hut 211 | chcsl 212 | chem 213 | ie 214 | imp 215 | impf 216 | ind 217 | indoevr 218 | inf 219 | instr 220 | interj 221 | ión 222 | iron 223 | it 224 | kanad 225 | katalán 226 | klas 227 | kniž 228 | komp 229 | konj 230 | 231 | konkr 232 | kř 233 | kuch 234 | lat 235 | lék 236 | les 237 | lid 238 | lit 239 | liturg 240 | lok 241 | log 242 | m 243 | mat 244 | meteor 245 | metr 246 | mod 247 | ms 248 | mysl 249 | n 250 | náb 251 | námoř 252 | neklas 253 | něm 254 | nesklon 255 | nom 256 | ob 257 | obch 258 | obyč 259 | ojed 260 | opt 261 | part 262 | pas 263 | pejor 264 | pers 265 | pf 266 | pl 267 | plpf 268 | 269 | práv 270 | prep 271 | předl 272 | přivl 273 | r 274 | rcsl 275 | refl 276 | reg 277 | rkp 278 | ř 279 | řec 280 | s 281 | samohl 282 | sg 283 | sl 284 | souhl 285 | spec 286 | srov 287 | stfr 288 | střv 289 | stsl 290 | subj 291 | subst 292 | superl 293 | sv 294 | sz 295 | táz 296 | tech 297 | telev 298 | teol 299 | trans 300 | typogr 301 | var 302 | vedl 303 | verb 304 | vl. 
jm 305 | voj 306 | vok 307 | vůb 308 | vulg 309 | výtv 310 | vztaž 311 | zahr 312 | zájm 313 | zast 314 | zejm 315 | 316 | zeměd 317 | zkr 318 | zř 319 | mj 320 | dl 321 | atp 322 | sport 323 | Mgr 324 | horn 325 | MVDr 326 | JUDr 327 | RSDr 328 | Bc 329 | PhDr 330 | ThDr 331 | Ing 332 | aj 333 | apod 334 | PharmDr 335 | pomn 336 | ev 337 | slang 338 | nprap 339 | odp 340 | dop 341 | pol 342 | st 343 | stol 344 | p. n. l 345 | před n. l 346 | n. l 347 | př. Kr 348 | po Kr 349 | př. n. l 350 | odd 351 | RNDr 352 | tzv 353 | atd 354 | tzn 355 | resp 356 | tj 357 | p 358 | br 359 | č. j 360 | čj 361 | č. p 362 | čp 363 | a. s 364 | s. r. o 365 | spol. s r. o 366 | p. o 367 | s. p 368 | v. o. s 369 | k. s 370 | o. p. s 371 | o. s 372 | v. r 373 | v z 374 | ml 375 | vč 376 | kr 377 | mld 378 | hod 379 | popř 380 | ap 381 | event 382 | rus 383 | slov 384 | rum 385 | švýc 386 | P. T 387 | zvl 388 | hor 389 | dol 390 | S.O.S -------------------------------------------------------------------------------- /main_lstm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src.logger import create_logger 3 | import os 4 | import torch 5 | logger = create_logger('train.log') 6 | 7 | parser = argparse.ArgumentParser(description='Settings') 8 | parser.add_argument("--train_data", type=str, default='data/120w.bin', 9 | help="train data dir") 10 | parser.add_argument("--max_len", type=int, default=100, 11 | help="max length of sentences") 12 | parser.add_argument("--reload_model", type=str, default='', 13 | help="reload model") 14 | parser.add_argument("--batch_size", type=int, default=80, 15 | help="batch size sentences") 16 | parser.add_argument("--batch_size_tokens", type=int, default=-1, 17 | help="batch size tokens") 18 | parser.add_argument("--src_n_words", type=int, default=0, 19 | help="data") 20 | parser.add_argument("--tgt_n_words", type=int, default=0, 21 | help="data") 22 | parser.add_argument("--dropout", type=float, default=0.1, 23 | help="Dropout") 24 | parser.add_argument("--label-smoothing", type=float, default=0.1, 25 | help="Label smoothing") 26 | parser.add_argument("--attention", type=bool, default=True, 27 | help="Use an attention mechanism") 28 | parser.add_argument("--transformer", type=bool, default=False, 29 | help="Use Transformer") 30 | parser.add_argument("--lstm", type=bool, default=True, 31 | help="Use Bid-LSTM") 32 | parser.add_argument("--emb_dim", type=int, default=512, 33 | help="Embedding layer size") 34 | parser.add_argument("--n_enc_layers", type=int, default=4, 35 | help="Number of layers in the encoders") 36 | parser.add_argument("--n_dec_layers", type=int, default=4, 37 | help="Number of layers in the decoders") 38 | parser.add_argument("--hidden_dim", type=int, default=512, 39 | help="Hidden layer size") 40 | 41 | parser.add_argument("--transformer_ffn_emb_dim", type=int, default=2048, 42 | help="Transformer fully-connected hidden dim size") 43 | parser.add_argument("--attention_dropout", type=float, default=0, 44 | help="attention_dropout") 45 | parser.add_argument("--relu_dropout", type=float, default=0, 46 | help="relu_dropout") 47 | parser.add_argument("--encoder_attention_heads", type=int, default=8, 48 | help="encoder_attention_heads") 49 | parser.add_argument("--decoder_attention_heads", type=int, default=8, 50 | help="decoder_attention_heads") 51 | parser.add_argument("--encoder_normalize_before", type=bool, default=False, 52 | help="encoder_normalize_before") 53 | 
parser.add_argument("--decoder_normalize_before", type=bool, default=False, 54 | help="decoder_normalize_before") 55 | parser.add_argument("--share_encdec_emb", type=bool, default=False, 56 | help="share encoder and decoder embedding") 57 | parser.add_argument("--share_decpro_emb", type=bool, default=True, 58 | help="share decoder input and project embedding") 59 | parser.add_argument("--beam_size", type=int, default=6, 60 | help="beam search size") 61 | parser.add_argument("--length_penalty", type=float, default=1.0, 62 | help="length penalty") 63 | parser.add_argument("--clip_grad_norm", type=float, default=5.0, 64 | help="clip grad norm") 65 | parser.add_argument("--update_freq", type=int, default=1) 66 | parser.add_argument("--optim", type=str, default="adam_inverse_sqrt,lr=0.001") 67 | parser.add_argument("--gpu_num", type=int, default=1) 68 | 69 | 70 | if __name__ == '__main__': 71 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 72 | params = parser.parse_args() 73 | params.batch_size_tokens = 8000 74 | params.checkpoint_dir = 'all_models/lstm_6k_8' 75 | params.update_freq = 8 76 | params.seed = 1234 77 | params.gpu_num = 1 78 | 79 | if params.gpu_num == 1: 80 | from single_train import main 81 | main(params) 82 | else: 83 | from multiprocessing_train import main 84 | logger.info('GPU numbers: %s',params.gpu_num) 85 | main(params) 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/recaser/train-truecaser.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id: train-recaser.perl 1326 2007-03-26 05:44:27Z bojar $ 7 | 8 | # 9 | # Options: 10 | # 11 | # --possiblyUseFirstToken : boolean option; the default behaviour (when this 12 | # option is not provided) is that the first token of a sentence is ignored, on 13 | # the basis that the first word of a sentence is always capitalized; if this 14 | # option is provided then: a) if a sentence-initial token is *not* capitalized, 15 | # then it is counted, and b) if a capitalized sentence-initial token is the 16 | # only token of the segment, then it is counted, but with only 10% of the 17 | # weight of a normal token. 18 | 19 | use warnings; 20 | use strict; 21 | use Getopt::Long "GetOptions"; 22 | 23 | # apply switches 24 | my ($MODEL,$CORPUS); 25 | die("train-truecaser.perl --model truecaser --corpus cased [--possiblyUseFirstToken]") 26 | unless &GetOptions('corpus=s' => \$CORPUS, 27 | 'model=s' => \$MODEL, 28 | 'possiblyUseFirstToken' => \(my $possiblyUseFirstToken = 0)) 29 | && defined($CORPUS) && defined($MODEL); 30 | my %CASING; 31 | my %SENTENCE_END = ("."=>1,":"=>1,"?"=>1,"!"=>1); 32 | my %DELAYED_SENTENCE_START = ("("=>1,"["=>1,"\""=>1,"'"=>1,"'"=>1,"""=>1,"["=>1,"]"=>1); 33 | open(CORPUS,$CORPUS) || die("ERROR: could not open '$CORPUS'"); 34 | binmode(CORPUS, ":utf8"); 35 | while() { 36 | chop; 37 | my ($WORD,$MARKUP) = split_xml($_); 38 | my $start = 0; 39 | while($start<=$#$WORD && defined($DELAYED_SENTENCE_START{$$WORD[$start]})) { $start++; } 40 | my $firstWordOfSentence = 1; 41 | for(my $i=$start;$i<=$#$WORD;$i++) { 42 | my $currentWord = $$WORD[$i]; 43 | if (! $firstWordOfSentence && defined($SENTENCE_END{$$WORD[$i-1]})) { 44 | $firstWordOfSentence = 1; 45 | } 46 | 47 | my $currentWordWeight = 0; 48 | if (! 
$firstWordOfSentence) { 49 | $currentWordWeight = 1; 50 | } elsif ($possiblyUseFirstToken) { 51 | # gated special handling of first word of sentence 52 | my $firstChar = substr($currentWord, 0, 1); 53 | if (lc($firstChar) eq $firstChar) { 54 | # if the first character is not upper case, count the token as full evidence (because if it's not capitalized, then there's no reason to be wary that the given casing is only due to being sentence-initial) 55 | $currentWordWeight = 1; 56 | } elsif (scalar(@$WORD) == 1) { 57 | # if the first character is upper case, but the current token is the only token of the segment, then count the token as partial evidence (because the segment is presumably not a sentence and the token is therefore not the first word of a sentence and is possibly in its natural case) 58 | $currentWordWeight = 0.1; 59 | } 60 | } 61 | if ($currentWordWeight > 0) { 62 | $CASING{ lc($currentWord) }{ $currentWord } += $currentWordWeight; 63 | } 64 | 65 | $firstWordOfSentence = 0; 66 | } 67 | } 68 | close(CORPUS); 69 | 70 | open(MODEL,">$MODEL") || die("ERROR: could not create '$MODEL'"); 71 | binmode(MODEL, ":utf8"); 72 | foreach my $type (keys %CASING) { 73 | my ($score,$total,$best) = (-1,0,""); 74 | foreach my $word (keys %{$CASING{$type}}) { 75 | my $count = $CASING{$type}{$word}; 76 | $total += $count; 77 | if ($count > $score) { 78 | $best = $word; 79 | $score = $count; 80 | } 81 | } 82 | print MODEL "$best ($score/$total)"; 83 | foreach my $word (keys %{$CASING{$type}}) { 84 | print MODEL " $word ($CASING{$type}{$word})" unless $word eq $best; 85 | } 86 | print MODEL "\n"; 87 | } 88 | close(MODEL); 89 | 90 | 91 | # store away xml markup 92 | sub split_xml { 93 | my ($line) = @_; 94 | my (@WORD,@MARKUP); 95 | my $i = 0; 96 | $MARKUP[0] = ""; 97 | while($line =~ /\S/) { 98 | # XML tag 99 | if ($line =~ /^\s*(<\S[^>]*>)(.*)$/) { 100 | $MARKUP[$i] .= $1." "; 101 | $line = $2; 102 | } 103 | # non-XML text 104 | elsif ($line =~ /^\s*([^\s<>]+)(.*)$/) { 105 | $WORD[$i++] = $1; 106 | $MARKUP[$i] = ""; 107 | $line = $2; 108 | } 109 | # '<' or '>' occurs in word, but it's not an XML tag 110 | elsif ($line =~ /^\s*(\S+)(.*)$/) { 111 | $WORD[$i++] = $1; 112 | $MARKUP[$i] = ""; 113 | $line = $2; 114 | } 115 | else { 116 | die("ERROR: huh? 
$line\n"); 117 | } 118 | } 119 | chop($MARKUP[$#MARKUP]); 120 | return (\@WORD,\@MARKUP); 121 | } 122 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src.logger import create_logger 3 | import os 4 | import torch 5 | logger = create_logger('train.log') 6 | 7 | parser = argparse.ArgumentParser(description='Settings') 8 | parser.add_argument("--train_data", type=str, default='data/cwmt.bin', 9 | help="train data dir") 10 | parser.add_argument("--max_len", type=int, default=100, 11 | help="max length of sentences") 12 | parser.add_argument("--reload_model", type=str, default='', 13 | help="reload model") 14 | parser.add_argument("--batch_size", type=int, default=80, 15 | help="batch size sentences") 16 | parser.add_argument("--batch_size_tokens", type=int, default=4000, 17 | help="batch size tokens") 18 | parser.add_argument("--src_n_words", type=int, default=0, 19 | help="data") 20 | parser.add_argument("--tgt_n_words", type=int, default=0, 21 | help="data") 22 | parser.add_argument("--dropout", type=float, default=0.1, 23 | help="Dropout") 24 | parser.add_argument("--label-smoothing", type=float, default=0.1, 25 | help="Label smoothing") 26 | parser.add_argument("--attention", type=bool, default=True, 27 | help="Use an attention mechanism") 28 | parser.add_argument("--transformer", type=bool, default=True, 29 | help="Use Transformer") 30 | parser.add_argument("--emb_dim", type=int, default=512, 31 | help="Embedding layer size") 32 | parser.add_argument("--n_enc_layers", type=int, default=6, 33 | help="Number of layers in the encoders") 34 | parser.add_argument("--n_dec_layers", type=int, default=6, 35 | help="Number of layers in the decoders") 36 | parser.add_argument("--hidden_dim", type=int, default=512, 37 | help="Hidden layer size") 38 | 39 | parser.add_argument("--transformer_ffn_emb_dim", type=int, default=2048, 40 | help="Transformer fully-connected hidden dim size") 41 | parser.add_argument("--attention_dropout", type=float, default=0, 42 | help="attention_dropout") 43 | parser.add_argument("--relu_dropout", type=float, default=0, 44 | help="relu_dropout") 45 | parser.add_argument("--encoder_attention_heads", type=int, default=8, 46 | help="encoder_attention_heads") 47 | parser.add_argument("--decoder_attention_heads", type=int, default=8, 48 | help="decoder_attention_heads") 49 | parser.add_argument("--encoder_normalize_before", type=bool, default=False, 50 | help="encoder_normalize_before") 51 | parser.add_argument("--decoder_normalize_before", type=bool, default=False, 52 | help="decoder_normalize_before") 53 | parser.add_argument("--share_encdec_emb", type=bool, default=False, 54 | help="share encoder and decoder embedding") 55 | parser.add_argument("--share_decpro_emb", type=bool, default=True, 56 | help="share decoder input and project embedding") 57 | parser.add_argument("--beam_size", type=int, default=5, 58 | help="beam search size") 59 | parser.add_argument("--length_penalty", type=float, default=1.0, 60 | help="length penalty") 61 | parser.add_argument("--clip_grad_norm", type=float, default=5.0, 62 | help="clip grad norm") 63 | parser.add_argument("--update_freq", type=int, default=1) 64 | parser.add_argument("--optim", type=str, default="adam_inverse_sqrt,lr=0.0005") 65 | parser.add_argument("--gpu_num", type=int, default=1) 66 | parser.add_argument("--checkpoint_dir", type=str, default="all_models/base") 67 | 
parser.add_argument("--seed", type=int, default=1244) 68 | parser.add_argument("--max_epoch", type=int, default=5) 69 | parser.add_argument("--save_freq_update", type=int, default=5000) 70 | parser.add_argument("--save_optimizer", type=bool, default=False, 71 | help="save optimizer parameters") 72 | parser.add_argument("--do_eval", type=bool, default=True, 73 | help="do evalution during training") 74 | parser.add_argument("--model_name",type=str, default="") 75 | parser.add_argument("--src_dico_file", type=str, default='') 76 | parser.add_argument("--tgt_dico_file", type=str, default='') 77 | parser.add_argument("--translate_file", type=str, default='') 78 | parser.add_argument("--reference_file", type=str, default='') 79 | 80 | 81 | if __name__ == '__main__': 82 | params = parser.parse_args() 83 | 84 | if params.gpu_num == 1: 85 | from single_train import main 86 | main(params) 87 | else: 88 | from multiprocessing_train import main 89 | logger.info('GPU numbers: %s',params.gpu_num) 90 | main(params) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /src/distributed_utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import os 3 | import pickle 4 | import socket 5 | import subprocess 6 | import warnings 7 | import math 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | 13 | def is_master(args): 14 | return args.rank == 0 15 | 16 | def suppress_output(is_master): 17 | """Suppress printing on the current device. Force printing with `force=True`.""" 18 | import builtins as __builtin__ 19 | builtin_print = __builtin__.print 20 | 21 | def print(*args, **kwargs): 22 | force = kwargs.pop('force', False) 23 | if is_master or force: 24 | builtin_print(*args, **kwargs) 25 | 26 | __builtin__.print = print 27 | 28 | 29 | def get_rank(): 30 | return dist.get_rank() 31 | 32 | def get_world_size(): 33 | return dist.get_world_size() 34 | 35 | def get_default_group(): 36 | return dist.group.WORLD 37 | 38 | def all_reduce(tensor, group=None): 39 | if group is None: 40 | group = get_default_group() 41 | return dist.all_reduce(tensor, group=group) 42 | 43 | def item(tensor): 44 | if hasattr(tensor, 'item'): 45 | return tensor.item() 46 | if hasattr(tensor, '__getitem__'): 47 | return tensor[0] 48 | return tensor 49 | 50 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 51 | buffer_size=10485760): 52 | """All-reduce and rescale tensors in chunks of the specified size. 53 | Args: 54 | tensors: list of Tensors to all-reduce 55 | rescale_denom: denominator for rescaling summed Tensors 56 | buffer_size: all-reduce chunk size in bytes 57 | """ 58 | # buffer size in bytes, determine equiv. 
# of elements based on data type 59 | buffer_t = tensors[0].new( 60 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 61 | buffer = [] 62 | 63 | def all_reduce_buffer(): 64 | # copy tensors into buffer_t 65 | offset = 0 66 | for t in buffer: 67 | numel = t.numel() 68 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 69 | offset += numel 70 | 71 | # all-reduce and rescale 72 | torch.distributed.all_reduce(buffer_t[:offset]) 73 | buffer_t.div_(rescale_denom) 74 | 75 | # copy all-reduced buffer back into tensors 76 | offset = 0 77 | for t in buffer: 78 | numel = t.numel() 79 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 80 | offset += numel 81 | 82 | filled = 0 83 | for t in tensors: 84 | sz = t.numel() * t.element_size() 85 | if sz > buffer_size: 86 | # tensor is bigger than buffer, all-reduce and rescale directly 87 | torch.distributed.all_reduce(t) 88 | t.div_(rescale_denom) 89 | elif filled + sz > buffer_size: 90 | # buffer is full, all-reduce and replace buffer with grad 91 | all_reduce_buffer() 92 | buffer = [t] 93 | filled = sz 94 | else: 95 | # add tensor to buffer 96 | buffer.append(t) 97 | filled += sz 98 | 99 | if len(buffer) > 0: 100 | all_reduce_buffer() 101 | 102 | 103 | def all_gather_list(data, max_size=4096): 104 | """Gathers arbitrary data from all nodes into a list.""" 105 | world_size = torch.distributed.get_world_size() 106 | if not hasattr(all_gather_list, '_in_buffer') or \ 107 | max_size != all_gather_list._in_buffer.size(): 108 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 109 | all_gather_list._out_buffers = [ 110 | torch.cuda.ByteTensor(max_size) 111 | for i in range(world_size) 112 | ] 113 | in_buffer = all_gather_list._in_buffer 114 | out_buffers = all_gather_list._out_buffers 115 | 116 | enc = pickle.dumps(data) 117 | enc_size = len(enc) 118 | if enc_size + 2 > max_size: 119 | raise ValueError( 120 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 121 | assert max_size < 255*256 122 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 123 | in_buffer[1] = enc_size % 255 124 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 125 | 126 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 127 | 128 | results = [] 129 | for i in range(world_size): 130 | out_buffer = out_buffers[i] 131 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 132 | 133 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 134 | result = pickle.loads(bytes_list) 135 | results.append(result) 136 | return results -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.sk: -------------------------------------------------------------------------------- 1 | Bc 2 | Mgr 3 | RNDr 4 | PharmDr 5 | PhDr 6 | JUDr 7 | PaedDr 8 | ThDr 9 | Ing 10 | MUDr 11 | MDDr 12 | MVDr 13 | Dr 14 | ThLic 15 | PhD 16 | ArtD 17 | ThDr 18 | Dr 19 | DrSc 20 | CSs 21 | prof 22 | obr 23 | Obr 24 | Č 25 | č 26 | absol 27 | adj 28 | admin 29 | adr 30 | Adr 31 | adv 32 | advok 33 | afr 34 | ak 35 | akad 36 | akc 37 | akuz 38 | et 39 | al 40 | alch 41 | amer 42 | anat 43 | angl 44 | Angl 45 | anglosas 46 | anorg 47 | ap 48 | apod 49 | arch 50 | archeol 51 | archit 52 | arg 53 | art 54 | astr 55 | astrol 56 | astron 57 | atp 58 | atď 59 | austr 60 | Austr 61 | aut 62 | belg 63 | Belg 64 | bibl 65 | Bibl 66 | biol 67 | bot 68 | bud 69 | bás 70 | býv 71 | cest 72 | chem 73 | cirk 74 | csl 75 | čs 76 | Čs 77 | dat 78 | dep 79 | det 80 | dial 81 | 
diaľ 82 | dipl 83 | distrib 84 | dokl 85 | dosl 86 | dopr 87 | dram 88 | duš 89 | dv 90 | dvojčl 91 | dór 92 | ekol 93 | ekon 94 | el 95 | elektr 96 | elektrotech 97 | energet 98 | epic 99 | est 100 | etc 101 | etonym 102 | eufem 103 | európ 104 | Európ 105 | ev 106 | evid 107 | expr 108 | fa 109 | fam 110 | farm 111 | fem 112 | feud 113 | fil 114 | filat 115 | filoz 116 | fi 117 | fon 118 | form 119 | fot 120 | fr 121 | Fr 122 | franc 123 | Franc 124 | fraz 125 | fut 126 | fyz 127 | fyziol 128 | garb 129 | gen 130 | genet 131 | genpor 132 | geod 133 | geogr 134 | geol 135 | geom 136 | germ 137 | gr 138 | Gr 139 | gréc 140 | Gréc 141 | gréckokat 142 | hebr 143 | herald 144 | hist 145 | hlav 146 | hosp 147 | hromad 148 | hud 149 | hypok 150 | ident 151 | i.e 152 | ident 153 | imp 154 | impf 155 | indoeur 156 | inf 157 | inform 158 | instr 159 | int 160 | interj 161 | inšt 162 | inštr 163 | iron 164 | jap 165 | Jap 166 | jaz 167 | jedn 168 | juhoamer 169 | juhových 170 | juhozáp 171 | juž 172 | kanad 173 | Kanad 174 | kanc 175 | kapit 176 | kpt 177 | kart 178 | katastr 179 | knih 180 | kniž 181 | komp 182 | konj 183 | konkr 184 | kozmet 185 | krajč 186 | kresť 187 | kt 188 | kuch 189 | lat 190 | latinskoamer 191 | lek 192 | lex 193 | lingv 194 | lit 195 | litur 196 | log 197 | lok 198 | max 199 | Max 200 | maď 201 | Maď 202 | medzinár 203 | mest 204 | metr 205 | mil 206 | Mil 207 | min 208 | Min 209 | miner 210 | ml 211 | mld 212 | mn 213 | mod 214 | mytol 215 | napr 216 | nar 217 | Nar 218 | nasl 219 | nedok 220 | neg 221 | negat 222 | neklas 223 | nem 224 | Nem 225 | neodb 226 | neos 227 | neskl 228 | nesklon 229 | nespis 230 | nespráv 231 | neved 232 | než 233 | niekt 234 | niž 235 | nom 236 | náb 237 | nákl 238 | námor 239 | nár 240 | obch 241 | obj 242 | obv 243 | obyč 244 | obč 245 | občian 246 | odb 247 | odd 248 | ods 249 | ojed 250 | okr 251 | Okr 252 | opt 253 | opyt 254 | org 255 | os 256 | osob 257 | ot 258 | ovoc 259 | par 260 | part 261 | pejor 262 | pers 263 | pf 264 | Pf 265 | P.f 266 | p.f 267 | pl 268 | Plk 269 | pod 270 | podst 271 | pokl 272 | polit 273 | politol 274 | polygr 275 | pomn 276 | popl 277 | por 278 | porad 279 | porov 280 | posch 281 | potrav 282 | použ 283 | poz 284 | pozit 285 | poľ 286 | poľno 287 | poľnohosp 288 | poľov 289 | pošt 290 | pož 291 | prac 292 | predl 293 | pren 294 | prep 295 | preuk 296 | priezv 297 | Priezv 298 | privl 299 | prof 300 | práv 301 | príd 302 | príj 303 | prík 304 | príp 305 | prír 306 | prísl 307 | príslov 308 | príč 309 | psych 310 | publ 311 | pís 312 | písm 313 | pôv 314 | refl 315 | reg 316 | rep 317 | resp 318 | rozk 319 | rozlič 320 | rozpráv 321 | roč 322 | Roč 323 | ryb 324 | rádiotech 325 | rím 326 | samohl 327 | semest 328 | sev 329 | severoamer 330 | severových 331 | severozáp 332 | sg 333 | skr 334 | skup 335 | sl 336 | Sloven 337 | soc 338 | soch 339 | sociol 340 | sp 341 | spol 342 | Spol 343 | spoloč 344 | spoluhl 345 | správ 346 | spôs 347 | st 348 | star 349 | starogréc 350 | starorím 351 | s.r.o 352 | stol 353 | stor 354 | str 355 | stredoamer 356 | stredoškol 357 | subj 358 | subst 359 | superl 360 | sv 361 | sz 362 | súkr 363 | súp 364 | súvzť 365 | tal 366 | Tal 367 | tech 368 | tel 369 | Tel 370 | telef 371 | teles 372 | telev 373 | teol 374 | trans 375 | turist 376 | tuzem 377 | typogr 378 | tzn 379 | tzv 380 | ukaz 381 | ul 382 | Ul 383 | umel 384 | univ 385 | ust 386 | ved 387 | vedľ 388 | verb 389 | veter 390 | vin 391 | viď 392 | vl 393 | vod 394 | vodohosp 395 | pnl 396 | vulg 397 | vyj 398 | vys 
399 | vysokoškol 400 | vzťaž 401 | vôb 402 | vých 403 | výd 404 | výrob 405 | výsk 406 | výsl 407 | výtv 408 | výtvar 409 | význ 410 | včel 411 | vš 412 | všeob 413 | zahr 414 | zar 415 | zariad 416 | zast 417 | zastar 418 | zastaráv 419 | zb 420 | zdravot 421 | združ 422 | zjemn 423 | zlat 424 | zn 425 | Zn 426 | zool 427 | zr 428 | zried 429 | zv 430 | záhr 431 | zák 432 | zákl 433 | zám 434 | záp 435 | západoeur 436 | zázn 437 | územ 438 | účt 439 | čast 440 | čes 441 | Čes 442 | čl 443 | čísl 444 | živ 445 | pr 446 | fak 447 | Kr 448 | p.n.l 449 | A 450 | B 451 | C 452 | D 453 | E 454 | F 455 | G 456 | H 457 | I 458 | J 459 | K 460 | L 461 | M 462 | N 463 | O 464 | P 465 | Q 466 | R 467 | S 468 | T 469 | U 470 | V 471 | W 472 | X 473 | Y 474 | Z 475 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/training/clean-corpus-n.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id: clean-corpus-n.perl 3633 2010-10-21 09:49:27Z phkoehn $ 7 | use warnings; 8 | use strict; 9 | use Getopt::Long; 10 | my $help; 11 | my $lc = 0; # lowercase the corpus? 12 | my $ignore_ratio = 0; 13 | my $ignore_xml = 0; 14 | my $enc = "utf8"; # encoding of the input and output files 15 | # set to anything else you wish, but I have not tested it yet 16 | my $max_word_length = 1000; # any segment with a word (or factor) exceeding this length in chars 17 | # is discarded; motivated by symal.cpp, which has its own such parameter (hardcoded to 1000) 18 | # and crashes if it encounters a word that exceeds it 19 | my $ratio = 9; 20 | 21 | GetOptions( 22 | "help" => \$help, 23 | "lowercase|lc" => \$lc, 24 | "encoding=s" => \$enc, 25 | "ratio=f" => \$ratio, 26 | "ignore-ratio" => \$ignore_ratio, 27 | "ignore-xml" => \$ignore_xml, 28 | "max-word-length|mwl=s" => \$max_word_length 29 | ) or exit(1); 30 | 31 | if (scalar(@ARGV) < 6 || $help) { 32 | print "syntax: clean-corpus-n.perl [-ratio n] corpus l1 l2 clean-corpus min max [lines retained file]\n"; 33 | exit; 34 | } 35 | 36 | my $corpus = $ARGV[0]; 37 | my $l1 = $ARGV[1]; 38 | my $l2 = $ARGV[2]; 39 | my $out = $ARGV[3]; 40 | my $min = $ARGV[4]; 41 | my $max = $ARGV[5]; 42 | 43 | my $linesRetainedFile = ""; 44 | if (scalar(@ARGV) > 6) { 45 | $linesRetainedFile = $ARGV[6]; 46 | open(LINES_RETAINED,">$linesRetainedFile") or die "Can't write $linesRetainedFile"; 47 | } 48 | 49 | print STDERR "clean-corpus.perl: processing $corpus.$l1 & .$l2 to $out, cutoff $min-$max, ratio $ratio\n"; 50 | 51 | my $opn = undef; 52 | my $l1input = "$corpus.$l1"; 53 | if (-e $l1input) { 54 | $opn = $l1input; 55 | } elsif (-e $l1input.".gz") { 56 | $opn = "gunzip -c $l1input.gz |"; 57 | } else { 58 | die "Error: $l1input does not exist"; 59 | } 60 | open(F,$opn) or die "Can't open '$opn'"; 61 | $opn = undef; 62 | my $l2input = "$corpus.$l2"; 63 | if (-e $l2input) { 64 | $opn = $l2input; 65 | } elsif (-e $l2input.".gz") { 66 | $opn = "gunzip -c $l2input.gz |"; 67 | } else { 68 | die "Error: $l2input does not exist"; 69 | } 70 | 71 | open(E,$opn) or die "Can't open '$opn'"; 72 | 73 | open(FO,">$out.$l1") or die "Can't write $out.$l1"; 74 | open(EO,">$out.$l2") or die "Can't write $out.$l2"; 75 | 76 | # necessary for proper lowercasing 77 | my $binmode; 78 | if ($enc eq "utf8") { 79 | $binmode = ":utf8"; 80 | 
} else { 81 | $binmode = ":encoding($enc)"; 82 | } 83 | binmode(F, $binmode); 84 | binmode(E, $binmode); 85 | binmode(FO, $binmode); 86 | binmode(EO, $binmode); 87 | 88 | my $innr = 0; 89 | my $outnr = 0; 90 | my $factored_flag; 91 | while(my $f = ) { 92 | $innr++; 93 | print STDERR "." if $innr % 10000 == 0; 94 | print STDERR "($innr)" if $innr % 100000 == 0; 95 | my $e = ; 96 | die "$corpus.$l2 is too short!" if !defined $e; 97 | chomp($e); 98 | chomp($f); 99 | if ($innr == 1) { 100 | $factored_flag = ($e =~ /\|/ || $f =~ /\|/); 101 | } 102 | 103 | #if lowercasing, lowercase 104 | if ($lc) { 105 | $e = lc($e); 106 | $f = lc($f); 107 | } 108 | 109 | $e =~ s/\|//g unless $factored_flag; 110 | $e =~ s/\s+/ /g; 111 | $e =~ s/^ //; 112 | $e =~ s/ $//; 113 | $f =~ s/\|//g unless $factored_flag; 114 | $f =~ s/\s+/ /g; 115 | $f =~ s/^ //; 116 | $f =~ s/ $//; 117 | next if $f eq ''; 118 | next if $e eq ''; 119 | 120 | my $ec = &word_count($e); 121 | my $fc = &word_count($f); 122 | next if $ec > $max; 123 | next if $fc > $max; 124 | next if $ec < $min; 125 | next if $fc < $min; 126 | next if !$ignore_ratio && $ec/$fc > $ratio; 127 | next if !$ignore_ratio && $fc/$ec > $ratio; 128 | # Skip this segment if any factor is longer than $max_word_length 129 | my $max_word_length_plus_one = $max_word_length + 1; 130 | next if $e =~ /[^\s\|]{$max_word_length_plus_one}/; 131 | next if $f =~ /[^\s\|]{$max_word_length_plus_one}/; 132 | 133 | # An extra check: none of the factors can be blank! 134 | die "There is a blank factor in $corpus.$l1 on line $innr: $f" 135 | if $f =~ /[ \|]\|/; 136 | die "There is a blank factor in $corpus.$l2 on line $innr: $e" 137 | if $e =~ /[ \|]\|/; 138 | 139 | $outnr++; 140 | print FO $f."\n"; 141 | print EO $e."\n"; 142 | 143 | if ($linesRetainedFile ne "") { 144 | print LINES_RETAINED $innr."\n"; 145 | } 146 | } 147 | 148 | if ($linesRetainedFile ne "") { 149 | close LINES_RETAINED; 150 | } 151 | 152 | print STDERR "\n"; 153 | my $e = ; 154 | die "$corpus.$l2 is too long!" 
if defined $e; 155 | 156 | print STDERR "Input sentences: $innr Output sentences: $outnr\n"; 157 | 158 | sub word_count { 159 | my ($line) = @_; 160 | if ($ignore_xml) { 161 | $line =~ s/<\S[^>]*\S>/ /g; 162 | $line =~ s/\s+/ /g; 163 | $line =~ s/^ //g; 164 | $line =~ s/ $//g; 165 | } 166 | my @w = split(/ /,$line); 167 | return scalar @w; 168 | } 169 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from src.logger import create_logger 4 | import os 5 | from src.model import build_mt_model 6 | from src.data.loader import load_data 7 | import subprocess 8 | import re 9 | 10 | logger = create_logger('translate.log') 11 | 12 | parser = argparse.ArgumentParser(description='Settings') 13 | parser.add_argument("--train_data", type=str, default='', 14 | help="train data dir") 15 | parser.add_argument("--max_len", type=int, default=100, 16 | help="max length of sentences") 17 | parser.add_argument("--reload_model", type=str, default='', 18 | help="reload model") 19 | parser.add_argument("--batch_size", type=int, default=128, 20 | help="batch size") 21 | parser.add_argument("--batch_size_tokens", type=int, default=-1, 22 | help="batch size tokens") 23 | parser.add_argument("--src_n_words", type=int, default=0, 24 | help="data") 25 | parser.add_argument("--tgt_n_words", type=int, default=0, 26 | help="data") 27 | parser.add_argument("--dropout", type=float, default=0, 28 | help="Dropout") 29 | parser.add_argument("--label-smoothing", type=float, default=0, 30 | help="Label smoothing") 31 | parser.add_argument("--attention", type=bool, default=True, 32 | help="Use an attention mechanism") 33 | parser.add_argument("--transformer", type=bool, default=True, 34 | help="Use Transformer") 35 | parser.add_argument("--lstm", type=bool, default=False, 36 | help="Use LSTM") 37 | parser.add_argument("--emb_dim", type=int, default=512, 38 | help="Embedding layer size") 39 | parser.add_argument("--n_enc_layers", type=int, default=6, 40 | help="Number of layers in the encoders") 41 | parser.add_argument("--n_dec_layers", type=int, default=6, 42 | help="Number of layers in the decoders") 43 | parser.add_argument("--hidden_dim", type=int, default=512, 44 | help="Hidden layer size") 45 | 46 | parser.add_argument("--transformer_ffn_emb_dim", type=int, default=2048, 47 | help="Transformer fully-connected hidden dim size") 48 | parser.add_argument("--attention_dropout", type=float, default=0, 49 | help="attention_dropout") 50 | parser.add_argument("--relu_dropout", type=float, default=0, 51 | help="relu_dropout") 52 | parser.add_argument("--encoder_attention_heads", type=int, default=8, 53 | help="encoder_attention_heads") 54 | parser.add_argument("--decoder_attention_heads", type=int, default=8, 55 | help="decoder_attention_heads") 56 | parser.add_argument("--encoder_normalize_before", type=bool, default=False, 57 | help="encoder_normalize_before") 58 | parser.add_argument("--decoder_normalize_before", type=bool, default=False, 59 | help="decoder_normalize_before") 60 | parser.add_argument("--share_encdec_emb", type=bool, default=False, 61 | help="share encoder and decoder embedding") 62 | parser.add_argument("--share_decpro_emb", type=bool, default=True, 63 | help="share decoder input and project embedding") 64 | parser.add_argument("--beam_size", type=int, default=5, 65 | help="beam search size") 66 | parser.add_argument("--length_penalty", 
type=float, default=1.0, 67 | help="length penalty") 68 | parser.add_argument("--clip_grad_norm", type=float, default=5.0, 69 | help="clip grad norm") 70 | parser.add_argument("--model_name",type=str, default="") 71 | parser.add_argument("--checkpoint_dir", type=str, default='') 72 | parser.add_argument("--src_dico_file", type=str, default='') 73 | parser.add_argument("--tgt_dico_file", type=str, default='') 74 | parser.add_argument("--translate_file", type=str, default='') 75 | parser.add_argument("--gpu_num", type=int, default=1) 76 | parser.add_argument("--seed", type=int, default=1234) 77 | parser.add_argument("--reference_file", type=str, default='') 78 | params = parser.parse_args() 79 | params.reload_model = '{}/{}'.format(params.checkpoint_dir, params.model_name) 80 | params.out_file = '{}/predict_{}'.format(params.checkpoint_dir, params.model_name[:-3]) 81 | if __name__ == '__main__': 82 | data = load_data(params, name='test') 83 | encoder, decoder, _ = build_mt_model(params) 84 | encoder.eval() 85 | decoder.eval() 86 | iterator = data.get_iterator(shuffle=False, group_by_size=False)() 87 | file = open(params.out_file, 'w',encoding='utf-8') 88 | total = 0 89 | with torch.no_grad(): 90 | for (sen1, len1) in iterator: 91 | len1, bak_order = len1.sort(descending=True) 92 | sen1 = sen1[:,bak_order] 93 | sen1 = sen1.cuda() 94 | encoded = encoder(sen1, len1) 95 | sent2, len2, _ = decoder.generate(encoded) 96 | total += len2.size(0) 97 | logger.info('Translating %i sentences.' % total) 98 | for j in bak_order.argsort().tolist(): 99 | file.write(params.tgt_dico.idx2string(sent2[:, j]).replace('@@ ', '')+'\n') 100 | file.close() 101 | # calculate bleu value 102 | ''' 103 | command = f'perl scripts/multi-bleu.perl {params.reference_file} < {params.out_file}' 104 | print(command) 105 | p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) 106 | result = p.communicate()[0].decode("utf-8") 107 | bleu = re.findall(r"BLEU = (.+?),", result)[0] 108 | logger.info(result) 109 | logger.info(bleu) 110 | file.close() 111 | with open('{}/bleu.log'.format(params.checkpoint_dir),'a+') as f: 112 | f.write(str(params.model_name)+' '+result) 113 | ''' 114 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/generic/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
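#
# Usage sketch (mirrors the usage message printed below; not part of the
# original Moses header):
#
#   perl multi-bleu.perl [-lc] reference < hypothesis
#
# The reference argument may also be the stem of several files named
# reference0, reference1, ... for multi-reference scoring; -lc lowercases both
# hypothesis and references before n-gram matching.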
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /scripts/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /scripts/learn_joint_bpe_and_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author: Rico Sennrich 4 | 5 | """Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text. 6 | This script learns BPE jointly on a concatenation of a list of texts (typically the source and target side of a parallel corpus, 7 | applies the learned operation to each and (optionally) returns the resulting vocabulary of each text. 8 | The vocabulary can be used in apply_bpe.py to avoid producing symbols that are rare or OOV in a training text. 9 | 10 | Reference: 11 | Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units. 12 | Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany. 13 | """ 14 | 15 | from __future__ import unicode_literals 16 | 17 | import sys 18 | import os 19 | import inspect 20 | import codecs 21 | import argparse 22 | import tempfile 23 | import warnings 24 | from collections import Counter 25 | 26 | #hack to get imports working if running this as a script, or within a package 27 | if __name__ == '__main__': 28 | import learn_bpe 29 | import apply_bpe 30 | else: 31 | from . import learn_bpe 32 | from . 
import apply_bpe 33 | 34 | # hack for python2/3 compatibility 35 | from io import open 36 | argparse.open = open 37 | 38 | def create_parser(subparsers=None): 39 | 40 | if subparsers: 41 | parser = subparsers.add_parser('learn-joint-bpe-and-vocab', 42 | formatter_class=argparse.RawDescriptionHelpFormatter, 43 | description="learn BPE-based word segmentation") 44 | else: 45 | parser = argparse.ArgumentParser( 46 | formatter_class=argparse.RawDescriptionHelpFormatter, 47 | description="learn BPE-based word segmentation") 48 | 49 | parser.add_argument( 50 | '--input', '-i', type=argparse.FileType('r'), required=True, nargs = '+', 51 | metavar='PATH', 52 | help="Input texts (multiple allowed).") 53 | parser.add_argument( 54 | '--output', '-o', type=argparse.FileType('w'), required=True, 55 | metavar='PATH', 56 | help="Output file for BPE codes.") 57 | parser.add_argument( 58 | '--symbols', '-s', type=int, default=10000, 59 | help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))") 60 | parser.add_argument( 61 | '--separator', type=str, default='@@', metavar='STR', 62 | help="Separator between non-final subword units (default: '%(default)s'))") 63 | parser.add_argument( 64 | '--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None, 65 | metavar='PATH', dest='vocab', 66 | help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py') 67 | parser.add_argument( 68 | '--min-frequency', type=int, default=2, metavar='FREQ', 69 | help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))') 70 | parser.add_argument( 71 | '--total-symbols', '-t', action="store_true", 72 | help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).") 73 | parser.add_argument( 74 | '--verbose', '-v', action="store_true", 75 | help="verbose mode.") 76 | 77 | return parser 78 | 79 | def learn_joint_bpe_and_vocab(args): 80 | 81 | if args.vocab and len(args.input) != len(args.vocab): 82 | sys.stderr.write('Error: number of input files and vocabulary files must match\n') 83 | sys.exit(1) 84 | 85 | # read/write files as UTF-8 86 | args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input] 87 | args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab] 88 | 89 | # get combined vocabulary of all input texts 90 | full_vocab = Counter() 91 | for f in args.input: 92 | full_vocab += learn_bpe.get_vocabulary(f) 93 | f.seek(0) 94 | 95 | vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()] 96 | 97 | # learn BPE on combined vocabulary 98 | with codecs.open(args.output.name, 'w', encoding='UTF-8') as output: 99 | learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, total_symbols=args.total_symbols) 100 | 101 | with codecs.open(args.output.name, encoding='UTF-8') as codes: 102 | bpe = apply_bpe.BPE(codes, separator=args.separator) 103 | 104 | # apply BPE to each training corpus and get vocabulary 105 | for train_file, vocab_file in zip(args.input, args.vocab): 106 | 107 | tmp = tempfile.NamedTemporaryFile(delete=False) 108 | tmp.close() 109 | 110 | tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') 111 | 112 | train_file.seek(0) 113 | for line in train_file: 114 | tmpout.write(bpe.segment(line).strip()) 115 | tmpout.write('\n') 116 | 117 | tmpout.close() 118 | tmpin = 
codecs.open(tmp.name, encoding='UTF-8') 119 | 120 | vocab = learn_bpe.get_vocabulary(tmpin) 121 | tmpin.close() 122 | os.remove(tmp.name) 123 | 124 | for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True): 125 | vocab_file.write("{0} {1}\n".format(key, freq)) 126 | vocab_file.close() 127 | 128 | 129 | if __name__ == '__main__': 130 | 131 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 132 | newdir = os.path.join(currentdir, 'subword_nmt') 133 | if os.path.isdir(newdir): 134 | warnings.simplefilter('default') 135 | warnings.warn( 136 | "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir), 137 | DeprecationWarning 138 | ) 139 | 140 | # python 2/3 compatibility 141 | if sys.version_info < (3, 0): 142 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) 143 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) 144 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin) 145 | else: 146 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer) 147 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer) 148 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer) 149 | 150 | parser = create_parser() 151 | args = parser.parse_args() 152 | 153 | if sys.version_info < (3, 0): 154 | args.separator = args.separator.decode('UTF-8') 155 | 156 | assert(len(args.input) == len(args.vocab)) 157 | 158 | learn_joint_bpe_and_vocab(args) 159 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/generic/bsbleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # compute Bleu scores with confidence intervals via boostrap resampling 3 | # written by Ulrich Germann 4 | # 5 | # This file is part of moses. Its use is licensed under the GNU Lesser General 6 | # Public License version 2.1 or, at your option, any later version. 7 | 8 | from argparse import ArgumentParser 9 | import math 10 | import os 11 | from random import randint 12 | import sys, gzip 13 | 14 | 15 | def count_ngrams(snt, max_n): 16 | """ 17 | Return a dictionary of ngram counts (up to length /max_n/) 18 | for sentence (list of words) /snt/. 19 | """ 20 | ret = {} 21 | for i in xrange(len(snt)): 22 | for k in xrange(i + 1, min(i + max_n + 1, len(snt) + 1)): 23 | key = tuple(snt[i:k]) 24 | ret[key] = ret.get(key, 0) + 1 25 | return ret 26 | 27 | 28 | def max_counts(ng1, ng2): 29 | """ 30 | Return a dicitonary of ngram counts such that 31 | each count is the greater of the two individual counts 32 | for each ngram in the input ngram count dictionaries 33 | /ng1/ and /ng2/. 34 | """ 35 | ret = ng1.copy() 36 | for k, v in ng2.items(): 37 | ret[k] = max(ret.get(k, 0), v) 38 | return ret 39 | 40 | 41 | def ng_hits(hyp, ref, max_n): 42 | """ 43 | Return a list of ngram counts such that each ngram count 44 | is the minimum of the counts in hyp and ref, up to ngram 45 | length /max_n/. 
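    The list is indexed by n-1, e.g. ret[0] holds the clipped unigram matches and ret[3] the clipped 4-gram matches.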
46 | """ 47 | ret = [0 for i in xrange(max_n)] 48 | for ng, cnt in hyp.items(): 49 | k = ng 50 | if len(k) <= max_n: 51 | ret[len(k) - 1] += min(cnt, ref.get(ng, 0)) 52 | return ret 53 | 54 | 55 | class BleuScore: 56 | def __init__(self, hyp, ref, max_n=4, bootstrap=1000): 57 | # print len(hyp.ngrams), len(ref.ngrams), "X" 58 | self.hits = [ 59 | ng_hits(hyp.ngrams[i], ref.ngrams[i], max_n) 60 | for i in xrange(len(hyp.ngrams))] 61 | self.max_n = max_n 62 | self.hyp = hyp 63 | self.ref = ref 64 | self.lower = None 65 | self.upper = None 66 | self.median = None 67 | self.actual = self.score([i for i in xrange(len(hyp.snt))]) 68 | if bootstrap: 69 | self.bootstrap = [self.score([randint(0, len(hyp.snt) - 1) 70 | for s in hyp.snt]) 71 | for i in xrange(bootstrap)] 72 | self.bootstrap.sort() 73 | else: 74 | self.bootstrap = [self.actual] 75 | pass 76 | 77 | def score(self, sample): 78 | hits = [0 for i in xrange(self.max_n)] 79 | self.hyplen = 0 80 | self.reflen = 0 81 | self.total = [0 for i in hits] 82 | for i in sample: 83 | self.hyplen += len(self.hyp.snt[i]) 84 | self.reflen += len(self.ref.snt[i]) 85 | for n in xrange(self.max_n): 86 | hits[n] += self.hits[i][n] 87 | self.total[n] += max(len(self.hyp.snt[i]) - n, 0) 88 | pass 89 | self.prec = [float(hits[n]) / self.total[n] 90 | for n in xrange(self.max_n)] 91 | ret = sum([math.log(x) for x in self.prec]) / self.max_n 92 | self.BP = min( 93 | 1, math.exp(1. - float(self.reflen) / float(self.hyplen))) 94 | ret += math.log(self.BP) 95 | return math.exp(ret) 96 | 97 | 98 | class Document: 99 | def __init__(self, fname=None): 100 | self.fname = fname 101 | if fname: 102 | if fname[-3:] == ".gz": 103 | self.snt = [line.strip().split() for line in gzip.open(fname).readlines()] 104 | else: 105 | self.snt = [line.strip().split() for line in open(fname)] 106 | pass 107 | self.ngrams = [count_ngrams(snt, 4) for snt in self.snt] 108 | # print self.snt 109 | else: 110 | self.snt = None 111 | self.ngrams = None 112 | 113 | def merge(self, R): 114 | self.fname = "multi-ref" 115 | self.ngrams = [x for x in R[0].ngrams] 116 | self.snt = [x for x in R[0].snt] 117 | for i in xrange(len(R[0].ngrams)): 118 | for k in xrange(1, len(R)): 119 | self.ngrams[i] = max_counts(self.ngrams[i], R[k].ngrams[i]) 120 | 121 | def update(self, hyp, R): 122 | for i, hyp_snt in enumerate(hyp.snt): 123 | clen = len(hyp_snt) 124 | K = 0 125 | for k in xrange(1, len(R)): 126 | k_snt = R[k].snt[i] 127 | assert len(R[k].snt) == len(hyp.snt), ( 128 | "Mismatch in number of sentences " + 129 | "between reference and candidate") 130 | if abs(len(k_snt) - clen) == abs(len(R[K].snt[i]) - clen): 131 | if len(k_snt) < len(R[K].snt[i]): 132 | K = k 133 | elif abs(len(k_snt) - clen) < abs(len(R[K].snt[i]) - clen): 134 | K = k 135 | self.snt[i] = R[K].snt[i] 136 | 137 | 138 | if __name__ == "__main__": 139 | argparser = ArgumentParser() 140 | argparser.add_argument( 141 | "-r", "--ref", nargs='+', help="Reference translation(s).") 142 | argparser.add_argument( 143 | "-c", "--cand", nargs='+', help="Candidate translations.") 144 | argparser.add_argument( 145 | "-i", "--individual", action='store_true', 146 | help="Compute BLEU scores for individual references.") 147 | argparser.add_argument( 148 | "-b", "--bootstrap", type=int, default=1000, 149 | help="Sample size for bootstrap resampling.") 150 | argparser.add_argument( 151 | "-a", "--alpha", type=float, default=.05, 152 | help="1-alpha = confidence interval.") 153 | args = argparser.parse_args(sys.argv[1:]) 154 | R = [Document(fname) 
for fname in args.ref] 155 | C = [Document(fname) for fname in args.cand] 156 | Rx = Document() # for multi-reference BLEU 157 | Rx.merge(R) 158 | for c in C: 159 | # compute multi-reference BLEU 160 | Rx.update(c, R) 161 | bleu = BleuScore(c, Rx, bootstrap=args.bootstrap) 162 | print "%5.2f %s [%5.2f-%5.2f; %5.2f] %s" % ( 163 | 100 * bleu.actual, 164 | os.path.basename(Rx.fname), 165 | 100 * bleu.bootstrap[int((args.alpha / 2) * args.bootstrap)], 166 | 100 * bleu.bootstrap[int((1 - (args.alpha / 2)) * args.bootstrap)], 167 | 100 * bleu.bootstrap[int(.5 * args.bootstrap)], 168 | c.fname) # os.path.basename(c.fname)) 169 | 170 | if args.individual: 171 | for r in R: 172 | bleu = BleuScore(c, r, bootstrap=args.bootstrap) 173 | print " %5.2f %s" % ( 174 | 100 * bleu.actual, os.path.basename(r.fname)) 175 | # print bleu.prec, bleu.hyplen, bleu.reflen, bleu.BP 176 | 177 | # print [ 178 | # sum([bleu.hits[i][n] for i in xrange(len(bleu.hits))]) 179 | # for n in xrange(4)] 180 | -------------------------------------------------------------------------------- /scripts/multi-bleu-detok.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # This file uses the internal tokenization of mteval-v13a.pl, 7 | # giving the exact same (case-sensitive) results on untokenized text. 8 | # Using this script with detokenized output and untokenized references is 9 | # preferrable over multi-bleu.perl, since scores aren't affected by tokenization differences. 10 | # 11 | # like multi-bleu.perl , it supports plain text input and multiple references. 12 | 13 | # $Id$ 14 | use warnings; 15 | use strict; 16 | 17 | my $lowercase = 0; 18 | if ($ARGV[0] eq "-lc") { 19 | $lowercase = 1; 20 | shift; 21 | } 22 | 23 | my $stem = $ARGV[0]; 24 | if (!defined $stem) { 25 | print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n"; 26 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 27 | exit(1); 28 | } 29 | 30 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 31 | 32 | my @REF; 33 | my $ref=0; 34 | while(-e "$stem$ref") { 35 | &add_to_ref("$stem$ref",\@REF); 36 | $ref++; 37 | } 38 | &add_to_ref($stem,\@REF) if -e $stem; 39 | die("ERROR: could not find reference file $stem") unless scalar @REF; 40 | 41 | # add additional references explicitly specified on the command line 42 | shift; 43 | foreach my $stem (@ARGV) { 44 | &add_to_ref($stem,\@REF) if -e $stem; 45 | } 46 | 47 | 48 | 49 | sub add_to_ref { 50 | my ($file,$REF) = @_; 51 | my $s=0; 52 | if ($file =~ /.gz$/) { 53 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 54 | } else { 55 | open(REF,$file) or die "Can't read $file"; 56 | } 57 | while() { 58 | chop; 59 | $_ = tokenization($_); 60 | push @{$$REF[$s++]}, $_; 61 | } 62 | close(REF); 63 | } 64 | 65 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 66 | my $s=0; 67 | while() { 68 | chop; 69 | $_ = lc if $lowercase; 70 | $_ = tokenization($_); 71 | my @WORD = split; 72 | my %REF_NGRAM = (); 73 | my $length_translation_this_sentence = scalar(@WORD); 74 | my ($closest_diff,$closest_length) = (9999,9999); 75 | foreach my $reference (@{$REF[$s]}) { 76 | # print "$s $_ <=> $reference\n"; 77 | $reference = lc($reference) if $lowercase; 78 | my @WORD = split(' ',$reference); 79 | my $length = scalar(@WORD); 80 | my 
$diff = abs($length_translation_this_sentence-$length); 81 | if ($diff < $closest_diff) { 82 | $closest_diff = $diff; 83 | $closest_length = $length; 84 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 85 | } elsif ($diff == $closest_diff) { 86 | $closest_length = $length if $length < $closest_length; 87 | # from two references with the same closeness to me 88 | # take the *shorter* into account, not the "first" one. 89 | } 90 | for(my $n=1;$n<=4;$n++) { 91 | my %REF_NGRAM_N = (); 92 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 93 | my $ngram = "$n"; 94 | for(my $w=0;$w<$n;$w++) { 95 | $ngram .= " ".$WORD[$start+$w]; 96 | } 97 | $REF_NGRAM_N{$ngram}++; 98 | } 99 | foreach my $ngram (keys %REF_NGRAM_N) { 100 | if (!defined($REF_NGRAM{$ngram}) || 101 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 102 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 103 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 104 | } 105 | } 106 | } 107 | } 108 | $length_translation += $length_translation_this_sentence; 109 | $length_reference += $closest_length; 110 | for(my $n=1;$n<=4;$n++) { 111 | my %T_NGRAM = (); 112 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 113 | my $ngram = "$n"; 114 | for(my $w=0;$w<$n;$w++) { 115 | $ngram .= " ".$WORD[$start+$w]; 116 | } 117 | $T_NGRAM{$ngram}++; 118 | } 119 | foreach my $ngram (keys %T_NGRAM) { 120 | $ngram =~ /^(\d+) /; 121 | my $n = $1; 122 | # my $corr = 0; 123 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 124 | $TOTAL[$n] += $T_NGRAM{$ngram}; 125 | if (defined($REF_NGRAM{$ngram})) { 126 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 127 | $CORRECT[$n] += $T_NGRAM{$ngram}; 128 | # $corr = $T_NGRAM{$ngram}; 129 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 130 | } 131 | else { 132 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 133 | # $corr = $REF_NGRAM{$ngram}; 134 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 135 | } 136 | } 137 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 138 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 139 | } 140 | } 141 | $s++; 142 | } 143 | my $brevity_penalty = 1; 144 | my $bleu = 0; 145 | 146 | my @bleu=(); 147 | 148 | for(my $n=1;$n<=4;$n++) { 149 | if (defined ($TOTAL[$n])){ 150 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 151 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 152 | }else{ 153 | $bleu[$n]=0; 154 | } 155 | } 156 | 157 | if ($length_reference==0){ 158 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 159 | exit(1); 160 | } 161 | 162 | if ($length_translation<$length_reference) { 163 | $brevity_penalty = exp(1-$length_reference/$length_translation); 164 | } 165 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 166 | my_log( $bleu[2] ) + 167 | my_log( $bleu[3] ) + 168 | my_log( $bleu[4] ) ) / 4) ; 169 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 170 | 100*$bleu, 171 | 100*$bleu[1], 172 | 100*$bleu[2], 173 | 100*$bleu[3], 174 | 100*$bleu[4], 175 | $brevity_penalty, 176 | $length_translation / $length_reference, 177 | $length_translation, 178 | $length_reference; 179 | 180 | sub my_log { 181 | return -9999999999 unless $_[0]; 182 | return log($_[0]); 183 | } 184 | 185 | 186 | 187 | sub tokenization 188 | { 189 | my ($norm_text) = @_; 190 | 191 | # language-independent part: 192 | $norm_text =~ s///g; # strip "skipped" tags 193 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines 194 | $norm_text =~ s/\n/ /g; # join lines 195 | $norm_text =~ s/"/"/g; # convert SGML tag for quote to " 196 | $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & 197 | $norm_text =~ s/</ 198 | $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < 199 | 200 | # language-dependent part (assuming Western languages): 201 | $norm_text = " $norm_text "; 202 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation 203 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit 204 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit 205 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit 206 | $norm_text =~ s/\s+/ /g; # one space only between words 207 | $norm_text =~ s/^\s+//; # no leading space 208 | $norm_text =~ s/\s+$//; # no trailing space 209 | 210 | return $norm_text; 211 | } 212 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/generic/multi-bleu-detok.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # This file uses the internal tokenization of mteval-v13a.pl, 7 | # giving the exact same (case-sensitive) results on untokenized text. 8 | # Using this script with detokenized output and untokenized references is 9 | # preferrable over multi-bleu.perl, since scores aren't affected by tokenization differences. 10 | # 11 | # like multi-bleu.perl , it supports plain text input and multiple references. 
12 | 13 | # $Id$ 14 | use warnings; 15 | use strict; 16 | 17 | my $lowercase = 0; 18 | if ($ARGV[0] eq "-lc") { 19 | $lowercase = 1; 20 | shift; 21 | } 22 | 23 | my $stem = $ARGV[0]; 24 | if (!defined $stem) { 25 | print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n"; 26 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 27 | exit(1); 28 | } 29 | 30 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 31 | 32 | my @REF; 33 | my $ref=0; 34 | while(-e "$stem$ref") { 35 | &add_to_ref("$stem$ref",\@REF); 36 | $ref++; 37 | } 38 | &add_to_ref($stem,\@REF) if -e $stem; 39 | die("ERROR: could not find reference file $stem") unless scalar @REF; 40 | 41 | # add additional references explicitly specified on the command line 42 | shift; 43 | foreach my $stem (@ARGV) { 44 | &add_to_ref($stem,\@REF) if -e $stem; 45 | } 46 | 47 | 48 | 49 | sub add_to_ref { 50 | my ($file,$REF) = @_; 51 | my $s=0; 52 | if ($file =~ /.gz$/) { 53 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 54 | } else { 55 | open(REF,$file) or die "Can't read $file"; 56 | } 57 | while() { 58 | chop; 59 | $_ = tokenization($_); 60 | push @{$$REF[$s++]}, $_; 61 | } 62 | close(REF); 63 | } 64 | 65 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 66 | my $s=0; 67 | while() { 68 | chop; 69 | $_ = lc if $lowercase; 70 | $_ = tokenization($_); 71 | my @WORD = split; 72 | my %REF_NGRAM = (); 73 | my $length_translation_this_sentence = scalar(@WORD); 74 | my ($closest_diff,$closest_length) = (9999,9999); 75 | foreach my $reference (@{$REF[$s]}) { 76 | # print "$s $_ <=> $reference\n"; 77 | $reference = lc($reference) if $lowercase; 78 | my @WORD = split(' ',$reference); 79 | my $length = scalar(@WORD); 80 | my $diff = abs($length_translation_this_sentence-$length); 81 | if ($diff < $closest_diff) { 82 | $closest_diff = $diff; 83 | $closest_length = $length; 84 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 85 | } elsif ($diff == $closest_diff) { 86 | $closest_length = $length if $length < $closest_length; 87 | # from two references with the same closeness to me 88 | # take the *shorter* into account, not the "first" one. 89 | } 90 | for(my $n=1;$n<=4;$n++) { 91 | my %REF_NGRAM_N = (); 92 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 93 | my $ngram = "$n"; 94 | for(my $w=0;$w<$n;$w++) { 95 | $ngram .= " ".$WORD[$start+$w]; 96 | } 97 | $REF_NGRAM_N{$ngram}++; 98 | } 99 | foreach my $ngram (keys %REF_NGRAM_N) { 100 | if (!defined($REF_NGRAM{$ngram}) || 101 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 102 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 103 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 104 | } 105 | } 106 | } 107 | } 108 | $length_translation += $length_translation_this_sentence; 109 | $length_reference += $closest_length; 110 | for(my $n=1;$n<=4;$n++) { 111 | my %T_NGRAM = (); 112 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 113 | my $ngram = "$n"; 114 | for(my $w=0;$w<$n;$w++) { 115 | $ngram .= " ".$WORD[$start+$w]; 116 | } 117 | $T_NGRAM{$ngram}++; 118 | } 119 | foreach my $ngram (keys %T_NGRAM) { 120 | $ngram =~ /^(\d+) /; 121 | my $n = $1; 122 | # my $corr = 0; 123 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 124 | $TOTAL[$n] += $T_NGRAM{$ngram}; 125 | if (defined($REF_NGRAM{$ngram})) { 126 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 127 | $CORRECT[$n] += $T_NGRAM{$ngram}; 128 | # $corr = $T_NGRAM{$ngram}; 129 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 130 | } 131 | else { 132 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 133 | # $corr = $REF_NGRAM{$ngram}; 134 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 135 | } 136 | } 137 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 138 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 139 | } 140 | } 141 | $s++; 142 | } 143 | my $brevity_penalty = 1; 144 | my $bleu = 0; 145 | 146 | my @bleu=(); 147 | 148 | for(my $n=1;$n<=4;$n++) { 149 | if (defined ($TOTAL[$n])){ 150 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 151 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 152 | }else{ 153 | $bleu[$n]=0; 154 | } 155 | } 156 | 157 | if ($length_reference==0){ 158 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 159 | exit(1); 160 | } 161 | 162 | if ($length_translation<$length_reference) { 163 | $brevity_penalty = exp(1-$length_reference/$length_translation); 164 | } 165 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 166 | my_log( $bleu[2] ) + 167 | my_log( $bleu[3] ) + 168 | my_log( $bleu[4] ) ) / 4) ; 169 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 170 | 100*$bleu, 171 | 100*$bleu[1], 172 | 100*$bleu[2], 173 | 100*$bleu[3], 174 | 100*$bleu[4], 175 | $brevity_penalty, 176 | $length_translation / $length_reference, 177 | $length_translation, 178 | $length_reference; 179 | 180 | sub my_log { 181 | return -9999999999 unless $_[0]; 182 | return log($_[0]); 183 | } 184 | 185 | 186 | 187 | sub tokenization 188 | { 189 | my ($norm_text) = @_; 190 | 191 | # language-independent part: 192 | $norm_text =~ s///g; # strip "skipped" tags 193 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines 194 | $norm_text =~ s/\n/ /g; # join lines 195 | $norm_text =~ s/"/"/g; # convert SGML tag for quote to " 196 | $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & 197 | $norm_text =~ s/</ 198 | $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < 199 | 200 | # language-dependent part (assuming Western languages): 201 | $norm_text = " $norm_text "; 202 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation 203 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit 204 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit 205 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit 206 | $norm_text =~ s/\s+/ /g; # one space only between words 207 | $norm_text =~ s/^\s+//; # no leading space 208 | $norm_text =~ s/\s+$//; # no trailing space 209 | 210 | return $norm_text; 211 | } 212 | -------------------------------------------------------------------------------- /src/data/dictionary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from logging import getLogger 4 | 5 | logger = getLogger() 6 | 7 | BOS_WORD = '' 8 | EOS_WORD = '' 9 | PAD_WORD = '' 10 | UNK_WORD = '' 11 | 12 | 13 | class Dictionary(object): 14 | def __init__(self, id2word, word2id): 15 | assert len(id2word) == len(word2id) 16 | self.id2word = id2word 17 | self.word2id = word2id 18 | self.bos_index = word2id[BOS_WORD] 19 | self.eos_index = word2id[EOS_WORD] 20 | self.pad_index = word2id[PAD_WORD] 21 | self.unk_index = word2id[UNK_WORD] 22 | self.check_valid() 23 | 24 | def __len__(self): 25 | """Returns the number of words in the dictionary""" 26 | return len(self.id2word) 27 | 28 | def __getitem__(self, i): 29 | """ 30 | Returns the word of the specified index. 
31 | """ 32 | return self.id2word[i] 33 | 34 | def __contains__(self, w): 35 | """ 36 | Returns whether a word is in the dictionary. 37 | """ 38 | return w in self.word2id 39 | 40 | def __eq__(self, y): 41 | """ 42 | Compare this dictionary with another one. 43 | """ 44 | self.check_valid() 45 | y.check_valid() 46 | if len(self.id2word) != len(y): 47 | return False 48 | return all(self.id2word[i] == y[i] for i in range(len(y))) 49 | 50 | def check_valid(self): 51 | """ 52 | Check that the dictionary is valid. 53 | """ 54 | assert self.bos_index == 0 55 | assert self.eos_index == 1 56 | assert self.pad_index == 2 57 | assert self.unk_index == 3 58 | assert len(self.id2word) == len(self.word2id) 59 | for i in range(len(self.id2word)): 60 | assert self.word2id[self.id2word[i]] == i 61 | 62 | def idx2string(self, indexes): 63 | str = [] 64 | for i in indexes: 65 | i = i.item() 66 | if i == self.bos_index: 67 | continue 68 | if i == self.eos_index: 69 | break 70 | str.append(self.id2word[i]) 71 | return ' '.join(str) 72 | 73 | def str2idx(self, str): 74 | return torch.Tensor([self.index[w] for w in str.split()]) 75 | 76 | def index(self, word, no_unk=False): 77 | """ 78 | Returns the index of the specified word. 79 | """ 80 | if no_unk: 81 | return self.word2id[word] 82 | else: 83 | return self.word2id.get(word, self.unk_index) 84 | 85 | def prune(self, max_vocab): 86 | """ 87 | Limit the vocabulary size. 88 | """ 89 | assert max_vocab >= 1 90 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab} 91 | self.word2id = {v: k for k, v in self.id2word.items()} 92 | self.check_valid() 93 | 94 | @staticmethod 95 | def read_vocab(vocab_path): 96 | """ 97 | Create a dictionary from a vocabulary file. 98 | """ 99 | skipped = 0 100 | assert os.path.isfile(vocab_path), vocab_path 101 | word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3} 102 | f = open(vocab_path, 'r', encoding='utf-8') 103 | for i, line in enumerate(f): 104 | if '\u2028' in line: 105 | skipped += 1 106 | continue 107 | line = line.rstrip().split() 108 | assert len(line) == 2, (i, line) 109 | assert line[0] not in word2id and line[1].isdigit(), (i, line) 110 | word2id[line[0]] = 4 + i - skipped # shift because of extra words 111 | f.close() 112 | id2word = {v: k for k, v in word2id.items()} 113 | dico = Dictionary(id2word, word2id) 114 | logger.info("Read %i words from the vocabulary file." % len(dico)) 115 | if skipped > 0: 116 | logger.warning("Skipped %i empty lines!" % skipped) 117 | return dico 118 | 119 | @staticmethod 120 | def index_data(src_txt_path, tgt_txt_path, src_dico, tgt_dico, bin_path): 121 | """ 122 | Index sentences with a dictionary. 123 | """ 124 | if os.path.isfile(bin_path): 125 | print("Exsited file %s ..." % bin_path) 126 | return None 127 | 128 | positions_s = [] 129 | sentences_s = [] 130 | unk_words_s = {} 131 | 132 | positions_t = [] 133 | sentences_t = [] 134 | unk_words_t = {} 135 | 136 | # index sentences 137 | fs = open(src_txt_path, 'r', encoding='utf-8') 138 | ft = open(tgt_txt_path, 'r', encoding='utf-8') 139 | for i, line in enumerate(fs): 140 | line_t = ft.readline() 141 | if i % 1000000 == 0 and i > 0: 142 | print(i) 143 | s = line.rstrip().split() 144 | s_t = line_t.rstrip().split() 145 | # skip empty sentences 146 | if (len(s) == 0) or (len(s_t) == 0): 147 | print("Empty sentence in line %i." 
% i) 148 | # continue 149 | # index sentence words 150 | count_unk_s = 0 151 | count_unk_t = 0 152 | count_unk_sdoc = 0 153 | indexed_s = [] 154 | indexed_t = [] 155 | indexed_sdoc = [] 156 | for w in s: 157 | word_id = src_dico.index(w, no_unk=False) 158 | if word_id < 4 and word_id != src_dico.unk_index: 159 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id)) 160 | continue 161 | indexed_s.append(word_id) 162 | if word_id == src_dico.unk_index: 163 | unk_words_s[w] = unk_words_s.get(w, 0) + 1 164 | count_unk_s += 1 165 | for w in s_t: 166 | word_id = tgt_dico.index(w, no_unk=False) 167 | if word_id < 4 and word_id != tgt_dico.unk_index: 168 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id)) 169 | continue 170 | indexed_t.append(word_id) 171 | if word_id == tgt_dico.unk_index: 172 | unk_words_t[w] = unk_words_t.get(w, 0) + 1 173 | count_unk_t += 1 174 | # add sentence 175 | positions_s.append([len(sentences_s), len(sentences_s) + len(indexed_s)]) 176 | sentences_s.extend(indexed_s) 177 | sentences_s.append(-1) 178 | 179 | positions_t.append([len(sentences_t), len(sentences_t) + len(indexed_t)]) 180 | sentences_t.extend(indexed_t) 181 | sentences_t.append(-1) 182 | 183 | fs.close() 184 | ft.close() 185 | 186 | # tensorize data 187 | positions_s = torch.LongTensor(positions_s) 188 | sentences_s = torch.LongTensor(sentences_s) 189 | positions_t = torch.LongTensor(positions_t) 190 | sentences_t = torch.LongTensor(sentences_t) 191 | 192 | data = { 193 | 'src_dico': src_dico, 194 | 'tgt_dico':tgt_dico, 195 | 'src_positions': positions_s, 196 | 'tgt_positions':positions_t, 197 | 'src_sentences': sentences_s, 198 | 'tgt_sentences':sentences_t, 199 | 'src_unk_words': unk_words_s, 200 | 'tgt_unk_words':unk_words_t 201 | } 202 | print("Saving the data to %s ..." % bin_path) 203 | torch.save(data, bin_path) 204 | 205 | return data 206 | -------------------------------------------------------------------------------- /src/modules/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | 6 | from .. import fairseq_utils as utils 7 | 8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | 12 | See "Attention Is All You Need" for more details. 13 | """ 14 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True): 15 | super().__init__() 16 | self.embed_dim = embed_dim 17 | self.num_heads = num_heads 18 | self.dropout = dropout 19 | self.head_dim = embed_dim // num_heads 20 | assert self.head_dim * num_heads == self.embed_dim 21 | self.scaling = self.head_dim**-0.5 22 | self._mask = None 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 25 | if bias: 26 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 27 | else: 28 | self.register_parameter('in_proj_bias', None) 29 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | nn.init.xavier_uniform_(self.in_proj_weight) 35 | nn.init.xavier_uniform_(self.out_proj.weight) 36 | if self.in_proj_bias is not None: 37 | nn.init.constant_(self.in_proj_bias, 0.) 38 | nn.init.constant_(self.out_proj.bias, 0.) 
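    # Shape conventions of the existing code (summary, no new behaviour): inputs are
    # (time, batch, embed_dim); q/k/v are reshaped to (batch * num_heads, time, head_dim)
    # so that every head is handled by a single batched matmul.
    # Minimal usage sketch (hypothetical sizes):
    #   mha = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)
    #   out, weights = mha(x, x, x)  # self-attention over x of shape (T, B, 512)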
39 | 40 | def forward(self, query, key, value, mask_future_timesteps=False, 41 | key_padding_mask=None, incremental_state=None, 42 | need_weights=True, static_kv=False): 43 | """Input shape: Time x Batch x Channel 44 | 45 | Self-attention can be implemented by passing in the same arguments for 46 | query, key and value. Future timesteps can be masked with the 47 | `mask_future_timesteps` argument. Padding elements can be excluded from 48 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 49 | batch x src_len, where padding elements are indicated by 1s. 50 | """ 51 | 52 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 53 | kv_same = key.data_ptr() == value.data_ptr() 54 | 55 | if incremental_state is not None: 56 | saved_state = utils.get_incremental_state( 57 | self, 58 | incremental_state, 59 | 'attn_state', 60 | ) or {} 61 | 62 | if 'prev_key' in saved_state: 63 | # previous time steps are cached - no need to recompute 64 | # key and value if they are static 65 | if static_kv: 66 | assert kv_same 67 | key = key.data.new(0) 68 | value = value.data.new(0) 69 | else: 70 | saved_state = None 71 | 72 | tgt_len, bsz, embed_dim = query.size() 73 | assert embed_dim == self.embed_dim 74 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 75 | assert key.size() == value.size() 76 | 77 | if qkv_same: 78 | # self-attention 79 | q, k, v = self.in_proj_qkv(query) 80 | elif kv_same: 81 | # encoder-decoder attention 82 | q = self.in_proj_q(query) 83 | k, v = self.in_proj_kv(key) 84 | else: 85 | q = self.in_proj_q(query) 86 | k = self.in_proj_k(key) 87 | v = self.in_proj_v(value) 88 | q *= self.scaling 89 | 90 | if saved_state is not None: 91 | if 'prev_key' in saved_state: 92 | k = torch.cat((saved_state['prev_key'], k), dim=0) 93 | if 'prev_value' in saved_state: 94 | v = torch.cat((saved_state['prev_value'], v), dim=0) 95 | saved_state['prev_key'] = k 96 | saved_state['prev_value'] = v 97 | utils.set_incremental_state( 98 | self, 99 | incremental_state, 100 | 'attn_state', 101 | saved_state, 102 | ) 103 | 104 | src_len = k.size(0) 105 | 106 | if key_padding_mask is not None: 107 | assert key_padding_mask.size(0) == bsz 108 | assert key_padding_mask.size(1) == src_len 109 | 110 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 111 | k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 112 | v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 113 | 114 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 115 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 116 | 117 | # only apply masking at training time (when incremental state is None) 118 | if mask_future_timesteps and incremental_state is None: 119 | assert query.size() == key.size(), \ 120 | 'mask_future_timesteps only applies to self-attention' 121 | attn_weights += self.buffered_mask(attn_weights.data).detach().unsqueeze(0) 122 | if key_padding_mask is not None: 123 | # don't attend to padding symbols 124 | if key_padding_mask.data.max() > 0: 125 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 126 | attn_weights = attn_weights.masked_fill( 127 | key_padding_mask.unsqueeze(1).unsqueeze(2), 128 | -1e18, 129 | ) 130 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 131 | attn_weights = F.softmax(attn_weights, dim=-1) 132 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) 133 | 134 | attn = 
torch.bmm(attn_weights, v) 135 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 136 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 137 | attn = self.out_proj(attn) 138 | 139 | if need_weights: 140 | # average attention weights over heads 141 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 142 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 143 | else: 144 | attn_weights = None 145 | 146 | return attn, attn_weights 147 | 148 | def in_proj_qkv(self, query): 149 | return self._in_proj(query).chunk(3, dim=-1) 150 | 151 | def in_proj_kv(self, key): 152 | if key.numel() == 0: 153 | return (key, key) 154 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 155 | 156 | def in_proj_q(self, query): 157 | return self._in_proj(query, end=self.embed_dim) 158 | 159 | def in_proj_k(self, key): 160 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 161 | 162 | def in_proj_v(self, value): 163 | return self._in_proj(value, start=2 * self.embed_dim) 164 | 165 | def _in_proj(self, input, start=None, end=None): 166 | weight = self.in_proj_weight 167 | bias = self.in_proj_bias 168 | if end is not None: 169 | weight = weight[:end, :] 170 | if bias is not None: 171 | bias = bias[:end] 172 | if start is not None: 173 | weight = weight[start:, :] 174 | if bias is not None: 175 | bias = bias[start:] 176 | return F.linear(input, weight, bias) 177 | 178 | def buffered_mask(self, tensor): 179 | dim = tensor.size(-1) 180 | if self._mask is None: 181 | self._mask = torch.triu(tensor.new(dim, dim).fill_(-1e18), 1) 182 | if self._mask.size(0) < dim: 183 | self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-1e18), 1) 184 | return self._mask[:dim, :dim] 185 | 186 | def reorder_incremental_state(self, incremental_state, new_order): 187 | saved_state = utils.get_incremental_state(self, incremental_state, 'attn_state') 188 | if saved_state is not None: 189 | for k in saved_state.keys(): 190 | saved_state[k] = saved_state[k].index_select(1, new_order) 191 | utils.set_incremental_state(self, incremental_state, 'attn_state', saved_state) 192 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | from itertools import chain 4 | from logging import getLogger 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.nn.utils import clip_grad_norm_ 9 | from tqdm import tqdm 10 | 11 | from src.utils import get_optimizer 12 | from src.distributed_utils import all_gather_list,all_reduce_and_rescale_tensors 13 | 14 | 15 | logger = getLogger() 16 | 17 | 18 | class TrainerMT(): 19 | def __init__(self, encoder, decoder, data, test_data, params, num_updates): 20 | self.encoder = encoder 21 | self.decoder = decoder 22 | self.data = data 23 | self.test_data = test_data 24 | self.params = params 25 | 26 | self.enc_dec_params = list(self.encoder.parameters())+list(self.decoder.parameters()) 27 | # optimizers 28 | self.optimizer = get_optimizer(self.enc_dec_params, self.params.optim) 29 | self.optimizer._num_updates = num_updates 30 | # training statistics 31 | self.epoch = getattr(params, 'now_epoch', 0) 32 | self.n_iter = 0 33 | self.oom = 0 34 | self.n_sentences = 0 35 | self.stats = { 36 | 'processed_s': 0, 37 | 'processed_w': 0, 38 | 'loss': [] 39 | } 40 | self.sample_sizes = [] 41 | 42 | def train_epoch(self): 43 | self.iterator = 
self.get_iterator() 44 | for (sent1, len1), (sent2, len2) in tqdm(self.iterator, mininterval=2, desc=' - (Training) ', leave=False, total=self.data.total): 45 | self.train_step(sent1, len1, sent2, len2) 46 | # save only when 1. the main process 2. once in update_freq 3. every save_freq_update updates 47 | if ((self.n_iter+1) % self.params.update_freq == 0) and (self.params.rank == 0 or self.params.gpu_num == 1) and (self.optimizer._num_updates != 0 and self.optimizer._num_updates % self.params.save_freq_update == 0): 48 | checkpoint = { 49 | 'encoder': self.encoder.state_dict(), 50 | 'decoder': self.decoder.state_dict(), 51 | 'params': self.params, 52 | 'epoch': self.epoch, 53 | 'num_updates': self.optimizer._num_updates 54 | } 55 | if self.params.save_optimizer: 56 | checkpoint['optimizer'] = self.optimizer.state_dict() 57 | 58 | self.params.model_name = f'model_epoch{self.epoch}_update{self.optimizer._num_updates}.pt' 59 | torch.save(checkpoint, self.params.checkpoint_dir + '/' + self.params.model_name) 60 | 61 | # do evaluation 62 | if self.params.do_eval: 63 | self.evaluate() 64 | self.encoder.train() 65 | self.decoder.train() 66 | 67 | # save epoch checkpoint 68 | if self.params.gpu_num == 1 or self.params.rank == 0: 69 | checkpoint = { 70 | 'encoder': self.encoder.state_dict(), 71 | 'decoder': self.decoder.state_dict(), 72 | 'params': self.params, 73 | 'epoch': self.epoch, 74 | 'num_updates': self.optimizer._num_updates 75 | } 76 | if (self.epoch == self.params.max_epoch-1) or self.params.save_optimizer: 77 | checkpoint['optimizer'] = self.optimizer.state_dict() 78 | self.params.model_name = f'model_epoch{self.epoch}.pt' 79 | torch.save(checkpoint, self.params.checkpoint_dir + '/' + self.params.model_name) 80 | # do evaluation 81 | if self.params.do_eval: 82 | self.evaluate() 83 | self.encoder.train() 84 | self.decoder.train() 85 | self.epoch += 1 86 | 87 | def train_step(self, sent1, len1, sent2, len2): 88 | if self.params.update_freq == 1: 89 | need_zero = True 90 | need_reduction = True 91 | else: 92 | need_reduction = True if (self.n_iter+1) % self.params.update_freq == 0 else False 93 | need_zero = True if self.n_iter % self.params.update_freq == 0 else False 94 | self.encoder.train() 95 | self.decoder.train() 96 | sent1, sent2 = sent1.cuda(), sent2.cuda() 97 | try: 98 | if need_zero: 99 | self.optimizer.zero_grad() 100 | encoded = self.encoder(sent1, len1) 101 | scores = self.decoder(encoded, sent2[:-1]) 102 | loss,sample_size = self.decoder.loss_fn(scores.view(-1, self.decoder.n_words), sent2[1:].view(-1)) 103 | 104 | # check NaN 105 | if (loss != loss).data.any(): 106 | logger.error("NaN detected") 107 | exit() 108 | # optimizer 109 | loss.backward() 110 | self.sample_sizes.append(sample_size) 111 | # print(f'forward gpu-{self.params.rank},iter-{self.n_iter}{self.enc_dec_params[0].grad.data[0][0:20]}') 112 | 113 | except Exception as e: 114 | logger.error(e) 115 | torch.cuda.empty_cache() 116 | self.n_iter += 1 117 | self.oom += 1 118 | return 119 | 120 | if need_reduction: 121 | try: 122 | # sample_sizes contain gpu_num*update_delay numbers of tokens like [1948, 2013, ..] 
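                    # e.g. with 2 GPUs and update_freq=2 the gathered list could look like
                    # [1948, 2013, 1871, 1990]; its sum is the token count used below to
                    # rescale the accumulated gradients for this optimizer step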
123 | # now we get the total token nums of all delay batch and gpus 124 | if self.params.gpu_num > 1: 125 | sample_sizes = all_gather_list(self.sample_sizes) 126 | sample_sizes = list(chain.from_iterable(sample_sizes)) 127 | sample_size = sum(sample_sizes) 128 | grads = [p.grad.data for p in self.enc_dec_params if p.requires_grad and p.grad is not None] 129 | all_reduce_and_rescale_tensors(grads,float(sample_size)/self.params.gpu_num) 130 | else: 131 | sample_size = sum(self.sample_sizes) 132 | for p in self.enc_dec_params: 133 | if p.requires_grad and p.grad is not None: 134 | p.grad.data.mul_(1/float(sample_size)) 135 | clip_grad_norm_(self.enc_dec_params, self.params.clip_grad_norm) 136 | self.optimizer.step() 137 | self.sample_sizes = [] 138 | 139 | except Exception as e: 140 | logger.error(e) 141 | exit(0) 142 | 143 | # number of processed sentences / words 144 | self.stats['processed_s'] += len2.size(0) 145 | self.stats['processed_w'] += len2.sum() 146 | self.n_iter += 1 147 | del loss 148 | torch.cuda.empty_cache() 149 | 150 | def evaluate(self): 151 | test_iterator = self.test_data.get_iterator(shuffle=False, group_by_size=False)() 152 | self.encoder.eval() 153 | self.decoder.eval() 154 | self.params.out_file = '{}/predict_{}'.format(self.params.checkpoint_dir, self.params.model_name[:-3]) 155 | 156 | file = open(self.params.out_file, 'w',encoding='utf-8') 157 | total = 0 158 | out_sents = [] 159 | with torch.no_grad(): 160 | for (sen1, len1) in test_iterator: 161 | len1, bak_order = len1.sort(descending=True) 162 | sen1 = sen1[:,bak_order] 163 | sen1 = sen1.cuda() 164 | encoded = self.encoder(sen1, len1) 165 | sent2, len2, _ = self.decoder.generate(encoded) 166 | total += len2.size(0) 167 | logger.info('Translating %i sentences.' % total) 168 | for j in bak_order.argsort().tolist(): 169 | out1 = self.params.tgt_dico.idx2string(sent2[:, j]).replace('@@ ', '') 170 | out_sents.append(out1) 171 | file.write(out1 + '\n') 172 | file.close() 173 | command = f'perl scripts/multi-bleu.perl {self.params.reference_file} < {self.params.out_file}' 174 | logger.info(command) 175 | p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) 176 | result = p.communicate()[0].decode("utf-8") 177 | bleu_multibleu = re.findall(r"BLEU = (.+?),", result)[0] 178 | logger.info(bleu_multibleu) 179 | with open('{}/bleu.log'.format(self.params.checkpoint_dir),'a+') as f: 180 | f.write(f'{self.params.model_name}\t{bleu_multibleu}\n') 181 | 182 | 183 | def get_iterator(self): 184 | if self.params.gpu_num == 1: 185 | iterator = self.data.get_iterator(shuffle=True, group_by_size=True)() 186 | else: 187 | iterator = self.data.get_iterator(shuffle=True, group_by_size=True, partition=self.params.rank)() 188 | return iterator 189 | -------------------------------------------------------------------------------- /scripts/mosesdecoder/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.lt: -------------------------------------------------------------------------------- 1 | # Anything in this file, followed by a period (and an upper-case word), 2 | # does NOT indicate an end-of-sentence marker. 3 | # Special cases are included for prefixes that ONLY appear before 0-9 numbers. 
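# For example, a prefix listed below (such as "dr") keeps "dr. Jonas" inside one
# sentence, while a prefix marked #NUMERIC_ONLY# only suppresses a split when the
# period is followed by a number.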
4 | 5 | # Any single upper case letter followed by a period is not a sentence ender 6 | # (excluding I occasionally, but we leave it in) 7 | # usually upper case letters are initials in a name 8 | A 9 | Ā 10 | B 11 | C 12 | Č 13 | D 14 | E 15 | Ē 16 | F 17 | G 18 | Ģ 19 | H 20 | I 21 | Ī 22 | J 23 | K 24 | Ķ 25 | L 26 | Ļ 27 | M 28 | N 29 | Ņ 30 | O 31 | P 32 | Q 33 | R 34 | S 35 | Š 36 | T 37 | U 38 | Ū 39 | V 40 | W 41 | X 42 | Y 43 | Z 44 | Ž 45 | 46 | # Initialis -- Džonas 47 | Dz 48 | Dž 49 | Just 50 | 51 | # Day and month abbreviations 52 | # m. menesis d. diena g. gimes 53 | m 54 | mėn 55 | d 56 | g 57 | gim 58 | # Pirmadienis Penktadienis 59 | Pr 60 | Pn 61 | Pirm 62 | Antr 63 | Treč 64 | Ketv 65 | Penkt 66 | Šešt 67 | Sekm 68 | Saus 69 | Vas 70 | Kov 71 | Bal 72 | Geg 73 | Birž 74 | Liep 75 | Rugpj 76 | Rugs 77 | Spal 78 | Lapkr 79 | Gruod 80 | 81 | # Business, governmental, geographical terms 82 | a 83 | # aikštė 84 | adv 85 | # advokatas 86 | akad 87 | # akademikas 88 | aklg 89 | # akligatvis 90 | akt 91 | # aktorius 92 | al 93 | # alėja 94 | A.V 95 | # antspaudo vieta 96 | aps 97 | apskr 98 | # apskritis 99 | apyg 100 | # apygarda 101 | aps 102 | apskr 103 | # apskritis 104 | asist 105 | # asistentas 106 | asmv 107 | avd 108 | # asmenvardis 109 | a.k 110 | asm 111 | asm.k 112 | # asmens kodas 113 | atsak 114 | # atsakingasis 115 | atsisk 116 | sąsk 117 | # atsiskaitomoji sąskaita 118 | aut 119 | # autorius 120 | b 121 | k 122 | b.k 123 | # banko kodas 124 | bkl 125 | # bakalauras 126 | bt 127 | # butas 128 | buv 129 | # buvęs, -usi 130 | dail 131 | # dailininkas 132 | dek 133 | # dekanas 134 | dėst 135 | # dėstytojas 136 | dir 137 | # direktorius 138 | dirig 139 | # dirigentas 140 | doc 141 | # docentas 142 | drp 143 | # durpynas 144 | dš 145 | # dešinysis 146 | egz 147 | # egzempliorius 148 | eil 149 | # eilutė 150 | ekon 151 | # ekonomika 152 | el 153 | # elektroninis 154 | etc 155 | ež 156 | # ežeras 157 | faks 158 | # faksas 159 | fak 160 | # fakultetas 161 | gen 162 | # generolas 163 | gyd 164 | # gydytojas 165 | gv 166 | # gyvenvietė 167 | įl 168 | # įlanka 169 | Įn 170 | # įnagininkas 171 | insp 172 | # inspektorius 173 | pan 174 | # ir panašiai 175 | t.t 176 | # ir taip toliau 177 | k.a 178 | # kaip antai 179 | kand 180 | # kandidatas 181 | kat 182 | # katedra 183 | kyš 184 | # kyšulys 185 | kl 186 | # klasė 187 | kln 188 | # kalnas 189 | kn 190 | # knyga 191 | koresp 192 | # korespondentas 193 | kpt 194 | # kapitonas 195 | kr 196 | # kairysis 197 | kt 198 | # kitas 199 | kun 200 | # kunigas 201 | l 202 | e 203 | p 204 | l.e.p 205 | # laikinai einantis pareigas 206 | ltn 207 | # leitenantas 208 | m 209 | mst 210 | # miestas 211 | m.e 212 | # mūsų eros 213 | m.m 214 | # mokslo metai 215 | mot 216 | # moteris 217 | mstl 218 | # miestelis 219 | mgr 220 | # magistras 221 | mgnt 222 | # magistrantas 223 | mjr 224 | # majoras 225 | mln 226 | # milijonas 227 | mlrd 228 | # milijardas 229 | mok 230 | # mokinys 231 | mokyt 232 | # mokytojas 233 | moksl 234 | # mokslinis 235 | nkt 236 | # nekaitomas 237 | ntk 238 | # neteiktinas 239 | Nr 240 | nr 241 | # numeris 242 | p 243 | # ponas 244 | p.d 245 | a.d 246 | # pašto dėžutė, abonentinė dėžutė 247 | p.m.e 248 | # prieš mūsų erą 249 | pan 250 | # ir panašiai 251 | pav 252 | # paveikslas 253 | pavad 254 | # pavaduotojas 255 | pirm 256 | # pirmininkas 257 | pl 258 | # plentas 259 | plg 260 | # palygink 261 | plk 262 | # pulkininkas; pelkė 263 | pr 264 | # prospektas 265 | Kr 266 | pr.Kr 267 | # prieš Kristų 268 | prok 269 | # 
prokuroras 270 | prot 271 | # protokolas 272 | pss 273 | # pusiasalis 274 | pšt 275 | # paštas 276 | pvz 277 | # pavyzdžiui 278 | r 279 | # rajonas 280 | red 281 | # redaktorius 282 | rš 283 | # raštų kalbos 284 | sąs 285 | # sąsiuvinis 286 | saviv 287 | sav 288 | # savivaldybė 289 | sekr 290 | # sekretorius 291 | sen 292 | # seniūnija, seniūnas 293 | sk 294 | # skaityk; skyrius 295 | skg 296 | # skersgatvis 297 | skyr 298 | sk 299 | # skyrius 300 | skv 301 | # skveras 302 | sp 303 | # spauda; spaustuvė 304 | spec 305 | # specialistas 306 | sr 307 | # sritis 308 | st 309 | # stotis 310 | str 311 | # straipsnis 312 | stud 313 | # studentas 314 | š 315 | š.m 316 | # šių metų 317 | šnek 318 | # šnekamosios 319 | tir 320 | # tiražas 321 | tūkst 322 | # tūkstantis 323 | up 324 | # upė 325 | upl 326 | # upelis 327 | vad 328 | # vadinamasis, -oji 329 | vlsč 330 | # valsčius 331 | ved 332 | # vedėjas 333 | vet 334 | # veterinarija 335 | virš 336 | # viršininkas, viršaitis 337 | vyr 338 | # vyriausiasis, -ioji; vyras 339 | vyresn 340 | # vyresnysis 341 | vlsč 342 | # valsčius 343 | vs 344 | # viensėdis 345 | Vt 346 | vt 347 | # vietininkas 348 | vtv 349 | vv 350 | # vietovardis 351 | žml 352 | # žemėlapis 353 | 354 | # Technical terms, abbreviations used in guidebooks, advertisments, etc. 355 | # Generally lower-case. 356 | air 357 | # airiškai 358 | amer 359 | # amerikanizmas 360 | anat 361 | # anatomija 362 | angl 363 | # angl. angliskai 364 | arab 365 | # arabų 366 | archeol 367 | archit 368 | asm 369 | # asmuo 370 | astr 371 | # astronomija 372 | austral 373 | # australiškai 374 | aut 375 | # automobilis 376 | av 377 | # aviacija 378 | bažn 379 | bdv 380 | # būdvardis 381 | bibl 382 | # Biblija 383 | biol 384 | # biologija 385 | bot 386 | # botanika 387 | brt 388 | # burtai, burtažodis. 389 | brus 390 | # baltarusių 391 | buh 392 | # buhalterija 393 | chem 394 | # chemija 395 | col 396 | # collectivum 397 | con 398 | conj 399 | # conjunctivus, jungtukas 400 | dab 401 | # dab. 
dabartine 402 | dgs 403 | # daugiskaita 404 | dial 405 | # dialektizmas 406 | dipl 407 | dktv 408 | # daiktavardis 409 | džn 410 | # dažnai 411 | ekon 412 | el 413 | # elektra 414 | esam 415 | # esamasis laikas 416 | euf 417 | # eufemizmas 418 | fam 419 | # familiariai 420 | farm 421 | # farmacija 422 | filol 423 | # filologija 424 | filos 425 | # filosofija 426 | fin 427 | # finansai 428 | fiz 429 | # fizika 430 | fiziol 431 | # fiziologija 432 | flk 433 | # folkloras 434 | fon 435 | # fonetika 436 | fot 437 | # fotografija 438 | geod 439 | # geodezija 440 | geogr 441 | geol 442 | # geologija 443 | geom 444 | # geometrija 445 | glžk 446 | gr 447 | # graikų 448 | gram 449 | her 450 | # heraldika 451 | hidr 452 | # hidrotechnika 453 | ind 454 | # Indų 455 | iron 456 | # ironiškai 457 | isp 458 | # ispanų 459 | ist 460 | istor 461 | # istorija 462 | it 463 | # italų 464 | įv 465 | reikšm 466 | įv.reikšm 467 | # įvairiomis reikšmėmis 468 | jap 469 | # japonų 470 | juok 471 | # juokaujamai 472 | jūr 473 | # jūrininkystė 474 | kalb 475 | # kalbotyra 476 | kar 477 | # karyba 478 | kas 479 | # kasyba 480 | kin 481 | # kinematografija 482 | klaus 483 | # klausiamasis 484 | knyg 485 | # knyginis 486 | kom 487 | # komercija 488 | komp 489 | # kompiuteris 490 | kosm 491 | # kosmonautika 492 | kt 493 | # kitas 494 | kul 495 | # kulinarija 496 | kuop 497 | # kuopine 498 | l 499 | # laikas 500 | lit 501 | # literatūrinis 502 | lingv 503 | # lingvistika 504 | log 505 | # logika 506 | lot 507 | # lotynų 508 | mat 509 | # matematika 510 | maž 511 | # mažybinis 512 | med 513 | # medicina 514 | medž 515 | # medžioklė 516 | men 517 | # menas 518 | menk 519 | # menkinamai 520 | metal 521 | # metalurgija 522 | meteor 523 | min 524 | # mineralogija 525 | mit 526 | # mitologija 527 | mok 528 | # mokyklinis 529 | ms 530 | # mįslė 531 | muz 532 | # muzikinis 533 | n 534 | # naujasis 535 | neig 536 | # neigiamasis 537 | neol 538 | # neologizmas 539 | niek 540 | # niekinamai 541 | ofic 542 | # oficialus 543 | opt 544 | # optika 545 | orig 546 | # original 547 | p 548 | # pietūs 549 | pan 550 | # panašiai 551 | parl 552 | # parlamentas 553 | pat 554 | # patarlė 555 | paž 556 | # pažodžiui 557 | plg 558 | # palygink 559 | poet 560 | # poetizmas 561 | poez 562 | # poezija 563 | poligr 564 | # poligrafija 565 | polit 566 | # politika 567 | ppr 568 | # paprastai 569 | pranc 570 | pr 571 | # prancūzų, prūsų 572 | priet 573 | # prietaras 574 | prek 575 | # prekyba 576 | prk 577 | # perkeltine 578 | prs 579 | # persona, asmuo 580 | psn 581 | # pasenęs žodis 582 | psich 583 | # psichologija 584 | pvz 585 | # pavyzdžiui 586 | r 587 | # rytai 588 | rad 589 | # radiotechnika 590 | rel 591 | # religija 592 | ret 593 | # retai 594 | rus 595 | # rusų 596 | sen 597 | # senasis 598 | sl 599 | # slengas, slavų 600 | sov 601 | # sovietinis 602 | spec 603 | # specialus 604 | sport 605 | stat 606 | # statyba 607 | sudurt 608 | # sudurtinis 609 | sutr 610 | # sutrumpintas 611 | suv 612 | # suvalkiečių 613 | š 614 | # šiaurė 615 | šach 616 | # šachmatai 617 | šiaur 618 | škot 619 | # škotiškai 620 | šnek 621 | # šnekamoji 622 | teatr 623 | tech 624 | techn 625 | # technika 626 | teig 627 | # teigiamas 628 | teis 629 | # teisė 630 | tekst 631 | # tekstilė 632 | tel 633 | # telefonas 634 | teol 635 | # teologija 636 | v 637 | # tik vyriškosios, vakarai 638 | t.p 639 | t 640 | p 641 | # ir taip pat 642 | t.t 643 | # ir taip toliau 644 | t.y 645 | # tai yra 646 | vaik 647 | # vaikų 648 | vart 649 | # vartojama 650 | vet 651 | # veterinarija 
652 | vid 653 | # vidurinis 654 | vksm 655 | # veiksmažodis 656 | vns 657 | # vienaskaita 658 | vok 659 | # vokiečių 660 | vulg 661 | # vulgariai 662 | zool 663 | # zoologija 664 | žr 665 | # žiūrėk 666 | ž.ū 667 | ž 668 | ū 669 | # žemės ūkis 670 | 671 | # List of titles. These are often followed by upper-case names, but do 672 | # not indicate sentence breaks 673 | # 674 | # Jo Eminencija 675 | Em. 676 | # Gerbiamasis 677 | Gerb 678 | gerb 679 | # malonus 680 | malon 681 | # profesorius 682 | Prof 683 | prof 684 | # daktaras (mokslų) 685 | Dr 686 | dr 687 | habil 688 | med 689 | # inž inžinierius 690 | inž 691 | Inž 692 | 693 | 694 | #Numbers only. These should only induce breaks when followed by a numeric sequence 695 | # add NUMERIC_ONLY after the word for this function 696 | #This case is mostly for the english "No." which can either be a sentence of its own, or 697 | #if followed by a number, a non-breaking prefix 698 | No #NUMERIC_ONLY# 699 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | 7 | logger = getLogger() 8 | 9 | 10 | class Dataset(object): 11 | def __init__(self, params): 12 | self.eos_index = params.eos_index 13 | self.pad_index = params.pad_index 14 | self.unk_index = params.unk_index 15 | self.bos_index = params.bos_index 16 | self.batch_size = params.batch_size 17 | self.batch_size_tokens = params.batch_size_tokens 18 | self.gpu_num = params.gpu_num 19 | self.seed = params.seed 20 | 21 | def batch_sentences(self, sentences): 22 | """ 23 | Take as input a list of n sentences (torch.LongTensor vectors) and return 24 | a tensor of size (s_len, n) where s_len is the length of the longest 25 | sentence, and a vector lengths containing the length of each sentence. 26 | """ 27 | lengths = torch.LongTensor([len(s) + 2 for s in sentences]) 28 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index) 29 | 30 | sent[0] = self.bos_index 31 | for i, s in enumerate(sentences): 32 | sent[1:lengths[i] - 1, i].copy_(s) 33 | sent[lengths[i] - 1, i] = self.eos_index 34 | 35 | return sent, lengths 36 | 37 | 38 | class MonolingualDataset(Dataset): 39 | def __init__(self, sent1, pos1, dico1, params): 40 | super(MonolingualDataset, self).__init__(params) 41 | self.sent1 = sent1 42 | self.pos1 = pos1 43 | self.dico1 = dico1 44 | self.lengths1 = (self.pos1[:, 1] - self.pos1[:, 0]) 45 | self.is_parallel = False 46 | 47 | # check number of sentences 48 | assert len(self.pos1) == (self.sent1 == -1).sum() 49 | 50 | def __len__(self): 51 | """ 52 | Number of sentences in the dataset. 53 | """ 54 | return len(self.pos1) 55 | 56 | def get_batches_iterator(self, batches): 57 | """ 58 | Return a sentences iterator, given the associated sentence batches. 59 | """ 60 | def iterator(): 61 | for sentence_ids in batches: 62 | pos1 = self.pos1[sentence_ids] 63 | sent1 = [self.sent1[a:b] for a, b in pos1] 64 | yield self.batch_sentences(sent1) 65 | return iterator 66 | 67 | def get_iterator(self, shuffle=False, group_by_size=False, n_sentences=-1): 68 | """ 69 | Return a sentences iterator. 
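        Sentences can be shuffled and grouped by length, and are then split into
        batches of at most `batch_size` sentences.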
70 | """ 71 | np.random.seed(self.seed) 72 | 73 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences 74 | assert 0 < n_sentences <= len(self.pos1) 75 | assert type(shuffle) is bool and type(group_by_size) is bool 76 | 77 | # select sentences to iterate over 78 | if shuffle: 79 | indices = np.random.permutation(len(self.pos1))[:n_sentences] 80 | else: 81 | indices = np.arange(n_sentences) 82 | 83 | # group sentences by lengths 84 | if group_by_size: 85 | indices = indices[np.argsort(self.lengths1[indices], kind='mergesort')] 86 | 87 | # create batches / optionally shuffle them 88 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 89 | if shuffle: 90 | np.random.shuffle(batches) 91 | 92 | # return the iterator 93 | return self.get_batches_iterator(batches) 94 | 95 | 96 | class ParallelDataset(Dataset): 97 | 98 | def __init__(self, sent1, pos1, dico1, sent2, pos2, dico2, params): 99 | super(ParallelDataset, self).__init__(params) 100 | self.sent1 = sent1 101 | self.sent2 = sent2 102 | self.pos1 = pos1 103 | self.pos2 = pos2 104 | self.dico1 = dico1 105 | self.dico2 = dico2 106 | self.lengths1 = (self.pos1[:, 1] - self.pos1[:, 0]) 107 | self.lengths2 = (self.pos2[:, 1] - self.pos2[:, 0]) 108 | self.is_parallel = True 109 | self.total = 0 110 | 111 | # check number of sentences 112 | assert len(self.pos1) == (self.sent1 == -1).sum() 113 | assert len(self.pos2) == (self.sent2 == -1).sum() 114 | 115 | self.remove_empty_sentences() 116 | 117 | def __len__(self): 118 | """ 119 | Number of sentences in the dataset. 120 | """ 121 | return len(self.pos1) 122 | 123 | def remove_empty_sentences(self): 124 | """ 125 | Remove empty sentences. 126 | """ 127 | init_size = len(self.pos1) 128 | indices = np.arange(len(self.pos1)) 129 | indices = indices[self.lengths1[indices] > 0] 130 | indices = indices[self.lengths2[indices] > 0] 131 | self.pos1 = self.pos1[indices] 132 | self.pos2 = self.pos2[indices] 133 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 134 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 135 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 136 | 137 | def remove_long_sentences(self, max_len): 138 | """ 139 | Remove sentences exceeding a certain length. 140 | """ 141 | assert max_len > 0 142 | init_size = len(self.pos1) 143 | indices = np.arange(len(self.pos1)) 144 | indices = indices[self.lengths1[indices] <= max_len] # indices[True,False.....] = [0,1,2....] 145 | indices = indices[self.lengths2[indices] <= max_len] 146 | self.pos1 = self.pos1[indices] 147 | self.pos2 = self.pos2[indices] 148 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 149 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 150 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 151 | 152 | def select_data(self, a, b): 153 | """ 154 | Only retain a subset of the dataset. 155 | """ 156 | assert 0 <= a <= b <= len(self.pos1) 157 | if a < b: 158 | self.pos1 = self.pos1[a:b] 159 | self.pos2 = self.pos2[a:b] 160 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 161 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 162 | else: 163 | self.pos1 = torch.LongTensor() 164 | self.pos2 = torch.LongTensor() 165 | self.lengths1 = torch.LongTensor() 166 | self.lengths2 = torch.LongTensor() 167 | 168 | 169 | def get_batches_iterator(self, batches): 170 | """ 171 | Return a sentences iterator, given the associated sentence batches. 
172 | """ 173 | def iterator(): 174 | for sentence_ids in batches: 175 | pos1 = self.pos1[sentence_ids] 176 | pos2 = self.pos2[sentence_ids] 177 | sent1 = [self.sent1[a:b] for a, b in pos1] 178 | sent2 = [self.sent2[a:b] for a, b in pos2] 179 | yield self.batch_sentences(sent1), self.batch_sentences(sent2) 180 | return iterator 181 | 182 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, partition=None): 183 | """ 184 | Return a sentences iterator. 185 | """ 186 | np.random.seed(self.seed) 187 | self.seed += 1 188 | # may affect the randomness of the data 189 | # multi-GPU training may need this, otherwise the data in each process would differ 190 | 191 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences 192 | assert 0 < n_sentences <= len(self.pos1) 193 | assert type(shuffle) is bool and type(group_by_size) is bool 194 | 195 | # select sentences to iterate over 196 | if shuffle: 197 | indices = np.random.permutation(len(self.pos1))[:n_sentences] 198 | else: 199 | indices = np.arange(n_sentences) 200 | 201 | # group sentences by lengths 202 | if group_by_size: 203 | indices = indices[np.argsort(200-self.lengths2[indices], kind='mergesort')] 204 | indices = indices[np.argsort(200-self.lengths1[indices], kind='mergesort')] 205 | # this yields the ids of all sentences, ordered by length from shortest to longest 206 | # same as the ordered indices in fairseq's language_pair_dataset.py 207 | # sort on 200 - length to reverse the order, because of the padding in the LSTM 208 | 209 | # create batches / optionally shuffle them 210 | if self.batch_size_tokens == -1: 211 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 212 | else: 213 | # split sentences into batches based on the maximum number of tokens per batch 214 | max_tokens = self.batch_size_tokens 215 | batches = [] 216 | batch = [] 217 | 218 | def is_batch_full(num_tokens): 219 | if len(batch) == 0: 220 | return False 221 | if num_tokens > max_tokens: 222 | return True 223 | return False 224 | 225 | sample_len = 0 226 | sample_lens = [] 227 | id = 0 228 | while id < len(indices): 229 | idx = indices[id] 230 | sample_lens.append(max(self.lengths1[idx], self.lengths2[idx])) 231 | history = sample_len 232 | sample_len = max(sample_len, sample_lens[-1]) 233 | num_tokens = len(batch) * sample_len 234 | if is_batch_full(num_tokens): 235 | # prevent a sudden increase of num_tokens (e.g. 30*50=1500 -> 31*100=3100) 236 | batch.pop() 237 | id -= 1 238 | batches.append(np.array(batch)) 239 | batch = [] 240 | sample_lens = [] 241 | sample_len = 0 242 | batch.append(idx) 243 | id += 1 244 | 245 | 246 | if len(batch) > 0: 247 | batches.append(np.array(batch)) 248 | batches = np.array(batches) 249 | 250 | if shuffle: 251 | np.random.shuffle(batches) 252 | 253 | self.total = len(batches) 254 | # partition the batches across GPUs 255 | if partition is not None: 256 | part_len = int((1.0/self.gpu_num)*self.total) 257 | batches = batches[part_len*partition:(partition+1)*part_len] 258 | 259 | self.total = len(batches) 260 | # return the iterator 261 | return self.get_batches_iterator(batches) 262 | -------------------------------------------------------------------------------- /scripts/learn_bpe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author: Rico Sennrich 4 | 5 | """Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text. 6 | Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary 7 | of a text to a configurable number of symbols, with only a small increase in the number of tokens. 
8 | 9 | Reference: 10 | Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units. 11 | Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany. 12 | """ 13 | 14 | from __future__ import unicode_literals 15 | 16 | import os 17 | import sys 18 | import inspect 19 | import codecs 20 | import re 21 | import copy 22 | import argparse 23 | import warnings 24 | from collections import defaultdict, Counter 25 | 26 | # hack for python2/3 compatibility 27 | from io import open 28 | argparse.open = open 29 | 30 | def create_parser(subparsers=None): 31 | 32 | if subparsers: 33 | parser = subparsers.add_parser('learn-bpe', 34 | formatter_class=argparse.RawDescriptionHelpFormatter, 35 | description="learn BPE-based word segmentation") 36 | else: 37 | parser = argparse.ArgumentParser( 38 | formatter_class=argparse.RawDescriptionHelpFormatter, 39 | description="learn BPE-based word segmentation") 40 | 41 | parser.add_argument( 42 | '--input', '-i', type=argparse.FileType('r'), default=sys.stdin, 43 | metavar='PATH', 44 | help="Input text (default: standard input).") 45 | 46 | parser.add_argument( 47 | '--output', '-o', type=argparse.FileType('w'), default=sys.stdout, 48 | metavar='PATH', 49 | help="Output file for BPE codes (default: standard output)") 50 | parser.add_argument( 51 | '--symbols', '-s', type=int, default=10000, 52 | help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))") 53 | parser.add_argument( 54 | '--min-frequency', type=int, default=2, metavar='FREQ', 55 | help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))') 56 | parser.add_argument('--dict-input', action="store_true", 57 | help="If set, input file is interpreted as a dictionary where each line contains a word-count pair") 58 | parser.add_argument( 59 | '--total-symbols', '-t', action="store_true", 60 | help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).") 61 | parser.add_argument( 62 | '--verbose', '-v', action="store_true", 63 | help="verbose mode.") 64 | 65 | return parser 66 | 67 | def get_vocabulary(fobj, is_dict=False): 68 | """Read text and return dictionary that encodes vocabulary 69 | """ 70 | vocab = Counter() 71 | for i, line in enumerate(fobj): 72 | if is_dict: 73 | try: 74 | word, count = line.strip('\r\n ').split(' ') 75 | except: 76 | print('Failed reading vocabulary file at line {0}: {1}'.format(i, line)) 77 | sys.exit(1) 78 | vocab[word] += int(count) 79 | else: 80 | for word in line.strip('\r\n ').split(' '): 81 | if word: 82 | vocab[word] += 1 83 | return vocab 84 | 85 | def update_pair_statistics(pair, changed, stats, indices): 86 | """Minimally update the indices and frequency of symbol pairs 87 | 88 | if we merge a pair of symbols, only pairs that overlap with occurrences 89 | of this pair are affected, and need to be updated. 
90 | """ 91 | stats[pair] = 0 92 | indices[pair] = defaultdict(int) 93 | first, second = pair 94 | new_pair = first+second 95 | for j, word, old_word, freq in changed: 96 | 97 | # find all instances of pair, and update frequency/indices around it 98 | i = 0 99 | while True: 100 | # find first symbol 101 | try: 102 | i = old_word.index(first, i) 103 | except ValueError: 104 | break 105 | # if first symbol is followed by second symbol, we've found an occurrence of pair (old_word[i:i+2]) 106 | if i < len(old_word)-1 and old_word[i+1] == second: 107 | # assuming a symbol sequence "A B C", if "B C" is merged, reduce the frequency of "A B" 108 | if i: 109 | prev = old_word[i-1:i+1] 110 | stats[prev] -= freq 111 | indices[prev][j] -= 1 112 | if i < len(old_word)-2: 113 | # assuming a symbol sequence "A B C B", if "B C" is merged, reduce the frequency of "C B". 114 | # however, skip this if the sequence is A B C B C, because the frequency of "C B" will be reduced by the previous code block 115 | if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second: 116 | nex = old_word[i+1:i+3] 117 | stats[nex] -= freq 118 | indices[nex][j] -= 1 119 | i += 2 120 | else: 121 | i += 1 122 | 123 | i = 0 124 | while True: 125 | try: 126 | # find new pair 127 | i = word.index(new_pair, i) 128 | except ValueError: 129 | break 130 | # assuming a symbol sequence "A BC D", if "B C" is merged, increase the frequency of "A BC" 131 | if i: 132 | prev = word[i-1:i+1] 133 | stats[prev] += freq 134 | indices[prev][j] += 1 135 | # assuming a symbol sequence "A BC B", if "B C" is merged, increase the frequency of "BC B" 136 | # however, if the sequence is A BC BC, skip this step because the count of "BC BC" will be incremented by the previous code block 137 | if i < len(word)-1 and word[i+1] != new_pair: 138 | nex = word[i:i+2] 139 | stats[nex] += freq 140 | indices[nex][j] += 1 141 | i += 1 142 | 143 | 144 | def get_pair_statistics(vocab): 145 | """Count frequency of all symbol pairs, and create index""" 146 | 147 | # data structure of pair frequencies 148 | stats = defaultdict(int) 149 | 150 | #index from pairs to words 151 | indices = defaultdict(lambda: defaultdict(int)) 152 | 153 | for i, (word, freq) in enumerate(vocab): 154 | prev_char = word[0] 155 | for char in word[1:]: 156 | stats[prev_char, char] += freq 157 | indices[prev_char, char][i] += 1 158 | prev_char = char 159 | 160 | return stats, indices 161 | 162 | 163 | def replace_pair(pair, vocab, indices): 164 | """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'""" 165 | first, second = pair 166 | pair_str = ''.join(pair) 167 | pair_str = pair_str.replace('\\','\\\\') 168 | changes = [] 169 | pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)') 170 | if sys.version_info < (3, 0): 171 | iterator = indices[pair].iteritems() 172 | else: 173 | iterator = indices[pair].items() 174 | for j, freq in iterator: 175 | if freq < 1: 176 | continue 177 | word, freq = vocab[j] 178 | new_word = ' '.join(word) 179 | new_word = pattern.sub(pair_str, new_word) 180 | new_word = tuple(new_word.split(' ')) 181 | 182 | vocab[j] = (new_word, freq) 183 | changes.append((j, new_word, word, freq)) 184 | 185 | return changes 186 | 187 | def prune_stats(stats, big_stats, threshold): 188 | """Prune statistics dict for efficiency of max() 189 | 190 | The frequency of a symbol pair never increases, so pruning is generally safe 191 | (until the most frequent pair is less frequent than a pair we previously pruned) 192 | big_stats keeps full statistics for when we need to access pruned items 193 | """ 194 | for item,freq in list(stats.items()): 195 | if freq < threshold: 196 | del stats[item] 197 | if freq < 0: 198 | big_stats[item] += freq 199 | else: 200 | big_stats[item] = freq 201 | 202 | 203 | def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False): 204 | """Learn num_symbols BPE operations from vocabulary, and write to outfile. 205 | """ 206 | 207 | # version 0.2 changes the handling of the end-of-word token ('</w>' instead of '<w>'); 208 | # version numbering allows backward compatibility 209 | outfile.write('#version: 0.2\n') 210 | 211 | vocab = get_vocabulary(infile, is_dict) 212 | vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()]) 213 | sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) 214 | 215 | stats, indices = get_pair_statistics(sorted_vocab) 216 | big_stats = copy.deepcopy(stats) 217 | 218 | if total_symbols: 219 | uniq_char_internal = set() 220 | uniq_char_final = set() 221 | for word in vocab: 222 | for char in word[:-1]: 223 | uniq_char_internal.add(char) 224 | uniq_char_final.add(word[-1]) 225 | sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal))) 226 | sys.stderr.write('Number of word-final characters: 
{0}\n'.format(len(uniq_char_final))) 227 | sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final))) 228 | num_symbols -= len(uniq_char_internal) + len(uniq_char_final) 229 | 230 | # threshold is inspired by Zipfian assumption, but should only affect speed 231 | threshold = max(stats.values()) / 10 232 | for i in range(num_symbols): 233 | if stats: 234 | most_frequent = max(stats, key=lambda x: (stats[x], x)) 235 | 236 | # we probably missed the best pair because of pruning; go back to full statistics 237 | if not stats or (i and stats[most_frequent] < threshold): 238 | prune_stats(stats, big_stats, threshold) 239 | stats = copy.deepcopy(big_stats) 240 | most_frequent = max(stats, key=lambda x: (stats[x], x)) 241 | # threshold is inspired by Zipfian assumption, but should only affect speed 242 | threshold = stats[most_frequent] * i/(i+10000.0) 243 | prune_stats(stats, big_stats, threshold) 244 | 245 | if stats[most_frequent] < min_frequency: 246 | sys.stderr.write('no pair has frequency >= {0}. Stopping\n'.format(min_frequency)) 247 | break 248 | 249 | if verbose: 250 | sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent])) 251 | outfile.write('{0} {1}\n'.format(*most_frequent)) 252 | changes = replace_pair(most_frequent, sorted_vocab, indices) 253 | update_pair_statistics(most_frequent, changes, stats, indices) 254 | stats[most_frequent] = 0 255 | if not i % 100: 256 | prune_stats(stats, big_stats, threshold) 257 | 258 | 259 | if __name__ == '__main__': 260 | 261 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 262 | newdir = os.path.join(currentdir, 'subword_nmt') 263 | if os.path.isdir(newdir): 264 | warnings.simplefilter('default') 265 | warnings.warn( 266 | "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir), 267 | DeprecationWarning 268 | ) 269 | 270 | # python 2/3 compatibility 271 | if sys.version_info < (3, 0): 272 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) 273 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) 274 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin) 275 | else: 276 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer) 277 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer) 278 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer) 279 | 280 | parser = create_parser() 281 | args = parser.parse_args() 282 | 283 | # read/write files as UTF-8 284 | if args.input.name != '<stdin>': 285 | args.input = codecs.open(args.input.name, encoding='utf-8') 286 | if args.output.name != '<stdout>': 287 | args.output = codecs.open(args.output.name, 'w', encoding='utf-8') 288 | 289 | learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols) 290 | --------------------------------------------------------------------------------
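A minimal sketch of driving learn_bpe() programmatically on a toy in-memory corpus, assuming the scripts/ directory is the working directory (or on sys.path); the toy sentences and the merge count are illustrative only, and the script is normally invoked from the command line via the --input/--output/--symbols options defined in create_parser():

import io
from learn_bpe import learn_bpe  # assumes scripts/ is importable; path handling is an assumption

# illustrative toy corpus; real use would pass a tokenized training file
toy_corpus = io.StringIO("low lower lowest newer newest wider\n")
codes = io.StringIO()

# learn 10 merge operations; min_frequency=1 keeps rare pairs eligible on such a tiny corpus
learn_bpe(toy_corpus, codes, num_symbols=10, min_frequency=1)

# the output starts with the '#version: 0.2' header, followed by one merge operation per line
print(codes.getvalue())

The command-line entry point produces the same kind of codes file; it simply wraps the input and output file arguments in UTF-8 codecs before calling learn_bpe().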