├── third_party
│   ├── __init__.py
│   ├── ud-conversion-tools
│   │   ├── lib
│   │   │   ├── __init__.py
│   │   │   └── conll.py
│   │   └── conllu_to_conll.py
│   ├── xlm.py
│   ├── processors
│   │   ├── pawsx.py
│   │   ├── xnli.py
│   │   ├── utils.py
│   │   └── squad.py
│   ├── utils_tag.py
│   ├── utils_retrieve.py
│   ├── run_retrieval.py
│   └── xlm_roberta.py
├── xtreme_score.png
├── install_tools.sh
├── LICENSE
├── scripts
│   ├── preprocess_panx.sh
│   ├── preprocess_udpos.sh
│   ├── run_tatoeba.sh
│   ├── train.sh
│   ├── predict_qa.sh
│   ├── eval_qa.sh
│   ├── train_xnli.sh
│   ├── train_pawsx.sh
│   ├── run_bucc2018.sh
│   ├── train_udpos.sh
│   ├── train_panx.sh
│   ├── train_qa.sh
│   └── download_data.sh
├── .gitignore
├── evaluate.py
├── README.md
├── conda-env.txt
└── utils_preprocess.py

--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/third_party/ud-conversion-tools/lib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xtreme_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunjieHu/xtreme-dev/HEAD/xtreme_score.png

--------------------------------------------------------------------------------
/install_tools.sh:
--------------------------------------------------------------------------------
REPO=$PWD
LIB=$REPO/third_party
mkdir -p $LIB

# install conda env
conda create --name xtreme --file conda-env.txt
conda init bash
source activate xtreme

# install the latest transformers
cd $LIB
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
cd $LIB

pip install seqeval
pip install tensorboardx

# install XLM tokenizer dependencies
pip install sacremoses
pip install pythainlp
pip install jieba

git clone https://github.com/neubig/kytea.git && cd kytea
autoreconf -i
./configure --prefix=$HOME/local
make && make install
pip install kytea
cd $LIB

wget -O $LIB/evaluate_squad.py https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py
wget -O $LIB/evaluate_mlqa.py https://raw.githubusercontent.com/facebookresearch/MLQA/master/mlqa_evaluation_v1.py
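A typical invocation, as a sketch: run the script once from the repository root (it assumes conda and standard build tools such as autoreconf and make are already available):

  bash install_tools.sh

Since KyTea is configured with --prefix=$HOME/local, that prefix may need to be added to PATH and LD_LIBRARY_PATH before the kytea bindings can find the library.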
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 JunjieHu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/scripts/preprocess_panx.sh:
--------------------------------------------------------------------------------
#!/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
DATA_DIR=${2:-"$REPO/download/"}

TASK='panx'
MAXL=128
LANGS="ar,he,vi,id,jv,ms,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,yo,my,zh,kk,tr,et,fi,hu"
LC=""
if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then
  MODEL_TYPE="xlmr"
fi

SAVE_DIR="$DATA_DIR/$TASK/${TASK}_processed_maxlen${MAXL}"
mkdir -p $SAVE_DIR
python3 $REPO/utils_preprocess.py \
  --data_dir $DATA_DIR/$TASK/ \
  --task panx_tokenize \
  --model_name_or_path $MODEL \
  --model_type $MODEL_TYPE \
  --max_len $MAXL \
  --output_dir $SAVE_DIR \
  --languages $LANGS $LC
if [ ! -f $SAVE_DIR/labels.txt ]; then
  cat $SAVE_DIR/*/*.${MODEL} | cut -f 2 | grep -v "^$" | sort | uniq > $SAVE_DIR/labels.txt
fi

--------------------------------------------------------------------------------
/scripts/preprocess_udpos.sh:
--------------------------------------------------------------------------------
#!/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
DATA_DIR=${2:-"$REPO/download/"}

TASK='udpos'
MAXL=128
LANGS='af,ar,bg,de,el,en,es,et,eu,fa,fi,fr,he,hi,hu,id,it,ja,kk,ko,mr,nl,pt,ru,ta,te,th,tl,tr,ur,vi,yo,zh'
LC=""
if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then
  MODEL_TYPE="xlmr"
fi

SAVE_DIR="$DATA_DIR/${TASK}/udpos_processed_maxlen${MAXL}"
mkdir -p $SAVE_DIR
python3 $REPO/utils_preprocess.py \
  --data_dir $DATA_DIR/${TASK}/ \
  --task udpos_tokenize \
  --model_name_or_path $MODEL \
  --model_type $MODEL_TYPE \
  --max_len $MAXL \
  --output_dir $SAVE_DIR \
  --languages $LANGS $LC #>> $SAVE_DIR/process.log
if [ ! -f $SAVE_DIR/labels.txt ]; then
  echo "create label"
  cat $SAVE_DIR/*/*.${MODEL} | cut -f 2 | grep -v "^$" | sort | uniq > $SAVE_DIR/labels.txt
fi

--------------------------------------------------------------------------------
/scripts/run_tatoeba.sh:
--------------------------------------------------------------------------------
#!/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
GPU=${2:-0}
DATA_DIR=${3:-"$REPO/download/"}
OUT_DIR=${4:-"$REPO/outputs/"}

export CUDA_VISIBLE_DEVICES=$GPU

TASK='tatoeba'
TL='en'
MAXL=512
LC=""
LAYER=7
NLAYER=12
if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
  DIM=768
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
  DIM=1280
elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then
  MODEL_TYPE="xlmr"
  DIM=1024
  NLAYER=24
  LAYER=13
fi

OUT=$OUT_DIR/$TASK/${MODEL}_${MAXL}/
mkdir -p $OUT
for SL in ar he vi id jv tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw zh kk tr et fi hu; do
  python $REPO/third_party/run_retrieval.py \
    --model_type $MODEL_TYPE \
    --model_name_or_path $MODEL \
    --embed_size $DIM \
    --batch_size 100 \
    --task_name $TASK \
    --src_language $SL \
    --tgt_language en \
    --data_dir $DATA_DIR/$TASK/ \
    --max_seq_length $MAXL \
    --output_dir $OUT \
    --log_file embed-cosine \
    --num_layers $NLAYER \
    --dist cosine $LC \
    --specific_layer $LAYER
done

--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
TASK=${2:-pawsx}
GPU=${3:-0}
DATA_DIR=${4:-"$REPO/download/"}
OUT_DIR=${5:-"$REPO/outputs-temp/"}
echo "Fine-tuning $MODEL on $TASK using GPU $GPU"
echo "Load data from $DATA_DIR, and save models to $OUT_DIR"

if [ $TASK == 'pawsx' ]; then
  bash $REPO/scripts/train_pawsx.sh $MODEL $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'xnli' ]; then
  bash $REPO/scripts/train_xnli.sh $MODEL $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'udpos' ]; then
  bash $REPO/scripts/preprocess_udpos.sh $MODEL $DATA_DIR
  bash $REPO/scripts/train_udpos.sh $MODEL $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'panx' ]; then
  bash $REPO/scripts/preprocess_panx.sh $MODEL $DATA_DIR
  bash $REPO/scripts/train_panx.sh $MODEL $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'xquad' ]; then
  bash $REPO/scripts/train_qa.sh $MODEL squad $TASK $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'mlqa' ]; then
  bash $REPO/scripts/train_qa.sh $MODEL squad $TASK $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'tydiqa' ]; then
  bash $REPO/scripts/train_qa.sh $MODEL tydiqa $TASK $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'bucc2018' ]; then
  bash $REPO/scripts/run_bucc2018.sh $MODEL $GPU $DATA_DIR $OUT_DIR
elif [ $TASK == 'tatoeba' ]; then
  bash $REPO/scripts/run_tatoeba.sh $MODEL $GPU $DATA_DIR $OUT_DIR
fi
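A usage sketch of the dispatcher above, assuming the datasets have already been fetched into the default ./download directory by scripts/download_data.sh:

  bash scripts/train.sh bert-base-multilingual-cased xnli 0

The positional arguments are MODEL, TASK, GPU, DATA_DIR and OUT_DIR; trailing arguments that are omitted fall back to the defaults declared at the top of the script.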
--------------------------------------------------------------------------------
/scripts/predict_qa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Script to obtain predictions using a trained model on XQuAD, TyDi QA, and MLQA.
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
MODEL_TYPE=${2:-bert}
MODEL_PATH=${3}
TGT=${4:-xquad}
GPU=${5:-0}
DATA_DIR=${6:-"$REPO/download/"}

if [ ! -d "${MODEL_PATH}" ]; then
  echo "Model path ${MODEL_PATH} does not exist."
  exit 1
fi

DIR=${DATA_DIR}/${TGT}/
PREDICTIONS_DIR=${REPO}/predictions
PRED_DIR=${PREDICTIONS_DIR}/$TGT/
mkdir -p "${PRED_DIR}"

if [ $TGT == 'xquad' ]; then
  langs=( en es de el ru tr ar vi th zh hi )
elif [ $TGT == 'mlqa' ]; then
  langs=( en es de ar hi vi zh )
elif [ $TGT == 'tydiqa' ]; then
  langs=( en ar bn fi id ko ru sw te )
fi

echo "************************"
echo ${MODEL}
echo "************************"

echo
echo "Predictions on $TGT"
for lang in ${langs[@]}; do
  echo "  $lang  "
  if [ $TGT == 'xquad' ]; then
    TEST_FILE=${DIR}/xquad.$lang.json
  elif [ $TGT == 'mlqa' ]; then
    TEST_FILE=${DIR}/MLQA_V1/test/test-context-$lang-question-$lang.json
  elif [ $TGT == 'tydiqa' ]; then
    TEST_FILE=${DIR}/tydiqa-goldp-v1.1-dev/tydiqa.$lang.dev.json
  fi

  CUDA_VISIBLE_DEVICES=${GPU} python third_party/run_squad.py \
    --model_type ${MODEL_TYPE} \
    --model_name_or_path ${MODEL_PATH} \
    --do_eval \
    --do_lower_case \
    --eval_lang ${lang} \
    --predict_file "${TEST_FILE}" \
    --output_dir "${PRED_DIR}" &> /dev/null
done
--------------------------------------------------------------------------------
/scripts/eval_qa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Script to evaluate the predictions of a trained model on XQuAD, TyDi QA, and MLQA.
REPO=$PWD
DIR=${REPO}/download
XQUAD_DIR=${DIR}/xquad
MLQA_DIR=${DIR}/mlqa
TYDIQA_DIR=${DIR}/tydiqa

EVAL_SQUAD=${DIR}/squad/evaluate-v1.1.py
EVAL_MLQA=${MLQA_DIR}/mlqa_evaluation_v1.py

PREDICTIONS_DIR=${REPO}/predictions
XQUAD_PRED_DIR=${PREDICTIONS_DIR}/xquad
MLQA_PRED_DIR=${PREDICTIONS_DIR}/mlqa
TYDIQA_PRED_DIR=${PREDICTIONS_DIR}/tydiqa

for pred_path in ${PREDICTIONS_DIR} ${XQUAD_PRED_DIR} ${MLQA_PRED_DIR} ${TYDIQA_PRED_DIR}; do
  if [ ! -d ${pred_path} ]; then
    echo "Predictions path ${pred_path} does not exist."
    exit 1
  fi
done

echo
echo "XQuAD"
for lang in en es de el ru tr ar vi th zh hi; do
  echo -n "  $lang  "
  TEST_FILE=${XQUAD_DIR}/xquad.$lang.json
  PRED_FILE=${XQUAD_PRED_DIR}/predictions_${lang}_.json
  python "${EVAL_SQUAD}" "${TEST_FILE}" "${PRED_FILE}"
done

echo
echo "MLQA"
for lang in en es de ar hi vi zh; do
  echo -n "  $lang  "
  TEST_FILE=${MLQA_DIR}/MLQA_V1/test/test-context-$lang-question-$lang.json
  PRED_FILE=${MLQA_PRED_DIR}/predictions_${lang}_.json
  python "${EVAL_MLQA}" "${TEST_FILE}" "${PRED_FILE}" ${lang}
done

echo
echo "TyDi QA Gold Passage"
for lang in en ar bn fi id ko ru sw te; do
  echo -n "  $lang  "
  TEST_FILE=${TYDIQA_DIR}/tydiqa-goldp-v1.1-dev/tydiqa.$lang.dev.json
  PRED_FILE=${TYDIQA_PRED_DIR}/predictions_${lang}_.json
  python "${EVAL_SQUAD}" "${TEST_FILE}" "${PRED_FILE}"
done

--------------------------------------------------------------------------------
/scripts/train_xnli.sh:
--------------------------------------------------------------------------------
#!/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
GPU=${2:-0}
DATA_DIR=${3:-"$REPO/download/"}
OUT_DIR=${4:-"$REPO/outputs/"}

export CUDA_VISIBLE_DEVICES=$GPU

TASK='xnli'
LR=2e-5
EPOCH=5
MAXL=128
LANGS="ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh"
LC=""

if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then
  MODEL_TYPE="xlmr"
fi

if [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-roberta-large" ]; then
  BATCH_SIZE=2
  GRAD_ACC=16
else
  BATCH_SIZE=8
  GRAD_ACC=4
fi

SAVE_DIR="$OUT_DIR/$TASK/${MODEL}-LR${LR}-epoch${EPOCH}-MaxLen${MAXL}/"
mkdir -p $SAVE_DIR

python $PWD/third_party/run_classify.py \
  --model_type $MODEL_TYPE \
  --model_name_or_path $MODEL \
  --train_language en \
  --task_name $TASK \
  --do_train \
  --do_eval \
  --do_predict \
  --data_dir $DATA_DIR/${TASK} \
  --gradient_accumulation_steps $GRAD_ACC \
  --per_gpu_train_batch_size $BATCH_SIZE \
  --learning_rate $LR \
  --num_train_epochs $EPOCH \
  --max_seq_length $MAXL \
  --output_dir $SAVE_DIR/ \
  --save_steps 100 \
  --eval_all_checkpoints \
  --log_file 'train' \
  --predict_languages $LANGS \
  --save_only_best_checkpoint \
  --overwrite_output_dir \
  --eval_test_set $LC

--------------------------------------------------------------------------------
/scripts/train_pawsx.sh:
--------------------------------------------------------------------------------
#!/bin/bash
REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
GPU=${2:-0}
DATA_DIR=${3:-"$REPO/download/"}
OUT_DIR=${4:-"$REPO/outputs/"}

export CUDA_VISIBLE_DEVICES=$GPU

TASK='pawsx'
LR=2e-5
EPOCH=5
MAXL=128
LANGS="de,en,es,fr,ja,ko,zh"
LC=""

if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
"xlm-roberta-base" ]; then 23 | MODEL_TYPE="xlmr" 24 | fi 25 | 26 | if [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-roberta-large" ]; then 27 | BATCH_SIZE=2 28 | GRAD_ACC=16 29 | else 30 | BATCH_SIZE=8 31 | GRAD_ACC=4 32 | fi 33 | 34 | SAVE_DIR="${OUT_DIR}/${TASK}/${MODEL}-LR${LR}-epoch${EPOCH}-MaxLen${MAXL}/" 35 | mkdir -p $SAVE_DIR 36 | 37 | python $PWD/third_party/run_classify.py \ 38 | --model_type $MODEL_TYPE \ 39 | --model_name_or_path $MODEL \ 40 | --train_language en \ 41 | --task_name $TASK \ 42 | --do_train \ 43 | --do_eval \ 44 | --do_predict \ 45 | --train_split train \ 46 | --test_split test \ 47 | --data_dir $DATA_DIR/$TASK/ \ 48 | --gradient_accumulation_steps $GRAD_ACC \ 49 | --save_steps 200 \ 50 | --per_gpu_train_batch_size $BATCH_SIZE \ 51 | --learning_rate $LR \ 52 | --num_train_epochs $EPOCH \ 53 | --max_seq_length $MAXL \ 54 | --output_dir $SAVE_DIR \ 55 | --eval_all_checkpoints \ 56 | --overwrite_output_dir \ 57 | --overwrite_cache \ 58 | --log_file 'train.log' \ 59 | --predict_languages $LANGS \ 60 | --save_only_best_checkpoint $LC \ 61 | --eval_test_set 62 | -------------------------------------------------------------------------------- /scripts/run_bucc2018.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | REPO=$PWD 3 | MODEL=${1:-bert-base-multilingual-cased} 4 | GPU=${2:-0} 5 | DATA_DIR=${3:-"$REPO/download/"} 6 | OUT_DIR=${4:-"$REPO/outputs/"} 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU 9 | 10 | TASK='bucc2018' 11 | DATA_DIR=$DATA_DIR/$TASK/ 12 | MAXL=512 13 | TL='en' 14 | 15 | NLAYER=12 16 | LC="" 17 | if [ $MODEL == "bert-base-multilingual-cased" ]; then 18 | MODEL_TYPE="bert" 19 | DIM=768 20 | elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then 21 | MODEL_TYPE="xlm" 22 | DIM=1280 23 | LC=" --do_lower_case" 24 | elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then 25 | MODEL_TYPE="xlmr" 26 | DIM=1024 27 | NLAYER=24 28 | fi 29 | 30 | SP='test' 31 | for SL in fr ru zh de; do 32 | PRED_DIR=$REPO/predictions/ 33 | OUT=$OUT_DIR/$TASK/$MODEL-${SL} 34 | mkdir -p $OUT 35 | for sp in 'test' 'dev'; do 36 | for lg in "$SL" "$TL"; do 37 | FILE=$DATA_DIR/${SL}-${TL}.${sp}.${lg} 38 | cut -f2 $FILE > $OUT/${SL}-${TL}.${sp}.${lg}.txt 39 | cut -f1 $FILE > $OUT/${SL}-${TL}.${sp}.${lg}.id 40 | done 41 | done 42 | 43 | CP="candidates" 44 | python $REPO/third_party/run_retrieval.py \ 45 | --model_type $MODEL_TYPE \ 46 | --model_name_or_path $MODEL \ 47 | --embed_size $DIM \ 48 | --batch_size 100 \ 49 | --task_name $TASK \ 50 | --src_language $SL \ 51 | --tgt_language $TL \ 52 | --pool_type cls \ 53 | --max_seq_length $MAXL \ 54 | --data_dir $DATA_DIR \ 55 | --output_dir $OUT \ 56 | --predict_dir $PRED_DIR \ 57 | --candidate_prefix $CP \ 58 | --log_file mine-bitext-${SL}.log \ 59 | --extract_embeds \ 60 | --mine_bitext \ 61 | --specific_layer 7 \ 62 | --dist cosine $LC 63 | 64 | done 65 | -------------------------------------------------------------------------------- /scripts/train_udpos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | REPO=$PWD 3 | MODEL=${1:-bert-base-multilingual-cased} 4 | GPU=${2:-0} 5 | DATA_DIR=${3:-"$REPO/download/"} 6 | OUT_DIR=${4:-"$REPO/outputs/"} 7 | 8 | TASK='udpos' 9 | export CUDA_VISIBLE_DEVICES=$GPU 10 | LANGS='af,ar,bg,de,el,en,es,et,eu,fa,fi,fr,he,hi,hu,id,it,ja,kk,ko,mr,nl,pt,ru,ta,te,th,tl,tr,ur,vi,yo,zh' 11 | NUM_EPOCHS=10 12 | MAX_LENGTH=128 13 | LR=2e-5 14 | 15 | 
LC="" 16 | if [ $MODEL == "bert-base-multilingual-cased" ]; then 17 | MODEL_TYPE="bert" 18 | elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024"]; then 19 | MODEL_TYPE="xlm" 20 | LC=" --do_lower_case" 21 | elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then 22 | MODEL_TYPE="xlmr" 23 | fi 24 | 25 | if [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-mlm-100-1280" ]; then 26 | BATCH_SIZE=2 27 | GRAD_ACC=16 28 | else 29 | BATCH_SIZE=8 30 | GRAD_ACC=4 31 | fi 32 | 33 | DATA_DIR=$DATA_DIR/$TASK/${TASK}_processed_maxlen${MAX_LENGTH}/ 34 | OUTPUT_DIR="$OUT_DIR/$TASK/${MODEL}-LR${LR}-epoch${NUM_EPOCH}-MaxLen${MAX_LENGTH}/" 35 | mkdir -p $OUTPUT_DIR 36 | python3 $REPO/third_party/run_tag.py \ 37 | --data_dir $DATA_DIR \ 38 | --model_type $MODEL_TYPE \ 39 | --labels $DATA_DIR/labels.txt \ 40 | --model_name_or_path $MODEL \ 41 | --output_dir $OUTPUT_DIR \ 42 | --max_seq_length $MAX_LENGTH \ 43 | --num_train_epochs $NUM_EPOCHS \ 44 | --per_gpu_train_batch_size $BATCH_SIZE \ 45 | --save_steps 500 \ 46 | --seed 1 \ 47 | --learning_rate $LR \ 48 | --do_train \ 49 | --do_eval \ 50 | --do_predict \ 51 | --do_predict_dev \ 52 | --evaluate_during_training \ 53 | --predict_langs $LANGS \ 54 | --gradient_accumulation_steps $GRAD_ACC \ 55 | --log_file $OUTPUT_DIR/train.log \ 56 | --eval_all_checkpoints \ 57 | --overwrite_output_dir \ 58 | --save_only_best_checkpoint $LC 59 | 60 | -------------------------------------------------------------------------------- /scripts/train_panx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | REPO=$PWD 3 | MODEL=${1:-bert-base-multilingual-cased} 4 | GPU=${2:-0} 5 | DATA_DIR=${3:-"$REPO/download/"} 6 | OUT_DIR=${4:-"$REPO/outputs/"} 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU 9 | TASK='panx' 10 | LANGS="ar,he,vi,id,jv,ms,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,yo,my,zh,kk,tr,et,fi,hu" 11 | NUM_EPOCHS=10 12 | MAX_LENGTH=128 13 | LR=2e-5 14 | 15 | LC="" 16 | if [ $MODEL == "bert-base-multilingual-cased" ]; then 17 | MODEL_TYPE="bert" 18 | elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then 19 | MODEL_TYPE="xlm" 20 | LC=" --do_lower_case" 21 | elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then 22 | MODEL_TYPE="xlmr" 23 | fi 24 | 25 | if [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-roberta-large" ]; then 26 | BATCH_SIZE=2 27 | GRAD_ACC=16 28 | else 29 | BATCH_SIZE=8 30 | GRAD_ACC=4 31 | fi 32 | 33 | DATA_DIR=$DATA_DIR/${TASK}/${TASK}_processed_maxlen${MAX_LENGTH}/ 34 | OUTPUT_DIR="$OUT_DIR/$TASK/${MODEL}-LR${LR}-epoch${NUM_EPOCH}-MaxLen${MAX_LENGTH}/" 35 | mkdir -p $OUTPUT_DIR 36 | python $REPO/third_party/run_tag.py \ 37 | --data_dir $DATA_DIR \ 38 | --model_type $MODEL_TYPE \ 39 | --labels $DATA_DIR/labels.txt \ 40 | --model_name_or_path $MODEL \ 41 | --output_dir $OUTPUT_DIR \ 42 | --max_seq_length $MAX_LENGTH \ 43 | --num_train_epochs $NUM_EPOCHS \ 44 | --per_gpu_train_batch_size $BATCH_SIZE \ 45 | --per_gpu_eval_batch_size 32 \ 46 | --save_steps 1000 \ 47 | --seed 1 \ 48 | --learning_rate $LR \ 49 | --do_train \ 50 | --do_eval \ 51 | --do_predict \ 52 | --predict_langs $LANGS \ 53 | --train_langs en \ 54 | --gradient_accumulation_steps $GRAD_ACC \ 55 | --log_file $OUTPUT_DIR/train.log \ 56 | --eval_all_checkpoints \ 57 | --eval_patience -1 \ 58 | --overwrite_output_dir \ 59 | --save_only_best_checkpoint $LC 60 | 61 | 62 | 
--------------------------------------------------------------------------------
/scripts/train_qa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Script to train a model on SQuAD v1.1 or the English TyDiQA-GoldP train data.

REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
SRC=${2:-squad}
TGT=${3:-xquad}
GPU=${4:-0}
DATA_DIR=${5:-"$REPO/download/"}
OUT_DIR=${6:-"$REPO/outputs/"}

MAXL=384
LR=3e-5
NUM_EPOCHS=2.0
if [ $MODEL == "bert-base-multilingual-cased" ]; then
  MODEL_TYPE="bert"
elif [ $MODEL == "xlm-mlm-100-1280" ] || [ $MODEL == "xlm-mlm-tlm-xnli15-1024" ]; then
  MODEL_TYPE="xlm"
  LC=" --do_lower_case"
elif [ $MODEL == "xlm-roberta-large" ] || [ $MODEL == "xlm-roberta-base" ]; then
  MODEL_TYPE="xlmr"
fi

# Model path where trained model should be stored
MODEL_PATH=$OUT_DIR/$SRC/${MODEL}_LR${LR}_EPOCH${NUM_EPOCHS}_maxlen${MAXL}
mkdir -p $MODEL_PATH
# Train either on the SQuAD or TyDiQA-GoldP English train file
if [ $SRC == 'squad' ]; then
  TRAIN_FILE=${DATA_DIR}/squad/train-v1.1.json
  PREDICT_FILE=${DATA_DIR}/squad/dev-v1.1.json
else
  TRAIN_FILE=${DATA_DIR}/tydiqa/tydiqa-goldp-v1.1-train/tydiqa.goldp.en.train.json
  PREDICT_FILE=${DATA_DIR}/tydiqa/tydiqa-goldp-v1.1-dev/tydiqa.en.dev.json
fi

# train
CUDA_VISIBLE_DEVICES=$GPU python third_party/run_squad.py \
  --model_type ${MODEL_TYPE} \
  --model_name_or_path ${MODEL} \
  --do_lower_case \
  --do_train \
  --do_eval \
  --train_file ${TRAIN_FILE} \
  --predict_file ${PREDICT_FILE} \
  --per_gpu_train_batch_size 4 \
  --learning_rate ${LR} \
  --num_train_epochs ${NUM_EPOCHS} \
  --max_seq_length $MAXL \
  --doc_stride 128 \
  --save_steps -1 \
  --overwrite_output_dir \
  --gradient_accumulation_steps 4 \
  --warmup_steps 500 \
  --output_dir ${MODEL_PATH} \
  --weight_decay 0.0001 \
  --threads 8 \
  --train_lang en \
  --eval_lang en

# predict
bash scripts/predict_qa.sh $MODEL $MODEL_TYPE $MODEL_PATH $TGT $GPU $DATA_DIR
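A usage sketch of the full question-answering pipeline: train_qa.sh fine-tunes on English SQuAD (or TyDiQA-GoldP) and already calls predict_qa.sh at the end, after which eval_qa.sh scores the files under ./predictions, reading the gold data and the official SQuAD/MLQA evaluation scripts from ./download:

  bash scripts/train_qa.sh bert-base-multilingual-cased squad xquad 0 ./download/ ./outputs/
  bash scripts/eval_qa.sh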
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
#lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/third_party/ud-conversion-tools/conllu_to_conll.py:
--------------------------------------------------------------------------------
from collections import defaultdict
from itertools import islice
from pathlib import Path
import argparse
import sys, copy

from lib.conll import CoNLLReader

def main():
    parser = argparse.ArgumentParser(description="""Convert conllu to conll format""")
    parser.add_argument('input', help="conllu file")
    parser.add_argument('output', help="target file", type=Path)
    parser.add_argument('--replace_subtokens_with_fused_forms', help="By default removes fused tokens", default=False, action="store_true")
    parser.add_argument('--remove_deprel_suffixes', help="Restrict deprels to the common universal subset, e.g. nmod:tmod becomes nmod", default=False, action="store_true")
    parser.add_argument('--remove_node_properties', help="space-separated list of node properties to remove: form, lemma, cpostag, postag, feats", choices=['form', 'lemma', 'cpostag', 'postag', 'feats'], metavar='prop', type=str, nargs='+')
    parser.add_argument('--lang', help="specify a language 2-letter code", default="default")
    parser.add_argument('--output_format', choices=['conll2006', 'conll2009', 'conllu'], default="conll2006")
    parser.add_argument('--remove_arabic_diacritics', help="remove Arabic short vowels", default=False, action="store_true")
    parser.add_argument('--print_comments', default=False, action="store_true")
    parser.add_argument('--print_fused_forms', default=False, action="store_true")

    args = parser.parse_args()

    if sys.version_info < (3, 0):
        print("Sorry, requires Python 3.x.")  # suggestion: install anaconda python
        sys.exit(1)

    POSRANKPRECEDENCEDICT = defaultdict(list)
    POSRANKPRECEDENCEDICT["default"] = "VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART CCONJ SCONJ X PUNCT ".split(" ")
    # POSRANKPRECEDENCEDICT["de"] = "PROPN ADP DET ".split(" ")
    POSRANKPRECEDENCEDICT["es"] = "VERB AUX PRON ADP DET".split(" ")
    POSRANKPRECEDENCEDICT["fr"] = "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" ")
    POSRANKPRECEDENCEDICT["it"] = "VERB AUX ADV PRON ADP DET INTJ".split(" ")

    if args.lang in POSRANKPRECEDENCEDICT:
        current_pos_precedence_list = POSRANKPRECEDENCEDICT[args.lang]
    else:
        current_pos_precedence_list = POSRANKPRECEDENCEDICT["default"]

    cio = CoNLLReader()
    orig_treebank = cio.read_conll_u(args.input)  # , args.keep_fused_forms, args.lang, POSRANKPRECEDENCEDICT)
    modif_treebank = copy.copy(orig_treebank)

    # As per Dec 2015 the args.lang variable is redundant once you have current_pos_precedence_list
    # We keep it for future modifications, i.e. any language-specific modules
    for s in modif_treebank:
        # print('sentence', s.get_sentence_as_string(printid=True))
        s.filter_sentence_content(args.replace_subtokens_with_fused_forms, args.lang, current_pos_precedence_list, args.remove_node_properties, args.remove_deprel_suffixes, args.remove_arabic_diacritics)

    cio.write_conll(modif_treebank, args.output, args.output_format, print_fused_forms=args.print_fused_forms, print_comments=args.print_comments)

if __name__ == "__main__":
    main()
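A standalone usage sketch of the converter, mirroring the call made in scripts/download_data.sh (input.conllu and output.conll are placeholder paths):

  python third_party/ud-conversion-tools/conllu_to_conll.py input.conllu output.conll --lang en --replace_subtokens_with_fused_forms --print_fused_forms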
--------------------------------------------------------------------------------
/third_party/xlm.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.modeling_xlm import XLMModel, XLMPreTrainedModel, XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING
from transformers import XLMConfig, add_start_docstrings

@add_start_docstrings("""XLM Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLM_START_DOCSTRING,
    XLM_INPUTS_DOCSTRING)
class XLMForTokenClassification(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
    Examples::
        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-100-1280')
        model = XLMForTokenClassification.from_pretrained('xlm-mlm-100-1280')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]
    """
    def __init__(self, config):
        super(XLMForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.transformer = XLMModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.transformer(input_ids,
                                   attention_mask=attention_mask,
                                   langs=langs,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   head_mask=head_mask,
                                   inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
--------------------------------------------------------------------------------
/third_party/processors/pawsx.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PAWS-X utils (dataset loading and evaluation) """


import logging
import os

from transformers import DataProcessor
from .utils import InputExample


logger = logging.getLogger(__name__)


class PawsxProcessor(DataProcessor):
    """Processor for the PAWS-X dataset."""

    def __init__(self):
        pass

    def get_examples(self, data_dir, language='en', split='train'):
        """See base class."""
        examples = []
        for lg in language.split(','):
            lines = self._read_tsv(os.path.join(data_dir, "{}-{}.tsv".format(split, lg)))

            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "%s-%s-%s" % (split, lg, i)
                text_a = line[0]
                text_b = line[1]
                if split == 'test' and len(line) != 3:
                    label = "0"
                else:
                    label = str(line[2].strip())
                assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
                examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=lg))
        return examples

    def get_translate_examples(self, data_dir, language='en', split='train'):
        """See base class."""
        languages = language.split(',')
        examples = []
        for language in languages:
            if split == 'train':
                file_path = os.path.join(data_dir, "translated/en-{}-translated.tsv".format(language))
            else:
                file_path = os.path.join(data_dir, "translated/test-{}-en-translated.tsv".format(language))
            logger.info("reading from " + file_path)
            lines = self._read_tsv(file_path)
            for (i, line) in enumerate(lines):
                guid = "%s-%s-%s" % (split, language, i)
                text_a = line[0]
                text_b = line[1]
                label = str(line[2].strip())
                assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
                examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=language))
        return examples

    def get_train_examples(self, data_dir, language='en'):
        """See base class."""
        return self.get_examples(data_dir, language, split='train')

    def get_translate_train_examples(self, data_dir, language='en'):
        """See base class."""
        return self.get_translate_examples(data_dir, language, split='train')

    def get_translate_test_examples(self, data_dir, language='en'):
        """See base class."""
        return self.get_translate_examples(data_dir, language, split='test')

    def get_test_examples(self, data_dir, language='en'):
        """See base class."""
        return self.get_examples(data_dir, language, split='test')

    def get_dev_examples(self, data_dir, language='en'):
        """See base class."""
        return self.get_examples(data_dir, language, split='dev')

    def get_labels(self):
        """See base class."""
        return ["0", "1"]


pawsx_processors = {
    "pawsx": PawsxProcessor,
}

pawsx_output_modes = {
    "pawsx": "classification",
}

pawsx_tasks_num_labels = {
    "pawsx": 2,
}

--------------------------------------------------------------------------------
/third_party/processors/xnli.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XNLI utils (dataset loading and evaluation) """


import logging
import os

from transformers import DataProcessor
from .utils import InputExample

logger = logging.getLogger(__name__)


class XnliProcessor(DataProcessor):
    """Processor for the XNLI dataset.
    Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""

    def __init__(self):
        pass

    def get_examples(self, data_dir, language='en', split='train'):
        """See base class."""
        examples = []
        for lg in language.split(','):
            lines = self._read_tsv(os.path.join(data_dir, "{}-{}.tsv".format(split, lg)))

            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "%s-%s-%s" % (split, lg, i)
                text_a = line[0]
                text_b = line[1]
                if split == 'test' and len(line) != 3:
                    label = "neutral"
                else:
                    label = str(line[2].strip())
                assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
                examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=lg))
        return examples

    def get_train_examples(self, data_dir, language='en'):
        return self.get_examples(data_dir, language, split='train')

    def get_dev_examples(self, data_dir, language='en'):
        return self.get_examples(data_dir, language, split='dev')

    def get_test_examples(self, data_dir, language='en'):
        return self.get_examples(data_dir, language, split='test')

    def get_translate_train_examples(self, data_dir, language='en'):
        """See base class."""
        examples = []
        for lg in language.split(','):
            file_path = os.path.join(data_dir, "XNLI-Translated/en-{}-translated.tsv".format(lg))
            logger.info("reading file from " + file_path)
            lines = self._read_tsv(file_path)
            for (i, line) in enumerate(lines):
                guid = "%s-%s-%s" % ("translate-train", lg, i)
                text_a = line[0]
                text_b = line[1]
                label = "contradiction" if line[2].strip() == "contradictory" else line[2].strip()
                assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
                examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=lg))
        return examples

    def get_translate_test_examples(self, data_dir, language='en'):
        lg = language
        lines = self._read_tsv(os.path.join(data_dir, "XNLI-Translated/test-{}-en-translated.tsv".format(lg)))
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s-%s" % ("translate-test", language, i)
            text_a = line[0]
            text_b = line[1]
            label = "contradiction" if line[2].strip() == "contradictory" else line[2].strip()
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=language))
        return examples

    def get_pseudo_test_examples(self, data_dir, language='en'):
        lines = self._read_tsv(os.path.join(data_dir, "XNLI-Translated/pseudo-test-set/en-{}-pseudo-translated.csv".format(language)))
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s-%s" % ("pseudo-test", language, i)
            text_a = line[0]
            text_b = line[1]
            label = "contradiction" if line[2].strip() == "contradictory" else line[2].strip()
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, language=language))
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]


xnli_processors = {
    "xnli": XnliProcessor,
}

xnli_output_modes = {
    "xnli": "classification",
}

xnli_tasks_num_labels = {
    "xnli": 3,
}
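A minimal usage sketch for these processors, assuming the repository root is on PYTHONPATH and that download/xnli was produced by scripts/download_data.sh (per get_examples above, each split is a {split}-{lang}.tsv file with a header row and sentence1 <tab> sentence2 <tab> label columns):

  from third_party.processors.xnli import XnliProcessor

  processor = XnliProcessor()
  # Reads download/xnli/train-en.tsv, skipping the header row
  train_examples = processor.get_train_examples('download/xnli', language='en')
  # A comma-separated list loads several languages at once
  test_examples = processor.get_test_examples('download/xnli', language='ar,bg,de')
  print(len(train_examples), processor.get_labels())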
--------------------------------------------------------------------------------
/third_party/processors/utils.py:
--------------------------------------------------------------------------------
import copy
import csv
import json
import logging
from transformers import XLMTokenizer

logger = logging.getLogger(__name__)


class InputExample(object):
    """
    A single training/test example for simple sequence classification.
    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
        text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
        label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
    """

    def __init__(self, guid, text_a, text_b=None, label=None, language=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
        self.language = language

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


class InputFeatures(object):
    """
    A single set of features of data.
    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        label: Label corresponding to the input
    """

    def __init__(self, input_ids, attention_mask=None, token_type_ids=None, langs=None, label=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label
        self.langs = langs

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


def convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    label_list=None,
    output_mode=None,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
    lang2id=None,
):
    """
    Loads a data file into a list of ``InputFeatures``
    Args:
        examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length
        task: GLUE task
        label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
        output_mode: String indicating the output mode. Either ``regression`` or ``classification``
        pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
        pad_token: Padding token
        pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
        mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
            and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
            actual values)
    Returns:
        If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
        containing the task-specific features. If the input is a list of ``InputExamples``, will return
        a list of task-specific ``InputFeatures`` which can be fed to the model.
    """
    # is_tf_dataset = False
    # if is_tf_available() and isinstance(examples, tf.data.Dataset):
    #     is_tf_dataset = True

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d" % (ex_index))
        # if is_tf_dataset:
        #     example = processor.get_example_from_tensor_dict(example)
        #     example = processor.tfds_map(example)

        if isinstance(tokenizer, XLMTokenizer):
            inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length, lang=example.language)
        else:
            inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length)

        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        if lang2id is not None:
            lid = lang2id.get(example.language, lang2id["en"])
        else:
            lid = 0
        langs = [lid] * max_length

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        if output_mode == "classification":
            label = label_map[example.label]
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("sentence: %s" % " ".join(tokenizer.convert_ids_to_tokens(input_ids)))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label))
            logger.info("language: %s, (lid = %d)" % (example.language, lid))

        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, langs=langs, label=label
            )
        )
    return features
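A sketch of how InputExample and convert_examples_to_features fit together (the model name and the sentences are illustrative placeholders):

  from transformers import BertTokenizer
  from third_party.processors.utils import InputExample, convert_examples_to_features

  tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
  examples = [InputExample(guid='train-en-1', text_a='A cat sat on the mat.',
                           text_b='A cat is sitting on the mat.',
                           label='entailment', language='en')]
  features = convert_examples_to_features(
      examples, tokenizer, max_length=128,
      label_list=['contradiction', 'entailment', 'neutral'],  # e.g. XnliProcessor().get_labels()
      output_mode='classification',
      pad_token=tokenizer.pad_token_id)
  # Each feature is padded to max_length and carries a per-token language id
  print(len(features[0].input_ids), features[0].label)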
15 | """Evaluation.""" 16 | 17 | import argparse 18 | from seqeval.metrics import precision_score, recall_score, f1_score 19 | import sys 20 | 21 | from third_party.evaluate_squad import evaluate as squad_eval 22 | from third_party.evaluate_mlqa import evaluate as mlqa_eval 23 | 24 | def read_tag(file): 25 | labels = [] 26 | example = [] 27 | for line in open(file, 'r'): 28 | items = line.strip().split('\t') 29 | if len(items) == 2: 30 | example.append(items[1].strip()) 31 | else: 32 | labels.append(example) 33 | example = [] 34 | return labels 35 | 36 | 37 | def read_label(file): 38 | return [l.strip() for l in open(file)] 39 | 40 | 41 | def read_squad(file): 42 | expected_version = '1.1' 43 | with open(file) as dataset_file: 44 | dataset_json = json.load(dataset_file) 45 | if 'version' in dataset_json and dataset_json['version'] != expected_version: 46 | print('Evaluation expects v-' + expected_version, 47 | ', but got dataset with v-' + dataset_json['version'], 48 | file=sys.stderr) 49 | if 'data' in dataset_json: 50 | return dataset_json['data'] 51 | else: 52 | return dataset_json 53 | 54 | 55 | def f1(labels, predictions, language=None): 56 | f1 = f1_score(labels, predictions) 57 | precision = precision_score(labels, predictions) 58 | recall = recall_score(labels, predictions) 59 | return {'f1': f1, 'precision': precision, 'recall': recall} 60 | 61 | 62 | def accuracy(labels, predictions, language=None): 63 | correct = sum([int(p == l) for p, l in zip(predictions, labels)]) 64 | accuracy = float(correct) / len(predictions) 65 | return {'accuracy': accuracy} 66 | 67 | 68 | def squad_em_f1(labels, predictions, language=None): 69 | return squad_eval(labels, predictions) 70 | 71 | def mlqa_em_f1(labels, predictions, language): 72 | if language is None: 73 | print('required 2-char language code for the argument `language`') 74 | exit(0) 75 | return mlqa_eval(labels, predictions) 76 | 77 | 78 | GROUP2TASK = { 79 | "classification": ["pawsx", "xnli"], 80 | "tagging": ["udpos", "panx"], 81 | "qa": ["xquad", "mlqa", "tydiqa"], 82 | "retrieval": ["bucc2018", "tatoeba"], 83 | } 84 | 85 | 86 | TASK2LANGS = { 87 | "pawsx": "de,en,es,fr,ja,ko,zh".split(","), 88 | "xnli": "ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh".split(","), 89 | "panx": "ar,he,vi,id,jv,ms,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,yo,my,zh,kk,tr,et,fi,hu".split(","), 90 | "udpos": "af,ar,bg,de,el,en,es,et,eu,fa,fi,fr,he,hi,hu,id,it,ja,kk,ko,mr,nl,pt,ru,ta,te,th,tl,tr,ur,vi,yo,zh".split(","), 91 | "bucc2018": "de,fr,ru,zh".split(","), 92 | "tatoeba": "ar,he,vi,id,jv,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,zh,kk,tr,et,fi,hu".split(","), 93 | "xquad": "en,es,de,el,ru,tr,ar,vi,th,zh,hi".split(","), 94 | "mlqa": "en,es,de,ar,hi,vi,zh".split(","), 95 | "tydiqa": "en,ar,bn,fi,id,ko,ru,sw,te".split(","), 96 | } 97 | 98 | 99 | READER_FUNCTION = { 100 | 'pawsx': read_label, 101 | 'xnli': read_label, 102 | 'panx': read_tag, 103 | 'udpos': read_tag, 104 | 'bucc2018': read_label, 105 | 'tatoeba': read_label, 106 | 'xquad': read_squad, 107 | 'mlqa': read_squad, 108 | 'tydiqa': read_squad, 109 | } 110 | 111 | 112 | METRIC_FUNCTION = { 113 | 'pawsx': accuracy, 114 | 'xnli': accuracy, 115 | 'panx': f1, 116 | 'udpos': f1, 117 | 'bucc2018': f1, 118 | 'tatoeba': accuracy, 119 | 'xquad': squad_em_f1, 120 | 'mlqa': mlqa_em_f1, 121 | 'tydiqa': squad_em_f1, 122 | } 123 | 124 | 125 | def evaluate_one_task(prediction_file, label_file, task, language=None): 126 | """Evalute the 
    Args:
        prediction_file (string): path to the prediction tsv file.
        label_file (string): path to the ground truth tsv file.
    Return:
        result (dict): a dictionary with the task's metric values.

    For classification tasks, both input files contain one example per line as follows:
        ``[label]\t[sentence1]\t[sentence2]``
    """
    predictions = READER_FUNCTION[task](prediction_file)
    labels = READER_FUNCTION[task](label_file)
    assert len(predictions) == len(labels), 'Number of examples in {} and {} not matched'.format(prediction_file, label_file)
    result = METRIC_FUNCTION[task](labels, predictions, language)
    return result


def evaluate(prediction_folder, label_folder):
    """Evaluate on all tasks if available.
    Args:
        prediction_folder (string): prediction folder that contains each task's predictions in one subfolder.
        label_folder (string): label folder that contains each task's ground-truth labels in one subfolder.
    Return:
        overall_scores (dict): a dictionary with sub-group scores. key: group label.
        detailed_scores (dict): a dictionary with all detailed scores. key: task label.
    """
    prediction_tasks = next(os.walk(prediction_folder))[1]
    label_tasks = next(os.walk(label_folder))[1]

    detailed_scores = {}
    for task, langs in TASK2LANGS.items():
        if task in prediction_tasks and task in label_tasks:
            suffix = "json" if task in GROUP2TASK["qa"] else "tsv"
            # collect scores over all languages
            score = defaultdict(dict)
            for lg in langs:
                prediction_file = os.path.join(prediction_folder, task, f"test-{lg}.{suffix}")
                label_file = os.path.join(label_folder, task, f"test-{lg}.{suffix}")
                score_lg = evaluate_one_task(prediction_file, label_file, task, language=lg)
                for metric in score_lg:
                    score[metric][lg] = score_lg[metric]
            # average over all languages
            for m in list(score):
                score[f'avg_{m}'] = sum(score[m].values()) / len(score[m])
            if task in GROUP2TASK["qa"]:
                score['avg_metric'] = (score['avg_exact_match'] + score['avg_f1']) / 2
            elif 'avg_f1' in score:
                score['avg_metric'] = score['avg_f1']
            elif 'avg_accuracy' in score:
                score['avg_metric'] = score['avg_accuracy']
            detailed_scores[task] = score

    # Display logic:
    # If scores of all tasks in a sub-group are available, show the score in the sub table
    overall_scores = {}
    for group, group_tasks in GROUP2TASK.items():
        if all(task in detailed_scores for task in group_tasks):
            overall_scores[group] = sum(detailed_scores[task]['avg_metric'] for task in group_tasks) / len(group_tasks)

    # If scores of all tasks are available, show the overall score in the main table
    all_tasks = list(TASK2LANGS.keys())
    if all(task in detailed_scores for task in all_tasks):
        overall_scores['all_task'] = sum(detailed_scores[task]['avg_metric'] for task in all_tasks) / len(all_tasks)

    return overall_scores, detailed_scores


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prediction_folder', default=None, type=str, required=True,
                        help='the predictions of one model')
    parser.add_argument('--label_folder', default=None, type=str, required=True,
                        help='the ground truth files')
    args = parser.parse_args()
    overall_scores, detailed_scores = evaluate(args.prediction_folder, args.label_folder)
    print(overall_scores)
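A usage sketch of the scorer (the folder names are placeholders): evaluate() walks one subfolder per task and expects per-language files named test-{lang}.tsv, or test-{lang}.json for the QA tasks:

  python evaluate.py --prediction_folder ./predictions --label_folder ./labels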
| -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | REPO=$PWD 2 | DIR=$REPO/download/ 3 | mkdir -p $DIR 4 | 5 | # download XNLI dataset 6 | function download_xnli { 7 | OUTPATH=$DIR/xnli-tmp/ 8 | if [ ! -d $OUTPATH/XNLI-MT-1.0 ]; then 9 | if [ ! -f $OUTPATH/XNLI-MT-1.0.zip ]; then 10 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-MT-1.0.zip -P $OUTPATH -q --show-progress 11 | fi 12 | unzip -qq $OUTPATH/XNLI-MT-1.0.zip -d $OUTPATH 13 | fi 14 | if [ ! -d $OUTPATH/XNLI-1.0 ]; then 15 | if [ ! -f $OUTPATH/XNLI-1.0.zip ]; then 16 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip -P $OUTPATH -q --show-progress 17 | fi 18 | unzip -qq $OUTPATH/XNLI-1.0.zip -d $OUTPATH 19 | fi 20 | python $REPO/utils_preprocess.py \ 21 | --data_dir $OUTPATH \ 22 | --output_dir $DIR/xnli/ \ 23 | --task xnli 24 | rm -rf $OUTPATH 25 | echo "Successfully downloaded data at $DIR/xnli" >> $DIR/download.log 26 | } 27 | 28 | # download PAWS-X dataset 29 | function download_pawsx { 30 | cd $DIR 31 | wget https://storage.googleapis.com/paws/pawsx/x-final.tar.gz -q --show-progress 32 | tar xzf x-final.tar.gz -C $DIR/ 33 | python $REPO/utils_preprocess.py \ 34 | --data_dir $DIR/x-final \ 35 | --output_dir $DIR/pawsx/ \ 36 | --task pawsx 37 | rm -rf x-final x-final.tar.gz 38 | echo "Successfully downloaded data at $DIR/pawsx" >> $DIR/download.log 39 | } 40 | 41 | # download UD-POS dataset 42 | function download_udpos { 43 | base_dir=$DIR/udpos-tmp 44 | out_dir=$base_dir/conll/ 45 | mkdir -p $out_dir 46 | cd $base_dir 47 | curl -s --remote-name-all https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3105/ud-treebanks-v2.5.tgz 48 | tar -xzf $base_dir/ud-treebanks-v2.5.tgz 49 | 50 | langs=(af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh) 51 | for x in $base_dir/ud-treebanks-v2.5/*/*.conllu; do 52 | file="$(basename $x)" 53 | IFS='_' read -r -a array <<< "$file" 54 | lang=${array[0]} 55 | if [[ " ${langs[@]} " =~ " ${lang} " ]]; then 56 | lang_dir=$out_dir/$lang/ 57 | mkdir -p $lang_dir 58 | y=$lang_dir/${file/conllu/conll} 59 | if [ !
-f "$y" ]; then 60 | echo "python $REPO/third_party/ud-conversion-tools/conllu_to_conll.py $x $y --lang $lang --replace_subtokens_with_fused_forms --print_fused_forms" 61 | python $REPO/third_party/ud-conversion-tools/conllu_to_conll.py $x $y --lang $lang --replace_subtokens_with_fused_forms --print_fused_forms 62 | else 63 | echo "${y} exists" 64 | fi 65 | fi 66 | done 67 | 68 | python $REPO/utils_preprocess.py --data_dir $out_dir/ --output_dir $DIR/udpos/ --task udpos 69 | rm -rf $out_dir ud-treebanks-v2.5.tgz $DIR/udpos-tmp 70 | echo "Successfully downloaded data at $DIR/udpos" >> $DIR/download.log 71 | } 72 | 73 | function download_panx { 74 | echo "Download panx NER dataset" 75 | if [ -f $DIR/AmazonPhotos.zip ]; then 76 | unzip -qq $DIR/AmazonPhotos.zip -d $DIR/ 77 | base_dir=$DIR/panx_dataset/ && cd $base_dir 78 | langs=(ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu) 79 | for lg in ${langs[@]}; do 80 | tar xzf $base_dir/${lg}.tar.gz 81 | for f in dev test train; do mv $base_dir/$f $base_dir/${lg}-${f}; done 82 | done 83 | python $REPO/utils_preprocess.py \ 84 | --data_dir $base_dir \ 85 | --output_dir $DIR/panx \ 86 | --task panx 87 | rm -rf $base_dir 88 | echo "Successfully downloaded data at $DIR/panx" >> $DIR/download.log 89 | else 90 | echo "Please download the AmazonPhotos.zip file on Amazon Cloud Drive manually and save it to $DIR/AmazonPhotos.zip" 91 | echo "https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN" 92 | fi 93 | } 94 | 95 | function download_tatoeba { 96 | base_dir=$DIR/tatoeba-tmp/ 97 | wget https://github.com/facebookresearch/LASER/archive/master.zip 98 | unzip -qq -o master.zip -d $base_dir/ 99 | mv $base_dir/LASER-master/data/tatoeba/v1/* $base_dir/ 100 | python $REPO/utils_preprocess.py \ 101 | --data_dir $base_dir \ 102 | --output_dir $DIR/tatoeba \ 103 | --task tatoeba 104 | rm -rf $base_dir master.zip 105 | echo "Successfully downloaded data at $DIR/tatoeba" >> $DIR/download.log 106 | } 107 | 108 | function download_bucc18 { 109 | base_dir=$DIR/bucc2018/ 110 | cd $DIR 111 | for lg in zh ru de fr; do 112 | wget https://comparable.limsi.fr/bucc2018/bucc2018-${lg}-en.training-gold.tar.bz2 -q --show-progress 113 | tar -xjf bucc2018-${lg}-en.training-gold.tar.bz2 114 | wget https://comparable.limsi.fr/bucc2018/bucc2018-${lg}-en.sample-gold.tar.bz2 -q --show-progress 115 | tar -xjf bucc2018-${lg}-en.sample-gold.tar.bz2 116 | done 117 | mv $base_dir/*/* $base_dir/ 118 | for f in $base_dir/*training*; do mv $f ${f/training/test}; done 119 | for f in $base_dir/*sample*; do mv $f ${f/sample/dev}; done 120 | rm -rf $base_dir/*test.gold $DIR/bucc2018*tar.bz2 $base_dir/{zh,ru,de,fr}-en/ 121 | echo "Successfully downloaded data at $DIR/bucc2018" >> $DIR/download.log 122 | } 123 | 124 | 125 | function download_squad { 126 | echo "download squad" 127 | base_dir=$DIR/squad/ 128 | mkdir -p $base_dir && cd $base_dir 129 | wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json -q --show-progress 130 | wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json -q --show-progress 131 | # Download the SQuAD evaluation script (used for XQuAD and TyDiQA-GoldP) 132 | wget https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py -q --show-progress 133 | echo "Successfully downloaded data at $DIR/squad" >> $DIR/download.log 134 | } 135 | 136 | function download_xquad { 137 | echo "download xquad"
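# Each xquad.{lang}.json fetched below is the full XQuAD test set in SQuAD v1.1 JSON format.
# utils_preprocess.py (--task xquad) then derives per-language test files from them; the
# test-{lang}.json naming is an assumption based on the convention evaluate.py reads.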
138 | base_dir=$DIR/xquad/ 139 | mkdir -p $base_dir && cd $base_dir 140 | for lang in ar de el en es hi ru th tr vi zh; do 141 | wget https://raw.githubusercontent.com/deepmind/xquad/master/xquad.${lang}.json -q --show-progress 142 | done 143 | python $REPO/utils_preprocess.py --data_dir $base_dir --output_dir $base_dir --task xquad 144 | echo "Successfully downloaded data at $DIR/xquad" >> $DIR/download.log 145 | } 146 | 147 | function download_mlqa { 148 | echo "download mlqa" 149 | base_dir=$DIR/mlqa/ 150 | mkdir -p $base_dir && cd $base_dir 151 | zip_file=MLQA_V1.zip 152 | wget https://dl.fbaipublicfiles.com/MLQA/${zip_file} -q --show-progress 153 | unzip -qq ${zip_file} 154 | rm ${zip_file} 155 | # Download the MLQA evaluation script 156 | wget https://raw.githubusercontent.com/facebookresearch/MLQA/master/mlqa_evaluation_v1.py -q --show-progress 157 | python $REPO/utils_preprocess.py --data_dir $base_dir/MLQA_V1/test --output_dir $base_dir --task mlqa 158 | echo "Successfully downloaded data at $DIR/mlqa" >> $DIR/download.log 159 | } 160 | 161 | function download_tydiqa { 162 | echo "download tydiqa-goldp" 163 | base_dir=$DIR/tydiqa/ 164 | mkdir -p $base_dir && cd $base_dir 165 | tydiqa_train_file=tydiqa-goldp-v1.1-train.json 166 | tydiqa_dev_file=tydiqa-goldp-v1.1-dev.tgz 167 | wget https://storage.googleapis.com/tydiqa/v1.1/${tydiqa_train_file} -q --show-progress 168 | wget https://storage.googleapis.com/tydiqa/v1.1/${tydiqa_dev_file} -q --show-progress 169 | tar -xf ${tydiqa_dev_file} 170 | rm ${tydiqa_dev_file} 171 | out_dir=$base_dir/tydiqa-goldp-v1.1-train 172 | python $REPO/utils_preprocess.py --data_dir $base_dir --output_dir $out_dir --task tydiqa 173 | mv $base_dir/$tydiqa_train_file $out_dir/ 174 | echo "Successfully downloaded data at $DIR/tydiqa" >> $DIR/download.log 175 | } 176 | 177 | download_xnli 178 | download_pawsx 179 | download_tatoeba 180 | download_bucc18 181 | download_squad 182 | download_xquad 183 | download_mlqa 184 | download_tydiqa 185 | download_udpos 186 | download_panx 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XTREME: A Massively Multilingual Multi-task Benchmark for Evaluating Cross-lingual Generalization 2 | 3 | [**Tasks**](#tasks-and-languages) | [**Download**](#download-the-data) | 4 | [**Baselines**](#build-a-baseline-system) | 5 | [**Leaderboard**](#leaderboard-submission) | 6 | [**Website**](https://ai.google.com/research/xtreme) | 7 | [**Paper**](https://arxiv.org/pdf/2003.11080.pdf) 8 | 9 | This repository contains information about XTREME, code for downloading data, and 10 | implementations of baseline systems for the benchmark. 11 | 12 | # Introduction 13 | 14 | The Cross-lingual TRansfer Evaluation of Multilingual Encoders (XTREME) benchmark evaluates the cross-lingual generalization ability of pre-trained multilingual models. It covers 40 typologically diverse languages (spanning 12 language families) and includes nine tasks that collectively require reasoning about different levels of syntax and semantics. The languages in XTREME are selected to maximize language diversity, coverage in existing tasks, and availability of training data.
Among these are many under-studied languages, such as the Dravidian languages Tamil (spoken in southern India, Sri Lanka, and Singapore), Telugu and Malayalam (spoken mainly in southern India), and the Niger-Congo languages Swahili and Yoruba, spoken in Africa. 15 | 16 | For a full description of the benchmark, see [the paper](https://arxiv.org/abs/2003.11080). 17 | 18 | # Tasks and Languages 19 | 20 | The tasks included in XTREME cover a range of standard paradigms in natural language processing, including sentence classification, structured prediction, sentence retrieval and question answering. The full list of tasks can be seen in the image below. 21 | 22 | ![The datasets used in XTREME](xtreme_score.png) 23 | 24 | In order for models to be successful on the XTREME benchmark, they must learn representations that generalize across many tasks and languages. Each of the tasks covers a subset of the 40 languages included in XTREME (shown here with their ISO 639-1 codes): af, ar, bg, bn, de, el, en, es, et, eu, fa, fi, fr, he, hi, hu, id, it, ja, jv, ka, kk, ko, ml, mr, ms, my, nl, pt, ru, sw, ta, te, th, tl, tr, ur, vi, yo, and zh. The languages were selected among the top 100 languages with the [most Wikipedia articles](https://meta.wikimedia.org/wiki/List_of_Wikipedias) to maximize language diversity, task coverage, and availability of training data. They include members of the Afro-Asiatic, Austro-Asiatic, Austronesian, Dravidian, Indo-European, Japonic, Kartvelian, Kra-Dai, Niger-Congo, Sino-Tibetan, Turkic, and Uralic language families as well as of two isolates, Basque and Korean. 25 | 26 | # Download the data 27 | 28 | In order to run experiments on XTREME, the first step is to download the dependencies. We assume you have installed [`anaconda`](https://www.anaconda.com/) and use Python 3.7+. The additional requirements, including `transformers`, `seqeval` (for sequence labelling evaluation), `tensorboardx`, `jieba`, `kytea`, and `pythainlp` (for text segmentation in Chinese, Japanese, and Thai), and `sacremoses`, can be installed by running the following script: 29 | ``` 30 | bash install_tools.sh 31 | ``` 32 | 33 | The next step is to download the data. To this end, first create a `download` folder with ```mkdir -p download``` in the root of this project. You then need to manually download `panx_dataset` (for NER) from [here](https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN) (note that it will download as `AmazonPhotos.zip`) to the `download` directory. Finally, run the following command to download the remaining datasets: 34 | ``` 35 | bash scripts/download_data.sh 36 | ``` 37 | 38 | # Build a baseline system 39 | 40 | The evaluation setting in XTREME is zero-shot cross-lingual transfer from English. We fine-tune models that were pre-trained on multilingual data on the labelled data of each XTREME task in English. Each fine-tuned model is then applied to the test data of the same task in other languages to obtain predictions. 41 | 42 | For every task, we provide a single script `scripts/train.sh` that fine-tunes pre-trained models implemented in the [Transformers](https://github.com/huggingface/transformers) repo. To fine-tune a different model, simply pass a different `MODEL` argument to the script, where the currently supported models are `bert-base-multilingual-cased`, `xlm-mlm-100-1280` and `xlm-roberta-large`. 43 | 44 | ## Universal dependencies part-of-speech tagging 45 | 46 | For part-of-speech tagging, we use data from the Universal Dependencies v2.5.
You can fine-tune a pre-trained multilingual model on the English POS tagging data with the following command: 47 | ``` 48 | bash scripts/train.sh [MODEL] udpos 49 | ``` 50 | 51 | ## Wikiann named entity recognition 52 | 53 | For named entity recognition (NER), we use data from the Wikiann (panx) dataset. You can fine-tune a pre-trained multilingual model on the English NER data with the following command: 54 | ``` 55 | bash scripts/train.sh [MODEL] panx 56 | ``` 57 | 58 | ## PAWS-X sentence classification 59 | 60 | For sentence classification, we use the Cross-lingual Paraphrase Adversaries from Word Scrambling (PAWS-X) dataset. You can fine-tune a pre-trained multilingual model on the English PAWS data with the following command: 61 | ``` 62 | bash scripts/train.sh [MODEL] pawsx 63 | ``` 64 | 65 | ## XNLI sentence classification 66 | 67 | The second sentence classification dataset is the Cross-lingual Natural Language Inference (XNLI) dataset. You can fine-tune a pre-trained multilingual model on the English MNLI data with the following command: 68 | ``` 69 | bash scripts/train.sh [MODEL] xnli 70 | ``` 71 | 72 | ## XQuAD, MLQA, TyDiQA-GoldP question answering 73 | 74 | For question answering, we use the data from the XQuAD, MLQA, and TyDiQA-Gold Passage datasets. 75 | For XQuAD and MLQA, the model should be trained on the English SQuAD training set. For TyDiQA-Gold Passage, the model is trained on the English TyDiQA-GoldP training set. Using the following command, you can first fine-tune a pre-trained multilingual model on the corresponding English training data and then obtain predictions on the test data of all tasks. 76 | ``` 77 | bash scripts/train.sh [MODEL] [xquad,mlqa,tydiqa] 78 | ``` 79 | 80 | ## BUCC sentence retrieval 81 | 82 | For cross-lingual sentence retrieval, we use the data from the Building and Using Parallel Corpora (BUCC) shared task. As the models are not trained for this task and the representations of the pre-trained models are used directly to obtain similarity judgements, you can apply the model to the test data of the task right away: 83 | ``` 84 | bash scripts/train.sh [MODEL] bucc2018 85 | ``` 86 | 87 | ## Tatoeba sentence retrieval 88 | 89 | The second cross-lingual sentence retrieval dataset we use is the Tatoeba dataset. Similarly to BUCC, you can directly apply the model to obtain predictions on the test data of the task: 90 | ``` 91 | bash scripts/train.sh [MODEL] tatoeba 92 | ``` 93 | 94 | # Leaderboard Submission 95 | 96 | ## Submissions 97 | To submit your predictions to [**XTREME**](https://ai.google.com/research/xtreme), please create one single folder that contains 9 sub-folders named after all the tasks, i.e., `udpos`, `panx`, `xnli`, `pawsx`, `xquad`, `mlqa`, `tydiqa`, `bucc2018`, `tatoeba`. Inside each sub-folder, create a file containing the prediction labels of the test set for all languages, and name the file using the format `test-{language}.{extension}` where `language` indicates the 2-character language code, and `extension` is `json` for QA tasks and `tsv` for other tasks. 98 | 99 | ## Evaluation 100 | We will compare your submissions with our label files using the following command: 101 | ``` 102 | python evaluate.py --prediction_folder [path] --label_folder [path] 103 | ``` 104 | 105 | # Paper 106 | 107 | If you use our benchmark or the code in this repo, please cite our paper.
108 | ``` 109 | @article{hu2020xtreme, 110 | author = {Junjie Hu and Sebastian Ruder and Aditya Siddhant and Graham Neubig and Orhan Firat and Melvin Johnson}, 111 | title = {XTREME: A Massively Multilingual Multi-task Benchmark for Evaluating Cross-lingual Generalization}, 112 | journal = {CoRR}, 113 | volume = {abs/2003.11080}, 114 | year = {2020}, 115 | archivePrefix = {arXiv}, 116 | eprint = {2003.11080} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /conda-env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 7 | https://conda.anaconda.org/anaconda/linux-64/ca-certificates-2020.1.1-0.tar.bz2 8 | https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-10.0.130-0.tar.bz2 9 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.4-243.tar.bz2 10 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.tar.bz2 11 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.1.0-hdf63c60_0.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.1.0-hdf63c60_0.tar.bz2 13 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.4-243.tar.bz2 14 | https://conda.anaconda.org/anaconda/linux-64/expat-2.2.6-he6710b0_0.tar.bz2 15 | https://conda.anaconda.org/anaconda/linux-64/gmp-6.1.2-hb3b607b_0.tar.bz2 16 | https://conda.anaconda.org/anaconda/linux-64/icu-58.2-h211956c_0.tar.bz2 17 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 18 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 19 | https://conda.anaconda.org/anaconda/linux-64/libsodium-1.0.16-h1bed415_0.tar.bz2 20 | https://conda.anaconda.org/anaconda/linux-64/libuuid-1.0.3-h1bed415_2.tar.bz2 21 | https://conda.anaconda.org/anaconda/linux-64/libxcb-1.13-h1bed415_1.tar.bz2 22 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.1-he6710b0_1.tar.bz2 23 | https://conda.anaconda.org/anaconda/linux-64/openssl-1.1.1-h7b6447c_0.tar.bz2 24 | https://conda.anaconda.org/anaconda/linux-64/pcre-8.43-he6710b0_0.tar.bz2 25 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.4-h14c3975_4.tar.bz2 26 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.tar.bz2 27 | https://conda.anaconda.org/anaconda/linux-64/glib-2.56.2-hd408876_0.tar.bz2 28 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20181209-hc058e9b_0.tar.bz2 29 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.tar.bz2 30 | https://conda.anaconda.org/anaconda/linux-64/libxml2-2.9.9-hea5a465_1.tar.bz2 31 | https://conda.anaconda.org/anaconda/linux-64/pandoc-2.2.3.2-0.tar.bz2 32 | https://repo.anaconda.com/pkgs/main/linux-64/readline-7.0-h7b6447c_5.tar.bz2 33 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.8-hbc83047_0.tar.bz2 34 | https://conda.anaconda.org/anaconda/linux-64/zeromq-4.3.1-he6710b0_3.tar.bz2 35 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.3.7-h0b5b093_0.tar.bz2 36 | https://conda.anaconda.org/anaconda/linux-64/dbus-1.13.12-h746ee38_0.tar.bz2 37 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.9.1-h8a8886c_1.tar.bz2 38 | https://conda.anaconda.org/anaconda/linux-64/gstreamer-1.14.0-hb453b48_1.tar.bz2 39 | 
https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.1.0-h2733197_0.tar.bz2 40 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.30.1-h7b6447c_0.tar.bz2 41 | https://conda.anaconda.org/anaconda/linux-64/fontconfig-2.13.0-h9420a91_0.tar.bz2 42 | https://conda.anaconda.org/anaconda/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.tar.bz2 43 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.5-h0371630_0.tar.bz2 44 | https://conda.anaconda.org/anaconda/noarch/attrs-19.3.0-py_0.tar.bz2 45 | https://conda.anaconda.org/anaconda/linux-64/backcall-0.1.0-py37_0.tar.bz2 46 | https://conda.anaconda.org/anaconda/linux-64/certifi-2019.11.28-py37_0.tar.bz2 47 | https://conda.anaconda.org/anaconda/noarch/decorator-4.4.1-py_0.tar.bz2 48 | https://conda.anaconda.org/anaconda/noarch/defusedxml-0.6.0-py_0.tar.bz2 49 | https://conda.anaconda.org/anaconda/linux-64/entrypoints-0.3-py37_0.tar.bz2 50 | https://conda.anaconda.org/anaconda/linux-64/ipython_genutils-0.2.0-py37_0.tar.bz2 51 | https://conda.anaconda.org/anaconda/linux-64/markupsafe-1.1.1-py37h7b6447c_0.tar.bz2 52 | https://conda.anaconda.org/anaconda/linux-64/mistune-0.8.4-py37h7b6447c_0.tar.bz2 53 | https://conda.anaconda.org/anaconda/noarch/more-itertools-8.0.2-py_0.tar.bz2 54 | https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.9.0-py37hfd86e86_0.tar.bz2 55 | https://repo.anaconda.com/pkgs/main/noarch/olefile-0.46-py_0.tar.bz2 56 | https://conda.anaconda.org/anaconda/linux-64/pandocfilters-1.4.2-py37_1.tar.bz2 57 | https://conda.anaconda.org/anaconda/noarch/parso-0.5.2-py_0.tar.bz2 58 | https://conda.anaconda.org/anaconda/linux-64/pickleshare-0.7.5-py37_0.tar.bz2 59 | https://conda.anaconda.org/anaconda/noarch/prometheus_client-0.7.1-py_0.tar.bz2 60 | https://conda.anaconda.org/anaconda/linux-64/ptyprocess-0.6.0-py37_0.tar.bz2 61 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.19-py_0.tar.bz2 62 | https://repo.anaconda.com/pkgs/main/noarch/pytz-2019.3-py_0.tar.bz2 63 | https://conda.anaconda.org/anaconda/linux-64/pyzmq-18.1.0-py37he6710b0_0.tar.bz2 64 | https://conda.anaconda.org/anaconda/linux-64/qt-5.9.7-h5867ecd_1.tar.bz2 65 | https://conda.anaconda.org/anaconda/linux-64/send2trash-1.5.0-py37_0.tar.bz2 66 | https://conda.anaconda.org/anaconda/linux-64/sip-4.19.13-py37he6710b0_0.tar.bz2 67 | https://repo.anaconda.com/pkgs/main/linux-64/six-1.13.0-py37_0.tar.bz2 68 | https://conda.anaconda.org/anaconda/noarch/testpath-0.4.4-py_0.tar.bz2 69 | https://conda.anaconda.org/anaconda/linux-64/tornado-6.0.3-py37h7b6447c_0.tar.bz2 70 | https://conda.anaconda.org/anaconda/linux-64/wcwidth-0.1.7-py37_0.tar.bz2 71 | https://conda.anaconda.org/anaconda/linux-64/webencodings-0.5.1-py37_1.tar.bz2 72 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.13.2-py37h2e261b9_0.tar.bz2 73 | https://conda.anaconda.org/anaconda/linux-64/jedi-0.15.1-py37_0.tar.bz2 74 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.3.0-py37he904b0f_0.tar.bz2 75 | https://conda.anaconda.org/anaconda/linux-64/networkx-1.11-py37_1.tar.bz2 76 | https://conda.anaconda.org/anaconda/linux-64/pexpect-4.7.0-py37_0.tar.bz2 77 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-6.2.1-py37h34e0f95_0.tar.bz2 78 | https://conda.anaconda.org/anaconda/linux-64/pyqt-5.9.2-py37h22d08a2_1.tar.bz2 79 | https://conda.anaconda.org/anaconda/linux-64/pyrsistent-0.15.6-py37h7b6447c_0.tar.bz2 80 | https://conda.anaconda.org/anaconda/noarch/python-dateutil-2.8.1-py_0.tar.bz2 81 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-42.0.2-py37_0.tar.bz2 82 | 
https://conda.anaconda.org/anaconda/linux-64/terminado-0.8.3-py37_0.tar.bz2 83 | https://conda.anaconda.org/anaconda/linux-64/traitlets-4.3.3-py37_0.tar.bz2 84 | https://conda.anaconda.org/anaconda/noarch/zipp-0.6.0-py_0.tar.bz2 85 | https://conda.anaconda.org/anaconda/noarch/bleach-3.1.0-py_0.tar.bz2 86 | https://conda.anaconda.org/anaconda/linux-64/importlib_metadata-1.3.0-py37_0.tar.bz2 87 | https://conda.anaconda.org/anaconda/noarch/jinja2-2.10.3-py_0.tar.bz2 88 | https://conda.anaconda.org/anaconda/linux-64/jupyter_core-4.6.1-py37_0.tar.bz2 89 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.17.4-py37hde5b4d6_0.tar.bz2 90 | https://conda.anaconda.org/anaconda/noarch/pygments-2.5.2-py_0.tar.bz2 91 | https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.33.6-py37_0.tar.bz2 92 | https://conda.anaconda.org/anaconda/linux-64/jsonschema-3.2.0-py37_0.tar.bz2 93 | https://conda.anaconda.org/anaconda/linux-64/jupyter_client-5.3.4-py37_0.tar.bz2 94 | https://repo.anaconda.com/pkgs/main/linux-64/pip-19.3.1-py37_0.tar.bz2 95 | https://conda.anaconda.org/anaconda/noarch/prompt_toolkit-3.0.2-py_0.tar.bz2 96 | https://conda.anaconda.org/anaconda/linux-64/ipython-7.10.2-py37h39e3cac_0.tar.bz2 97 | https://conda.anaconda.org/anaconda/linux-64/nbformat-4.4.0-py37_0.tar.bz2 98 | https://conda.anaconda.org/anaconda/linux-64/ipykernel-5.1.3-py37h39e3cac_0.tar.bz2 99 | https://conda.anaconda.org/anaconda/linux-64/nbconvert-5.6.1-py37_0.tar.bz2 100 | https://conda.anaconda.org/anaconda/linux-64/jupyter_console-5.2.0-py37_1.tar.bz2 101 | https://conda.anaconda.org/anaconda/linux-64/notebook-6.0.2-py37_0.tar.bz2 102 | https://conda.anaconda.org/anaconda/noarch/qtconsole-4.6.0-py_0.tar.bz2 103 | https://conda.anaconda.org/anaconda/linux-64/widgetsnbextension-3.5.1-py37_0.tar.bz2 104 | https://conda.anaconda.org/anaconda/noarch/ipywidgets-7.5.1-py_0.tar.bz2 105 | https://conda.anaconda.org/anaconda/linux-64/jupyter-1.0.0-py37_7.tar.bz2 106 | https://conda.anaconda.org/pytorch/linux-64/faiss-gpu-1.6.0-py37h1a5d453_0.tar.bz2 107 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.0.15-py37ha843d7b_0.tar.bz2 108 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.1.0-py37hd6b4f25_0.tar.bz2 109 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.17.4-py37hc1035e2_0.tar.bz2 110 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-0.25.3-py37he6710b0_0.tar.bz2 111 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.3.1-py3.7_cuda10.0.130_cudnn7.6.3_0.tar.bz2 112 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.4.2-py37_cu100.tar.bz2 113 | -------------------------------------------------------------------------------- /third_party/utils_tag.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, 3 | # The HuggingFace Inc. team, and The XTREME Benchmark Authors. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Utility functions for NER/POS tagging tasks.""" 18 | 19 | from __future__ import absolute_import, division, print_function 20 | 21 | import logging 22 | import os 23 | from io import open 24 | from transformers import XLMTokenizer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class InputExample(object): 30 | """A single training/test example for token classification.""" 31 | 32 | def __init__(self, guid, words, labels, langs=None): 33 | """Constructs an InputExample. 34 | 35 | Args: 36 | guid: Unique id for the example. 37 | words: list. The words of the sequence. 38 | labels: (Optional) list. The labels for each word of the sequence. This should be 39 | specified for train and dev examples, but not for test examples. 40 | """ 41 | self.guid = guid 42 | self.words = words 43 | self.labels = labels 44 | self.langs = langs 45 | 46 | 47 | class InputFeatures(object): 48 | """A single set of features of data.""" 49 | 50 | def __init__(self, input_ids, input_mask, segment_ids, label_ids, langs=None): 51 | self.input_ids = input_ids 52 | self.input_mask = input_mask 53 | self.segment_ids = segment_ids 54 | self.label_ids = label_ids 55 | self.langs = langs 56 | 57 | 58 | def read_examples_from_file(file_path, lang, lang2id=None): 59 | if not os.path.exists(file_path): 60 | logger.info("[Warning] file {} does not exist".format(file_path)) 61 | return [] 62 | guid_index = 1 63 | examples = [] 64 | subword_len_counter = 0 65 | if lang2id: 66 | lang_id = lang2id.get(lang, lang2id['en']) 67 | else: 68 | lang_id = 0 69 | logger.info("lang_id={}, lang={}, lang2id={}".format(lang_id, lang, lang2id)) 70 | with open(file_path, encoding="utf-8") as f: 71 | words = [] 72 | labels = [] 73 | langs = [] 74 | for line in f: 75 | if line.startswith("-DOCSTART-") or line == "" or line == "\n": 76 | if words: # only emit an example once tokens have been accumulated 77 | examples.append(InputExample(guid="{}-{}".format(lang, guid_index), 78 | words=words, 79 | labels=labels, 80 | langs=langs)) 81 | guid_index += 1 82 | words = [] 83 | labels = [] 84 | langs = [] 85 | subword_len_counter = 0 86 | else: 87 | print('guid_index', guid_index, words, langs, labels, subword_len_counter) 88 | else: 89 | splits = line.split("\t") 90 | word = splits[0] 91 | 92 | words.append(splits[0]) 93 | langs.append(lang_id) 94 | if len(splits) > 1: 95 | labels.append(splits[-1].replace("\n", "")) 96 | else: 97 | # Examples could have no label for mode = "test" 98 | labels.append("O") 99 | if words: 100 | examples.append(InputExample(guid="{}-{}".format(lang, guid_index), 101 | words=words, 102 | labels=labels, 103 | langs=langs)) 104 | return examples 105 | 106 | def convert_examples_to_features(examples, 107 | label_list, 108 | max_seq_length, 109 | tokenizer, 110 | cls_token_at_end=False, 111 | cls_token="[CLS]", 112 | cls_token_segment_id=1, 113 | sep_token="[SEP]", 114 | sep_token_extra=False, 115 | pad_on_left=False, 116 | pad_token=0, 117 | pad_token_segment_id=0, 118 | pad_token_label_id=-1, 119 | sequence_a_segment_id=0, 120 | mask_padding_with_zero=True, 121 | lang='en'): 122 | """ Loads a data file into a list of `InputBatch`s 123 | `cls_token_at_end` defines the location of the CLS token: 124 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 125 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 126 | `cls_token_segment_id` defines the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 127 | """ 128 | 129 | 
label_map = {label: i for i, label in enumerate(label_list)} 130 | 131 | features = [] 132 | for (ex_index, example) in enumerate(examples): 133 | if ex_index % 10000 == 0: 134 | logger.info("Writing example %d of %d", ex_index, len(examples)) 135 | 136 | tokens = [] 137 | label_ids = [] 138 | for word, label in zip(example.words, example.labels): 139 | if isinstance(tokenizer, XLMTokenizer): 140 | word_tokens = tokenizer.tokenize(word, lang=lang) 141 | else: 142 | word_tokens = tokenizer.tokenize(word) 143 | if len(word) != 0 and len(word_tokens) == 0: 144 | word_tokens = [tokenizer.unk_token] 145 | tokens.extend(word_tokens) 146 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 147 | label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)) 148 | 149 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 150 | special_tokens_count = 3 if sep_token_extra else 2 151 | if len(tokens) > max_seq_length - special_tokens_count: 152 | print('truncate token', len(tokens), max_seq_length, special_tokens_count) 153 | tokens = tokens[:(max_seq_length - special_tokens_count)] 154 | label_ids = label_ids[:(max_seq_length - special_tokens_count)] 155 | 156 | # The convention in BERT is: 157 | # (a) For sequence pairs: 158 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 159 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 160 | # (b) For single sequences: 161 | # tokens: [CLS] the dog is hairy . [SEP] 162 | # type_ids: 0 0 0 0 0 0 0 163 | # 164 | # Where "type_ids" are used to indicate whether this is the first 165 | # sequence or the second sequence. The embedding vectors for `type=0` and 166 | # `type=1` were learned during pre-training and are added to the wordpiece 167 | # embedding vector (and position vector). This is not *strictly* necessary 168 | # since the [SEP] token unambiguously separates the sequences, but it makes 169 | # it easier for the model to learn the concept of sequences. 170 | 171 | tokens += [sep_token] 172 | label_ids += [pad_token_label_id] 173 | if sep_token_extra: 174 | # roberta uses an extra separator b/w pairs of sentences 175 | tokens += [sep_token] 176 | label_ids += [pad_token_label_id] 177 | segment_ids = [sequence_a_segment_id] * len(tokens) 178 | 179 | if cls_token_at_end: 180 | tokens += [cls_token] 181 | label_ids += [pad_token_label_id] 182 | segment_ids += [cls_token_segment_id] 183 | else: 184 | tokens = [cls_token] + tokens 185 | label_ids = [pad_token_label_id] + label_ids 186 | segment_ids = [cls_token_segment_id] + segment_ids 187 | 188 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 189 | 190 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 191 | # tokens are attended to. 192 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 193 | 194 | # Zero-pad up to the sequence length. 
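# pad_on_left is meant for XLNet-style models, which expect left padding; BERT/XLM/XLM-R
# style models pad on the right. With mask_padding_with_zero=True, real tokens get
# attention mask 1 and padding gets 0 (see the input_mask construction above).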
195 | padding_length = max_seq_length - len(input_ids) 196 | if pad_on_left: 197 | input_ids = ([pad_token] * padding_length) + input_ids 198 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 199 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 200 | label_ids = ([pad_token_label_id] * padding_length) + label_ids 201 | else: 202 | input_ids += ([pad_token] * padding_length) 203 | input_mask += ([0 if mask_padding_with_zero else 1] * padding_length) 204 | segment_ids += ([pad_token_segment_id] * padding_length) 205 | label_ids += ([pad_token_label_id] * padding_length) 206 | 207 | if example.langs and len(example.langs) > 0: 208 | langs = [example.langs[0]] * max_seq_length 209 | else: 210 | print('example.langs', example.langs, example.words) 211 | print('ex_index', ex_index, len(examples)) 212 | langs = None 213 | 214 | assert len(input_ids) == max_seq_length 215 | assert len(input_mask) == max_seq_length 216 | assert len(segment_ids) == max_seq_length 217 | assert len(label_ids) == max_seq_length 218 | assert langs is None or len(langs) == max_seq_length # langs may be None when example.langs is missing 219 | 220 | if ex_index < 5: 221 | logger.info("*** Example ***") 222 | logger.info("guid: %s", example.guid) 223 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 224 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 225 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 226 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 227 | logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 228 | logger.info("langs: {}".format(langs)) 229 | 230 | features.append( 231 | InputFeatures(input_ids=input_ids, 232 | input_mask=input_mask, 233 | segment_ids=segment_ids, 234 | label_ids=label_ids, 235 | langs=langs)) 236 | return features 237 | 238 | 239 | def get_labels(path): 240 | with open(path, "r") as f: 241 | labels = f.read().splitlines() 242 | if "O" not in labels: 243 | labels = ["O"] + labels 244 | return labels 245 | -------------------------------------------------------------------------------- /third_party/utils_retrieve.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This code is adapted from the LASER repository. 3 | # https://github.com/facebookresearch/LASER 4 | # Copyright The LASER Team Authors, and The XTREME Benchmark Authors. 5 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License.
18 | """Utility functions for retrieval tasks.""" 19 | 20 | 21 | import os 22 | import sys 23 | import faiss 24 | import tempfile 25 | import numpy as np 26 | 27 | 28 | 29 | def knn(x, y, k, use_gpu, dist='cosine'): 30 | return knnGPU(x, y, k) if use_gpu else knnCPU(x, y, k, dist) # the GPU path always uses inner-product (cosine) search 31 | 32 | 33 | def knnGPU(x, y, k, mem=5*1024*1024*1024): 34 | dim = x.shape[1] 35 | batch_size = mem // (dim*4) 36 | sim = np.zeros((x.shape[0], k), dtype=np.float32) 37 | ind = np.zeros((x.shape[0], k), dtype=np.int64) 38 | for xfrom in range(0, x.shape[0], batch_size): 39 | xto = min(xfrom + batch_size, x.shape[0]) 40 | bsims, binds = [], [] 41 | for yfrom in range(0, y.shape[0], batch_size): 42 | yto = min(yfrom + batch_size, y.shape[0]) 43 | print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) 44 | idx = faiss.IndexFlatIP(dim) 45 | idx = faiss.index_cpu_to_all_gpus(idx) 46 | idx.add(y[yfrom:yto]) 47 | bsim, bind = idx.search(x[xfrom:xto], min(k, yto-yfrom)) 48 | bsims.append(bsim) 49 | binds.append(bind + yfrom) 50 | del idx 51 | bsims = np.concatenate(bsims, axis=1) 52 | binds = np.concatenate(binds, axis=1) 53 | aux = np.argsort(-bsims, axis=1) 54 | for i in range(xfrom, xto): 55 | for j in range(k): 56 | sim[i, j] = bsims[i-xfrom, aux[i-xfrom, j]] 57 | ind[i, j] = binds[i-xfrom, aux[i-xfrom, j]] 58 | return sim, ind 59 | 60 | 61 | def knnCPU(x, y, k, dist='cosine'): 62 | # x: query, y: database 63 | dim = x.shape[1] 64 | if dist == 'cosine': 65 | idx = faiss.IndexFlatIP(dim) 66 | else: 67 | idx = faiss.IndexFlatL2(dim) 68 | idx.add(y) 69 | sim, ind = idx.search(x, k) 70 | 71 | if dist != 'cosine': 72 | sim = 1 / (1 + sim) 73 | return sim, ind 74 | 75 | 76 | def score(x, y, fwd_mean, bwd_mean, margin, dist='cosine'): 77 | if dist == 'cosine': 78 | return margin(x.dot(y), (fwd_mean + bwd_mean) / 2) 79 | else: 80 | l2 = ((x - y) ** 2).sum() 81 | sim = 1 / (1 + l2) 82 | return margin(sim, (fwd_mean + bwd_mean) / 2) 83 | 84 | 85 | def score_candidates(x, y, candidate_inds, fwd_mean, bwd_mean, margin, dist='cosine'): 86 | print(' - scoring {:d} candidates using {}'.format(x.shape[0], dist)) 87 | scores = np.zeros(candidate_inds.shape) 88 | for i in range(scores.shape[0]): 89 | for j in range(scores.shape[1]): 90 | k = candidate_inds[i, j] 91 | scores[i, j] = score(x[i], y[k], fwd_mean[i], bwd_mean[k], margin, dist) 92 | return scores 93 | 94 | 95 | def text_load_unify(fname, encoding, unify=True): 96 | print(' - loading texts {:s}: '.format(fname), end='') 97 | fin = open(fname, encoding=encoding, errors='surrogateescape') 98 | inds = [] 99 | sents = [] 100 | sent2ind = {} 101 | n = 0 102 | nu = 0 103 | for line in fin: 104 | new_ind = len(sent2ind) 105 | inds.append(sent2ind.setdefault(line, new_ind)) 106 | if unify: 107 | if inds[-1] == new_ind: 108 | sents.append(line[:-1]) 109 | nu += 1 110 | else: 111 | sents.append(line[:-1]) 112 | nu += 1 113 | n += 1 114 | print('{:d} lines, {:d} unique'.format(n, nu)) 115 | del sent2ind 116 | return inds, sents 117 | 118 | 119 | def unique_embeddings(emb, ind): 120 | aux = {j: i for i, j in enumerate(ind)} 121 | print(' - unify embeddings: {:d} -> {:d}'.format(len(emb), len(aux))) 122 | return emb[[aux[i] for i in range(len(aux))]] 123 | 124 | 125 | def shift_embeddings(x, y): 126 | print(' - shift embeddings') 127 | delta = x.mean(axis=0) - y.mean(axis=0) 128 | x2y = x - delta 129 | y2x = y + delta 130 | return x2y, y2x 131 | 132 | 133 | def mine_bitext(x, y, src_text_file, trg_text_file, output_file, mode='mine', 134 | retrieval='max', margin='ratio',
threshold=0, 135 | neighborhood=4, use_gpu=False, encoding='utf-8', dist='cosine', use_shift_embeds=False): 136 | src_inds, src_sents = text_load_unify(src_text_file, encoding, True) 137 | trg_inds, trg_sents = text_load_unify(trg_text_file, encoding, True) 138 | 139 | x = unique_embeddings(x, src_inds) 140 | y = unique_embeddings(y, trg_inds) 141 | if dist == 'cosine': 142 | faiss.normalize_L2(x) 143 | faiss.normalize_L2(y) 144 | 145 | if use_shift_embeds: 146 | x2y, y2x = shift_embeddings(x, y) 147 | 148 | # calculate knn in both directions 149 | if retrieval != 'bwd': 150 | print(' - perform {:d}-nn source against target, dist={}'.format(neighborhood, dist)) 151 | if use_shift_embeds: 152 | # project x to y space, and search k-nn ys for each x 153 | x2y_sim, x2y_ind = knn(x2y, y, min(y.shape[0], neighborhood), use_gpu, dist) 154 | x2y_mean = x2y_sim.mean(axis=1) 155 | else: 156 | x2y_sim, x2y_ind = knn(x, y, min(y.shape[0], neighborhood), use_gpu, dist) 157 | x2y_mean = x2y_sim.mean(axis=1) 158 | 159 | if retrieval != 'fwd': 160 | print(' - perform {:d}-nn target against source, dist={}'.format(neighborhood, dist)) 161 | if use_shift_embeds: 162 | y2x_sim, y2x_ind = knn(y2x, x, min(x.shape[0], neighborhood), use_gpu, dist) 163 | y2x_mean = y2x_sim.mean(axis=1) 164 | else: 165 | y2x_sim, y2x_ind = knn(y, x, min(x.shape[0], neighborhood), use_gpu, dist) 166 | y2x_mean = y2x_sim.mean(axis=1) 167 | 168 | # margin function 169 | if margin == 'absolute': 170 | margin = lambda a, b: a 171 | elif margin == 'distance': 172 | margin = lambda a, b: a - b 173 | else: # margin == 'ratio': 174 | margin = lambda a, b: a / b 175 | 176 | fout = open(output_file, mode='w', encoding=encoding, errors='surrogateescape') 177 | 178 | if mode == 'search': 179 | print(' - Searching for closest sentences in target') 180 | print(' - writing alignments to {:s}'.format(output_file)) 181 | scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin) 182 | best = x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)] 183 | 184 | nbex = x.shape[0] 185 | ref = np.linspace(0, nbex-1, nbex).astype(int) # [0, nbex) 186 | err = nbex - np.equal(best.reshape(nbex), ref).astype(int).sum() 187 | print(' - errors: {:d}={:.2f}%'.format(err, 100*err/nbex)) 188 | for i in src_inds: 189 | print(trg_sents[best[i]], file=fout) 190 | 191 | elif mode == 'score': 192 | for i, j in zip(src_inds, trg_inds): 193 | s = score(x[i], y[j], x2y_mean[i], y2x_mean[j], margin) 194 | print(s, src_sents[i], trg_sents[j], sep='\t', file=fout) 195 | 196 | elif mode == 'mine': 197 | print(' - mining for parallel data') 198 | if use_shift_embeds: 199 | fwd_scores = score_candidates(x2y, y, x2y_ind, x2y_mean, y2x_mean, margin) 200 | bwd_scores = score_candidates(y2x, x, y2x_ind, y2x_mean, x2y_mean, margin) 201 | else: 202 | fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin) 203 | bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean, margin) 204 | fwd_best = x2y_ind[np.arange(x.shape[0]), fwd_scores.argmax(axis=1)] 205 | bwd_best = y2x_ind[np.arange(y.shape[0]), bwd_scores.argmax(axis=1)] 206 | print(' - writing alignments to {:s}'.format(output_file)) 207 | if threshold > 0: 208 | print(' - with threshold of {:f}'.format(threshold)) 209 | if retrieval == 'fwd': 210 | for i, j in enumerate(fwd_best): 211 | print(fwd_scores[i].max(), src_sents[i], trg_sents[j], sep='\t', file=fout) 212 | if retrieval == 'bwd': 213 | for j, i in enumerate(bwd_best): 214 | print(bwd_scores[j].max(), src_sents[i], trg_sents[j],
sep='\t', file=fout) 215 | if retrieval == 'intersect': 216 | for i, j in enumerate(fwd_best): 217 | if bwd_best[j] == i: 218 | print(fwd_scores[i].max(), src_sents[i], trg_sents[j], sep='\t', file=fout) 219 | if retrieval == 'max': 220 | indices = np.stack((np.concatenate((np.arange(x.shape[0]), bwd_best)), 221 | np.concatenate((fwd_best, np.arange(y.shape[0])))), axis=1) 222 | scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) 223 | seen_src, seen_trg = set(), set() 224 | for i in np.argsort(-scores): 225 | src_ind, trg_ind = indices[i] 226 | if src_ind not in seen_src and trg_ind not in seen_trg: 227 | seen_src.add(src_ind) 228 | seen_trg.add(trg_ind) 229 | if scores[i] > threshold: 230 | print(scores[i], src_sents[src_ind], trg_sents[trg_ind], sep='\t', file=fout) 231 | fout.close() 232 | 233 | 234 | def bucc_optimize(candidate2score, gold): 235 | items = sorted(candidate2score.items(), key=lambda x: -x[1]) 236 | ngold = len(gold) 237 | nextract = ncorrect = 0 238 | threshold = 0 239 | best_f1 = 0 240 | for i in range(len(items)): 241 | nextract += 1 242 | if '\t'.join(items[i][0]) in gold: 243 | ncorrect += 1 244 | if ncorrect > 0: 245 | precision = ncorrect / nextract 246 | recall = ncorrect / ngold 247 | f1 = 2 * precision * recall / (precision + recall) 248 | if f1 > best_f1: 249 | best_f1 = f1 250 | threshold = (items[i][1] + items[i + 1][1]) / 2 if i + 1 < len(items) else items[i][1] # guard against indexing past the last candidate 251 | return threshold 252 | 253 | 254 | def bucc_extract(cand2score, th, fname, encoding='utf-8'): 255 | if fname: 256 | of = open(fname, 'w', encoding=encoding) 257 | bitexts = [] 258 | for (src, trg), score in cand2score.items(): 259 | if score >= th: 260 | bitexts.append(src + '\t' + trg) 261 | if fname: 262 | of.write(src + '\t' + trg + '\n') 263 | if fname: 264 | of.close() 265 | return bitexts 266 | 267 | 268 | def read_sent2id(text_file, id_file, encoding='utf-8'): 269 | repeated = set() 270 | sent2id = {} 271 | with open(id_file, encoding=encoding, errors='surrogateescape') as f: 272 | ids = [l.strip() for l in f] 273 | with open(text_file, encoding=encoding, errors='surrogateescape') as f: 274 | sentences = [l.strip() for l in f] 275 | for id, sent in zip(ids, sentences): 276 | if sent in sent2id: 277 | repeated.add(sent) 278 | else: 279 | sent2id[sent] = id 280 | for sent in repeated: 281 | del sent2id[sent] 282 | return sent2id 283 | 284 | 285 | def read_candidate2score(candidates_file, src_text_file, trg_text_file, src_id_file, trg_id_file, encoding='utf-8'): 286 | print(' - reading sentences {}'.format(candidates_file)) 287 | src_sent2id = read_sent2id(src_text_file, src_id_file, encoding) 288 | trg_sent2id = read_sent2id(trg_text_file, trg_id_file, encoding) 289 | 290 | print(' - reading candidates {}'.format(candidates_file)) 291 | candidate2score = {} 292 | with open(candidates_file, encoding=encoding, errors='surrogateescape') as f: 293 | for line in f: 294 | score, src, trg = line.split('\t') 295 | score = float(score) 296 | src = src.strip() 297 | trg = trg.strip() 298 | if src in src_sent2id and trg in trg_sent2id: 299 | src_id = src_sent2id[src] 300 | trg_id = trg_sent2id[trg] 301 | score = max(score, candidate2score.get((src_id, trg_id), score)) 302 | candidate2score[(src_id, trg_id)] = score 303 | return candidate2score 304 | 305 | 306 | def bucc_eval(candidates_file, gold_file, src_file, trg_file, src_id_file, trg_id_file, predict_file, threshold=None, encoding='utf-8'): 307 | candidate2score = read_candidate2score(candidates_file, src_file, trg_file, src_id_file, trg_id_file, encoding) 308 | 309 | if
threshold is not None and gold_file is None: 310 | print(' - using threshold {}'.format(threshold)) 311 | else: 312 | print(' - optimizing threshold on gold alignments {}'.format(gold_file)) 313 | gold = {line.strip() for line in open(gold_file)} 314 | threshold = bucc_optimize(candidate2score, gold) 315 | 316 | bitexts = bucc_extract(candidate2score, threshold, predict_file) 317 | if gold_file is not None: 318 | ncorrect = len(gold.intersection(bitexts)) 319 | if ncorrect > 0: 320 | precision = ncorrect / len(bitexts) 321 | recall = ncorrect / len(gold) 322 | f1 = 2*precision*recall / (precision + recall) 323 | else: 324 | precision = recall = f1 = 0 325 | 326 | print(' - best threshold={:f}: precision={:.2f}, recall={:.2f}, F1={:.2f}' 327 | .format(threshold, 100*precision, 100*recall, 100*f1)) 328 | return {'best-threshold': threshold, 'precision': 100*precision, 'recall': 100*recall, 'F1': 100*f1} 329 | else: 330 | return None 331 | 332 | 333 | def similarity_search(x, y, dim, normalize=False): 334 | num = x.shape[0] 335 | idx = faiss.IndexFlatL2(dim) 336 | if normalize: 337 | faiss.normalize_L2(x) 338 | faiss.normalize_L2(y) 339 | idx.add(x) 340 | scores, prediction = idx.search(y, 1) 341 | return prediction 342 | 343 | -------------------------------------------------------------------------------- /third_party/ud-conversion-tools/lib/conll.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from collections import Counter 3 | import re 4 | 5 | 6 | #TODO make these parse functions static methods of CoNLLReader 7 | def parse_id(id_str): 8 | if id_str == '_': 9 | return None 10 | if "." in id_str: 11 | return None 12 | ids = tuple(map(int, id_str.split("-"))) 13 | if len(ids) == 1: 14 | return ids[0] 15 | else: 16 | return ids 17 | 18 | def parse_feats(feats_str): 19 | if feats_str == '_': 20 | return {} 21 | feat_pairs = [pair.split("=") for pair in feats_str.split("|")] 22 | return {k: v for k, v in feat_pairs} 23 | 24 | def parse_deps(dep_str): 25 | if dep_str == '_': 26 | return [] 27 | dep_pairs = [pair.split(":") for pair in dep_str.split("|")] 28 | return [(int(pair[0]), pair[1]) for pair in dep_pairs if pair[0].isdigit()] 29 | 30 | 31 | 32 | 33 | class DependencyTree(nx.DiGraph): 34 | """ 35 | A DependencyTree as networkx graph: 36 | nodes store information about tokens 37 | edges store edge related info, e.g. dependency relations 38 | """ 39 | 40 | def __init__(self): 41 | nx.DiGraph.__init__(self) 42 | 43 | def pathtoroot(self, child): 44 | path = [] 45 | newhead = self.head_of(child) 46 | while newhead: 47 | path.append(newhead) 48 | newhead = self.head_of(newhead) 49 | return path 50 | 51 | def head_of(self, n): 52 | for u, v in self.edges(): 53 | if v == n: 54 | return u 55 | return None 56 | 57 | def get_sentence_as_string(self,printid=False): 58 | out = [] 59 | for token_i in range(1, max(self.nodes()) + 1): 60 | if printid: 61 | out.append(str(token_i)+":"+self.node[token_i]['form']) 62 | else: 63 | out.append(self.node[token_i]['form']) 64 | return u" ".join(out) 65 | 66 | def subsumes(self, head, child): 67 | if head in self.pathtoroot(child): 68 | return True 69 | 70 | def remove_arabic_diacritics(self): 71 | # The following code is based on nltk.stem.isri 72 | # It is equivalent to an iterative application of isri.norm(word,num=1) 73 | # i.e.
we do not remove any hamza characters 74 | 75 | re_short_vowels = re.compile(r'[\u064B-\u0652]') 76 | for n in self.nodes(): 77 | self.node[n]["form"] = re_short_vowels.sub('', self.node[n]["form"]) 78 | 79 | 80 | def get_highest_index_of_span(self, span): # retrieves the node index that is closest to root 81 | #TODO: CANDIDATE FOR DEPRECATION 82 | distancestoroot = [len(self.pathtoroot(x)) for x in span] 83 | shortestdistancetoroot = min(distancestoroot) 84 | spanhead = span[distancestoroot.index(shortestdistancetoroot)] 85 | return spanhead 86 | 87 | def get_deepest_index_of_span(self, span): # retrieves the node index that is farthest from root 88 | #TODO: CANDIDATE FOR DEPRECATION 89 | distancestoroot = [len(self.pathtoroot(x)) for x in span] 90 | longestdistancetoroot = max(distancestoroot) 91 | lownode = span[distancestoroot.index(longestdistancetoroot)] 92 | return lownode 93 | 94 | def span_makes_subtree(self, initidx, endidx): 95 | G = nx.DiGraph() 96 | span_nodes = list(range(initidx,endidx+1)) 97 | span_words = [self.node[x]["form"] for x in span_nodes] 98 | G.add_nodes_from(span_nodes) 99 | for h,d in self.edges(): 100 | if h in span_nodes and d in span_nodes: 101 | G.add_edge(h,d) 102 | return nx.is_tree(G) 103 | 104 | def _choose_spanhead_from_heuristics(self,span_nodes,pos_precedence_list): 105 | distancestoroot = [len(nx.ancestors(self,x)) for x in span_nodes] 106 | shortestdistancetoroot = min(distancestoroot) 107 | distance_counter = Counter(distancestoroot) 108 | 109 | highest_nodes_in_span = [] 110 | # Heuristic Nr 1: If there is one single highest node in the span, it becomes the head 111 | # N.B. no need for the subspan to be a tree if there is one single highest element 112 | if distance_counter[shortestdistancetoroot] == 1: 113 | spanhead = span_nodes[distancestoroot.index(shortestdistancetoroot)] 114 | return spanhead 115 | 116 | # Heuristic Nr 2: Choose by POS ranking the best head out of the highest nodes 117 | for x in span_nodes: 118 | if len(nx.ancestors(self,x)) == shortestdistancetoroot: 119 | highest_nodes_in_span.append(x) 120 | 121 | best_rank = len(pos_precedence_list) + 1 122 | candidate_head = -1 123 | span_upos = [self.node[x]["cpostag"] for x in highest_nodes_in_span] 124 | for upos, idx in zip(span_upos,highest_nodes_in_span): 125 | if pos_precedence_list.index(upos) < best_rank: 126 | best_rank = pos_precedence_list.index(upos) 127 | candidate_head = idx 128 | return candidate_head 129 | 130 | def _remove_node_properties(self,fields): 131 | for n in sorted(self.nodes()): 132 | for fieldname in self.node[n].keys(): 133 | if fieldname in fields: 134 | self.node[n][fieldname]="_" 135 | 136 | def _remove_deprel_suffixes(self): 137 | for h,d in self.edges(): 138 | if ":" in self[h][d]["deprel"]: 139 | self[h][d]["deprel"]=self[h][d]["deprel"].split(":")[0] 140 | 141 | def _keep_fused_form(self,posPreferenceDicts): 142 | # For a span A,B and external tokens C, such as A > B > C, we have to 143 | # Make A the head of the span 144 | # Attach C-level tokens to A 145 | # Remove B-level tokens, which are the subtokens of the fused form della: de la 146 | 147 | if self.graph["multi_tokens"] == {}: 148 | return 149 | 150 | spanheads = [] 151 | spanhead_fused_token_dict = {} 152 | # This double iteration is overkill, one could skip the spanhead identification 153 | # but in this way we avoid modifying the tree as we read it 154 | for fusedform_idx in sorted(self.graph["multi_tokens"]): 155 | fusedform_start, fusedform_end =
self.graph["multi_tokens"][fusedform_idx]["id"] 156 | fuseform_span = list(range(fusedform_start,fusedform_end+1)) 157 | spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) 158 | #if not spanhead: 159 | # spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) 160 | spanheads.append(spanhead) 161 | spanhead_fused_token_dict[spanhead] = fusedform_idx 162 | 163 | # try: 164 | # order = list(nx.topological_sort(self)) 165 | # except nx.NetworkXUnfeasible: 166 | # msg = 'Circular dependency detected between hooks' 167 | # problem_graph = ', '.join(f'{a} -> {b}' 168 | # for a, b in nx.find_cycle(self)) 169 | # print('nx.simple_cycles', list(nx.simple_cycles(self))) 170 | # print(problem_graph) 171 | # exit(0) 172 | # for edge in list(nx.simple_cycles(self)): 173 | # self.remove_edge(edge[0], edge[1]) 174 | self = remove_all_cycle(self) 175 | bottom_up_order = [x for x in nx.topological_sort(self) if x in spanheads] 176 | for spanhead in bottom_up_order: 177 | fusedform_idx = spanhead_fused_token_dict[spanhead] 178 | fusedform = self.graph["multi_tokens"][fusedform_idx]["form"] 179 | fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"] 180 | fuseform_span = list(range(fusedform_start,fusedform_end+1)) 181 | 182 | if spanhead: 183 | #Step 1: Replace form of head span (A) with fusedtoken form -- in this way we keep the lemma and features if any 184 | self.node[spanhead]["form"] = fusedform 185 | # 2- Reattach C-level (external dependents) to A 186 | #print(fuseform_span,spanhead) 187 | 188 | internal_dependents = set(fuseform_span) - set([spanhead]) 189 | external_dependents = [nx.bfs_successors(self,x) for x in internal_dependents] 190 | for depdict in external_dependents: 191 | for localhead in depdict: 192 | for ext_dep in depdict[localhead]: 193 | if ext_dep in self[localhead]: 194 | deprel = self[localhead][ext_dep]["deprel"] 195 | self.remove_edge(localhead,ext_dep) 196 | self.add_edge(spanhead,ext_dep,deprel=deprel) 197 | 198 | #3- Remove B-level tokens 199 | for int_dep in internal_dependents: 200 | self.remove_edge(self.head_of(int_dep),int_dep) 201 | self.remove_node(int_dep) 202 | 203 | #4 reconstruct tree at the very end 204 | new_index_dict = {} 205 | for new_node_index, old_node_idex in enumerate(sorted(self.nodes())): 206 | new_index_dict[old_node_idex] = new_node_index 207 | 208 | T = DependencyTree() # Transfer DiGraph, to replace self 209 | 210 | for n in sorted(self.nodes()): 211 | T.add_node(new_index_dict[n],self.node[n]) 212 | 213 | for h, d in self.edges(): 214 | T.add_edge(new_index_dict[h],new_index_dict[d],deprel=self[h][d]["deprel"]) 215 | #4A Quick removal of edges and nodes 216 | self.__init__() 217 | 218 | #4B Rewriting the Deptree in Self 219 | # TODO There must be a more elegant way to rewrite self -- self = T, for instance? 220 | for n in sorted(T.nodes()): 221 | self.add_node(n,T.node[n]) 222 | 223 | for h,d in T.edges(): 224 | self.add_edge(h,d,T[h][d]) 225 | 226 | # 5.
remove all fused forms from the multi_tokens field 227 | self.graph["multi_tokens"] = {} 228 | 229 | # if not nx.is_tree(self): 230 | # print("Not a tree after fused-form heuristics:",self.get_sentence_as_string()) 231 | 232 | def filter_sentence_content(self,replace_subtokens_with_fused_forms=False, lang=None, posPreferenceDict=None,node_properties_to_remove=None,remove_deprel_suffixes=False,remove_arabic_diacritics=False): 233 | if replace_subtokens_with_fused_forms: 234 | self._keep_fused_form(posPreferenceDict) 235 | if remove_deprel_suffixes: 236 | self._remove_deprel_suffixes() 237 | if node_properties_to_remove: 238 | self._remove_node_properties(node_properties_to_remove) 239 | if remove_arabic_diacritics: 240 | self.remove_arabic_diacritics() 241 | 242 | def remove_all_cycle(G): 243 | GC = nx.DiGraph(G.edges()) 244 | edges = list(nx.simple_cycles(GC)) 245 | for edge in edges: 246 | for i in range(len(edge)-1): 247 | for j in range(i+1, len(edge)): 248 | a, b = edge[i], edge[j] 249 | if G.has_edge(a, b): 250 | # print('remove {} - {}'.format(a, b)) 251 | G.remove_edge(a, b) 252 | return G 253 | 254 | 255 | class CoNLLReader(object): 256 | """ 257 | conll input/output 258 | """ 259 | 260 | """Static properties""" 261 | CONLL06_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('phead', str), ('pdeprel',str)] 262 | #CONLL06_COLUMNS = ['id', 'form', 'lemma', 'cpostag', 'postag', 'feats', 'head', 'deprel', 'phead', 'pdeprel'] 263 | CONLL06DENSE_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('edgew',str)] 264 | CONLL_U_COLUMNS = [('id', parse_id), ('form', str), ('lemma', str), ('cpostag', str), 265 | ('postag', str), ('feats', str), ('head', parse_id), ('deprel', str), 266 | ('deps', parse_deps), ('misc', str)] 267 | #CONLL09_COLUMNS = ['id','form','lemma','plemma','cpostag','pcpostag','feats','pfeats','head','phead','deprel','pdeprel'] 268 | 269 | 270 | 271 | def __init__(self): 272 | pass 273 | 274 | def read_conll_2006(self, filename): 275 | sentences = [] 276 | sent = DependencyTree() 277 | for line_num, conll_line in enumerate(open(filename)): 278 | parts = conll_line.strip().split("\t") 279 | if len(parts) in (8, 10): 280 | token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06_COLUMNS, parts)} 281 | 282 | sent.add_node(token_dict['id'], token_dict) 283 | sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) 284 | elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): 285 | sentences.append(sent) 286 | sent = DependencyTree() 287 | else: 288 | raise Exception("Invalid input format in line nr: ", line_num, conll_line, filename) 289 | 290 | return sentences 291 | 292 | def read_conll_2006_dense(self, filename): 293 | sentences = [] 294 | sent = DependencyTree() 295 | for conll_line in open(filename): 296 | parts = conll_line.strip().split("\t") 297 | if len(parts) == 9: 298 | token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06DENSE_COLUMNS, parts)} 299 | 300 | sent.add_node(token_dict['id'], token_dict) 301 | sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) 302 | elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): 303 | sentences.append(sent) 304 | sent = DependencyTree() 305 | else: 306 | raise Exception("Invalid input format in line: ", conll_line, filename) 307 | 308 | return sentences 309 | 310 | 311 |
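# Minimal usage sketch (illustrative file names and POS ranking, not from this repo;
# conllu_to_conll.py is the actual driver):
#   from pathlib import Path
#   pos_precedence = ['VERB', 'NOUN', 'PROPN', 'PRON', 'ADJ', 'NUM', 'ADV', 'ADP', 'X']
#   reader = CoNLLReader()
#   sents = reader.read_conll_u('input.conllu')
#   for sent in sents:
#       sent.filter_sentence_content(replace_subtokens_with_fused_forms=True,
#                                    posPreferenceDict=pos_precedence)
#   reader.write_conll(sents, Path('output.conll'), 'conll2006', print_fused_forms=True)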
312 |     def write_conll(self, list_of_graphs, conll_path,conllformat, print_fused_forms=False,print_comments=False):
313 |         # TODO add comment writing
314 |         if conllformat == "conllu":
315 |             columns = [colname for colname, fname in self.CONLL_U_COLUMNS]
316 |         else:
317 |             columns = [colname for colname, fname in self.CONLL06_COLUMNS]
318 | 
319 |         with conll_path.open('w') as out:
320 |             for sent_i, sent in enumerate(list_of_graphs):
321 |                 if sent_i > 0:
322 |                     print("", file=out)
323 |                 if print_comments:
324 |                     for c in sent.graph["comment"]:
325 |                         print(c, file=out)
326 |                 for token_i in range(1, max(sent.nodes()) + 1):
327 |                     token_dict = dict(sent.node[token_i])
328 |                     head_i = sent.head_of(token_i)
329 |                     if head_i is None:
330 |                         token_dict['head'] = 0
331 |                         token_dict['deprel'] = ''
332 |                     else:
333 |                         token_dict['head'] = head_i
334 |                         token_dict['deprel'] = sent[head_i][token_i]['deprel']
335 |                     token_dict['id'] = token_i
336 |                     row = [str(token_dict.get(col, '_')) for col in columns]
337 |                     if print_fused_forms and token_i in sent.graph["multi_tokens"]:
338 |                         currentmulti = sent.graph["multi_tokens"][token_i]
339 |                         currentmulti["id"]=str(currentmulti["id"][0])+"-"+str(currentmulti["id"][1])
340 |                         currentmulti["feats"]="_"
341 |                         currentmulti["head"]="_"
342 |                         rowmulti = [str(currentmulti.get(col, '_')) for col in columns]
343 |                         print(u"\t".join(rowmulti),file=out)
344 |                     print(u"\t".join(row), file=out)
345 | 
346 |             # empty line afterwards
347 |             print(u"", file=out)
348 | 
349 | 
350 |     def read_conll_u(self,filename,keepFusedForm=False, lang=None, posPreferenceDict=None):
351 |         sentences = []
352 |         sent = DependencyTree()
353 |         multi_tokens = {}
354 | 
355 |         for line_no, line in enumerate(open(filename).readlines()):
356 |             line = line.strip("\n")
357 |             if not line:
358 |                 # Add extra properties to ROOT node if exists
359 |                 if 0 in sent:
360 |                     for key in ('form', 'lemma', 'cpostag', 'postag'):
361 |                         sent.node[0][key] = 'ROOT'
362 | 
363 |                 # Handle multi-tokens
364 |                 sent.graph['multi_tokens'] = multi_tokens
365 |                 multi_tokens = {}
366 |                 sentences.append(sent)
367 |                 sent = DependencyTree()
368 |             elif line.startswith("#"):
369 |                 if 'comment' not in sent.graph:
370 |                     sent.graph['comment'] = [line]
371 |                 else:
372 |                     sent.graph['comment'].append(line)
373 |             else:
374 |                 parts = line.split("\t")
375 |                 if len(parts) != len(self.CONLL_U_COLUMNS):
376 |                     error_msg = 'Invalid number of columns in line {} (found {}, expected {})'.format(line_no, len(parts), len(self.CONLL_U_COLUMNS))
377 |                     raise Exception(error_msg)
378 | 
379 |                 token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL_U_COLUMNS, parts)}
380 |                 if isinstance(token_dict['id'], int):
381 |                     sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel'])
382 |                     sent.node[token_dict['id']].update({k: v for (k, v) in token_dict.items()
383 |                                                         if k not in ('head', 'id', 'deprel', 'deps')})
384 |                     for head, deprel in token_dict['deps']:
385 |                         sent.add_edge(head, token_dict['id'], deprel=deprel, secondary=True)
386 |                 elif token_dict['id'] is not None:
387 |                     #print(token_dict['id'])
388 |                     first_token_id = int(token_dict['id'][0])
389 |                     multi_tokens[first_token_id] = token_dict
390 |         return sentences
391 | 
--------------------------------------------------------------------------------
/third_party/run_retrieval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright The XTREME Benchmark Authors.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Use pre-trained models for retrieval tasks.""" 17 | 18 | 19 | import argparse 20 | 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import DataLoader, TensorDataset 29 | from torch.utils.data import RandomSampler, SequentialSampler 30 | from tqdm import tqdm, trange 31 | 32 | from transformers import ( 33 | BertConfig, 34 | BertModel, 35 | BertTokenizer, 36 | XLMConfig, 37 | XLMModel, 38 | XLMTokenizer, 39 | XLMRobertaConfig, 40 | XLMRobertaTokenizer, 41 | XLMRobertaModel, 42 | ) 43 | from processors.utils import InputFeatures 44 | from utils_retrieve import mine_bitext, bucc_eval, similarity_search 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | ALL_MODELS = sum( 49 | (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLMConfig, XLMRobertaConfig)), () 50 | ) 51 | 52 | 53 | MODEL_CLASSES = { 54 | "bert": (BertConfig, BertModel, BertTokenizer), 55 | "xlm": (XLMConfig, XLMModel, XLMTokenizer), 56 | "xlmr": (XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer), 57 | } 58 | 59 | 60 | def load_embeddings(embed_file, num_sentences=None): 61 | logger.info(' loading from {}'.format(embed_file)) 62 | embeds = np.load(embed_file) 63 | return embeds 64 | 65 | 66 | def prepare_batch(sentences, tokenizer, model_type, device="cuda", max_length=512, lang='en', langid=None, use_local_max_length=True, pool_skip_special_token=False): 67 | pad_token = tokenizer.pad_token 68 | cls_token = tokenizer.cls_token 69 | sep_token = tokenizer.sep_token 70 | 71 | pad_token_id = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] 72 | pad_token_segment_id = 0 73 | 74 | batch_input_ids = [] 75 | batch_token_type_ids = [] 76 | batch_attention_mask = [] 77 | batch_size = len(sentences) 78 | batch_pool_mask = [] 79 | 80 | local_max_length = min(max([len(s) for s in sentences]) + 2, max_length) 81 | if use_local_max_length: 82 | max_length = local_max_length 83 | 84 | for sent in sentences: 85 | 86 | if len(sent) > max_length - 2: 87 | sent = sent[: (max_length - 2)] 88 | input_ids = tokenizer.convert_tokens_to_ids([cls_token] + sent + [sep_token]) 89 | 90 | padding_length = max_length - len(input_ids) 91 | attention_mask = [1] * len(input_ids) + [0] * padding_length 92 | pool_mask = [0] + [1] * (len(input_ids) - 2) + [0] * (padding_length + 1) 93 | input_ids = input_ids + ([pad_token_id] * padding_length) 94 | 95 | batch_input_ids.append(input_ids) 96 | batch_attention_mask.append(attention_mask) 97 | batch_pool_mask.append(pool_mask) 98 | 99 | input_ids = torch.LongTensor(batch_input_ids).to(device) 100 | attention_mask = torch.LongTensor(batch_attention_mask).to(device) 101 | 102 | if pool_skip_special_token: 103 | pool_mask = torch.LongTensor(batch_pool_mask).to(device) 104 | else: 105 | pool_mask = attention_mask 106 | 107 | 108 | if model_type == "xlm": 109 | langs = 
torch.LongTensor([[langid] * max_length for _ in range(len(sentences))]).to(device)
110 |     return {"input_ids": input_ids, "attention_mask": attention_mask, "langs": langs}, pool_mask
111 |   elif model_type == 'bert' or model_type == 'xlmr':
112 |     token_type_ids = torch.LongTensor([[0] * max_length for _ in range(len(sentences))]).to(device)
113 |     return {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}, pool_mask
114 | 
115 | 
116 | def tokenize_text(text_file, tok_file, tokenizer, lang=None):
117 |   if os.path.exists(tok_file):
118 |     tok_sentences = [l.strip().split(' ') for l in open(tok_file)]
119 |     logger.info(' -- loading from existing tok_file {}'.format(tok_file))
120 |     return tok_sentences
121 | 
122 |   tok_sentences = []
123 |   sents = [l.strip() for l in open(text_file)]
124 |   with open(tok_file, 'w') as writer:
125 |     for sent in tqdm(sents, desc='tokenize'):
126 |       if isinstance(tokenizer, XLMTokenizer):
127 |         tok_sent = tokenizer.tokenize(sent, lang=lang)
128 |       else:
129 |         tok_sent = tokenizer.tokenize(sent)
130 |       tok_sentences.append(tok_sent)
131 |       writer.write(' '.join(tok_sent) + '\n')
132 |   logger.info(' -- save tokenized sentences to {}'.format(tok_file))
133 | 
134 |   logger.info('============ First 5 tokenized sentences ===============')
135 |   for i in range(min(5, len(tok_sentences))):  # guard against corpora with fewer than 5 sentences
136 |     logger.info('S{}: {}'.format(i, ' '.join(tok_sentences[i])))
137 |   logger.info('==================================')
138 |   return tok_sentences
139 | 
140 | 
141 | def load_model(args, lang='en'):  # `lang` is now a parameter; it was previously an undefined name
142 |   config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
143 |   config = config_class.from_pretrained(args.model_name_or_path)
144 |   config.output_hidden_states = True
145 |   langid = config.lang2id.get(lang, config.lang2id["en"]) if args.model_type == 'xlm' else 0
146 |   logger.info("langid={}, lang={}".format(langid, lang))
147 |   tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
148 |   logger.info("tokenizer.pad_token={}, pad_token_id={}".format(tokenizer.pad_token, tokenizer.pad_token_id))
149 |   model = model_class.from_pretrained(args.model_name_or_path, config=config)
150 |   model.to(args.device)
151 |   model.eval()
152 |   return config, model, tokenizer
153 | 
154 | 
155 | def extract_embeddings(args, text_file, tok_file, embed_file, lang='en', pool_type='mean'):
156 |   num_embeds = args.num_layers
157 |   all_embed_files = ["{}_{}.npy".format(embed_file, i) for i in range(num_embeds)]
158 |   if all(os.path.exists(f) for f in all_embed_files):
159 |     logger.info('loading files from {}'.format(all_embed_files))
160 |     return [load_embeddings(f) for f in all_embed_files]
161 | 
162 |   config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
163 |   config = config_class.from_pretrained(args.model_name_or_path)
164 |   config.output_hidden_states = True
165 |   langid = config.lang2id.get(lang, config.lang2id["en"]) if args.model_type == 'xlm' else 0
166 |   logger.info("langid={}, lang={}".format(langid, lang))
167 |   tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
168 |   logger.info("tokenizer.pad_token={}, pad_token_id={}".format(tokenizer.pad_token, tokenizer.pad_token_id))
169 |   if args.init_checkpoint:
170 |     model = model_class.from_pretrained(args.init_checkpoint, config=config, cache_dir=args.init_checkpoint)
171 |   else:
172 |     model = model_class.from_pretrained(args.model_name_or_path, config=config)
173 |   model.to(args.device)
174 |   model.eval()
175 | 
176 | 
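  # [Editor's note] The remainder of this function is the core of the
  # extraction: the corpus is tokenized once (and cached in tok_file), encoded
  # in mini-batches with output_hidden_states=True, and each of the last
  # `num_layers` hidden layers is pooled (mean or CLS) into one
  # (num_sents, embed_size) matrix per layer, optionally saved as
  # {embed_file}_{layer}.npy.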
sent_toks = tokenize_text(text_file, tok_file, tokenizer, lang) 177 | max_length = max([len(s) for s in sent_toks]) 178 | logger.info('max length of tokenized text = {}'.format(max_length)) 179 | 180 | batch_size = args.batch_size 181 | num_batch = int(np.ceil(len(sent_toks) * 1.0 / batch_size)) 182 | num_sents = len(sent_toks) 183 | 184 | all_embeds = [np.zeros(shape=(num_sents, args.embed_size), dtype=np.float32) for _ in range(num_embeds)] 185 | for i in tqdm(range(num_batch), desc='Batch'): 186 | start_index = i * batch_size 187 | end_index = min((i + 1) * batch_size, num_sents) 188 | batch, pool_mask = prepare_batch(sent_toks[start_index: end_index], 189 | tokenizer, 190 | args.model_type, 191 | args.device, 192 | args.max_seq_length, 193 | lang=lang, 194 | langid=langid, 195 | pool_skip_special_token=args.pool_skip_special_token) 196 | 197 | with torch.no_grad(): 198 | outputs = model(**batch) 199 | 200 | if args.model_type == 'bert' or args.model_type == 'xlmr': 201 | last_layer_outputs, first_token_outputs, all_layer_outputs = outputs 202 | elif args.model_type == 'xlm': 203 | last_layer_outputs, all_layer_outputs = outputs 204 | first_token_outputs = last_layer_outputs[:,0] # first element of the last layer 205 | 206 | # get the pool embedding 207 | if pool_type == 'cls': 208 | all_batch_embeds = cls_pool_embedding(all_layer_outputs[-args.num_layers:]) 209 | else: 210 | all_batch_embeds = [] 211 | all_layer_outputs = all_layer_outputs[-args.num_layers:] 212 | all_batch_embeds.extend(mean_pool_embedding(all_layer_outputs, pool_mask)) 213 | 214 | for embeds, batch_embeds in zip(all_embeds, all_batch_embeds): 215 | embeds[start_index: end_index] = batch_embeds.cpu().numpy().astype(np.float32) 216 | del last_layer_outputs, first_token_outputs, all_layer_outputs 217 | torch.cuda.empty_cache() 218 | 219 | if embed_file is not None: 220 | for file, embeds in zip(all_embed_files, all_embeds): 221 | logger.info('save embed {} to file {}'.format(embeds.shape, file)) 222 | np.save(file, embeds) 223 | return all_embeds 224 | 225 | 226 | def mean_pool_embedding(all_layer_outputs, masks): 227 | """ 228 | Args: 229 | embeds: list of torch.FloatTensor, (B, L, D) 230 | masks: torch.FloatTensor, (B, L) 231 | Return: 232 | sent_emb: list of torch.FloatTensor, (B, D) 233 | """ 234 | sent_embeds = [] 235 | for embeds in all_layer_outputs: 236 | embeds = (embeds * masks.unsqueeze(2)).sum(dim=1) / masks.sum(dim=1).view(-1, 1) 237 | sent_embeds.append(embeds) 238 | return sent_embeds 239 | 240 | 241 | def cls_pool_embedding(all_layer_outputs): 242 | sent_embeds = [] 243 | for embeds in all_layer_outputs: 244 | embeds = embeds[:, 0, :] 245 | sent_embeds.append(embeds) 246 | return sent_embeds 247 | 248 | 249 | def concate_embedding(all_embeds, last_k): 250 | if last_k == 1: 251 | return all_embeds[-1] 252 | else: 253 | embeds = np.hstack(all_embeds[-last_k:]) # (B,D) 254 | return embeds 255 | 256 | 257 | def main(): 258 | parser = argparse.ArgumentParser(description='BUCC bitext mining') 259 | parser.add_argument('--encoding', default='utf-8', 260 | help='character encoding for input/output') 261 | parser.add_argument('--src_file', default=None, help='src file') 262 | parser.add_argument('--tgt_file', default=None, help='tgt file') 263 | parser.add_argument('--gold', default=None, 264 | help='File name of gold alignments') 265 | parser.add_argument('--threshold', type=float, default=-1, 266 | help='Threshold (used with --output)') 267 | parser.add_argument('--embed_size', type=int, default=768, 268 | 
help='dimension of the sentence embeddings')
269 |   parser.add_argument('--pool_type', type=str, default='mean',
270 |                       help='pooling over word embeddings')
271 | 
272 |   # Required parameters
273 |   parser.add_argument(
274 |     "--data_dir",
275 |     default=None,
276 |     type=str,
277 |     required=True,
278 |     help="The input data dir. Should contain the input files for the task.",
279 |   )
280 |   parser.add_argument(
281 |     "--model_type",
282 |     default=None,
283 |     type=str,
284 |     required=True,
285 |     help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
286 |   )
287 |   parser.add_argument(
288 |     "--model_name_or_path",
289 |     default=None,
290 |     type=str,
291 |     required=True,
292 |     help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
293 |   )
294 |   parser.add_argument(
295 |     "--init_checkpoint",
296 |     default=None,
297 |     type=str,
298 |     help="Path to pre-trained model or shortcut name selected in the list"
299 |   )
300 |   parser.add_argument("--src_language", type=str, default="en", help="source language.")
301 |   parser.add_argument("--tgt_language", type=str, default="de", help="target language.")
302 |   parser.add_argument("--batch_size", type=int, default=100, help="batch size.")
303 |   parser.add_argument("--tgt_text_file", type=str, default=None, help="tgt_text_file.")
304 |   parser.add_argument("--src_text_file", type=str, default=None, help="src_text_file.")
305 |   parser.add_argument("--tgt_embed_file", type=str, default=None, help="tgt_embed_file")
306 |   parser.add_argument("--src_embed_file", type=str, default=None, help="src_embed_file")
307 |   parser.add_argument("--tgt_tok_file", type=str, default=None, help="tgt_tok_file")
308 |   parser.add_argument("--src_tok_file", type=str, default=None, help="src_tok_file")
309 |   parser.add_argument("--tgt_id_file", type=str, default=None, help="tgt_id_file")
310 |   parser.add_argument("--src_id_file", type=str, default=None, help="src_id_file")
311 |   parser.add_argument("--num_layers", type=int, default=12, help="num layers")
312 |   parser.add_argument("--candidate_prefix", type=str, default="candidates")
313 |   parser.add_argument("--pool_skip_special_token", action="store_true")
314 |   parser.add_argument("--dist", type=str, default='cosine')
315 |   parser.add_argument("--use_shift_embeds", action="store_true")
316 |   parser.add_argument("--extract_embeds", action="store_true")
317 |   parser.add_argument("--mine_bitext", action="store_true")
318 |   parser.add_argument("--predict_dir", type=str, default=None, help="prediction folder")
319 | 
320 | 
321 |   parser.add_argument(
322 |     "--output_dir",
323 |     default=None,
324 |     type=str,
325 |     required=True,
326 |     help="The output directory",
327 |   )
328 |   parser.add_argument("--log_file", default="train", type=str, help="log file")
329 | 
330 |   parser.add_argument(
331 |     "--task_name",
332 |     default="bucc2018",
333 |     type=str,
334 |     required=True,
335 |     help="The task name",
336 |   )
337 | 
338 |   # Other parameters
339 |   parser.add_argument(
340 |     "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
341 |   )
342 |   parser.add_argument(
343 |     "--tokenizer_name",
344 |     default="",
345 |     type=str,
346 |     help="Pretrained tokenizer name or path if not the same as model_name",
347 |   )
348 |   parser.add_argument(
349 |     "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
350 |   )
351 |   parser.add_argument(
352 |     "--cache_dir",
353 |     default="",
354 |     type=str,
355 |     help="Where do you want to store the pre-trained models downloaded from s3",
356 |   )
357 |   parser.add_argument(
358 |     "--max_seq_length",
359 |     default=128,
360 |     type=int,
361 |     help="The maximum total input sequence length after tokenization. Sequences longer "
362 |     "than this will be truncated, sequences shorter will be padded.",
363 |   )
364 |   parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
365 |   parser.add_argument(
366 |     "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
367 |   )
368 |   parser.add_argument(
369 |     "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
370 |   )
371 |   parser.add_argument(
372 |     "--unify", action="store_true", help="unify sentences"
373 |   )
374 |   parser.add_argument("--split", type=str, default='training', help='split of the bucc dataset')
375 |   parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
376 |   parser.add_argument("--concate_layers", action="store_true", help="concate_layers")
377 |   parser.add_argument("--specific_layer", type=int, default=7, help="use specific layer")
378 |   args = parser.parse_args()
379 |   os.makedirs(args.output_dir, exist_ok=True)  # the FileHandler below fails if output_dir does not exist yet
380 |   logging.basicConfig(handlers=[logging.FileHandler(os.path.join(args.output_dir, args.log_file)), logging.StreamHandler()],
381 |                       format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
382 |                       datefmt = '%m/%d/%Y %H:%M:%S',
383 |                       level = logging.INFO)
384 |   logging.info("Input args: %r" % args)
385 | 
386 |   # Setup CUDA, GPU
387 |   device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
388 |   args.n_gpu = torch.cuda.device_count()
389 |   args.device = device
390 | 
391 |   # Setup logging (a no-op here: the root logger is already configured above)
392 |   logging.basicConfig(
393 |     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
394 |     datefmt="%m/%d/%Y %H:%M:%S",
395 |     level=logging.INFO,
396 |   )
397 | 
398 |   if args.task_name == 'bucc2018':
399 |     best_threshold = None
400 |     SL, TL = args.src_language, args.tgt_language
401 |     for split in ['dev', 'test']:
402 |       prefix = os.path.join(args.output_dir, f'{SL}-{TL}.{split}')
403 |       if args.extract_embeds:
404 |         for lang in [SL, TL]:
405 |           extract_embeddings(args, f'{prefix}.{lang}.txt', f'{prefix}.{lang}.tok', f'{prefix}.{lang}.emb', lang=lang)
406 | 
407 |       if args.mine_bitext:
408 |         num_layers = args.num_layers
409 |         if args.specific_layer != -1:
410 |           indices = [args.specific_layer]
411 |         else:
412 |           indices = list(range(num_layers))
413 | 
414 |         for idx in indices:
415 |           suf = str(idx)
416 |           cand2score_file = os.path.join(args.output_dir, '{}_{}.tsv'.format(args.candidate_prefix, suf))
417 |           if os.path.exists(cand2score_file):
418 |             logger.info('cand2score_file {} exists'.format(cand2score_file))
419 |           else:
420 |             x = load_embeddings(f'{prefix}.{SL}.emb_{idx}.npy')
421 |             y = load_embeddings(f'{prefix}.{TL}.emb_{idx}.npy')
422 |             mine_bitext(x, y, f'{prefix}.{SL}.txt', f'{prefix}.{TL}.txt', cand2score_file, dist=args.dist, use_shift_embeds=args.use_shift_embeds)
423 |           gold_file = f'{prefix}.gold'
424 |           if os.path.exists(gold_file):
425 |             predict_file = os.path.join(args.predict_dir, f'test-{SL}.tsv')
426 |             results = bucc_eval(cand2score_file, gold_file, f'{prefix}.{SL}.txt', f'{prefix}.{TL}.txt', f'{prefix}.{SL}.id', f'{prefix}.{TL}.id', predict_file, best_threshold)  # fixed: `threshold` was undefined and the source-language id file was passed twice
427 |             best_threshold = results['best-threshold']
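            # [Editor's note] On the 'dev' pass best_threshold is still None,
            # so bucc_eval searches for the score threshold that maximizes F1
            # against the gold alignments; the resulting best_threshold is then
            # carried over and reused on the 'test' pass of this loop.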
428 |             logger.info('--Candidates: {}'.format(cand2score_file))
429 |             logger.info('index={} '.format(suf) + ' '.join('{}={:.4f}'.format(k,v) for k,v in results.items()))
430 | 
431 |   elif args.task_name == 'tatoeba':
432 |     lang3_dict = {'ara':'ar', 'heb':'he', 'vie':'vi', 'ind':'id',
433 |                   'jav':'jv', 'tgl':'tl', 'eus':'eu', 'mal':'ml', 'tam':'ta',
434 |                   'tel':'te', 'afr':'af', 'nld':'nl', 'eng':'en', 'deu':'de',
435 |                   'ell':'el', 'ben':'bn', 'hin':'hi', 'mar':'mr', 'urd':'ur',
436 |                   'fra':'fr', 'ita':'it', 'por':'pt', 'spa':'es',
437 |                   'bul':'bg', 'rus':'ru', 'jpn':'ja', 'kat':'ka', 'kor':'ko',
438 |                   'tha':'th', 'swh':'sw', 'cmn':'zh', 'kaz':'kk', 'tur':'tr',
439 |                   'est':'et', 'fin':'fi', 'hun':'hu', 'pes':'fa'}
440 |     lang2_dict = {l2: l3 for l3, l2 in lang3_dict.items()}
441 | 
442 |     src_lang2 = args.src_language
443 |     tgt_lang2 = args.tgt_language
444 |     src_lang3 = lang2_dict[args.src_language]
445 |     tgt_lang3 = lang2_dict[args.tgt_language]
446 |     src_text_file = os.path.join(args.data_dir, 'tatoeba.{}-eng.{}'.format(src_lang3, src_lang3))
447 |     tgt_text_file = os.path.join(args.data_dir, 'tatoeba.{}-eng.eng'.format(src_lang3))
448 |     src_tok_file = os.path.join(args.output_dir, 'tatoeba.{}-eng.tok.{}'.format(src_lang3, src_lang3))
449 |     tgt_tok_file = os.path.join(args.output_dir, 'tatoeba.{}-eng.tok.eng'.format(src_lang3))
450 | 
451 |     all_src_embeds = extract_embeddings(args, src_text_file, src_tok_file, None, lang=src_lang2)
452 |     all_tgt_embeds = extract_embeddings(args, tgt_text_file, tgt_tok_file, None, lang=tgt_lang2)
453 | 
454 |     idx = list(range(1, len(all_src_embeds) + 1, 4))
455 |     best_score = 0
456 |     best_rep = None
457 |     num_layers = len(all_src_embeds)
458 |     for i in [args.specific_layer]:
459 |       x, y = all_src_embeds[i], all_tgt_embeds[i]
460 |       predictions = similarity_search(x, y, args.embed_size, normalize=(args.dist == 'cosine'))
461 |       with open(os.path.join(args.output_dir, f'test_{src_lang2}_predictions.txt'), 'w') as fout:  # fixed: 'w' mode was missing; this file is written, not read
462 |         for p in predictions:
463 |           fout.write(str(p) + '\n')
464 | 
465 | if __name__ == "__main__":
466 |   main()
467 | 
468 | 
--------------------------------------------------------------------------------
/utils_preprocess.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import argparse
4 | from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer
5 | import os
6 | from collections import defaultdict
7 | import csv
8 | import random
9 | 
10 | import shutil
11 | import json
12 | 
13 | 
14 | TOKENIZERS = {
15 |   'bert': BertTokenizer,
16 |   'xlm': XLMTokenizer,
17 |   'xlmr': XLMRobertaTokenizer,
18 | }
19 | 
20 | 
21 | def panx_tokenize_preprocess(args):
22 |   def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
23 |     if not os.path.exists(infile):
24 |       print(f'{infile} does not exist')
25 |       return 0
26 |     special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
27 |     max_seq_len = max_len - special_tokens_count
28 |     subword_len_counter = idx = 0
29 |     with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx:
30 |       for line in fin:
31 |         line = line.strip()
32 |         if not line:
33 |           fout.write('\n')
34 |           fidx.write('\n')
35 |           idx += 1
36 |           subword_len_counter = 0
37 |           continue
38 | 
39 |         items = line.split()
40 |         token = items[0].strip()
41 |         if len(items) == 2:
42 |           label = items[1].strip()
43 |         else:
44 |           label = 'O'
45 |         current_subwords_len = len(tokenizer.tokenize(token))
46 | 
47 |         if
(current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: 48 | token = tokenizer.unk_token 49 | current_subwords_len = 1 50 | 51 | if (subword_len_counter + current_subwords_len) > max_seq_len: 52 | fout.write(f"\n{token}\t{label}\n") 53 | fidx.write(f"\n{idx}\n") 54 | subword_len_counter = current_subwords_len 55 | else: 56 | fout.write(f"{token}\t{label}\n") 57 | fidx.write(f"{idx}\n") 58 | subword_len_counter += current_subwords_len 59 | return 1 60 | 61 | model_type = args.model_type 62 | tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, 63 | do_lower_case=args.do_lower_case, 64 | cache_dir=args.cache_dir if args.cache_dir else None) 65 | for lang in args.languages.split(','): 66 | out_dir = os.path.join(args.output_dir, lang) 67 | if not os.path.exists(out_dir): 68 | os.makedirs(out_dir) 69 | if lang == 'en': 70 | files = ['dev', 'test', 'train'] 71 | else: 72 | files = ['dev', 'test'] 73 | for file in files: 74 | infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv') 75 | outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) 76 | idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) 77 | if os.path.exists(outfile) and os.path.exists(idxfile): 78 | print(f'{outfile} and {idxfile} exist') 79 | else: 80 | code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) 81 | if code > 0: 82 | print(f'finish preprocessing {outfile}') 83 | 84 | 85 | def panx_preprocess(args): 86 | def _process_one_file(infile, outfile): 87 | with open(infile, 'r') as fin, open(outfile, 'w') as fout: 88 | for l in fin: 89 | items = l.strip().split('\t') 90 | if len(items) == 2: 91 | label = items[1].strip() 92 | token = items[0].split(':')[1].strip() 93 | if 'test' in infile: 94 | fout.write(f'{token}\n') 95 | else: 96 | fout.write(f'{token}\t{label}\n') 97 | else: 98 | fout.write('\n') 99 | if not os.path.exists(args.output_dir): 100 | os.makedirs(args.output_dir) 101 | langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ') 102 | for lg in langs: 103 | for split in ['train', 'test', 'dev']: 104 | infile = os.path.join(args.data_dir, f'{lg}-{split}') 105 | outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv') 106 | _process_one_file(infile, outfile) 107 | 108 | 109 | def udpos_tokenize_preprocess(args): 110 | def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): 111 | if not os.path.exists(infile): 112 | print(f'{infile} does not exist') 113 | return 114 | subword_len_counter = idx = 0 115 | special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 116 | max_seq_len = max_len - special_tokens_count 117 | with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx: 118 | for line in fin: 119 | line = line.strip() 120 | if len(line) == 0 or line == '': 121 | fout.write('\n') 122 | fidx.write('\n') 123 | idx += 1 124 | subword_len_counter = 0 125 | continue 126 | 127 | items = line.split() 128 | if len(items) == 2: 129 | label = items[1].strip() 130 | else: 131 | label = "X" 132 | token = items[0].strip() 133 | current_subwords_len = len(tokenizer.tokenize(token)) 134 | 135 | if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: 136 | token = tokenizer.unk_token 137 | current_subwords_len = 1 138 | 139 | if (subword_len_counter + current_subwords_len) > max_seq_len: 140 | 
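          # [Editor's note] Overflow handling: if this token would push the
          # sentence past the subword budget, emit a sentence break (blank
          # line) first and restart the count. The .idx file written alongside
          # records the original sentence index for every token, so split
          # sentences can be stitched back together after prediction.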
fout.write(f"\n{token}\t{label}\n") 141 | fidx.write(f"\n{idx}\n") 142 | subword_len_counter = current_subwords_len 143 | else: 144 | fout.write(f"{token}\t{label}\n") 145 | fidx.write(f"{idx}\n") 146 | subword_len_counter += current_subwords_len 147 | 148 | model_type = args.model_type 149 | tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, 150 | do_lower_case=args.do_lower_case, 151 | cache_dir=args.cache_dir if args.cache_dir else None) 152 | for lang in args.languages.split(','): 153 | out_dir = os.path.join(args.output_dir, lang) 154 | if not os.path.exists(out_dir): 155 | os.makedirs(out_dir) 156 | if lang == 'en': 157 | files = ['dev', 'test', 'train'] 158 | else: 159 | files = ['dev', 'test'] 160 | for file in files: 161 | infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang)) 162 | outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) 163 | idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) 164 | if os.path.exists(outfile) and os.path.exists(idxfile): 165 | print(f'{outfile} and {idxfile} exist') 166 | else: 167 | _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) 168 | print(f'finish preprocessing {outfile}') 169 | 170 | 171 | def udpos_preprocess(args): 172 | def _read_one_file(file): 173 | data = [] 174 | sent, tag, lines = [], [], [] 175 | for line in open(file, 'r'): 176 | items = line.strip().split('\t') 177 | if len(items) != 10: 178 | empty = all(w == '_' for w in sent) 179 | if not empty: 180 | data.append((sent, tag, lines)) 181 | sent, tag, lines = [], [], [] 182 | else: 183 | sent.append(items[1].strip()) 184 | tag.append(items[3].strip()) 185 | lines.append(line.strip()) 186 | assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag) 187 | return data 188 | 189 | def isfloat(value): 190 | try: 191 | float(value) 192 | return True 193 | except ValueError: 194 | return False 195 | 196 | def remove_empty_space(data): 197 | new_data = {} 198 | for split in data: 199 | new_data[split] = [] 200 | for sent, tag, lines in data[split]: 201 | new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent] 202 | lines = [line.replace('\u200c', '') for line in lines] 203 | assert len(" ".join(new_sent).split(' ')) == len(tag) 204 | new_data[split].append((new_sent, tag, lines)) 205 | return new_data 206 | 207 | def check_file(file): 208 | for i, l in enumerate(open(file)): 209 | items = l.strip().split('\t') 210 | assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l) 211 | 212 | def _write_files(data, output_dir, lang, suffix): 213 | for split in data: 214 | if len(data[split]) > 0: 215 | prefix = os.path.join(output_dir, f'{split}-{lang}') 216 | if suffix == 'mt': 217 | with open(prefix + '.mt.tsv', 'w') as fout: 218 | for idx, (sent, tag, _) in enumerate(data[split]): 219 | newline = '\n' if idx != len(data[split]) - 1 else '' 220 | if split == 'test': 221 | fout.write('{}{}'.format(' '.join(sent, newline))) 222 | else: 223 | fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline)) 224 | check_file(prefix + '.mt.tsv') 225 | print(' - finish checking ' + prefix + '.mt.tsv') 226 | elif suffix == 'tsv': 227 | with open(prefix + '.tsv', 'w') as fout: 228 | for sidx, (sent, tag, _) in enumerate(data[split]): 229 | for widx, (w, t) in enumerate(zip(sent, tag)): 230 | newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n' 231 | if split == 'test': 232 | 
fout.write('{}{}'.format(w, newline)) 233 | else: 234 | fout.write('{}\t{}{}'.format(w, t, newline)) 235 | fout.write('\n') 236 | elif suffix == 'conll': 237 | with open(prefix + '.conll', 'w') as fout: 238 | for _, _, lines in data[split]: 239 | for l in lines: 240 | fout.write(l.strip() + '\n') 241 | fout.write('\n') 242 | print(f'finish writing file to {prefix}.{suffix}') 243 | 244 | if not os.path.exists(args.output_dir): 245 | os.makedirs(args.output_dir) 246 | 247 | languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ') 248 | for root, dirs, files in os.walk(args.data_dir): 249 | lg = root.strip().split('/')[-1] 250 | if root == args.data_dir or lg not in languages: 251 | continue 252 | 253 | data = {k: [] for k in ['train', 'dev', 'test']} 254 | for f in files: 255 | if f.endswith('conll'): 256 | file = os.path.join(root, f) 257 | examples = _read_one_file(file) 258 | if 'train' in f: 259 | data['train'].extend(examples) 260 | elif 'dev' in f: 261 | data['dev'].extend(examples) 262 | elif 'test' in f: 263 | data['test'].extend(examples) 264 | else: 265 | print('split not found: ', file) 266 | print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k,v in data.items()])) 267 | 268 | data = remove_empty_space(data) 269 | for sub in ['tsv']: 270 | _write_files(data, args.output_dir, lg, sub) 271 | 272 | 273 | def pawsx_preprocess(args): 274 | def _preprocess_one_file(infile, outfile, remove_label=False): 275 | data = [] 276 | for i, line in enumerate(open(infile, 'r')): 277 | if i == 0: 278 | continue 279 | items = line.strip().split('\t') 280 | sent1 = ' '.join(items[1].strip().split(' ')) 281 | sent2 = ' '.join(items[2].strip().split(' ')) 282 | label = items[3] 283 | data.append([sent1, sent2, label]) 284 | 285 | with open(outfile, 'w') as fout: 286 | writer = csv.writer(fout, delimiter='\t') 287 | for sent1, sent2, label in data: 288 | if remove_label: 289 | writer.writerow([sent1, sent2]) 290 | else: 291 | writer.writerow([sent1, sent2, label]) 292 | 293 | if not os.path.exists(args.output_dir): 294 | os.makedirs(args.output_dir) 295 | 296 | split2file = {'train': 'train', 'test': 'test_2k', 'dev': 'dev_2k'} 297 | for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']: 298 | for split in ['train', 'test', 'dev']: 299 | if split == 'train' and lang != 'en': 300 | continue 301 | file = split2file[split] 302 | infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file)) 303 | outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(lang, split)) 304 | _preprocess_one_file(infile, outfile, remove_label=(split == 'test')) 305 | print(f'finish preprocessing {outfile}') 306 | 307 | 308 | def xnli_preprocess(args): 309 | def _preprocess_file(infile, output_dir, split): 310 | all_langs = defaultdict(list) 311 | for i, line in enumerate(open(infile, 'r')): 312 | if i == 0: 313 | continue 314 | 315 | items = line.strip().split('\t') 316 | lang = items[0].strip() 317 | label = "contradiction" if items[1].strip() == "contradictory" else items[1].strip() 318 | sent1 = ' '.join(items[6].strip().split(' ')) 319 | sent2 = ' '.join(items[7].strip().split(' ')) 320 | all_langs[lang].append((sent1, sent2, label)) 321 | print(f'# langs={len(all_langs)}') 322 | for lang, pairs in all_langs.items(): 323 | outfile = os.path.join(output_dir, '{}-{}.tsv'.format(lang, split)) 324 | with open(outfile, 'w') as fout: 325 | writer = csv.writer(fout, delimiter='\t') 326 | for (sent1, sent2, label) in pairs: 327 | if split == 
'test': 328 | writer.writerow([sent1, sent2]) 329 | else: 330 | writer.writerow([sent1, sent2, label]) 331 | print(f'finish preprocess {outfile}') 332 | 333 | def _preprocess_train_file(infile, outfile): 334 | with open(outfile, 'w') as fout: 335 | writer = csv.writer(fout, delimiter='\t') 336 | for i, line in enumerate(open(infile, 'r')): 337 | if i == 0: 338 | continue 339 | 340 | items = line.strip().split('\t') 341 | sent1 = ' '.join(items[0].strip().split(' ')) 342 | sent2 = ' '.join(items[1].strip().split(' ')) 343 | label = "contradiction" if items[2].strip() == "contradictory" else items[2].strip() 344 | writer.writerow([sent1, sent2, label]) 345 | print(f'finish preprocess {outfile}') 346 | 347 | infile = os.path.join(args.data_dir, 'XNLI-MT-1.0/multinli/multinli.train.en.tsv') 348 | if not os.path.exists(args.output_dir): 349 | os.makedirs(args.output_dir) 350 | outfile = os.path.join(args.output_dir, 'train-en.tsv') 351 | _preprocess_train_file(infile, outfile) 352 | 353 | for split in ['test', 'dev']: 354 | infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split)) 355 | print(f'reading file {infile}') 356 | _preprocess_file(infile, args.output_dir, split) 357 | 358 | 359 | def tatoeba_preprocess(args): 360 | lang3_dict = { 361 | 'afr':'af', 'ara':'ar', 'bul':'bg', 'ben':'bn', 362 | 'deu':'de', 'ell':'el', 'spa':'es', 'est':'et', 363 | 'eus':'eu', 'pes':'fa', 'fin':'fi', 'fra':'fr', 364 | 'heb':'he', 'hin':'hi', 'hun':'hu', 'ind':'id', 365 | 'ita':'it', 'jpn':'ja', 'jav':'jv', 'kat':'ka', 366 | 'kaz':'kk', 'kor':'ko', 'mal':'ml', 'mar':'mr', 367 | 'nld':'nl', 'por':'pt', 'rus':'ru', 'swh':'sw', 368 | 'tam':'ta', 'tel':'te', 'tha':'th', 'tgl':'tl', 369 | 'tur':'tr', 'urd':'ur', 'vie':'vi', 'cmn':'zh', 370 | 'eng':'en', 371 | } 372 | if not os.path.exists(args.output_dir): 373 | os.makedirs(args.output_dir) 374 | for sl3, sl2 in lang3_dict.items(): 375 | if sl3 != 'eng': 376 | src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}' 377 | tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng' 378 | src_out = f'{args.output_dir}/{sl2}-en.{sl2}' 379 | tgt_out = f'{args.output_dir}/{sl2}-en.en' 380 | shutil.copy(src_file, src_out) 381 | tgts = [l.strip() for l in open(tgt_file)] 382 | idx = range(len(tgts)) 383 | data = zip(tgts, idx) 384 | with open(tgt_out, 'w') as ftgt: 385 | for t, i in sorted(data, key=lambda x: x[0]): 386 | ftgt.write(f'{t}\n') 387 | 388 | 389 | def xquad_preprocess(args): 390 | # Remove the test annotations to prevent accidental cheating 391 | remove_qa_test_annotations(args.data_dir) 392 | 393 | 394 | def mlqa_preprocess(args): 395 | # Remove the test annotations to prevent accidental cheating 396 | remove_qa_test_annotations(args.data_dir) 397 | 398 | 399 | def tydiqa_preprocess(args): 400 | LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi', 401 | 'indonesian': 'id', 'korean': 'ko', 'russian': 'ru', 402 | 'swahili': 'sw', 'telugu': 'te'} 403 | assert os.path.exists(args.data_dir) 404 | train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json') 405 | os.makedirs(args.output_dir, exist_ok=True) 406 | 407 | # Split the training file into language-specific files 408 | lang2data = defaultdict(list) 409 | with open(train_file, 'r') as f_in: 410 | data = json.load(f_in) 411 | version = data['version'] 412 | for doc in data['data']: 413 | for par in doc['paragraphs']: 414 | context = par['context'] 415 | for qa in par['qas']: 416 | question = qa['question'] 417 | question_id = qa['id'] 418 | example_lang = 
question_id.split('-')[0] 419 | q_id = question_id.split('-')[-1] 420 | for answer in qa['answers']: 421 | a_start, a_text = answer['answer_start'], answer['text'] 422 | a_end = a_start + len(a_text) 423 | assert context[a_start:a_end] == a_text 424 | lang2data[example_lang].append({'paragraphs': [{ 425 | 'context': context, 426 | 'qas': [{'answers': qa['answers'], 427 | 'question': question, 428 | 'id': q_id}]}]}) 429 | 430 | for lang, data in lang2data.items(): 431 | out_file = os.path.join( 432 | args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang]) 433 | with open(out_file, 'w') as f: 434 | json.dump({'data': data, 'version': version}, f) 435 | 436 | # Rename the dev files 437 | dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev') 438 | assert os.path.exists(dev_dir) 439 | for lang, iso in LANG2ISO.items(): 440 | src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang) 441 | dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso) 442 | os.rename(src_file, dst_file) 443 | 444 | # Remove the test annotations to prevent accidental cheating 445 | remove_qa_test_annotations(dev_dir) 446 | 447 | 448 | def remove_qa_test_annotations(test_dir): 449 | assert os.path.exists(test_dir) 450 | for file_name in os.listdir(test_dir): 451 | new_data = [] 452 | test_file = os.path.join(test_dir, file_name) 453 | with open(test_file, 'r') as f: 454 | data = json.load(f) 455 | version = data['version'] 456 | for doc in data['data']: 457 | for par in doc['paragraphs']: 458 | context = par['context'] 459 | for qa in par['qas']: 460 | question = qa['question'] 461 | question_id = qa['id'] 462 | for answer in qa['answers']: 463 | a_start, a_text = answer['answer_start'], answer['text'] 464 | a_end = a_start + len(a_text) 465 | assert context[a_start:a_end] == a_text 466 | new_data.append({'paragraphs': [{ 467 | 'context': context, 468 | 'qas': [{'answers': [{'answer_start': 0, 'text': ''}], 469 | 'question': question, 470 | 'id': question_id}]}]}) 471 | with open(test_file, 'w') as f: 472 | json.dump({'data': new_data, 'version': version}, f) 473 | 474 | 475 | if __name__ == "__main__": 476 | parser = argparse.ArgumentParser() 477 | 478 | ## Required parameters 479 | parser.add_argument("--data_dir", default=None, type=str, required=True, 480 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 481 | parser.add_argument("--output_dir", default=None, type=str, required=True, 482 | help="The output data dir where any processed files will be written to.") 483 | parser.add_argument("--task", default="panx", type=str, required=True, 484 | help="The task name") 485 | parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str, 486 | help="The pre-trained model") 487 | parser.add_argument("--model_type", default="bert", type=str, 488 | help="model type") 489 | parser.add_argument("--max_len", default=512, type=int, 490 | help="the maximum length of sentences") 491 | parser.add_argument("--do_lower_case", action='store_true', 492 | help="whether to do lower case") 493 | parser.add_argument("--cache_dir", default=None, type=str, 494 | help="cache directory") 495 | parser.add_argument("--languages", default="en", type=str, 496 | help="process language") 497 | parser.add_argument("--remove_last_token", action='store_true', 498 | help="whether to remove the last token") 499 | parser.add_argument("--remove_test_label", action='store_true', 500 | help="whether to remove test set label") 501 | args = parser.parse_args() 502 | 503 | if args.task == 'panx_tokenize': 504 | panx_tokenize_preprocess(args) 505 | if args.task == 'panx': 506 | panx_preprocess(args) 507 | if args.task == 'udpos_tokenize': 508 | udpos_tokenize_preprocess(args) 509 | if args.task == 'udpos': 510 | udpos_preprocess(args) 511 | if args.task == 'pawsx': 512 | pawsx_preprocess(args) 513 | if args.task == 'xnli': 514 | xnli_preprocess(args) 515 | if args.task == 'tatoeba': 516 | tatoeba_preprocess(args) 517 | if args.task == 'xquad': 518 | xquad_preprocess(args) 519 | if args.task == 'mlqa': 520 | mlqa_preprocess(args) 521 | if args.task == 'tydiqa': 522 | tydiqa_preprocess(args) 523 | -------------------------------------------------------------------------------- /third_party/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch XLM-RoBERTa model. 
""" 17 | 18 | 19 | import logging 20 | 21 | from transformers.configuration_roberta import RobertaConfig 22 | from transformers.file_utils import add_start_docstrings 23 | from roberta import ( 24 | RobertaForMaskedLM, 25 | RobertaForMultipleChoice, 26 | RobertaForSequenceClassification, 27 | RobertaForTokenClassification, 28 | RobertaForQuestionAnswering, 29 | RobertaModel, 30 | ) 31 | 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 37 | "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json", 38 | "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json", 39 | "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json", 40 | "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json", 41 | "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json", 42 | "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json", 43 | } 44 | 45 | 46 | class XLMRobertaConfig(RobertaConfig): 47 | pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 48 | 49 | XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = { 50 | "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-pytorch_model.bin", 51 | "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-pytorch_model.bin", 52 | "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin", 53 | "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin", 54 | "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin", 55 | "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin", 56 | } 57 | 58 | 59 | XLM_ROBERTA_START_DOCSTRING = r""" The XLM-RoBERTa model was proposed in 60 | `Unsupervised Cross-lingual Representation Learning at Scale`_ 61 | by Alexis Conneau, Kartikay Khandelwal, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. It is based on Facebook's RoBERTa model released in 2019. 62 | 63 | It is a large multi-lingual language model, trained on 2.5TB of filtered CommonCrawl data. 64 | 65 | This implementation is the same as RoBERTa. 66 | 67 | This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and 68 | refer to the PyTorch documentation for all matter related to general usage and behavior. 69 | 70 | .. _`Unsupervised Cross-lingual Representation Learning at Scale`: 71 | https://arxiv.org/abs/1911.02116 72 | 73 | .. 
_`torch.nn.Module`:
74 |         https://pytorch.org/docs/stable/nn.html#module
75 | 
76 |     Parameters:
77 |         config (:class:`~transformers.XLMRobertaConfig`): Model configuration class with all the parameters of the
78 |             model. Initializing with a config file does not load the weights associated with the model, only the configuration.
79 |             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
80 | """
81 | 
82 | XLM_ROBERTA_INPUTS_DOCSTRING = r"""
83 |     Inputs:
84 |         **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
85 |             Indices of input sequence tokens in the vocabulary.
86 |             To match pre-training, XLM-RoBERTa input sequences should be formatted with ``<s>`` and ``</s>`` tokens as follows:
87 | 
88 |             (a) For sequence pairs:
89 | 
90 |                 ``tokens: <s> Is this Jacksonville ? </s> </s> No it is not . </s>``
91 | 
92 |             (b) For single sequences:
93 | 
94 |                 ``tokens: <s> the dog is hairy . </s>``
95 | 
96 |             Fully encoded sequences or sequence pairs can be obtained using the XLMRobertaTokenizer.encode function with
97 |             the ``add_special_tokens`` parameter set to ``True``.
98 | 
99 |             XLM-RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
100 |             the right rather than the left.
101 | 
102 |             See :func:`transformers.PreTrainedTokenizer.encode` and
103 |             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
104 |         **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
105 |             Mask to avoid performing attention on padding token indices.
106 |             Mask values selected in ``[0, 1]``:
107 |             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
108 |         **token_type_ids**: (`optional`, needs to be trained) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
109 |             Optional segment token indices to indicate first and second portions of the inputs.
110 |             This embedding matrix is not trained (not pretrained during XLM-RoBERTa pretraining), you will have to train it
111 |             during finetuning.
112 |             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
113 |             corresponds to a `sentence B` token
114 |             (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
115 |         **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
116 |             Indices of positions of each input sequence tokens in the position embeddings.
117 |             Selected in the range ``[0, config.max_position_embeddings - 1]``.
118 |         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
119 |             Mask to nullify selected heads of the self-attention modules.
120 |             Mask values selected in ``[0, 1]``:
121 |             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
122 |         **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
123 |             Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
124 |             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
125 |             than the model's internal embedding lookup matrix.
126 | """
127 | 
128 | 
129 | @add_start_docstrings(
130 |     "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
131 |     XLM_ROBERTA_START_DOCSTRING,
132 |     XLM_ROBERTA_INPUTS_DOCSTRING,
133 | )
134 | class XLMRobertaModel(RobertaModel):
135 |     r"""
136 |     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
137 |         **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
138 |             Sequence of hidden-states at the output of the last layer of the model.
139 |         **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
140 |             Last layer hidden-state of the first token of the sequence (classification token)
141 |             further processed by a Linear layer and a Tanh activation function. The Linear
142 |             layer weights are trained from the next sentence prediction (classification)
143 |             objective during Bert pretraining. To match pre-training, input sequences are formatted with ``<s>`` and ``</s>`` tokens as follows:
144 | 
145 |             (a) For sequence pairs:
146 | 
147 |                 ``tokens: <s> is this jack ##son ##ville ? </s> </s> no it is not . </s>``
148 | 
149 |                 ``token_type_ids: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
150 | 
151 |             (b) For single sequences:
152 | 
153 |                 ``tokens: <s> the dog is hairy . </s>``
154 | 
155 |                 ``token_type_ids: 0 0 0 0 0 0 0``
156 | 
157 |             This output is usually *not* a good summary
158 |             of the semantic content of the input, you're often better with averaging or pooling
159 |             the sequence of hidden-states for the whole input sequence.
160 |         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
161 |             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
162 |             of shape ``(batch_size, sequence_length, hidden_size)``:
163 |             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
164 |         **attentions**: (`optional`, returned when ``config.output_attentions=True``)
165 |             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
166 |             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
167 | 
168 |     Examples::
169 | 
170 |         tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
171 |         model = XLMRobertaModel.from_pretrained('xlm-roberta-large')
172 |         input_ids = torch.tensor(tokenizer.encode("Schloß Nymphenburg ist sehr schön .")).unsqueeze(0)  # Batch size 1
173 |         outputs = model(input_ids)
174 |         last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
175 | 
176 |     """
177 |     config_class = XLMRobertaConfig
178 |     pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
179 | 
180 | 
181 | @add_start_docstrings(
182 |     """XLM-RoBERTa Model with a `language modeling` head on top. """,
183 |     XLM_ROBERTA_START_DOCSTRING,
184 |     XLM_ROBERTA_INPUTS_DOCSTRING,
185 | )
186 | class XLMRobertaForMaskedLM(RobertaForMaskedLM):
187 |     r"""
188 |     **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
189 |         Labels for computing the masked language modeling loss.
190 | Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) 191 | Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels 192 | in ``[0, ..., config.vocab_size]`` 193 | 194 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 195 | **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 196 | Masked language modeling loss. 197 | **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` 198 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 199 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 200 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 201 | of shape ``(batch_size, sequence_length, hidden_size)``: 202 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 203 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 204 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 205 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 206 | 207 | Examples:: 208 | 209 | tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 210 | model = XLMRobertaForMaskedLM.from_pretrained('xlm-roberta-large') 211 | input_ids = torch.tensor(tokenizer.encode("Schloß Nymphenburg ist sehr schön .")).unsqueeze(0) # Batch size 1 212 | outputs = model(input_ids, masked_lm_labels=input_ids) 213 | loss, prediction_scores = outputs[:2] 214 | 215 | """ 216 | config_class = XLMRobertaConfig 217 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 218 | 219 | 220 | @add_start_docstrings( 221 | """XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer 222 | on top of the pooled output) e.g. for GLUE tasks. """, 223 | XLM_ROBERTA_START_DOCSTRING, 224 | XLM_ROBERTA_INPUTS_DOCSTRING, 225 | ) 226 | class XLMRobertaForSequenceClassification(RobertaForSequenceClassification): 227 | r""" 228 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 229 | Labels for computing the sequence classification/regression loss. 230 | Indices should be in ``[0, ..., config.num_labels]``. 231 | If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), 232 | If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). 233 | 234 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 235 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 236 | Classification (or regression if config.num_labels==1) loss. 237 | **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` 238 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 239 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 240 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 241 | of shape ``(batch_size, sequence_length, hidden_size)``: 242 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
243 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 244 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 245 | Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. 246 | 247 | Examples:: 248 | 249 | tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 250 | model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large') 251 | input_ids = torch.tensor(tokenizer.encode("Schloß Nymphenburg ist sehr schön .")).unsqueeze(0) # Batch size 1 252 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 253 | outputs = model(input_ids, labels=labels) 254 | loss, logits = outputs[:2] 255 | 256 | """ 257 | config_class = XLMRobertaConfig 258 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 259 | 260 | 261 | @add_start_docstrings( 262 | """XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of 263 | the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, 264 | XLM_ROBERTA_START_DOCSTRING, 265 | XLM_ROBERTA_INPUTS_DOCSTRING, 266 | ) 267 | class XLMRobertaForMultipleChoice(RobertaForMultipleChoice): 268 | r""" 269 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 270 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 271 | Classification loss. 272 | **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension 273 | of the input tensors. (see `input_ids` above). 274 | Classification scores (before SoftMax). 275 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 276 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 277 | of shape ``(batch_size, sequence_length, hidden_size)``: 278 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 279 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 280 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 281 | Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. 282 | 283 | Examples:: 284 | 285 | tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 286 | model = XLMRobertaForMultipleChoice.from_pretrained('xlm-roberta-large') 287 | choices = ["Schloß Nymphenburg ist sehr schön .", "Der Schloßkanal auch !"] 288 | input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices (assumes both choices encode to the same length; pad them otherwise) 289 | labels = torch.tensor(1).unsqueeze(0) # Batch size 1 290 | outputs = model(input_ids, labels=labels) 291 | loss, classification_scores = outputs[:2] 292 | 293 | """ 294 | config_class = XLMRobertaConfig 295 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 296 | 297 | 298 | @add_start_docstrings( 299 | """XLM-RoBERTa Model with a token classification head on top (a linear layer on top of 300 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
""", 301 | XLM_ROBERTA_START_DOCSTRING, 302 | XLM_ROBERTA_INPUTS_DOCSTRING, 303 | ) 304 | class XLMRobertaForTokenClassification(RobertaForTokenClassification): 305 | r""" 306 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 307 | Labels for computing the token classification loss. 308 | Indices should be in ``[0, ..., config.num_labels - 1]``. 309 | 310 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 311 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 312 | Classification loss. 313 | **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)`` 314 | Classification scores (before SoftMax). 315 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 316 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 317 | of shape ``(batch_size, sequence_length, hidden_size)``: 318 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 319 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 320 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 321 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 322 | 323 | Examples:: 324 | 325 | tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 326 | model = XLMRobertaForTokenClassification.from_pretrained('xlm-roberta-large') 327 | input_ids = torch.tensor(tokenizer.encode("Schloß Nymphenburg ist sehr schön .", add_special_tokens=True)).unsqueeze(0) # Batch size 1 328 | labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 329 | outputs = model(input_ids, labels=labels) 330 | loss, scores = outputs[:2] 331 | 332 | """ 333 | config_class = XLMRobertaConfig 334 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 335 | 336 | @add_start_docstrings( 337 | """XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of 338 | the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, 339 | XLM_ROBERTA_START_DOCSTRING, 340 | XLM_ROBERTA_INPUTS_DOCSTRING, 341 | ) 342 | class XLMRobertaForQuestionAnswering(RobertaForQuestionAnswering): 343 | r""" 344 | **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 345 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 346 | Positions are clamped to the length of the sequence (`sequence_length`). 347 | Position outside of the sequence are not taken into account for computing the loss. 348 | **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 349 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 350 | Positions are clamped to the length of the sequence (`sequence_length`). 351 | Position outside of the sequence are not taken into account for computing the loss. 
352 | **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 353 | Labels whether a question has an answer or no answer (SQuAD 2.0) 354 | **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 355 | Labels for position (index) of the classification token to use as input for computing plausibility of the answer. 356 | **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 357 | Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...) 358 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 359 | **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: 360 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 361 | **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` 362 | Span-start scores (before SoftMax). 363 | **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` 364 | Span-end scores (before SoftMax). 365 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 366 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 367 | of shape ``(batch_size, sequence_length, hidden_size)``: 368 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 369 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 370 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 371 | Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
372 | Examples:: 373 | tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 374 | model = XLMRobertaForQuestionAnswering.from_pretrained('xlm-roberta-large') 375 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 376 | start_positions = torch.tensor([1]) 377 | end_positions = torch.tensor([3]) 378 | outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) 379 | loss, start_scores, end_scores = outputs[:3] 380 | """ 381 | config_class = XLMRobertaConfig 382 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP -------------------------------------------------------------------------------- /third_party/processors/squad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from multiprocessing import Pool, cpu_count 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from transformers.file_utils import is_tf_available, is_torch_available 11 | from transformers.tokenization_bert import whitespace_tokenize 12 | from transformers import DataProcessor 13 | 14 | if is_torch_available(): 15 | import torch 16 | from torch.utils.data import TensorDataset 17 | 18 | if is_tf_available(): 19 | import tensorflow as tf 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text, lang='en', lang2id=None): 25 | """Returns tokenized answer spans that better match the annotated answer.""" 26 | if lang2id is None: 27 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 28 | else: 29 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, lang=lang)) 30 | 31 | for new_start in range(input_start, input_end + 1): 32 | for new_end in range(input_end, new_start - 1, -1): 33 | text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) 34 | if text_span == tok_answer_text: 35 | return (new_start, new_end) 36 | 37 | return (input_start, input_end) 38 | 39 | 40 | def _check_is_max_context(doc_spans, cur_span_index, position): 41 | """Check if this is the 'max context' doc span for the token.""" 42 | best_score = None 43 | best_span_index = None 44 | for (span_index, doc_span) in enumerate(doc_spans): 45 | end = doc_span.start + doc_span.length - 1 46 | if position < doc_span.start: 47 | continue 48 | if position > end: 49 | continue 50 | num_left_context = position - doc_span.start 51 | num_right_context = end - position 52 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 53 | if best_score is None or score > best_score: 54 | best_score = score 55 | best_span_index = span_index 56 | 57 | return cur_span_index == best_span_index 58 | 59 | 60 | def _new_check_is_max_context(doc_spans, cur_span_index, position): 61 | """Check if this is the 'max context' doc span for the token.""" 62 | # if len(doc_spans) == 1: 63 | # return True 64 | best_score = None 65 | best_span_index = None 66 | for (span_index, doc_span) in enumerate(doc_spans): 67 | end = doc_span["start"] + doc_span["length"] - 1 68 | if position < doc_span["start"]: 69 | continue 70 | if position > end: 71 | continue 72 | num_left_context = position - doc_span["start"] 73 | num_right_context = end - position 74 | score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"] 75 | if best_score is None or score > best_score: 76 | best_score = score 77 |
best_span_index = span_index 78 | 79 | return cur_span_index == best_span_index 80 | 81 | 82 | def _is_whitespace(c): 83 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 84 | return True 85 | return False 86 | 87 | 88 | def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training, lang2id): 89 | features = [] 90 | if is_training and not example.is_impossible: 91 | # Get start and end position 92 | start_position = example.start_position 93 | end_position = example.end_position 94 | 95 | # If the answer cannot be found in the text, then skip this example. 96 | actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) 97 | cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) 98 | if actual_text.find(cleaned_answer_text) == -1: 99 | logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) 100 | return [] 101 | 102 | tok_to_orig_index = [] 103 | orig_to_tok_index = [] 104 | all_doc_tokens = [] 105 | for (i, token) in enumerate(example.doc_tokens): 106 | orig_to_tok_index.append(len(all_doc_tokens)) 107 | if lang2id is None: 108 | sub_tokens = tokenizer.tokenize(token) 109 | else: 110 | sub_tokens = tokenizer.tokenize(token, lang=example.language) 111 | for sub_token in sub_tokens: 112 | tok_to_orig_index.append(i) 113 | all_doc_tokens.append(sub_token) 114 | 115 | if is_training and not example.is_impossible: 116 | tok_start_position = orig_to_tok_index[example.start_position] 117 | if example.end_position < len(example.doc_tokens) - 1: 118 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 119 | else: 120 | tok_end_position = len(all_doc_tokens) - 1 121 | 122 | (tok_start_position, tok_end_position) = _improve_answer_span( 123 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text, 124 | lang=example.language, lang2id=lang2id 125 | ) 126 | 127 | spans = [] 128 | 129 | truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) 130 | sequence_added_tokens = ( 131 | tokenizer.max_len - tokenizer.max_len_single_sentence + 1 132 | if "roberta" in str(type(tokenizer)) 133 | else tokenizer.max_len - tokenizer.max_len_single_sentence 134 | ) 135 | sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair 136 | 137 | span_doc_tokens = all_doc_tokens 138 | while len(spans) * doc_stride < len(all_doc_tokens): 139 | 140 | encoded_dict = tokenizer.encode_plus( 141 | truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, 142 | span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, 143 | max_length=max_seq_length, 144 | return_overflowing_tokens=True, 145 | pad_to_max_length=True, 146 | stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, 147 | truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first", 148 | ) 149 | 150 | paragraph_len = min( 151 | len(all_doc_tokens) - len(spans) * doc_stride, 152 | max_seq_length - len(truncated_query) - sequence_pair_added_tokens, 153 | ) 154 | 155 | if tokenizer.pad_token_id in encoded_dict["input_ids"]: 156 | non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] 157 | else: 158 | non_padded_ids = encoded_dict["input_ids"] 159 | 160 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 161 | 162 | token_to_orig_map = {} 163 | for i in 
range(paragraph_len): 164 | index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i 165 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] 166 | 167 | encoded_dict["paragraph_len"] = paragraph_len 168 | encoded_dict["tokens"] = tokens 169 | encoded_dict["token_to_orig_map"] = token_to_orig_map 170 | encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens 171 | encoded_dict["token_is_max_context"] = {} 172 | encoded_dict["start"] = len(spans) * doc_stride 173 | encoded_dict["length"] = paragraph_len 174 | 175 | spans.append(encoded_dict) 176 | 177 | if "overflowing_tokens" not in encoded_dict: 178 | break 179 | span_doc_tokens = encoded_dict["overflowing_tokens"] 180 | 181 | for doc_span_index in range(len(spans)): 182 | for j in range(spans[doc_span_index]["paragraph_len"]): 183 | is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) 184 | index = ( 185 | j 186 | if tokenizer.padding_side == "left" 187 | else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j 188 | ) 189 | spans[doc_span_index]["token_is_max_context"][index] = is_max_context 190 | 191 | for span in spans: 192 | # Identify the position of the CLS token 193 | cls_index = span["input_ids"].index(tokenizer.cls_token_id) 194 | 195 | # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in the answer) 196 | # Original TF implem also keeps the classification token (set to 0) (not sure why...) 197 | p_mask = np.array(span["token_type_ids"]) 198 | 199 | p_mask = np.minimum(p_mask, 1) 200 | 201 | if tokenizer.padding_side == "right": 202 | # Invert the mask so that context tokens (which may contain the answer) are 0 203 | p_mask = 1 - p_mask 204 | 205 | p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1 206 | 207 | # Set the CLS index to '0' 208 | p_mask[cls_index] = 0 209 | 210 | span_is_impossible = example.is_impossible 211 | start_position = 0 212 | end_position = 0 213 | if is_training and not span_is_impossible: 214 | # For training, if our document chunk does not contain an annotation 215 | # we throw it out, since there is nothing to predict. 216 | doc_start = span["start"] 217 | doc_end = span["start"] + span["length"] - 1 218 | out_of_span = False 219 | 220 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): 221 | out_of_span = True 222 | 223 | if out_of_span: 224 | start_position = cls_index 225 | end_position = cls_index 226 | span_is_impossible = True 227 | else: 228 | if tokenizer.padding_side == "left": 229 | doc_offset = 0 230 | else: 231 | doc_offset = len(truncated_query) + sequence_added_tokens 232 | 233 | start_position = tok_start_position - doc_start + doc_offset 234 | end_position = tok_end_position - doc_start + doc_offset 235 | 236 | if lang2id is not None: 237 | lid = lang2id.get(example.language, lang2id["en"]) 238 | else: 239 | lid = 0 240 | langs = [lid] * max_seq_length 241 | 242 | features.append( 243 | SquadFeatures( 244 | span["input_ids"], 245 | span["attention_mask"], 246 | span["token_type_ids"], 247 | cls_index, 248 | p_mask.tolist(), 249 | example_index=0, # Cannot set unique_id and example_index here. They will be set after multiple processing.
250 | unique_id=0, 251 | paragraph_len=span["paragraph_len"], 252 | token_is_max_context=span["token_is_max_context"], 253 | tokens=span["tokens"], 254 | token_to_orig_map=span["token_to_orig_map"], 255 | start_position=start_position, 256 | end_position=end_position, 257 | langs=langs 258 | ) 259 | ) 260 | return features 261 | 262 | 263 | def squad_convert_example_to_features_init(tokenizer_for_convert): 264 | global tokenizer 265 | tokenizer = tokenizer_for_convert 266 | 267 | 268 | def squad_convert_examples_to_features( 269 | examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1, 270 | lang2id=None 271 | ): 272 | """ 273 | Converts a list of examples into a list of features that can be directly given as input to a model. 274 | It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs. 275 | 276 | Args: 277 | examples: list of :class:`~transformers.data.processors.squad.SquadExample` 278 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` 279 | max_seq_length: The maximum sequence length of the inputs. 280 | doc_stride: The stride used when the context is too large and is split across several features. 281 | max_query_length: The maximum length of the query. 282 | is_training: whether to create features for model evaluation or model training. 283 | return_dataset: Default False. Either 'pt' or 'tf'. 284 | if 'pt': returns a torch.data.TensorDataset, 285 | if 'tf': returns a tf.data.Dataset 286 | threads: multiple processing threads 287 | 288 | 289 | Returns: 290 | list of :class:`~transformers.data.processors.squad.SquadFeatures` 291 | 292 | Example:: 293 | 294 | processor = SquadV2Processor() 295 | examples = processor.get_dev_examples(data_dir) 296 | 297 | features = squad_convert_examples_to_features( 298 | examples=examples, 299 | tokenizer=tokenizer, 300 | max_seq_length=args.max_seq_length, 301 | doc_stride=args.doc_stride, 302 | max_query_length=args.max_query_length, 303 | is_training=not evaluate, 304 | ) 305 | """ 306 | 307 | # Defining helper methods 308 | features = [] 309 | threads = min(threads, cpu_count()) 310 | with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: 311 | annotate_ = partial( 312 | squad_convert_example_to_features, 313 | max_seq_length=max_seq_length, 314 | doc_stride=doc_stride, 315 | max_query_length=max_query_length, 316 | is_training=is_training, 317 | lang2id=lang2id 318 | ) 319 | features = list( 320 | tqdm( 321 | p.imap(annotate_, examples, chunksize=32), 322 | total=len(examples), 323 | desc="convert squad examples to features", 324 | ) 325 | ) 326 | new_features = [] 327 | unique_id = 1000000000 328 | example_index = 0 329 | for example_features in tqdm(features, total=len(features), desc="add example index and unique id"): 330 | if not example_features: 331 | continue 332 | for example_feature in example_features: 333 | example_feature.example_index = example_index 334 | example_feature.unique_id = unique_id 335 | new_features.append(example_feature) 336 | unique_id += 1 337 | example_index += 1 338 | features = new_features 339 | del new_features 340 | if return_dataset == "pt": 341 | if not is_torch_available(): 342 | raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.") 343 | 344 | # Convert to Tensors and build dataset
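        # Each SquadFeatures object becomes one row of every tensor below; cls_index
        # and p_mask are carried along so downstream QA code can ignore tokens that
        # cannot be part of an answer span, and langs carries the per-example
        # language ids needed by XLM-style models.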
345 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 346 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 347 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 348 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 349 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 350 | all_langs = torch.tensor([f.langs for f in features], dtype=torch.long) 351 | 352 | if not is_training: 353 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 354 | dataset = TensorDataset( 355 | all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask, all_langs 356 | ) 357 | else: 358 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 359 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 360 | dataset = TensorDataset( 361 | all_input_ids, 362 | all_attention_masks, 363 | all_token_type_ids, 364 | all_start_positions, 365 | all_end_positions, 366 | all_cls_index, 367 | all_p_mask, 368 | all_langs 369 | ) 370 | 371 | return features, dataset 372 | elif return_dataset == "tf": 373 | if not is_tf_available(): 374 | raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.") 375 | 376 | def gen(): 377 | for ex in features: 378 | yield ( 379 | { 380 | "input_ids": ex.input_ids, 381 | "attention_mask": ex.attention_mask, 382 | "token_type_ids": ex.token_type_ids, 383 | }, 384 | { 385 | "start_position": ex.start_position, 386 | "end_position": ex.end_position, 387 | "cls_index": ex.cls_index, 388 | "p_mask": ex.p_mask, 389 | }, 390 | ) 391 | 392 | return tf.data.Dataset.from_generator( 393 | gen, 394 | ( 395 | {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, 396 | {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32}, 397 | ), 398 | ( 399 | { 400 | "input_ids": tf.TensorShape([None]), 401 | "attention_mask": tf.TensorShape([None]), 402 | "token_type_ids": tf.TensorShape([None]), 403 | }, 404 | { 405 | "start_position": tf.TensorShape([]), 406 | "end_position": tf.TensorShape([]), 407 | "cls_index": tf.TensorShape([]), 408 | "p_mask": tf.TensorShape([None]), 409 | }, 410 | ), 411 | ) 412 | 413 | return features 414 | 415 | 416 | class SquadProcessor(DataProcessor): 417 | """ 418 | Processor for the SQuAD data set. 419 | Overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively.
420 | """ 421 | 422 | train_file = None 423 | dev_file = None 424 | 425 | def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): 426 | if not evaluate: 427 | answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8") 428 | answer_start = tensor_dict["answers"]["answer_start"][0].numpy() 429 | answers = [] 430 | else: 431 | answers = [ 432 | {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")} 433 | for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"]) 434 | ] 435 | 436 | answer = None 437 | answer_start = None 438 | 439 | return SquadExample( 440 | qas_id=tensor_dict["id"].numpy().decode("utf-8"), 441 | question_text=tensor_dict["question"].numpy().decode("utf-8"), 442 | context_text=tensor_dict["context"].numpy().decode("utf-8"), 443 | answer_text=answer, 444 | start_position_character=answer_start, 445 | title=tensor_dict["title"].numpy().decode("utf-8"), 446 | answers=answers, 447 | ) 448 | 449 | def get_examples_from_dataset(self, dataset, evaluate=False): 450 | """ 451 | Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset. 452 | 453 | Args: 454 | dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")` 455 | evaluate: boolean specifying if in evaluation mode or in training mode 456 | 457 | Returns: 458 | List of SquadExample 459 | 460 | Examples:: 461 | 462 | import tensorflow_datasets as tfds 463 | dataset = tfds.load("squad") 464 | 465 | training_examples = get_examples_from_dataset(dataset, evaluate=False) 466 | evaluation_examples = get_examples_from_dataset(dataset, evaluate=True) 467 | """ 468 | 469 | if evaluate: 470 | dataset = dataset["validation"] 471 | else: 472 | dataset = dataset["train"] 473 | 474 | examples = [] 475 | for tensor_dict in tqdm(dataset): 476 | examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) 477 | 478 | return examples 479 | 480 | def get_train_examples(self, data_dir, filename=None, language='en'): 481 | """ 482 | Returns the training examples from the data directory. 483 | 484 | Args: 485 | data_dir: Directory containing the data files used for training and evaluating. 486 | filename: None by default, specify this if the training file has a different name than the original one 487 | which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 488 | 489 | """ 490 | if data_dir is None: 491 | data_dir = "" 492 | 493 | if self.train_file is None: 494 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 495 | 496 | with open( 497 | os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8" 498 | ) as reader: 499 | input_data = json.load(reader)["data"] 500 | return self._create_examples(input_data, "train", language) 501 | 502 | def get_dev_examples(self, data_dir, filename=None, language='en'): 503 | """ 504 | Returns the evaluation example from the data directory. 505 | 506 | Args: 507 | data_dir: Directory containing the data files used for training and evaluating. 508 | filename: None by default, specify this if the evaluation file has a different name than the original one 509 | which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 
510 | """ 511 | if data_dir is None: 512 | data_dir = "" 513 | 514 | if self.dev_file is None: 515 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 516 | 517 | with open( 518 | os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" 519 | ) as reader: 520 | input_data = json.load(reader)["data"] 521 | return self._create_examples(input_data, "dev", language) 522 | 523 | def _create_examples(self, input_data, set_type, language): 524 | is_training = set_type == "train" 525 | examples = [] 526 | for entry in tqdm(input_data): 527 | title = entry["title"] if "title" in entry else "" 528 | for paragraph in entry["paragraphs"]: 529 | context_text = paragraph["context"] 530 | for qa in paragraph["qas"]: 531 | qas_id = qa["id"] 532 | question_text = qa["question"] 533 | start_position_character = None 534 | answer_text = None 535 | answers = [] 536 | 537 | if "is_impossible" in qa: 538 | is_impossible = qa["is_impossible"] 539 | else: 540 | is_impossible = False 541 | 542 | if not is_impossible: 543 | if is_training: 544 | answer = qa["answers"][0] 545 | answer_text = answer["text"] 546 | start_position_character = answer["answer_start"] 547 | else: 548 | answers = qa["answers"] 549 | 550 | example = SquadExample( 551 | qas_id=qas_id, 552 | question_text=question_text, 553 | context_text=context_text, 554 | answer_text=answer_text, 555 | start_position_character=start_position_character, 556 | title=title, 557 | is_impossible=is_impossible, 558 | answers=answers, 559 | language=language 560 | ) 561 | 562 | examples.append(example) 563 | return examples 564 | 565 | 566 | class SquadV1Processor(SquadProcessor): 567 | train_file = "train-v1.1.json" 568 | dev_file = "dev-v1.1.json" 569 | 570 | 571 | class SquadV2Processor(SquadProcessor): 572 | train_file = "train-v2.0.json" 573 | dev_file = "dev-v2.0.json" 574 | 575 | 576 | class SquadExample(object): 577 | """ 578 | A single training/test example for the Squad dataset, as loaded from disk. 579 | 580 | Args: 581 | qas_id: The example's unique identifier 582 | question_text: The question string 583 | context_text: The context string 584 | answer_text: The answer string 585 | start_position_character: The character position of the start of the answer 586 | title: The title of the example 587 | answers: None by default, this is used during evaluation. Holds answers as well as their start positions. 588 | is_impossible: False by default, set to True if the example has no possible answer. 589 | """ 590 | 591 | def __init__( 592 | self, 593 | qas_id, 594 | question_text, 595 | context_text, 596 | answer_text, 597 | start_position_character, 598 | title, 599 | answers=[], 600 | is_impossible=False, 601 | language='en' 602 | ): 603 | self.qas_id = qas_id 604 | self.question_text = question_text 605 | self.context_text = context_text 606 | self.answer_text = answer_text 607 | self.title = title 608 | self.is_impossible = is_impossible 609 | self.answers = answers 610 | 611 | self.start_position, self.end_position = 0, 0 612 | 613 | self.language = language 614 | 615 | doc_tokens = [] 616 | char_to_word_offset = [] 617 | prev_is_whitespace = True 618 | 619 | # Split on whitespace so that different tokens may be attributed to their original position. 
620 | for c in self.context_text: 621 | if _is_whitespace(c): 622 | prev_is_whitespace = True 623 | else: 624 | if prev_is_whitespace: 625 | doc_tokens.append(c) 626 | else: 627 | doc_tokens[-1] += c 628 | prev_is_whitespace = False 629 | char_to_word_offset.append(len(doc_tokens) - 1) 630 | 631 | self.doc_tokens = doc_tokens 632 | self.char_to_word_offset = char_to_word_offset 633 | 634 | # Start and end positions only have a value during training. 635 | if start_position_character is not None and not is_impossible: 636 | self.start_position = char_to_word_offset[start_position_character] 637 | self.end_position = char_to_word_offset[ 638 | min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) 639 | ] 640 | 641 | 642 | class SquadFeatures(object): 643 | """ 644 | Single squad example features to be fed to a model. 645 | Those features are model-specific and can be crafted from :class:`~transformers.data.processors.squad.SquadExample` 646 | using the :method:`~transformers.data.processors.squad.squad_convert_examples_to_features` method. 647 | 648 | Args: 649 | input_ids: Indices of input sequence tokens in the vocabulary. 650 | attention_mask: Mask to avoid performing attention on padding token indices. 651 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 652 | cls_index: the index of the CLS token. 653 | p_mask: Mask identifying tokens that can be answers vs. tokens that cannot. 654 | Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in the answer 655 | example_index: the index of the example 656 | unique_id: The unique Feature identifier 657 | paragraph_len: The length of the context 658 | token_is_max_context: List of booleans identifying which tokens have their maximum context in this feature object. 659 | If a token does not have its maximum context in this feature object, it means that another feature object 660 | has more information related to that token and should be prioritized over this feature for that token. 661 | tokens: list of tokens corresponding to the input ids 662 | token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. 663 | start_position: start of the answer token index 664 | end_position: end of the answer token index 665 | """ 666 | 667 | def __init__( 668 | self, 669 | input_ids, 670 | attention_mask, 671 | token_type_ids, 672 | cls_index, 673 | p_mask, 674 | example_index, 675 | unique_id, 676 | paragraph_len, 677 | token_is_max_context, 678 | tokens, 679 | token_to_orig_map, 680 | start_position, 681 | end_position, 682 | langs 683 | ): 684 | self.input_ids = input_ids 685 | self.attention_mask = attention_mask 686 | self.token_type_ids = token_type_ids 687 | self.cls_index = cls_index 688 | self.p_mask = p_mask 689 | 690 | self.example_index = example_index 691 | self.unique_id = unique_id 692 | self.paragraph_len = paragraph_len 693 | self.token_is_max_context = token_is_max_context 694 | self.tokens = tokens 695 | self.token_to_orig_map = token_to_orig_map 696 | 697 | self.start_position = start_position 698 | self.end_position = end_position 699 | self.langs = langs 700 | 701 | 702 | class SquadResult(object): 703 | """ 704 | Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. 705 | 706 | Args: 707 | unique_id: The unique identifier corresponding to that example.
708 | start_logits: The logits corresponding to the start of the answer 709 | end_logits: The logits corresponding to the end of the answer 710 | """ 711 | 712 | def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): 713 | self.start_logits = start_logits 714 | self.end_logits = end_logits 715 | self.unique_id = unique_id 716 | 717 | if start_top_index: 718 | self.start_top_index = start_top_index 719 | self.end_top_index = end_top_index 720 | self.cls_logits = cls_logits 721 | --------------------------------------------------------------------------------
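Taken together, SquadProcessor, squad_convert_examples_to_features, and SquadFeatures form the QA data pipeline of this module: the processor reads SQuAD-format JSON into SquadExample objects, and the converter windows each context into fixed-length feature spans and optionally packs them into a dataset. A minimal sketch of how these pieces are typically wired together is below; the data directory, checkpoint name, and hyperparameter values are illustrative assumptions, not values fixed by the module.

    # Sketch only: the path, checkpoint and hyperparameters below are assumptions.
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large")

    # Read dev examples (expects dev-v1.1.json under the assumed directory).
    processor = SquadV1Processor()
    examples = processor.get_dev_examples("download/squad", language="en")

    # Window each context into overlapping spans and build a PyTorch dataset.
    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
        threads=1,
        lang2id=None,  # only XLM-style checkpoints need a language-to-id mapping
    )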