├── .gitmodules
├── scripts
│   ├── preprocess-da.sh
│   ├── preprocess.sh
│   ├── extract-lex-giza.sh
│   ├── train-embed.sh
│   ├── wfw_backtranslation.sh
│   ├── translate.sh
│   ├── train.sh
│   ├── train-da-opt.sh
│   ├── train-muse.sh
│   └── train-da.sh
├── .gitignore
├── wfw_backtranslation.py
├── extract_lexicon.py
├── conda-dali-env.txt
├── extract_lexicon_giza.py
└── README.md
/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MUSE"] 2 | path = MUSE 3 | url = https://github.com/facebookresearch/MUSE.git 4 | branch = master 5 | [submodule "fastText"] 6 | path = fastText 7 | url = https://github.com/facebookresearch/fastText.git 8 | branch = master 9 | [submodule "fairseq"] 10 | path = fairseq 11 | url = https://github.com/pytorch/fairseq.git 12 | branch = master 13 | [submodule "mosesdecoder"] 14 | path = mosesdecoder 15 | url = https://github.com/moses-smt/mosesdecoder.git 16 | -------------------------------------------------------------------------------- /scripts/preprocess-da.sh: -------------------------------------------------------------------------------- 1 | sl=de 2 | tl=en 3 | 4 | sd='it' 5 | td='emea' 6 | data_dir=${PWD}/dataset/${td}-lex-w2w+${sd} 7 | out_dir=${PWD}/outputs/ 8 | dest_dir=$out_dir/data-bin-join/${sd}2${td}/ 9 | mkdir -p $dest_dir 10 | 11 | fairseq-preprocess --source-lang ${sl} --target-lang $tl \ 12 | --trainpref $data_dir/${td}-w2w-unsup+${sd}-para.train.bpe.clean \ 13 | --validpref $data_dir/${td}-w2w-unsup+${sd}-para.dev.bpe \ 14 | --testpref $data_dir/${td}-test.bpe \ 15 | --destdir $dest_dir \ 16 | --srcdict $out_dir/data-bin-join/${sd}/dict.${sl}.txt \ 17 | --tgtdict $out_dir/data-bin-join/${td}/dict.${tl}.txt 18 | 19 | -------------------------------------------------------------------------------- /scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | sl=de 2 | tl=en 3 | 4 | d='it' 5 | repo=$PWD 6 | data_dir=$repo/dataset/ 7 | out_dir=$repo/outputs/ 8 | mkdir -p $out_dir/data-bin-join/${d} 9 | 10 | fairseq-preprocess --source-lang ${sl} --target-lang $tl \ 11 | --trainpref $data_dir/${d}-train.bpe.clean \ 12 | --validpref $data_dir/${d}-dev.bpe \ 13 | --testpref $data_dir/${d}-test.bpe,$data_dir/emea-test.bpe,$data_dir/koran-test.bpe,$data_dir/subtitles-test.bpe,$data_dir/acquis-test.bpe \ 14 | --destdir $out_dir/data-bin-join/${d}/ \ 15 | --srcdict $out_dir/data-bin-join/${d}/dict.${sl}.txt \ 16 | --tgtdict $out_dir/data-bin-join/${d}/dict.${tl}.txt 17 | 18 | -------------------------------------------------------------------------------- /scripts/extract-lex-giza.sh: -------------------------------------------------------------------------------- 1 | 2 | REPO=$PWD 3 | mosesdecoder="$REPO/mosesdecoder" 4 | dir="$REPO/data" 5 | 6 | domain='acquis' 7 | out_dir="$REPO/outputs/moses-${domain}" 8 | 9 | mkdir -p "${out_dir}" 10 | cd "${out_dir}" 11 | nohup nice ${mosesdecoder}/scripts/training/train-model.perl -root-dir train \ 12 | -corpus ${dir}/${domain}-train.tc.clean \ 13 | -f de -e en -alignment grow-diag-final-and -reordering msd-bidirectional-fe \ 14 | -last-step 6 \ 15 | -external-bin-dir ${mosesdecoder}/tools >& training.out & 16 | echo "Training on ${domain}" 17 | 18 | cd $REPO 19 | 20 | python extract_lexicon_giza.py \ 21 | --s2t_lex_infile $out_dir/lex.f2e \ 22 | --t2s_lex_infile $out_dir/lex.e2f \ 23 | --src_infile $dir/${domain}-train.tc.de.clean \ 24 | --tgt_infile $dir/${domain}-train.tc.en.clean \ 25 | --lex_outfile $out_dir/lex 26 |
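A minimal sketch of feeding this GIZA++-based lexicon into MUSE as the supervised seed dictionary via scripts/train-muse.sh. The GPU id (0) and the `lex.seed` filename are illustrative, and the count column written by extract_lexicon_giza.py is dropped first, since MUSE's --dico_train loader typically expects plain whitespace-separated word pairs:
```
# Reduce the tab-separated "src<TAB>tgt<TAB>count" lexicon to plain "src tgt" pairs.
cut -f1,2 outputs/moses-acquis/lex | tr '\t' ' ' > outputs/moses-acquis/lex.seed
# Run supervised MUSE with the extracted seed lexicon (GPU id 0 is an example).
bash scripts/train-muse.sh 0 outputs/moses-acquis/lex.seed
```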
-------------------------------------------------------------------------------- /scripts/train-embed.sh: -------------------------------------------------------------------------------- 1 | repo=$PWD 2 | data_dir=$repo/dataset/ 3 | 4 | # combine unaligned monolingual data in both languages 5 | echo "combine monolingual data" 6 | cat ${data_dir}/*-train.tc.clean.en.mono > ${data_dir}/all-train.tc.clean.mono.en 7 | cat ${data_dir}/*-train.tc.clean.de.mono > ${data_dir}/all-train.tc.clean.mono.de 8 | prefix=all-train.tc.clean.mono 9 | 10 | # train embeddings on all the unaligned monolingual data in two languages 11 | fasttext="$repo/fastText/build/fasttext" 12 | out_dir="$repo/embed/" 13 | mkdir -p $out_dir 14 | for lang in de en; do 15 | echo "$fasttext skipgram -input ${data_dir}/${prefix}.${lang} -output ${out_dir}/${prefix}.${lang} -ws 10 -dim 512 -neg 10 -t 0.00001 -epoch 10" 16 | $fasttext skipgram -input ${data_dir}/${prefix}.${lang} -output ${out_dir}/${prefix}.${lang} -ws 10 -dim 512 -neg 10 -t 0.00001 -epoch 10 17 | done -------------------------------------------------------------------------------- /scripts/wfw_backtranslation.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=$1 2 | sl=de 3 | tl=en 4 | 5 | sd='it' 6 | td='emea' 7 | 8 | data_dir=${PWD}/dataset/ 9 | output_dir=${PWD}/dataset/${td}-lex-w2w+${sd} 10 | mkdir -p $output_dir 11 | 12 | for sp in "${td}-train.bpe.clean" "${td}-dev.bpe"; do 13 | python wfw_backtranslation.py \ 14 | --lexicon_infile ${PWD}/outputs/unsupervised-muse/debug/v1/S2T+T2S-de-en.lex \ 15 | --tgt_infile ${data_dir}/${sp}.en.mono \ 16 | --src_outfile ${output_dir}/${sp}.de.mono 17 | ln -s ${data_dir}/${sp}.en.mono ${output_dir}/${sp}.en.mono 18 | 19 | num_td=$(< $data_dir/$sp.en.mono wc -l) 20 | src_sp=${sp/$td/$sd} 21 | head -n $num_td ${data_dir}/${src_sp}.en > ${output_dir}/${src_sp}.en 22 | head -n $num_td ${data_dir}/${src_sp}.de > ${output_dir}/${src_sp}.de 23 | 24 | cat $output_dir/$sp.en.mono $output_dir/$src_sp.en > $output_dir/${sp/$td/${td}-w2w-unsup+${sd}-para}.en 25 | cat $output_dir/$sp.de.mono $output_dir/$src_sp.de > $output_dir/${sp/$td/${td}-w2w-unsup+${sd}-para}.de 26 | done -------------------------------------------------------------------------------- /scripts/translate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sl=de 3 | tl=en 4 | 5 | sd="it" 6 | td="emea" 7 | 8 | gpu=yes 9 | path=$1 # checkpoint path, e.g., $PWD/outputs/${sd}-${sl}-${tl}-epoch40/checkpoint_best.pt 10 | out_file=$2 # decode file, e.g., $PWD/outputs/${sd}-${sl}-${tl}-epoch40/decode-best-beam5.txt 11 | data_dir=$3 # binarized data folder, e.g., $PWD/datasets/${sd}2${td}/ 12 | split=$4 # prefix of test set, e.g., test, test1, test2 13 | 14 | if [[ $gpu == yes ]]; then 15 | fairseq-generate \ 16 | $data_dir \ 17 | --source-lang ${sl} --target-lang ${tl} \ 18 | --path $path \ 19 | --beam 5 --lenpen 1.2 \ 20 | --gen-subset $split \ 21 | --batch-size 1 \ 22 | --remove-bpe="@@ " > $out_file 23 | else 24 | fairseq-generate \ 25 | $data_dir \ 26 | --source-lang ${sl} --target-lang ${tl} \ 27 | --path $path \ 28 | --beam 5 --lenpen 1.2 \ 29 | --gen-subset ${split} \ 30 | --remove-bpe="@@ " \ 31 | --batch-size 1 \ 32 | --cpu > $out_file 33 | fi 34 | 35 | grep ^H- $out_file | cut -d' ' -f2- > ${out_file}.out 36 | -------------------------------------------------------------------------------- /scripts/train.sh: 
-------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=$1 2 | sl=de 3 | tl=en 4 | dataset='it' 5 | data_dir=${PWD}/outputs/data-bin-join/${dataset}/ 6 | 7 | epoch=40 8 | save_dir=${PWD}/outputs/${dataset}-${sl}-${tl}-epoch${epoch}/ 9 | mkdir -p $save_dir 10 | fairseq-train $data_dir \ 11 | --save-dir $save_dir \ 12 | --arch transformer \ 13 | --source-lang ${sl} --target-lang ${tl} \ 14 | --encoder-layers 6 --decoder-layers 6 \ 15 | --encoder-embed-dim 512 --decoder-embed-dim 512 \ 16 | --encoder-ffn-embed-dim 2048 --decoder-ffn-embed-dim 2048 \ 17 | --encoder-attention-heads 8 --decoder-attention-heads 8 \ 18 | --encoder-normalize-before --decoder-normalize-before \ 19 | --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \ 20 | --weight-decay 0.0001 \ 21 | --label-smoothing 0.2 --criterion label_smoothed_cross_entropy \ 22 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0 \ 23 | --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-7 \ 24 | --lr 1e-3 --min-lr 1e-9 \ 25 | --max-tokens 2000 \ 26 | --update-freq 8 \ 27 | --max-epoch ${epoch} --save-interval 1 \ 28 | --fp16 \ 29 | --save-interval-updates 5000 1> $save_dir/log 2> $save_dir/err 30 | 31 | -------------------------------------------------------------------------------- /scripts/train-da-opt.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=$1 2 | sl=de 3 | tl=en 4 | sd='it' 5 | td="emea" 6 | data_dir=${PWD}/outputs/data-bin-join/${sd}2${td}/ 7 | epoch=40 8 | out_dir=${PWD}/outputs/${sd}-${sl}-${tl}-epoch${epoch}/ 9 | save_dir=${PWD}/outputs/${sd}2${td}-${sl}-${tl}-epoch${epoch}-opt/ 10 | mkdir -p $save_dir 11 | fairseq-train $data_dir \ 12 | --save-dir $save_dir \ 13 | --restore-file $out_dir/checkpoint_best.pt \ 14 | --arch transformer \ 15 | --source-lang ${sl} --target-lang ${tl} \ 16 | --encoder-layers 6 --decoder-layers 6 \ 17 | --encoder-embed-dim 512 --decoder-embed-dim 512 \ 18 | --encoder-ffn-embed-dim 2048 --decoder-ffn-embed-dim 2048 \ 19 | --encoder-attention-heads 8 --decoder-attention-heads 8 \ 20 | --encoder-normalize-before --decoder-normalize-before \ 21 | --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \ 22 | --weight-decay 0.0001 \ 23 | --label-smoothing 0.2 --criterion label_smoothed_cross_entropy \ 24 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0 \ 25 | --lr 5e-4 --min-lr 1e-9 \ 26 | --max-tokens 2000 \ 27 | --update-freq 8 \ 28 | --max-epoch ${epoch} --save-interval 1 \ 29 | --fp16 \ 30 | --reset-dataloader \ 31 | --save-interval-updates 5000 1> $save_dir/log 2> $save_dir/err 32 | 33 | -------------------------------------------------------------------------------- /scripts/train-muse.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=$1 # [GPU id] 2 | sup_dict_path=$2 # [path to supervised seed lexicon], e.g., "$PWD/word-align/mono.fbw.lex" 3 | repo=$PWD 4 | 5 | # Install MUSE 6 | muse=$repo/MUSE/ 7 | muse_dict=$muse/data/crosslingual/dictionaries/de-en.0-5000.txt 8 | if [ ! 
-f $muse_dict ]; then 9 | cd $muse/data/ 10 | bash get_evaluation.sh 11 | cd $repo 12 | fi 13 | 14 | dir="$repo/data" 15 | src_emb="$repo/embed/all-train.tc.clean.mono.de.vec" 16 | tgt_emb="$repo/embed/all-train.tc.clean.mono.en.vec" 17 | 18 | 19 | db="S2T|T2S" 20 | if [ -f $sup_dict_path ]; then 21 | out_dir="$repo/outputs/supervised-muse/" 22 | mkdir -p $out_dir 23 | echo "supervised lexicon induction" 24 | python $muse/supervised.py \ 25 | --src_emb $src_emb \ 26 | --tgt_emb $tgt_emb \ 27 | --emb_dim 512 \ 28 | --dico_train $sup_dict_path \ 29 | --save_dict_path $save_dic \ 30 | --normalize_embeddings center \ 31 | --dico_build $db \ 32 | --exp_path $out_dir \ 33 | --eval_file $muse_dict 34 | else 35 | out_dir="$repo/outputs/unsupervised-muse/" 36 | mkdir -p $out_dir 37 | echo "unsupervised lexicon induction" 38 | python $muse/unsupervised.py \ 39 | --src_lang de \ 40 | --tgt_lang en \ 41 | --emb_dim 512 \ 42 | --src_emb $src_emb\ 43 | --tgt_emb $tgt_emb \ 44 | --normalize_embeddings center \ 45 | --exp_path $out_dir \ 46 | --n_refinement 5 \ 47 | --dis_most_frequent 0\ 48 | --exp_id v1 \ 49 | --dico_eval $muse_dict 50 | fi 51 | 52 | -------------------------------------------------------------------------------- /scripts/train-da.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=$1 2 | sl=de 3 | tl=en 4 | sd='it' 5 | td="emea" 6 | data_dir=${PWD}/outputs/data-bin-join/${sd}2${td}/ 7 | epoch=40 8 | out_dir=${PWD}/outputs/${sd}-${sl}-${tl}-epoch${epoch}/ 9 | save_dir=${PWD}/outputs/${sd}2${td}-${sl}-${tl}-epoch${epoch}/ 10 | mkdir -p $save_dir 11 | fairseq-train $data_dir \ 12 | --save-dir $save_dir \ 13 | --restore-file $out_dir/checkpoint_best.pt \ 14 | --arch transformer \ 15 | --source-lang ${sl} --target-lang ${tl} \ 16 | --encoder-layers 6 --decoder-layers 6 \ 17 | --encoder-embed-dim 512 --decoder-embed-dim 512 \ 18 | --encoder-ffn-embed-dim 2048 --decoder-ffn-embed-dim 2048 \ 19 | --encoder-attention-heads 8 --decoder-attention-heads 8 \ 20 | --encoder-normalize-before --decoder-normalize-before \ 21 | --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \ 22 | --weight-decay 0.0001 \ 23 | --label-smoothing 0.2 --criterion label_smoothed_cross_entropy \ 24 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0 \ 25 | --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-7 \ 26 | --lr 5e-4 --min-lr 1e-9 \ 27 | --max-tokens 2000 \ 28 | --update-freq 8 \ 29 | --max-epoch ${epoch} --save-interval 1 \ 30 | --fp16 \ 31 | --reset-dataloader \ 32 | --reset-optimizer \ 33 | --save-interval-updates 5000 1> $save_dir/log 2> $save_dir/err 34 | 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 |
-------------------------------------------------------------------------------- /wfw_backtranslation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import random 4 | 5 | def read_lexicons(lexicon_infile, delimiter=' ', src2tgt=True): 6 | lexicons = set() 7 | for line in open(lexicon_infile, 'r'): 8 | items = tuple(line.strip().split(delimiter)) 9 | if len(items) == 2: 10 | lex = items if src2tgt else (items[1], items[0]) 11 | lexicons.add(lex) 12 | print(f'Finished reading {len(lexicons)} lexicon entries from {lexicon_infile}') 13 | return lexicons 14 | 15 | def read_text(file, delimiter=' '): 16 | return [l.strip().split(delimiter) for l in open(file)] 17 | 18 | def write_text(sents, file): 19 | with open(file, 'w') as f: 20 | for sent in sents: 21 | f.write(sent + '\n') 22 | 23 | def word2word_backtranslate(lexicon_infile, tgt_infile, src_outfile): 24 | """ 25 | Args: 26 | lexicon_infile: path to the input lexicons 27 | tgt_infile: path to the input target sentences 28 | src_outfile: path to save the w2w-backtranslated source sentences 29 | """ 30 | lexicons = read_lexicons(lexicon_infile) 31 | tgts = read_text(tgt_infile) 32 | tgt_dict = defaultdict(list) 33 | for (s, t) in lexicons: 34 | tgt_dict[t].append(s) 35 | 36 | srcs = [] 37 | cnt = total = 0.0 38 | for tgt in tgts: 39 | src = [] 40 | for t in tgt: 41 | if t in tgt_dict: 42 | ridx = random.randint(0, len(tgt_dict[t]) - 1) 43 | src.append(tgt_dict[t][ridx]) 44 | cnt += 1 45 | else: 46 | src.append(t) 47 | total += 1 48 | srcs.append(" ".join(src)) 49 | print(f'Translated {cnt} out of {total} words ({cnt / total * 100:.2f}%)') 50 | 51 | write_text(srcs, src_outfile) 52 | return srcs 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='Word-for-word backtranslation') 57 | parser.add_argument("--lexicon_infile", type=str, default="", help="lexicon input file") 58 | parser.add_argument("--tgt_infile", type=str, default="", help="monolingual target input file") 59 | parser.add_argument("--src_outfile", type=str, default="", help="augmented source output file") 60 | args = parser.parse_args() 61 | 62 | word2word_backtranslate(args.lexicon_infile, args.tgt_infile, args.src_outfile) 63 | 64 | --------------------------------------------------------------------------------
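A tiny, self-contained illustration of what the script above does, using a made-up two-entry lexicon and a one-line target file (the /tmp paths are arbitrary): every target word found in the induced lexicon is replaced by one of its source-side translations, and out-of-lexicon words are copied through unchanged.
```
# Toy lexicon: space-separated "source target" pairs, as expected by read_lexicons().
printf 'katze cat\nhaus house\n' > /tmp/toy.lex
# Toy monolingual target-side input.
printf 'the cat sleeps in the house\n' > /tmp/toy.en
python wfw_backtranslation.py \
  --lexicon_infile /tmp/toy.lex \
  --tgt_infile /tmp/toy.en \
  --src_outfile /tmp/toy.de
cat /tmp/toy.de   # -> the katze sleeps in the haus
```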
/extract_lexicon.py: -------------------------------------------------------------------------------- 1 | """ 2 | python3 extract_lexicon.py \ 3 | --src_emb [path to src embeddings] \ 4 | --tgt_emb [path to tgt embeddings] \ 5 | --output [output lexicon] \ 6 | --dico_build ["S2T&T2S", "S2T", "T2S", "S2T|T2S"] 7 | """ 8 | import sys 9 | sys.path.append('./MUSE') 10 | from src.dico_builder import build_dictionary 11 | import numpy as np 12 | import io 13 | import argparse 14 | import torch 15 | import pickle 16 | from src.utils import bool_flag, initialize_exp 17 | import codecs 18 | 19 | 20 | def load_vec(emb_path, nmax=500000): 21 | vectors = [] 22 | word2id = {} 23 | with open(emb_path, 'r', encoding="utf-8") as f: 24 | next(f) 25 | for i, line in enumerate(f): 26 | word, vect = line.rstrip().split(' ', 1) 27 | vect = np.fromstring(vect, sep=' ') 28 | assert word not in word2id, 'word found twice' 29 | vectors.append(vect) 30 | word2id[word] = len(word2id) 31 | if len(word2id) == nmax: 32 | break 33 | id2word = {v: k for k, v in word2id.items()} 34 | embeddings = np.vstack(vectors) 35 | return embeddings, id2word, word2id 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser(description='Extract lexicon by nearest neighbor search') 39 | parser.add_argument("--src_emb", type=str, default="", help="Reload source embeddings") 40 | parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings") 41 | parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)") 42 | parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S") 43 | parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation") 44 | parser.add_argument("--dico_max_rank", type=int, default=0, help="Maximum dictionary words rank (0 to disable)") 45 | parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)") 46 | parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)") 47 | parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU") 48 | parser.add_argument("--output", type=str, default="", help="output path of the dictionary") 49 | params = parser.parse_args() 50 | 51 | src_word_embs, src_id2word, src_word2id = load_vec(params.src_emb) 52 | tgt_word_embs, tgt_id2word, tgt_word2id = load_vec(params.tgt_emb) 53 | src_word_embs = torch.FloatTensor(src_word_embs).cuda() 54 | tgt_word_embs = torch.FloatTensor(tgt_word_embs).cuda() 55 | dictionary = build_dictionary(src_emb=src_word_embs, tgt_emb=tgt_word_embs, params=params) 56 | dictionary = dictionary.cpu().numpy() 57 | f = codecs.open(params.output, 'w', encoding='utf8') 58 | for k, (i, j) in enumerate(dictionary): 59 | s_word = src_id2word[i] 60 | t_word = tgt_id2word[j] 61 | f.write(s_word + " " + t_word + "\n") 62 | print(k) 63 | print(dictionary.shape) 64 | print(dictionary[0]) 65 | -------------------------------------------------------------------------------- /conda-dali-env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.tar.bz2 6 | 
https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 7 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2019.5.15-1.tar.bz2 8 | https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-10.0.130-0.tar.bz2 9 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.4-243.tar.bz2 10 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.tar.bz2 11 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.1.0-hdf63c60_0.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.1.0-hdf63c60_0.tar.bz2 13 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.4-243.tar.bz2 14 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 15 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 16 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.1-he6710b0_1.tar.bz2 17 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1d-h7b6447c_1.tar.bz2 18 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.4-h14c3975_4.tar.bz2 19 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.tar.bz2 20 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20181209-hc058e9b_0.tar.bz2 21 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.tar.bz2 22 | https://repo.anaconda.com/pkgs/main/linux-64/readline-7.0-h7b6447c_5.tar.bz2 23 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.8-hbc83047_0.tar.bz2 24 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.3.7-h0b5b093_0.tar.bz2 25 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.9.1-h8a8886c_1.tar.bz2 26 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.0.10-h2733197_2.tar.bz2 27 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.28.0-h7b6447c_0.tar.bz2 28 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.6.8-h0371630_0.tar.bz2 29 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2019.6.16-py36_1.tar.bz2 30 | https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.9.0-py36hfd86e86_0.tar.bz2 31 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.16.4-py36hde5b4d6_0.tar.bz2 32 | https://repo.anaconda.com/pkgs/main/linux-64/olefile-0.46-py36_0.tar.bz2 33 | https://repo.anaconda.com/pkgs/main/linux-64/pycparser-2.19-py36_0.tar.bz2 34 | https://repo.anaconda.com/pkgs/main/linux-64/six-1.12.0-py36_0.tar.bz2 35 | https://repo.anaconda.com/pkgs/main/noarch/tqdm-4.32.1-py_0.tar.bz2 36 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.12.3-py36h2e261b9_0.tar.bz2 37 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.3.0-py36he904b0f_0.tar.bz2 38 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.0.2-py36hd81dba3_0.tar.bz2 39 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-6.0.0-py36h34e0f95_0.tar.bz2 40 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-41.0.1-py36_0.tar.bz2 41 | https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.33.4-py36_0.tar.bz2 42 | https://repo.anaconda.com/pkgs/main/linux-64/pip-19.1.1-py36_0.tar.bz2 43 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.0.12-py36ha843d7b_0.tar.bz2 44 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.16.4-py36h7e9f1db_0.tar.bz2 45 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.1.0-py3.6_cuda10.0.130_cudnn7.5.1_0.tar.bz2 46 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.3.1-py36h7c811a0_0.tar.bz2 47 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.3.0-py36_cu10.0.130_1.tar.bz2 48 | 
-------------------------------------------------------------------------------- /extract_lexicon_giza.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict, Counter 3 | import numpy as np 4 | import sys 5 | 6 | def read_lex(file, lex, s2t): 7 | punc = set(["?", ",", ".", "!", "$", "%", "^", "&", "*", "@", "~", "`", "-", "+", "_", "=", "{", "}", "[", "]", "<", ">", "/", "'", '"', "(", ")", ":", ";" ]) 8 | for l in open(file): 9 | items = l.strip().split(' ') 10 | if len(items) != 3: 11 | print('Skip line={} with length={}'.format(l.strip(), len(items))) 12 | continue 13 | s = items[0].strip() 14 | t = items[1].strip() 15 | p = float(items[2]) 16 | if len(s) == 0 or len(t) == 0: 17 | continue 18 | if (s in punc or t in punc) and s != t: 19 | continue 20 | if s2t: 21 | lex[(s,t)].append(p) 22 | else: 23 | lex[(t,s)].append(p) 24 | return lex 25 | 26 | def extract_lexicon(s2t_lex_infile, t2s_lex_infile, src_infile, tgt_infile, lex_outfile, filter_low_prob=False): 27 | # Read giza++ lexicons from both directions 28 | lex = defaultdict(list) 29 | lex = read_lex(s2t_lex_infile, lex, s2t=True) 30 | lex = read_lex(t2s_lex_infile, lex, s2t=False) 31 | comb_lex = {k: float(np.mean(v)) for k,v in lex.items()} 32 | print('No. of lex', len(comb_lex)) 33 | 34 | # filter out low-probability lexicon entries 35 | if filter_low_prob: 36 | mean_prob = float(np.mean(list(comb_lex.values()))) 37 | print('Filter lexicon entries with prob < {}'.format(mean_prob * 0.05)) 38 | comb_lex = {k:v for k,v in comb_lex.items() if v >= mean_prob * 0.05} 39 | print('No. of lex', len(comb_lex)) 40 | 41 | # read parallel text 42 | src_sents = [l.strip().split(' ') for l in open(src_infile)] 43 | tgt_sents = [l.strip().split(' ') for l in open(tgt_infile)] 44 | cnt_lex = Counter() 45 | for src, tgt in zip(src_sents, tgt_sents): 46 | for sw in src: 47 | for tw in tgt: 48 | pair = (sw, tw) 49 | if pair in comb_lex: 50 | cnt_lex[pair] += 1 51 | print('Count {} lexicons in {}/{}'.format(len(cnt_lex.values()), src_infile, tgt_infile)) 52 | 53 | final_lex = {} 54 | src_words, tgt_words = set(), set() 55 | with open(lex_outfile, 'w') as writer: 56 | for (sw, tw), cnt in cnt_lex.most_common(): 57 | if sw not in src_words and tw not in tgt_words: 58 | final_lex[(sw, tw)] = cnt 59 | writer.write('{}\t{}\t{}\n'.format(sw, tw, cnt)) 60 | src_words.add(sw) 61 | tgt_words.add(tw) 62 | print('No. of extracted lex', len(final_lex))
63 | return final_lex 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser(description='GIZA++ lexicon extraction') 67 | parser.add_argument("--s2t_lex_infile", type=str, default="", help="giza++ source to target lexicon input file") 68 | parser.add_argument("--t2s_lex_infile", type=str, default="", help="giza++ target to source lexicon input file") 69 | parser.add_argument("--src_infile", type=str, default="", help="parallel source input file") 70 | parser.add_argument("--tgt_infile", type=str, default="", help="parallel target input file") 71 | parser.add_argument("--lex_outfile", type=str, default="", help="output lexicon file") 72 | parser.add_argument("--filter_low_prob", action='store_true', help="whether to filter out lexicon entries with low probability") 73 | args = parser.parse_args() 74 | 75 | extract_lexicon(args.s2t_lex_infile, args.t2s_lex_infile, args.src_infile, args.tgt_infile, args.lex_outfile, args.filter_low_prob) 76 | 77 | 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Domain Adaptation of Neural Machine Translation by Lexicon Induction 2 | === 3 | Implemented by [Junjie Hu](http://www.cs.cmu.edu/~junjieh/) 4 | 5 | Contact: junjieh@cs.cmu.edu 6 | 7 | If you use the code in this repo, please cite our [ACL2019 paper](https://www.aclweb.org/anthology/P19-1286). 8 | 9 | @inproceedings{hu-etal-2019-domain, 10 | title = "Domain Adaptation of Neural Machine Translation by Lexicon Induction", 11 | author = "Hu, Junjie and Xia, Mengzhou and Neubig, Graham and Carbonell, Jaime", 12 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", 13 | month = jul, 14 | year = "2019", 15 | address = "Florence, Italy", 16 | publisher = "Association for Computational Linguistics", 17 | url = "https://www.aclweb.org/anthology/P19-1286", 18 | doi = "10.18653/v1/P19-1286", 19 | pages = "2989--3001", 20 | } 21 | 22 | 23 | Installation 24 | == 25 | - Anaconda environment 26 | ``` 27 | conda create --name dali --file conda-dali-env.txt 28 | ``` 29 | 30 | - Install fairseq 31 | ``` 32 | cd fairseq && pip install --editable . && cd .. 33 | ``` 34 | 35 | - Install fastText 36 | ``` 37 | cd fastText && mkdir build && cd build && cmake .. && make && cd ../.. 38 | ``` 39 | 40 | - Download MUSE's dictionary 41 | ``` 42 | cd MUSE/data/ && bash get_evaluation.sh && cd ../.. 43 | ``` 44 | 45 | Downloads 46 | == 47 | The preprocessed data and pre-trained models can be found [here](https://drive.google.com/drive/folders/1PmlmLg8ZgR4MVLb0svP2j5oE0jQ1G3Z4?usp=sharing). Extract ***dataset.tar.gz*** under the ***dali*** directory. Extract ***{data-bin, it-de-en-epoch40, it2emea-de-en}.tar.gz*** under the ***dali/outputs*** directory (a minimal extraction sketch follows the list below). 48 | 49 | - ***dataset.tar.gz***: train/dev/test data in five domains: it, emea, acquis, koran, subtitles. 50 | - ***data-bin.tar.gz***: fairseq's binarized data. 51 | - ***it-de-en-epoch40.tar.gz***: fairseq's transformer model pre-trained on data in the it domain. 52 | - ***it2emea-de-en.tar.gz***: fairseq's transformer model adapted from the it domain to the emea domain using DALI-U. 53 | - ***S2T+T2S-de-en.lex***: the lexicon induced by DALI-U. 54 | - ***embed-vec.tar.gz***: {it,emea,acquis,koran,subtitles}.{de,en}.vec embeddings trained in five domains respectively, and {de,en}.vec embeddings trained on the combination of five domains.
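A minimal sketch of the extraction steps described above, assuming the archives have already been downloaded from the Google Drive folder into the repository root (called ***dali*** here):
```
cd dali
tar xzf dataset.tar.gz                 # -> dali/dataset/
mkdir -p outputs
for f in data-bin it-de-en-epoch40 it2emea-de-en; do
    tar xzf ${f}.tar.gz -C outputs/    # -> dali/outputs/...
done
```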
55 | 56 | The model pre-trained on the it domain obtains the BLEU scores below on the in-domain (it) and out-of-domain (emea, koran, subtitles, acquis) test sets. After adaptation, the BLEU score on the emea test set rises from *8.23* to *18.25*. The BLEU scores are slightly different from those in the paper since we used different NMT toolkits (fairseq vs. OpenNMT), but we observed similar improvements to those reported in the paper. 57 |
58 | | Trained on | it (in-domain) | emea | koran | subtitles | acquis |
59 | |------------|----------------|------|-------|-----------|--------|
60 | | it         | 58.94          | 8.23 | 2.50  | 6.26      | 4.34   |
61 |
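The scores above can be recomputed from the decode files written by scripts/translate.sh (see the Demo section below). A minimal sketch, assuming [sacreBLEU](https://github.com/mjpost/sacrebleu) is installed (it is not part of this repo's environment); the exact field layout of fairseq's `H-`/`T-` output lines depends on the fairseq version, so the `cut` calls may need adjusting:
```
decode=outputs/it-de-en-epoch40/decode-test1-best.txt
grep ^H- $decode | cut -f3- > $decode.hyp    # model hypotheses (BPE already removed by --remove-bpe)
grep ^T- $decode | cut -f2- > $decode.ref    # tokenized references
sacrebleu $decode.ref --tokenize none < $decode.hyp
```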
78 | 79 | Demo 80 | == 81 | - Preprocess the data in the source (it) domain 82 | ``` 83 | bash scripts/preprocess.sh 84 | ``` 85 | 86 | - Train the transformer model in the source (it) domain 87 | ``` 88 | bash scripts/train.sh [GPU id] 89 | ``` 90 | 91 | - Perform DALI's data augmentation 92 | 1.1 (Unsupervised Lexicon Induction) Train the word embeddings 93 | ``` 94 | bash scripts/train-embed.sh 95 | ``` 96 | 1.2 (Supervised Lexicon Induction) Train the crosslingual embeddings by supervised lexicon induction 97 | ``` 98 | bash scripts/train-muse.sh [GPU id] [path to supervised seed lexicon] 99 | ``` 100 | 1.3 (Unsupervised Lexicon Induction) Train the crosslingual embeddings by unsupervised lexicon induction 101 | ``` 102 | bash scripts/train-muse.sh [GPU id] 103 | ``` 104 | 1.4 (Unsupervised Lexicon Induction) Obtain the word translations by nearest neighbor search 105 | ``` 106 | python3 extract_lexicon.py \ 107 | --src_emb $PWD/outputs/unsupervised-muse/debug/v1/vectors-de.txt \ 108 | --tgt_emb $PWD/outputs/unsupervised-muse/debug/v1/vectors-en.txt \ 109 | --output $PWD/outputs/unsupervised-muse/debug/v1/S2T+T2S-de-en.lex \ 110 | --dico_build "S2T&T2S" 111 | ``` 112 | 2.1 (Supervised Lexicon Induction) Obtain the word translations by GIZA++ 113 | ``` 114 | bash scripts/extract-lex-giza.sh 115 | ``` 116 | 3. Perform word-for-word back-translation 117 | ``` 118 | bash scripts/wfw_backtranslation.sh [GPU id] 119 | ``` 120 | 121 | 122 | - Preprocess the data in the target (emea) domain 123 | ``` 124 | bash scripts/preprocess-da.sh 125 | ``` 126 | 127 | - Adapt the pre-trained model to the target (emea) domain 128 | ``` 129 | bash scripts/train-da-opt.sh [GPU id] 130 | ``` 131 | 132 | - Translate the test1 set (the emea test set) with the pre-trained it-domain model 133 | ``` 134 | bash scripts/translate.sh \ 135 | outputs/it-de-en-epoch40/checkpoint_best.pt \ 136 | outputs/it-de-en-epoch40/decode-test1-best.txt \ 137 | outputs/data-bin-join/it \ 138 | test1 139 | ``` 140 | --------------------------------------------------------------------------------
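As a possible follow-up to the demo, the adapted checkpoint saved by scripts/train-da-opt.sh can be used to decode the emea test set binarized by scripts/preprocess-da.sh; the checkpoint directory and decode file names below follow the naming in train-da-opt.sh and may differ on your setup:
```
bash scripts/translate.sh \
    outputs/it2emea-de-en-epoch40-opt/checkpoint_best.pt \
    outputs/it2emea-de-en-epoch40-opt/decode-emea-test-best.txt \
    outputs/data-bin-join/it2emea \
    test
```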