.
├── laser
│   ├── __init__.py
│   ├── laser_task.py
│   ├── laser_dataset.py
│   └── laser_lstm.py
├── LICENSE
├── prepare-europarl.sh
├── README.md
├── embed.py
└── bucc.sh

/laser/__init__.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | # hack to prevent ModuleNotFoundError with multiprocessing
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4 | from . import laser_lstm
5 | from . import laser_task
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Raymond Hendy Susanto
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/prepare-europarl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ ! -d mosesdecoder ] ; then
4 | echo 'Cloning Moses github repository (for tokenization scripts)...'
5 | git clone https://github.com/moses-smt/mosesdecoder.git
6 | fi
7 | if [ ! -d fastBPE ] ; then
8 | git clone https://github.com/glample/fastBPE.git
9 | pushd fastBPE
10 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
11 | popd
12 | fi
13 | 
14 | SCRIPTS=mosesdecoder/scripts
15 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
16 | LC=$SCRIPTS/tokenizer/lowercase.perl
17 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
18 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl
19 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
20 | 
21 | prep=europarl_en_de_es_fr
22 | tmp=$prep/tmp
23 | orig=$prep/downloaded
24 | codes=40000
25 | bpe=$prep/bpe.40k
26 | 
27 | mkdir -p $orig $tmp $bpe
28 | 
29 | urlpref=http://opus.nlpl.eu/download.php?f=Europarl/v8/moses
30 | mkdir -p $orig
31 | for f in en-es.txt.zip de-en.txt.zip de-es.txt.zip de-fr.txt.zip en-fr.txt.zip es-fr.txt.zip; do
32 | if [ ! -f $orig/$f ] ; then
33 | wget $urlpref/$f -O $orig/$f
34 | unzip $orig/$f -d $orig
35 | rm -f $orig/{README,LICENSE}
36 | fi
37 | done
38 | 
39 | echo "pre-processing train data..."
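# Tokenization pipeline: for each side of every language pair, strip non-printing
# characters, normalize punctuation, run the Moses tokenizer, and lowercase.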
40 | for lang_pair in de-en de-es de-fr en-es en-fr es-fr ; do
41 | src=`echo $lang_pair | cut -d'-' -f1`
42 | tgt=`echo $lang_pair | cut -d'-' -f2`
43 | lang=$src-$tgt
44 | for l in $src $tgt; do
45 | rm -rf $tmp/train.tags.$lang.tok.$l
46 | for f in Europarl ; do
47 | cat $orig/$f.$lang.$l | \
48 | perl $REM_NON_PRINT_CHAR | \
49 | perl $NORM_PUNC $l | \
50 | perl $TOKENIZER -threads 20 -l $l -q -no-escape | \
51 | perl $LC >> $tmp/train.tags.$lang.tok.$l
52 | done
53 | done
54 | done
55 | 
56 | rm -f $tmp/train.all
57 | # apply length filtering before BPE
58 | for lang_pair in de-en de-es de-fr en-es en-fr es-fr ; do
59 | src=`echo $lang_pair | cut -d'-' -f1`
60 | tgt=`echo $lang_pair | cut -d'-' -f2`
61 | lang=$src-$tgt
62 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.$lang 1 100
63 | cat $tmp/train.$lang.{$src,$tgt} >> $tmp/train.all
64 | done
65 | 
66 | 
67 | # BPE
68 | fastBPE/fast learnbpe $codes $tmp/train.all > $bpe/codes
69 | for lang_pair in de-en de-es de-fr en-es en-fr es-fr ; do
70 | src=`echo $lang_pair | cut -d'-' -f1`
71 | tgt=`echo $lang_pair | cut -d'-' -f2`
72 | lang=$src-$tgt
73 | for l in $src $tgt; do
74 | fastBPE/fast applybpe $bpe/train.$lang.$l $tmp/train.$lang.$l $bpe/codes
75 | done
76 | done
77 | 
78 | cat $bpe/train.*.* | fastBPE/fast getvocab - > $bpe/vocab
79 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fairseq-laser
2 | 
3 | An implementation of the LASER architecture in the fairseq library, based on my understanding of the original papers by Artetxe and Schwenk (2018; [1] and [2]).
4 | 
5 | ## Requirements
6 | 
7 | * Python version >= 3.6
8 | * [PyTorch](https://pytorch.org/) (tested on version 1.3.1)
9 | * [fairseq](https://github.com/pytorch/fairseq) (tested on version 0.9.0)
10 | * [Faiss](https://github.com/facebookresearch/faiss), for bitext mining
11 | 
12 | ## Training
13 | 
14 | This example shows how to train a LASER model on 4 languages from Europarl v8 (English/French/Spanish/German) with an architecture similar to the one in [1].
15 | 
16 | ```bash
17 | # Download and preprocess the data
18 | bash prepare-europarl.sh
19 | 
20 | # Binarize datasets for each language pair
21 | bpe=europarl_en_de_es_fr/bpe.40k
22 | data_bin=data-bin/europarl.de_en_es_fr.bpe40k
23 | for lang_pair in de-en de-es de-fr en-es en-fr es-fr; do
24 | src=`echo $lang_pair | cut -d'-' -f1`
25 | tgt=`echo $lang_pair | cut -d'-' -f2`
26 | rm -f $data_bin/dict.$src.txt $data_bin/dict.$tgt.txt
27 | fairseq-preprocess --source-lang $src --target-lang $tgt \
28 | --trainpref $bpe/train.$src-$tgt \
29 | --joined-dictionary --tgtdict $bpe/vocab \
30 | --destdir $data_bin \
31 | --workers 20
32 | done
33 | 
34 | # Train a LASER model. To speed up, we only use 2 target languages
35 | # (English and Spanish) and train for 10 epochs.
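# Each training batch is drawn from a single language pair (sampled uniformly at
# random per batch), and the decoder is told which target language to generate
# through a learned language embedding (see laser/laser_dataset.py and
# laser/laser_lstm.py).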
36 | checkpoint=checkpoints/laser_lstm
37 | mkdir -p $checkpoint
38 | fairseq-train $data_bin \
39 | --max-epoch 10 \
40 | --ddp-backend=no_c10d \
41 | --task translation_laser --arch laser \
42 | --lang-pairs de-en,de-es,en-es,es-en,fr-en,fr-es \
43 | --optimizer adam --adam-betas '(0.9, 0.98)' \
44 | --lr 0.001 --criterion cross_entropy \
45 | --dropout 0.1 --save-dir $checkpoint \
46 | --max-tokens 12000 --fp16 \
47 | --valid-subset train --disable-validation \
48 | --no-progress-bar --log-interval 1000 \
49 | --user-dir laser/
50 | ```
51 | 
52 | ## Bitext mining
53 | 
54 | Here are some results from running the above model on the [BUCC 2018 shared task data](https://comparable.limsi.fr/bucc2017/cgi-bin/download-data-2018.cgi) (see `bucc.sh`). The scores are on the training set, since the gold standard for the test set has not been released (see Table 2 of [1] for a comparison with a similar model).
55 | 
56 | | Languages | Threshold | Precision | Recall | F1 score |
57 | |-----------|-----------|-----------|--------|----------|
58 | | fr-en | 1.102786 | 91.63 | 91.37 | 91.50 |
59 | | de-en | 1.095823 | 95.12 | 94.57 | 94.84 |
60 | 
61 | ## References
62 | 
63 | [1] Mikel Artetxe and Holger Schwenk, [*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136), arXiv, 3 Nov 2018.
64 | 
65 | [2] Mikel Artetxe and Holger Schwenk, [*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464), arXiv, 26 Dec 2018.
66 | 
--------------------------------------------------------------------------------
/laser/laser_task.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import logging
3 | import os
4 | 
5 | from fairseq import options, utils
6 | from fairseq.data import (
7 |     Dictionary,
8 |     LanguagePairDataset,
9 | )
10 | from fairseq.tasks import FairseqTask, register_task
11 | from fairseq.tasks.multilingual_translation import MultilingualTranslationTask, load_langpair_dataset
12 | 
13 | from .laser_dataset import LaserDataset
14 | 
15 | @register_task('translation_laser')
16 | class TranslationLaserTask(FairseqTask):
17 | 
18 |     @staticmethod
19 |     def add_args(parser):
20 |         """Add task-specific arguments to the parser."""
21 |         MultilingualTranslationTask.add_args(parser)
22 | 
23 |     def __init__(self, args, dicts, training):
24 |         super().__init__(args)
25 |         self.dicts = dicts
26 |         self.training = training
27 |         if training:
28 |             self.lang_pairs = args.lang_pairs
29 |             args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
30 |         else:
31 |             self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
32 |         self.langs = list(dicts.keys())
33 | 
34 |     @classmethod
35 |     def setup_task(cls, args, **kwargs):
36 |         dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
37 |         return cls(args, dicts, training)
38 | 
39 |     def load_dataset(self, split, epoch=0, **kwargs):
40 |         """Load a dataset split."""
41 | 
42 |         paths = self.args.data.split(':')
43 |         assert len(paths) > 0
44 |         data_path = paths[epoch % len(paths)]
45 | 
46 |         def language_pair_dataset(lang_pair):
47 |             src, tgt = lang_pair.split('-')
48 |             langpair_dataset = load_langpair_dataset(
49 |                 data_path, split, src, self.dicts[src], tgt, self.dicts[tgt],
50 |                 combine=True, dataset_impl=self.args.dataset_impl,
51 |                 upsample_primary=self.args.upsample_primary,
52 |                 left_pad_source=self.args.left_pad_source,
53 |                 left_pad_target=self.args.left_pad_target,
54 |                 max_source_positions=self.args.max_source_positions,
55 |                 max_target_positions=self.args.max_target_positions,
56 |             )
57 |             return langpair_dataset
58 | 
59 |         self.datasets[split] = LaserDataset(
60 |             OrderedDict([
61 |                 (lang_pair, language_pair_dataset(lang_pair))
62 |                 for lang_pair in self.lang_pairs
63 |             ]),
64 |             eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
65 |         )
66 | 
67 |     def build_dataset_for_inference(self, src_tokens, src_lengths):
68 |         lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
69 |         return LaserDataset(
70 |             OrderedDict([(
71 |                 lang_pair,
72 |                 LanguagePairDataset(
73 |                     src_tokens, src_lengths,
74 |                     self.source_dictionary
75 |                 ),
76 |             )]),
77 |             eval_key=lang_pair,
78 |         )
79 | 
80 |     def build_model(self, args):
81 |         # Check if task args are consistent with model args
82 |         if len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) != 0:
83 |             raise ValueError('--lang-pairs should include all the language pairs {}.'.format(args.lang_pairs))
84 | 
85 |         from fairseq import models
86 |         model = models.build_model(args, self)
87 |         from .laser_lstm import LaserModel
88 |         if not isinstance(model, LaserModel):
89 |             raise ValueError('TranslationLaserTask requires a LaserModel architecture')
90 |         return model
91 | 
92 |     @property
93 |     def source_dictionary(self):
94 |         return self.dicts[self.args.source_lang]
95 | 
96 |     @property
97 |     def target_dictionary(self):
98 |         return self.dicts[self.args.target_lang]
99 | 
--------------------------------------------------------------------------------
/embed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | """
3 | Embed sentences with a trained model. Batches data on-the-fly.
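Sentence embeddings are written to --output-file as a raw binary matrix,
one vector per input line, in the order of the input.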
4 | """ 5 | 6 | from collections import namedtuple 7 | import fileinput 8 | 9 | import torch 10 | import numpy as np 11 | 12 | from fairseq import checkpoint_utils, options, tasks, utils 13 | from fairseq.data import encoders 14 | 15 | 16 | Batch = namedtuple('Batch', 'ids src_tokens src_lengths') 17 | Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') 18 | 19 | 20 | def buffered_read(input, buffer_size): 21 | buffer = [] 22 | with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h: 23 | for src_str in h: 24 | buffer.append(src_str.strip()) 25 | if len(buffer) >= buffer_size: 26 | yield buffer 27 | buffer = [] 28 | 29 | if len(buffer) > 0: 30 | yield buffer 31 | 32 | 33 | def make_batches(lines, args, task, max_positions, encode_fn): 34 | tokens = [ 35 | task.source_dictionary.encode_line( 36 | encode_fn(src_str), add_if_not_exist=False 37 | ).long() 38 | for src_str in lines 39 | ] 40 | lengths = torch.LongTensor([t.numel() for t in tokens]) 41 | itr = task.get_batch_iterator( 42 | dataset=task.build_dataset_for_inference(tokens, lengths), 43 | max_tokens=args.max_tokens, 44 | max_sentences=args.max_sentences, 45 | max_positions=max_positions, 46 | ).next_epoch_itr(shuffle=False) 47 | for batch in itr: 48 | yield Batch( 49 | ids=batch['id'], 50 | src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'], 51 | ) 52 | 53 | 54 | def main(args): 55 | utils.import_user_module(args) 56 | 57 | if args.buffer_size < 1: 58 | args.buffer_size = 1 59 | if args.max_tokens is None and args.max_sentences is None: 60 | args.max_sentences = 1 61 | 62 | assert not args.max_sentences or args.max_sentences <= args.buffer_size, \ 63 | '--max-sentences/--batch-size cannot be larger than --buffer-size' 64 | 65 | print(args) 66 | 67 | use_cuda = torch.cuda.is_available() and not args.cpu 68 | 69 | # Setup task, e.g., translation 70 | task = tasks.setup_task(args) 71 | 72 | # Load ensemble 73 | print('| loading model(s) from {}'.format(args.path)) 74 | models, _model_args = checkpoint_utils.load_model_ensemble( 75 | args.path.split(':'), 76 | arg_overrides=eval(args.model_overrides), 77 | task=task, 78 | ) 79 | 80 | # Set dictionaries 81 | src_dict = task.source_dictionary 82 | tgt_dict = task.target_dictionary 83 | 84 | # Optimize ensemble for generation 85 | for model in models: 86 | model.make_generation_fast_( 87 | beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 88 | need_attn=args.print_alignment, 89 | ) 90 | if args.fp16: 91 | model.half() 92 | if use_cuda: 93 | model.cuda() 94 | 95 | model = models[0] # Just a single model 96 | 97 | # Handle tokenization and BPE 98 | tokenizer = encoders.build_tokenizer(args) 99 | bpe = encoders.build_bpe(args) 100 | 101 | def encode_fn(x): 102 | if tokenizer is not None: 103 | x = tokenizer.encode(x) 104 | if bpe is not None: 105 | x = bpe.encode(x) 106 | return x 107 | 108 | def decode_fn(x): 109 | if bpe is not None: 110 | x = bpe.decode(x) 111 | if tokenizer is not None: 112 | x = tokenizer.decode(x) 113 | return x 114 | 115 | max_positions = utils.resolve_max_positions( 116 | task.max_positions(), 117 | *[model.max_positions() for model in models] 118 | ) 119 | 120 | fout = open(args.output_file, mode='wb') 121 | if args.buffer_size > 1: 122 | print('| Sentence buffer size:', args.buffer_size) 123 | start_id = 0 124 | for inputs in buffered_read(args.input, args.buffer_size): 125 | indices = [] 126 | results = [] 127 | for batch in make_batches(inputs, args, 
                                 task, max_positions, encode_fn):
128 |             src_tokens = batch.src_tokens
129 |             src_lengths = batch.src_lengths
130 |             if use_cuda:
131 |                 src_tokens = src_tokens.cuda()
132 |                 src_lengths = src_lengths.cuda()
133 | 
134 |             model.eval()
135 |             embeddings = model.encoder(src_tokens, src_lengths)['sentemb']
136 |             embeddings = embeddings.detach().cpu().numpy()
137 |             for i, (id, emb) in enumerate(zip(batch.ids.tolist(), embeddings)):
138 |                 indices.append(id)
139 |                 results.append(emb)
140 |         np.vstack(results)[np.argsort(indices)].tofile(fout)
141 | 
142 |         # update running id counter
143 |         start_id += len(inputs)
144 |     fout.close()
145 | 
146 | 
147 | def cli_main():
148 |     parser = options.get_generation_parser(interactive=True)
149 |     parser.add_argument('--output-file', required=True,
150 |                         help='Output sentence embeddings')
151 |     args = options.parse_args_and_arch(parser)
152 |     main(args)
153 | 
154 | 
155 | if __name__ == '__main__':
156 |     cli_main()
157 | 
--------------------------------------------------------------------------------
/bucc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | 
4 | # bash script from LASER repo to mine for bitexts in the BUCC corpus
5 | # (https://github.com/facebookresearch/LASER/blob/master/tasks/bucc/bucc.sh)
6 | # modified to use fairseq to generate sentence embeddings
7 | 
8 | 
9 | if [ ! -d mosesdecoder ] ; then
10 | echo 'Cloning Moses github repository (for tokenization scripts)...'
11 | git clone https://github.com/moses-smt/mosesdecoder.git
12 | fi
13 | if [ ! -d fastBPE ] ; then
14 | git clone https://github.com/glample/fastBPE.git
15 | pushd fastBPE
16 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
17 | popd
18 | fi
19 | if [ ! -d LASER ] ; then
20 | echo 'Cloning LASER github repository...'
21 | git clone https://github.com/facebookresearch/LASER.git
22 | fi
23 | export LASER=$PWD/LASER
24 | 
25 | SCRIPTS=mosesdecoder/scripts
26 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
27 | LC=$SCRIPTS/tokenizer/lowercase.perl
28 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
29 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl
30 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
31 | 
32 | # general config
33 | bucc="bucc2018"
34 | data="."
35 | xdir=${data}/downloaded # tar files as distributed by the BUCC evaluation
36 | ddir=${data}/${bucc} # raw texts of BUCC
37 | edir=${data}/embed # normalized texts and embeddings
38 | langs=("fr" "de")
39 | ltrg="en" # English is always the 2nd language
40 | 
41 | # encoder
42 | data_bin=data-bin/europarl.de_en_es_fr.bpe40k/
43 | bpe=europarl_en_de_es_fr/bpe.40k
44 | checkpoint=checkpoints/laser_lstm/checkpoint_last.pt
45 | 
46 | # delete all generated files to re-run
47 | rerun=true
48 | if [ $rerun = true ] ; then
49 | rm -f $edir/*.enc.* $edir/*.candidates.tsv bucc2018.*.train.log
50 | fi
51 | 
52 | ###################################################################
53 | #
54 | # Extract files with labels and texts from the BUCC corpus
55 | #
56 | ###################################################################
57 | 
58 | GetData () {
59 | fn1=$1; fn2=$2; lang=$3
60 | outf="${edir}/${bucc}.${lang}-${ltrg}.${fn2}"
61 | for ll in ${ltrg} ${lang} ; do
62 | inf="${ddir}/${fn1}.${ll}"
63 | if [ ! -f ${outf}.txt.${ll} ] ; then
64 | echo " - extract files ${outf} in ${ll}"
65 | cat ${inf} | cut -f1 > ${outf}.id.${ll}
66 | cat ${inf} | cut -f2 > ${outf}.txt.${ll}
67 | fi
68 | done
69 | }
70 | 
71 | ExtractBUCC () {
72 | slang=$1
73 | tlang=${ltrg}
74 | 
75 | pushd ${data} > /dev/null
76 | if [ ! -d ${ddir}/${slang}-${tlang} ] ; then
77 | for tf in ${xdir}/${bucc}-${slang}-${tlang}.*.tar.bz2 ; do
78 | echo " - extract from tar `basename ${tf}`"
79 | tar jxf $tf
80 | done
81 | fi
82 | 
83 | GetData "${slang}-${tlang}/${slang}-${tlang}.sample" "dev" ${slang}
84 | GetData "${slang}-${tlang}/${slang}-${tlang}.training" "train" ${slang}
85 | GetData "${slang}-${tlang}/${slang}-${tlang}.test" "test" ${slang}
86 | popd > /dev/null
87 | }
88 | 
89 | 
90 | ###################################################################
91 | #
92 | # Tokenize and Embed
93 | #
94 | ###################################################################
95 | 
96 | Embed () {
97 | ll=$2
98 | txt="$1.txt.${ll}"
99 | enc="$1.enc.${ll}"
100 | tl="en"
101 | if [ $ll = "en" ]; then tl="es" ; fi
102 | if [ ! -s ${enc} ] ; then
103 | cat ${txt} | \
104 | perl $REM_NON_PRINT_CHAR | \
105 | perl $NORM_PUNC $ll | \
106 | perl $TOKENIZER -threads 20 -l $ll -q -no-escape | \
107 | perl $LC | \
108 | fastBPE/fast applybpe_stream $bpe/codes $bpe/vocab | \
109 | python3 embed.py $data_bin \
110 | --task translation_laser \
111 | --lang-pairs de-en,de-es,en-es,es-en,fr-en,fr-es \
112 | --source-lang $ll --target-lang $tl \
113 | --path $checkpoint \
114 | --buffer-size 2000 --batch-size 128 \
115 | --output-file ${enc} \
116 | --user-dir laser/
117 | fi
118 | }
119 | 
120 | 
121 | ###################################################################
122 | #
123 | # Mine for bitexts
124 | #
125 | ###################################################################
126 | 
127 | Mine () {
128 | bn=$1
129 | l1=$2
130 | l2=$3
131 | cand="${bn}.candidates.tsv"
132 | if [ ! -s ${cand} ] ; then
133 | python3 ${LASER}/source/mine_bitexts.py \
134 | ${bn}.txt.${l1} ${bn}.txt.${l2} \
135 | --src-lang ${l1} --trg-lang ${l2} \
136 | --src-embeddings ${bn}.enc.${l1} --trg-embeddings ${bn}.enc.${l2} \
137 | --unify --mode mine --retrieval max --margin ratio -k 4 \
138 | --output ${cand} \
139 | --verbose --gpu
140 | fi
141 | }
142 | 
143 | 
144 | ###################################################################
145 | #
146 | # Main loop
147 | #
148 | ###################################################################
149 | 
150 | echo -e "\nProcessing BUCC data in ${data}"
151 | 
152 | # create output directories
153 | for d in ${ddir} ${edir} ; do
154 | mkdir -p ${d}
155 | done
156 | 
157 | for lsrc in ${langs[@]} ; do
158 | ExtractBUCC ${lsrc}
159 | 
160 | # Tokenize and embed train
161 | bname="${bucc}.${lsrc}-${ltrg}"
162 | part="${bname}.train"
163 | Embed ${edir}/${part} ${lsrc}
164 | Embed ${edir}/${part} ${ltrg}
165 | 
166 | # mine for texts in train
167 | Mine ${edir}/${part} ${lsrc} ${ltrg}
168 | 
169 | # optimize threshold on BUCC training data and provided gold alignments
170 | if [ !
-s ${part}.log ] ; then 171 | python3 ${LASER}/tasks/bucc/bucc.py \ 172 | --src-lang ${lsrc} --trg-lang ${ltrg} \ 173 | --bucc-texts ${edir}/${part}.txt \ 174 | --bucc-ids ${edir}/${part}.id \ 175 | --candidates ${edir}/${part}.candidates.tsv \ 176 | --gold ${ddir}/${lsrc}-${ltrg}/${lsrc}-${ltrg}.training.gold \ 177 | --verbose \ 178 | | tee ${part}.log 179 | fi 180 | done 181 | -------------------------------------------------------------------------------- /laser/laser_dataset.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Callable, Dict, List 3 | 4 | import numpy as np 5 | 6 | from fairseq.data import FairseqDataset 7 | 8 | 9 | def uniform_sampler(x): 10 | # Sample from uniform distribution 11 | return np.random.choice(x, 1).item() 12 | 13 | 14 | def add_decoder_language(batch, lang_pair): 15 | tgt = lang_pair.split('-')[1] 16 | batch['net_input']['decoder_lang'] = tgt 17 | return batch 18 | 19 | 20 | class LaserDataset(FairseqDataset): 21 | """ 22 | Stores multiple instances of FairseqDataset together and in every iteration 23 | creates a batch by first sampling a dataset according to a specified 24 | probability distribution and then getting instances from that dataset. 25 | Adapted from: https://github.com/pytorch/fairseq/blob/master/fairseq/data/multi_corpus_sampled_dataset.py. 26 | 27 | Args: 28 | datasets: an OrderedDict of FairseqDataset instances. 29 | sampling_func: A function for sampling over list of dataset keys. 30 | Default strategy is to sample uniformly. 31 | eval_key: a dataset key for evaluation 32 | """ 33 | 34 | def __init__( 35 | self, 36 | datasets: Dict[str, FairseqDataset], 37 | sampling_func: Callable[[List], int] = None, 38 | eval_key: str = None, 39 | ): 40 | super().__init__() 41 | assert isinstance(datasets, OrderedDict) 42 | self.datasets = datasets 43 | if sampling_func is None: 44 | sampling_func = uniform_sampler 45 | self.sampling_func = sampling_func 46 | 47 | self.total_num_instances = 0 48 | for _, dataset in datasets.items(): 49 | assert isinstance(dataset, FairseqDataset) 50 | self.total_num_instances += dataset.__len__() 51 | 52 | self._ordered_indices = None 53 | self.eval_key = eval_key 54 | 55 | def __len__(self): 56 | """ 57 | Length of this dataset is the sum of individual datasets 58 | """ 59 | return self.total_num_instances 60 | 61 | def ordered_indices(self): 62 | """ 63 | Ordered indices for batching. Here we call the underlying 64 | dataset's ordered_indices() so that we get the same random ordering 65 | as we would have from using the underlying dataset directly. 66 | """ 67 | if self._ordered_indices is None: 68 | self._ordered_indices = OrderedDict( 69 | [ 70 | (key, dataset.ordered_indices()) 71 | for key, dataset in self.datasets.items() 72 | ] 73 | ) 74 | return np.arange(len(self)) 75 | 76 | def _map_index_to_dataset(self, key: int, index: int): 77 | """ 78 | Different underlying datasets have different lengths. In order to ensure 79 | we are not accessing an index outside the range of the current dataset 80 | size, we wrap around. This function should be called after we have 81 | created an ordering for this and all underlying datasets. 
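        For example, if a dataset has length 3, index 5 wraps around to
        position 5 % 3 = 2 within that dataset's ordering.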
82 | """ 83 | assert ( 84 | self._ordered_indices is not None 85 | ), "Must call MultiCorpusSampledDataset.ordered_indices() first" 86 | mapped_index = index % len(self.datasets[key]) 87 | return self._ordered_indices[key][mapped_index] 88 | 89 | def __getitem__(self, index: int): 90 | """ 91 | Get the item associated with index from each underlying dataset. 92 | Since index is in the range of [0, TotalNumInstances], we need to 93 | map the index to the dataset before retrieving the item. 94 | """ 95 | if self.eval_key is None: 96 | return OrderedDict( 97 | [ 98 | (key, dataset[self._map_index_to_dataset(key, index)]) 99 | for key, dataset in self.datasets.items() 100 | ] 101 | ) 102 | else: 103 | return self.datasets[self.eval_key][self._map_index_to_dataset(self.eval_key, index)] 104 | 105 | def collater(self, samples: List[Dict]): 106 | """ 107 | Generate a mini-batch for this dataset. 108 | To convert this into a regular mini-batch we use the following 109 | logic: 110 | 1. Select a dataset using the specified probability distribution. 111 | 2. Call the collater function of the selected dataset. 112 | """ 113 | if len(samples) == 0: 114 | return None 115 | 116 | if self.eval_key is None: 117 | selected_key = self.sampling_func(list(self.datasets.keys())) 118 | selected_samples = [sample[selected_key] for sample in samples] 119 | return add_decoder_language( 120 | self.datasets[selected_key].collater(selected_samples), 121 | selected_key 122 | ) 123 | else: 124 | return add_decoder_language( 125 | self.datasets[self.eval_key].collater(samples), 126 | self.eval_key 127 | ) 128 | 129 | def num_tokens(self, index: int): 130 | """ 131 | Return an example's length (number of tokens), used for batching. Here 132 | we return the max across all examples at index across all underlying 133 | datasets. 134 | """ 135 | return max( 136 | dataset.num_tokens(self._map_index_to_dataset(key, index)) 137 | for key, dataset in self.datasets.items() 138 | ) 139 | 140 | def size(self, index: int): 141 | """ 142 | Return an example's size as a float or tuple. Here we return the max 143 | across all underlying datasets. This value is used when filtering a 144 | dataset with max-positions. 
145 | """ 146 | return max( 147 | dataset.size(self._map_index_to_dataset(key, index)) 148 | for key, dataset in self.datasets.items() 149 | ) 150 | 151 | @property 152 | def supports_prefetch(self): 153 | return all( 154 | getattr(dataset, "supports_prefetch", False) 155 | for dataset in self.datasets.values() 156 | ) 157 | 158 | def prefetch(self, indices): 159 | for key, dataset in self.datasets.items(): 160 | dataset.prefetch( 161 | [self._map_index_to_dataset(key, index) for index in indices] 162 | ) 163 | -------------------------------------------------------------------------------- /laser/laser_lstm.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from fairseq import utils 8 | from fairseq.models import ( 9 | FairseqEncoder, 10 | FairseqIncrementalDecoder, 11 | FairseqEncoderDecoderModel, 12 | register_model, 13 | register_model_architecture, 14 | ) 15 | from fairseq.models.lstm import ( 16 | Embedding, 17 | Linear, 18 | LSTM, 19 | LSTMCell, 20 | LSTMModel, 21 | ) 22 | 23 | @register_model('laser') 24 | class LaserModel(FairseqEncoderDecoderModel): 25 | """Laser Encoder-Decoder implementation adapted from: 26 | https://github.com/pytorch/fairseq/blob/master/fairseq/models/lstm.py 27 | https://github.com/facebookresearch/LASER/blob/master/source/embed.py 28 | https://github.com/transducens/LASERtrain/blob/master/fairseq-modules/multilingual_lstm_laser.py 29 | """ 30 | def __init__(self, encoder, decoder): 31 | super().__init__(encoder, decoder) 32 | 33 | @staticmethod 34 | def add_args(parser): 35 | """Add model-specific arguments to the parser.""" 36 | LSTMModel.add_args(parser) 37 | parser.add_argument('--lang-embedding-size', type=int, default=32, 38 | help='language embedding dimension') 39 | 40 | @classmethod 41 | def build_model(cls, args, task): 42 | """Build a new model instance.""" 43 | from .laser_task import TranslationLaserTask 44 | assert isinstance(task, TranslationLaserTask) 45 | 46 | shared_dict = task.dicts[task.langs[0]] 47 | if any(task.dicts[lang] != shared_dict for lang in task.langs): 48 | raise ValueError('This model requires a joined dictionary.') 49 | 50 | # make sure that all args are properly defaulted (in case there are any new ones) 51 | base_architecture(args) 52 | 53 | # Languages index: lang codes into integers 54 | lang_dictionary = { 55 | task.langs[i] : i for i in range(len(task.langs)) 56 | } 57 | 58 | encoder = LaserEncoder( 59 | dictionary=task.source_dictionary, 60 | embed_dim=args.encoder_embed_dim, 61 | hidden_size=args.encoder_hidden_size, 62 | num_layers=args.encoder_layers, 63 | bidirectional=args.encoder_bidirectional, 64 | dropout_in=args.encoder_dropout_in, 65 | dropout_out=args.encoder_dropout_out, 66 | ) 67 | decoder = LaserDecoder( 68 | dictionary=task.target_dictionary, 69 | lang_dictionary=lang_dictionary, 70 | embed_dim=args.decoder_embed_dim, 71 | hidden_size=args.decoder_hidden_size, 72 | out_embed_dim=args.decoder_out_embed_dim, 73 | num_layers=args.decoder_layers, 74 | dropout_in=args.decoder_dropout_in, 75 | dropout_out=args.decoder_dropout_out, 76 | attention=False, 77 | encoder_output_units=int(args.encoder_hidden_size)*2, 78 | lang_embedding_size=args.lang_embedding_size, 79 | ) 80 | return cls(encoder, decoder) 81 | 82 | def forward(self, src_tokens, src_lengths, prev_output_tokens, decoder_lang, **kwargs): 83 | encoder_out = self.encoder(src_tokens, 
src_lengths=src_lengths, **kwargs) 84 | decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, lang=decoder_lang, **kwargs) 85 | return decoder_out 86 | 87 | 88 | class LaserEncoder(FairseqEncoder): 89 | 90 | def __init__( 91 | self, dictionary, embed_dim=320, hidden_size=512, num_layers=1, bidirectional=False, 92 | left_pad=True, padding_value=0., dropout_in=0.1, dropout_out=0.1 93 | ): 94 | super().__init__(dictionary) 95 | 96 | self.num_layers = num_layers 97 | self.dropout_in = dropout_in 98 | self.dropout_out = dropout_out 99 | self.bidirectional = bidirectional 100 | self.hidden_size = hidden_size 101 | 102 | num_embeddings = len(dictionary) 103 | self.padding_idx = dictionary.pad() 104 | self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) 105 | 106 | self.lstm = nn.LSTM( 107 | input_size=embed_dim, 108 | hidden_size=hidden_size, 109 | num_layers=num_layers, 110 | dropout=self.dropout_out if num_layers > 1 else 0., 111 | bidirectional=bidirectional, 112 | ) 113 | self.left_pad = left_pad 114 | self.padding_value = padding_value 115 | 116 | self.output_units = hidden_size 117 | if bidirectional: 118 | self.output_units *= 2 119 | 120 | def forward(self, src_tokens, src_lengths): 121 | if self.left_pad: 122 | # convert left-padding to right-padding 123 | src_tokens = utils.convert_padding_direction( 124 | src_tokens, 125 | self.padding_idx, 126 | left_to_right=True, 127 | ) 128 | 129 | bsz, seqlen = src_tokens.size() 130 | 131 | # embed tokens 132 | x = self.embed_tokens(src_tokens) 133 | x = F.dropout(x, p=self.dropout_in, training=self.training) 134 | 135 | # B x T x C -> T x B x C 136 | x = x.transpose(0, 1) 137 | 138 | # pack embedded source tokens into a PackedSequence 139 | packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) 140 | 141 | # apply LSTM 142 | if self.bidirectional: 143 | state_size = 2 * self.num_layers, bsz, self.hidden_size 144 | else: 145 | state_size = self.num_layers, bsz, self.hidden_size 146 | h0 = x.data.new(*state_size).zero_() 147 | c0 = x.data.new(*state_size).zero_() 148 | packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) 149 | 150 | # unpack outputs and apply dropout 151 | x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) 152 | x = F.dropout(x, p=self.dropout_out, training=self.training) 153 | assert list(x.size()) == [seqlen, bsz, self.output_units] 154 | 155 | if self.bidirectional: 156 | def combine_bidir(outs): 157 | return torch.cat([ 158 | torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units) 159 | for i in range(self.num_layers) 160 | ], dim=0) 161 | 162 | final_hiddens = combine_bidir(final_hiddens) 163 | final_cells = combine_bidir(final_cells) 164 | 165 | encoder_padding_mask = src_tokens.eq(self.padding_idx).t() 166 | 167 | # Set padded outputs to -inf so they are not selected by max-pooling 168 | padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) 169 | if padding_mask.any(): 170 | x = x.float().masked_fill_(padding_mask, float('-inf')).type_as(x) 171 | 172 | # Build the sentence embedding by max-pooling over the encoder outputs 173 | sentemb = x.max(dim=0)[0] 174 | 175 | return { 176 | 'sentemb': sentemb, 177 | 'encoder_out': (x, final_hiddens, final_cells), 178 | 'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None 179 | } 180 | 181 | 182 | class LaserDecoder(FairseqIncrementalDecoder): 183 | 184 | def __init__( 185 | self, dictionary, 
lang_dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, 186 | num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, 187 | encoder_output_units=512, pretrained_embed=None, 188 | share_input_output_embed=False, adaptive_softmax_cutoff=None, 189 | lang_embedding_size=32 190 | ): 191 | super().__init__(dictionary) 192 | self.dropout_in = dropout_in 193 | self.dropout_out = dropout_out 194 | self.hidden_size = hidden_size 195 | self.share_input_output_embed = share_input_output_embed 196 | self.lang_embedding_size = lang_embedding_size 197 | self.lang_dictionary = lang_dictionary 198 | self.embed_langs = nn.Embedding(len(lang_dictionary), lang_embedding_size) 199 | self.need_attn = False 200 | 201 | self.adaptive_softmax = None 202 | num_embeddings = len(dictionary) 203 | padding_idx = dictionary.pad() 204 | self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) 205 | 206 | self.encoder_output_units = encoder_output_units 207 | self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size) 208 | self.encoder_cell_proj = Linear(encoder_output_units, hidden_size) 209 | 210 | input_size = hidden_size + embed_dim + lang_embedding_size + encoder_output_units 211 | self.layers = nn.ModuleList([ 212 | LSTMCell( 213 | input_size=input_size if layer == 0 else hidden_size, 214 | hidden_size=hidden_size, 215 | ) 216 | for layer in range(num_layers) 217 | ]) 218 | 219 | self.attention = None 220 | if hidden_size != out_embed_dim: 221 | self.additional_fc = Linear(hidden_size, out_embed_dim) 222 | if not self.share_input_output_embed: 223 | self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) 224 | 225 | def forward(self, prev_output_tokens, encoder_out, lang, incremental_state=None): 226 | x, attn_scores = self.extract_features( 227 | prev_output_tokens, encoder_out, lang, incremental_state 228 | ) 229 | return self.output_layer(x), attn_scores 230 | 231 | def extract_features( 232 | self, prev_output_tokens, encoder_out, lang, incremental_state=None 233 | ): 234 | """ 235 | Similar to *forward* but only return features. 
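        At each time step the decoder input is the concatenation of the previous
        token embedding, the encoder's max-pooled sentence embedding, the
        input-feed vector, and the target-language embedding.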
236 | """ 237 | encoder_sentemb = encoder_out['sentemb'] 238 | encoder_padding_mask = encoder_out['encoder_padding_mask'] 239 | encoder_out = encoder_out['encoder_out'] 240 | 241 | if incremental_state is not None: 242 | prev_output_tokens = prev_output_tokens[:, -1:] 243 | bsz, seqlen = prev_output_tokens.size() 244 | 245 | # get outputs from encoder 246 | encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] 247 | srclen = encoder_outs.size(0) 248 | 249 | # embed tokens 250 | x = self.embed_tokens(prev_output_tokens) 251 | x = F.dropout(x, p=self.dropout_in, training=self.training) 252 | 253 | # B x T x C -> T x B x C 254 | x = x.transpose(0, 1) 255 | 256 | # embed language 257 | lang_tensor = torch.LongTensor( 258 | [self.lang_dictionary[lang]] * bsz 259 | ).to(device=prev_output_tokens.device) 260 | l = self.embed_langs(lang_tensor) 261 | 262 | # initialize previous states (or get from cache during incremental generation) 263 | cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') 264 | if cached_state is not None: 265 | prev_hiddens, prev_cells, input_feed = cached_state 266 | else: 267 | num_layers = len(self.layers) 268 | prev_hiddens = [encoder_sentemb for i in range(num_layers)] 269 | prev_cells = [encoder_sentemb for i in range(num_layers)] 270 | prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens] 271 | prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] 272 | input_feed = x.new_zeros(bsz, self.hidden_size) 273 | 274 | outs = [] 275 | for j in range(seqlen): 276 | # input feeding: concatenate context vector from previous time step 277 | input = torch.cat((x[j, :, :], encoder_sentemb, input_feed, l), dim=1) 278 | 279 | for i, rnn in enumerate(self.layers): 280 | # recurrent cell 281 | hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) 282 | 283 | # hidden state becomes the input to the next layer 284 | input = F.dropout(hidden, p=self.dropout_out, training=self.training) 285 | 286 | # save state for next time step 287 | prev_hiddens[i] = hidden 288 | prev_cells[i] = cell 289 | 290 | out = hidden 291 | out = F.dropout(out, p=self.dropout_out, training=self.training) 292 | 293 | # input feeding 294 | input_feed = out 295 | 296 | # save final output 297 | outs.append(out) 298 | 299 | # cache previous states (no-op except during incremental generation) 300 | utils.set_incremental_state( 301 | self, incremental_state, 'cached_state', 302 | (prev_hiddens, prev_cells, input_feed), 303 | ) 304 | 305 | # collect outputs across time steps 306 | x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) 307 | 308 | # T x B x C -> B x T x C 309 | x = x.transpose(1, 0) 310 | 311 | if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: 312 | x = self.additional_fc(x) 313 | x = F.dropout(x, p=self.dropout_out, training=self.training) 314 | 315 | return x, None 316 | 317 | def output_layer(self, x): 318 | """Project features to the vocabulary size.""" 319 | if self.adaptive_softmax is None: 320 | if self.share_input_output_embed: 321 | x = F.linear(x, self.embed_tokens.weight) 322 | else: 323 | x = self.fc_out(x) 324 | return x 325 | 326 | def reorder_incremental_state(self, incremental_state, new_order): 327 | super().reorder_incremental_state(incremental_state, new_order) 328 | cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') 329 | if cached_state is None: 330 | return 331 | 332 | def reorder_state(state): 333 | if isinstance(state, list): 334 | return 
[reorder_state(state_i) for state_i in state] 335 | return state.index_select(0, new_order) 336 | 337 | new_state = tuple(map(reorder_state, cached_state)) 338 | utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) 339 | 340 | 341 | @register_model_architecture('laser', 'laser') 342 | def base_architecture(args): 343 | args.dropout = getattr(args, 'dropout', 0.1) 344 | args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) 345 | args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) 346 | args.encoder_freeze_embed = getattr(args, 'encoder_freeze_embed', False) 347 | args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', 512) 348 | args.encoder_layers = getattr(args, 'encoder_layers', 1) 349 | args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', True) 350 | args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout) 351 | args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout) 352 | args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) 353 | args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) 354 | args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) 355 | args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 2048) 356 | args.decoder_layers = getattr(args, 'decoder_layers', 1) 357 | args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 2048) 358 | args.decoder_attention = getattr(args, 'decoder_attention', False) 359 | args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) 360 | args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) 361 | args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) 362 | args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) 363 | args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) 364 | --------------------------------------------------------------------------------
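The `--output-file` produced by `embed.py` (and used by `bucc.sh`) is a raw binary dump written with NumPy's `tofile`. Below is a minimal sketch for reading it back, assuming the default `laser` architecture (bidirectional encoder with 512 hidden units, i.e. 1024-dimensional sentence embeddings) and no `--fp16` (so float32 output); the file path is illustrative.

```python
import numpy as np

# 2 x 512: the max-pooled hidden state of the bidirectional encoder.
EMB_DIM = 1024

def load_embeddings(path, dim=EMB_DIM, dtype=np.float32):
    """Read embeddings written by embed.py: one vector per input line, in input order."""
    emb = np.fromfile(path, dtype=dtype)
    assert emb.size % dim == 0, "file size is not a multiple of the embedding dimension"
    return emb.reshape(-1, dim)

if __name__ == "__main__":
    # Illustrative path: embed.py writes to whatever --output-file was given.
    emb = load_embeddings("embed/bucc2018.fr-en.train.enc.fr")
    emb /= np.linalg.norm(emb, axis=1, keepdims=True)
    print("loaded", emb.shape, "cosine(sent0, sent1) =", float(emb[0] @ emb[1]))
```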