├── images ├── lap_yelp.pdf ├── mr_yelp.pdf ├── absa_amazon.pdf ├── absa_yelp.pdf ├── lap_amazon.pdf ├── mr_amazon.pdf ├── time_mr_yelp.pdf ├── time_absa_amazon.pdf ├── time_absa_yelp.pdf └── time_mr_amazon.pdf ├── scripts ├── finetune_ckpt_all_seed.sh ├── finetune_origin_all_seed.sh ├── run_mask_model.sh ├── finetune_ckpt.sh ├── finetune_origin.sh ├── run_all_pipeline.sh └── run_pretraining.sh ├── data ├── create_data_rule │ ├── run.sh │ ├── xarg_wrapper.sh │ ├── config.sh │ └── create_mask_dataset.sh ├── create_data_model │ ├── run.sh │ ├── xarg_wrapper.sh │ ├── config.sh │ └── create_mask_dataset.sh ├── merge_pkl.py ├── rand_mask_gen.py ├── merge_hdf5.py ├── create_data.py └── sc_mask_gen.py ├── requirements.txt ├── gather_results.py ├── convert_config.py ├── plot ├── plot_time_absa_yelp.py ├── plot_time_mr_amazon.py ├── plot_time_absa_amazon.py ├── plot_time_mr_yelp.py ├── plot_mr_yelp.py ├── plot_mr_amazon.py ├── plot_absa_amazon.py ├── plot_lap_amazon.py ├── plot_lap_yelp.py └── plot_absa_yelp.py ├── sig-test └── res-yelp.py ├── config └── test.json ├── README.md ├── .gitignore ├── model ├── schedulers.py ├── file_utils.py ├── fused_adam_local.py ├── optimization.py └── tokenization.py ├── run_pretraining.py ├── finetune.py └── mask_model_pretrain.py /images/lap_yelp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/lap_yelp.pdf -------------------------------------------------------------------------------- /images/mr_yelp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/mr_yelp.pdf -------------------------------------------------------------------------------- /images/absa_amazon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/absa_amazon.pdf -------------------------------------------------------------------------------- /images/absa_yelp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/absa_yelp.pdf -------------------------------------------------------------------------------- /images/lap_amazon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/lap_amazon.pdf -------------------------------------------------------------------------------- /images/mr_amazon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/mr_amazon.pdf -------------------------------------------------------------------------------- /images/time_mr_yelp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/time_mr_yelp.pdf -------------------------------------------------------------------------------- /images/time_absa_amazon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/time_absa_amazon.pdf -------------------------------------------------------------------------------- /images/time_absa_yelp.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/time_absa_yelp.pdf -------------------------------------------------------------------------------- /images/time_mr_amazon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/SelectiveMasking/HEAD/images/time_mr_amazon.pdf -------------------------------------------------------------------------------- /scripts/finetune_ckpt_all_seed.sh: -------------------------------------------------------------------------------- 1 | SEED_ARRAY=(13 43 83 181 271 347 433 659 727 859) 2 | 3 | for SEED in ${SEED_ARRAY[@]} 4 | do 5 | bash scripts/finetune_ckpt.sh ${SEED} 6 | done -------------------------------------------------------------------------------- /scripts/finetune_origin_all_seed.sh: -------------------------------------------------------------------------------- 1 | SEED_ARRAY=(13 43 83 181 271 347 433 659 727 859) 2 | 3 | for SEED in ${SEED_ARRAY[@]} 4 | do 5 | bash scripts/finetune_origin.sh ${SEED} 6 | done 7 | -------------------------------------------------------------------------------- /data/create_data_rule/run.sh: -------------------------------------------------------------------------------- 1 | source data/create_data_rule/config.sh 2 | 3 | mkdir -p ${OUTPUT_DIR} 4 | mkdir -p ${OUTPUT_DIR}/merged/ 5 | 6 | bash data/create_data_rule/xarg_wrapper.sh 7 | 8 | python3 data/merge_pkl.py ${OUTPUT_DIR} ${MAX_PROC} 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # progress bars in model download and training scripts 2 | tqdm 3 | # Accessing files from S3 directly. 
4 | boto3 5 | # Used for downloading models over HTTP 6 | requests 7 | six 8 | ipdb 9 | #Data processing 10 | h5py 11 | html2text 12 | nltk 13 | progressbar -------------------------------------------------------------------------------- /data/create_data_model/run.sh: -------------------------------------------------------------------------------- 1 | source data/create_data_model/config.sh 2 | 3 | mkdir -p ${OUTPUT_DIR}/model/merged/dev 4 | mkdir -p ${OUTPUT_DIR}/rand/merged/dev 5 | 6 | bash data/create_data_model/xarg_wrapper.sh 7 | 8 | python3 data/merge_hdf5.py ${OUTPUT_DIR}/model/ ${MAX_PROC} 9 | python3 data/merge_hdf5.py ${OUTPUT_DIR}/rand/ ${MAX_PROC} 10 | -------------------------------------------------------------------------------- /data/create_data_rule/xarg_wrapper.sh: -------------------------------------------------------------------------------- 1 | source data/create_data_rule/config.sh 2 | SHARD_COUNT=0 3 | rm xarg_list.txt 4 | touch xarg_list.txt 5 | PART=0 6 | for GPU_ID in ${GPU_LIST[@]}; do 7 | echo "${GPU_ID} ${PART}">> xarg_list.txt 8 | ((PART++)) 9 | done 10 | chmod 777 data/create_data_rule/create_mask_dataset.sh 11 | xargs -n 2 --max-procs=${MAX_PROC} --arg-file=xarg_list.txt data/create_data_rule/create_mask_dataset.sh 12 | rm xarg_list.txt 13 | -------------------------------------------------------------------------------- /data/create_data_model/xarg_wrapper.sh: -------------------------------------------------------------------------------- 1 | source data/create_data_model/config.sh 2 | 3 | SHARD_COUNT=0 4 | rm xarg_list.txt 5 | touch xarg_list.txt 6 | PART=0 7 | for GPU_ID in ${GPU_LIST[@]}; do 8 | echo "${GPU_ID} ${PART}">> xarg_list.txt 9 | ((PART++)) 10 | done 11 | chmod 777 data/create_data_model/create_mask_dataset.sh 12 | xargs -n 2 --max-procs=${MAX_PROC} --arg-file=xarg_list.txt data/create_data_model/create_mask_dataset.sh 13 | rm xarg_list.txt 14 | -------------------------------------------------------------------------------- /gather_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | dir_path = sys.argv[1] 6 | 7 | tot = 0 8 | seeds = [13, 43, 83, 181, 271, 347, 433, 659, 727, 859] 9 | for seed in seeds: 10 | path = os.path.join(dir_path, str(seed), "test_results.txt") 11 | with open(path, "r") as f: 12 | acc = float(f.readlines()[0].strip().split()[-1]) 13 | for l in f.readlines(): 14 | print(l) 15 | tot += acc 16 | 17 | print("Gathered results on different random seeds. Average Acc. 
: ") 18 | print(tot / len(seeds)) 19 | -------------------------------------------------------------------------------- /data/merge_pkl.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | import os 4 | import random 5 | 6 | origin_dir = sys.argv[1] 7 | num_files = int(sys.argv[2]) 8 | 9 | L = [] 10 | 11 | dev_rate = 0.1 12 | 13 | for i in range(num_files): 14 | filename = os.path.join(origin_dir, "{}.pkl".format(i)) 15 | print(filename) 16 | with open(filename, "rb") as f: 17 | L.extend(pickle.load(f)) 18 | 19 | all_data_size = len(L) 20 | 21 | train_data = L[0: int((1 - dev_rate) * all_data_size)] 22 | dev_data = L[int((1 - dev_rate) * all_data_size):] 23 | 24 | with open(os.path.join(origin_dir, "merged", "train.pkl"), "wb") as f: 25 | pickle.dump(train_data, f) 26 | 27 | with open(os.path.join(origin_dir, "merged", "valid.pkl"), "wb") as f: 28 | pickle.dump(dev_data, f) 29 | -------------------------------------------------------------------------------- /convert_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | 5 | config_path = sys.argv[1] 6 | 7 | with open(config_path) as f: 8 | config = json.load(f) 9 | 10 | os.makedirs("config/", exist_ok=True) 11 | 12 | f = open("config/bash_config.sh", "w") 13 | 14 | def set_env(d, prefix): 15 | for name in d: 16 | if isinstance(d[name], dict): 17 | set_env(d[name], prefix + "_" + name) 18 | else: 19 | # print("export {}={}".format((prefix + "_" + name).upper(), d[name])) 20 | # os.system("export {}={}".format((prefix + "_" + name).upper(), d[name])) 21 | # os.system("echo ${}".format((prefix + "_" + name).upper())) 22 | f.write("{}={}\n".format((prefix + "_" + name).upper(), d[name])) 23 | 24 | set_env(config, "E") 25 | 26 | f.close() 27 | -------------------------------------------------------------------------------- /scripts/run_mask_model.sh: -------------------------------------------------------------------------------- 1 | source config/bash_config.sh 2 | 3 | DATA_DIR=${E_SELECTIVE_MASKING_TRAIN_NN_DATA_DIR} 4 | OUTPUT_DIR=${E_SELECTIVE_MASKING_TRAIN_NN_OUTPUT_DIR} 5 | BERT_MODEL=${E_GENEPT_BERT_MODEL} 6 | 7 | CMD="mask_model_pretrain.py" 8 | CMD+=" --bert_model=${BERT_MODEL}" 9 | CMD+=" --task_name=MaskGen" 10 | CMD+=" --data_dir=${DATA_DIR}" 11 | CMD+=" --output_dir=${OUTPUT_DIR}" 12 | CMD+=" --max_seq_length=128 " 13 | CMD+=" --train_batch_size=32" 14 | CMD+=" --num_train_epochs=10" 15 | CMD+=" --learning_rate=1e-5" 16 | CMD+=" --do_lower_case" 17 | CMD+=" --gradient_accumulation_steps 2" 18 | CMD+=" --sample_weight=3" 19 | CMD+=" --do_train" 20 | CMD+=" --do_eval" 21 | # CMD+=" --save_all" 22 | 23 | export CUDA_VISIBLE_DEVICES=${E_SELECTIVE_MASKING_TRAIN_NN_GPU_LIST} 24 | CMD="python3 ${CMD}" 25 | 26 | echo ${CMD} 27 | 28 | ${CMD} -------------------------------------------------------------------------------- /scripts/finetune_ckpt.sh: -------------------------------------------------------------------------------- 1 | 2 | SEED=$1 3 | 4 | source config/bash_config.sh 5 | 6 | DATA_DIR=${E_FINE_TUNING_DATA_DIR} 7 | BERT_MODEL=${E_GENEPT_BERT_MODEL} 8 | OUTPUT_DIR=$E_FINE_TUNING_OUTPUT_DIR}/${SEED} 9 | CKPT=${E_FINE_TUNING_CKPT} 10 | 11 | CMD="finetune.py" 12 | CMD+=" --bert_model=${BERT_MODEL}" 13 | CMD+=" --do_train" 14 | CMD+=" --do_eval" 15 | CMD+=" --task_name=absa_term" 16 | CMD+=" --data_dir=${DATA_DIR}" 17 | CMD+=" --output_dir=${OUTPUT_DIR} " 18 | CMD+=" 
--max_seq_length=256 " 19 | CMD+=" --train_batch_size=32" 20 | CMD+=" --num_train_epochs=10" 21 | CMD+=" --learning_rate=2e-5" 22 | CMD+=" --do_lower_case" 23 | CMD+=" --gradient_accumulation_steps 2" 24 | CMD+=" --seed=${SEED}" 25 | CMD+=" --fp16" 26 | CMD+=" --ckpt ${CKPT}" 27 | 28 | export CUDA_VISIBLE_DEVICES=${E_FINE_TUNING_GPU_LIST} 29 | 30 | CMD="python3 ${CMD}" 31 | 32 | echo ${CMD} 33 | 34 | ${CMD} -------------------------------------------------------------------------------- /scripts/finetune_origin.sh: -------------------------------------------------------------------------------- 1 | SEED=${1:-42} 2 | 3 | source config/bash_config.sh 4 | 5 | DATA_DIR=${E_SELECTIVE_MASKING_FINETUNE_BERT_DATA_DIR} 6 | BERT_MODEL=${E_GENEPT_BERT_MODEL} 7 | OUTPUT_DIR=${E_SELECTIVE_MASKING_FINETUNE_BERT_OUTPUT_DIR}/${SEED} 8 | 9 | CMD="finetune.py" 10 | CMD+=" --bert_model=${BERT_MODEL}" 11 | CMD+=" --do_train" 12 | CMD+=" --do_eval" 13 | CMD+=" --task_name=absa_term" 14 | CMD+=" --data_dir=${DATA_DIR}" 15 | CMD+=" --output_dir=${OUTPUT_DIR} " 16 | CMD+=" --max_seq_length=256 " 17 | CMD+=" --train_batch_size=32" 18 | CMD+=" --num_train_epochs=10" 19 | CMD+=" --learning_rate=2e-5" 20 | CMD+=" --do_lower_case" 21 | CMD+=" --gradient_accumulation_steps 2" 22 | CMD+=" --seed=${SEED}" 23 | CMD+=" --fp16" 24 | 25 | export CUDA_VISIBLE_DEVICES=${E_SELECTIVE_MASKING_FINETUNE_BERT_GPU_LIST} 26 | 27 | CMD="python3 ${CMD}" 28 | 29 | echo ${CMD} 30 | 31 | ${CMD} -------------------------------------------------------------------------------- /data/create_data_model/config.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | source config/bash_config.sh 3 | 4 | MAX_SEQUENCE_LENGTH=256 5 | MAX_PREDICTIONS_PER_SEQUENCE=40 6 | MASKED_LM_PROB=0.15 7 | SEED=12345 8 | DUPE_FACTOR=1 9 | DO_WITH_RAND=true 10 | N_LINES_PER_SHARD_APPROX=99000 11 | 12 | GPU_LIST=(${E_SELECTIVE_MASKING_IN_DOMAIN_MASK_GPU_LIST[@]}) # Adjust this based on memory requirements and available number of cores 13 | MAX_PROC=${#GPU_LIST[@]} 14 | 15 | MODE=model 16 | TASK_NAME=${E_SELECTIVE_MASKING_IN_DOMAIN_MASK_TASK_NAME} 17 | 18 | INPUT_DIR=${E_SELECTIVE_MASKING_IN_DOMAIN_MASK_DATA_DIR} 19 | 20 | OUTPUT_DIR=${E_SELECTIVE_MASKING_IN_DOMAIN_MASK_OUTPUT_DIR} 21 | 22 | # model to generate mask training sets 23 | BERT_MODEL=${E_SELECTIVE_MASKING_IN_DOMAIN_MASK_BERT_MODEL} 24 | 25 | TOP_SEN_RATE=1 26 | THRESHOLD=0.01 27 | 28 | WITH_RAND="" 29 | if [ "$DO_WITH_RAND" = true ] ; then 30 | WITH_RAND="--with_rand" 31 | fi -------------------------------------------------------------------------------- /data/create_data_rule/config.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | source config/bash_config.sh 3 | 4 | MAX_SEQUENCE_LENGTH=128 5 | MAX_PREDICTIONS_PER_SEQUENCE=20 6 | MASKED_LM_PROB=0.15 7 | SEED=12345 8 | DUPE_FACTOR=1 9 | DO_WITH_RAND=false 10 | N_LINES_PER_SHARD_APPROX=99000 11 | 12 | GPU_LIST=(${E_SELECTIVE_MASKING_DOWNSTREAM_MASK_GPU_LIST[@]}) # Adjust this based on memory requirements and available number of cores 13 | MAX_PROC=${#GPU_LIST[@]} 14 | 15 | MODE=rule 16 | TASK_NAME=${E_SELECTIVE_MASKING_DOWNSTREAM_MASK_TASK_NAME} 17 | 18 | INPUT_DIR=${E_SELECTIVE_MASKING_DOWNSTREAM_MASK_DATA_DIR} 19 | 20 | OUTPUT_DIR=${E_SELECTIVE_MASKING_DOWNSTREAM_MASK_OUTPUT_DIR} 21 | 22 | # model to generate mask training sets 23 | BERT_MODEL=${E_SELECTIVE_MASKING_DOWNSTREAM_MASK_BERT_MODEL} 24 | 25 | TOP_SEN_RATE=1 26 | THRESHOLD=0.01 27 | 28 | WITH_RAND="" 29 | if [ "$DO_WITH_RAND" = true ] ; then 30 | WITH_RAND="--with_rand" 31 | fi -------------------------------------------------------------------------------- /data/create_data_model/create_mask_dataset.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | source data/create_data_model/config.sh 4 | 5 | echo "Bert model: ${BERT_MODEL}" 6 | echo "Input dir: ${INPUT_DIR}" 7 | echo "Output dir: ${OUTPUT_DIR}" 8 | 9 | GPU_ID=$1 10 | PART=$2 11 | 12 | echo "GPU id: ${GPU_ID} PART: ${PART}" 13 | CMD="data/create_data.py" 14 | CMD+=" --input_dir=${INPUT_DIR}" 15 | CMD+=" --output_dir=${OUTPUT_DIR}" 16 | CMD+=" --max_seq_length=${MAX_SEQUENCE_LENGTH}" 17 | CMD+=" --max_predictions_per_seq=${MAX_PREDICTIONS_PER_SEQUENCE}" 18 | CMD+=" --masked_lm_prob=${MASKED_LM_PROB}" 19 | CMD+=" --random_seed=${SEED}" 20 | CMD+=" --dupe_factor=${DUPE_FACTOR}" 21 | CMD+=" --bert_model=${BERT_MODEL}" 22 | CMD+=" --task_name=${TASK_NAME}" 23 | CMD+=" --top_sen_rate=${TOP_SEN_RATE}" 24 | CMD+=" --part ${PART}" 25 | CMD+=" --threshold=${THRESHOLD}" 26 | CMD+=" --max_proc=${MAX_PROC}" 27 | CMD+=" --mode=${MODE}" 28 | CMD+=" --do_lower_case" 29 | CMD+=" ${WITH_RAND}" 30 | 31 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 32 | CMD="python3 ${CMD}" 33 | 34 | echo ${CMD} 35 | 36 | ${CMD} -------------------------------------------------------------------------------- /data/create_data_rule/create_mask_dataset.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source data/create_data_rule/config.sh 4 | 5 | echo "Bert model: ${BERT_MODEL}" 6 | echo "Input dir: ${INPUT_DIR}" 7 | echo "Output dir: ${OUTPUT_DIR}" 8 | 9 | GPU_ID=$1 10 | PART=$2 11 | 12 | echo "GPU id: ${GPU_ID} PART: ${PART}" 13 | CMD="data/create_data.py" 14 | CMD+=" --input_dir=${INPUT_DIR}" 15 | CMD+=" --output_dir=${OUTPUT_DIR}" 16 | CMD+=" --max_seq_length=${MAX_SEQUENCE_LENGTH}" 17 | CMD+=" --max_predictions_per_seq=${MAX_PREDICTIONS_PER_SEQUENCE}" 18 | CMD+=" --masked_lm_prob=${MASKED_LM_PROB}" 19 | CMD+=" --random_seed=${SEED}" 20 | CMD+=" --dupe_factor=${DUPE_FACTOR}" 21 | CMD+=" --bert_model=${BERT_MODEL}" 22 | CMD+=" --task_name=${TASK_NAME}" 23 | CMD+=" --top_sen_rate=${TOP_SEN_RATE}" 24 | CMD+=" --part ${PART}" 25 | CMD+=" --threshold=${THRESHOLD}" 26 | CMD+=" --max_proc=${MAX_PROC}" 27 | CMD+=" --mode=${MODE}" 28 | CMD+=" --do_lower_case" 29 | CMD+=" ${WITH_RAND}" 30 | 31 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 32 | CMD="python3 ${CMD}" 33 | 34 | echo ${CMD} 35 | 36 | ${CMD} -------------------------------------------------------------------------------- /plot/plot_time_absa_yelp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | fracs = [1244, 86, 829, 2160] 5 | labels = ["GenePT", "Selective Masking", "TaskPT", "Saved Cost"] 6 | 7 | plt.figure(figsize=(9, 9)) 8 | 9 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 10 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 11 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 12 | 13 | plt.axes(aspect=1) 14 | plt.pie( 15 | x=fracs, 16 | labels=labels, 17 | startangle=90, 18 | colors=["lightskyblue", "gold", "lightgreen", "white"], 19 | wedgeprops={'linewidth': 0.5, 'edgecolor': "black"}, 20 | explode=[0, 0, 0, 0.06], 21 | shadow=True, 22 | labeldistance=10, 23 | radius=1, 24 | autopct='%3.1f %%', 25 | pctdistance=0.8, 26 | textprops={ 27 | 'size': 16 28 | } 29 | ) 30 | 31 | font1 = {'family': 'Times New Roman', 32 | 'weight': 'normal', 33 | 'size': 11.5, 34 | } 35 | plt.legend(loc="upper right", prop=font1) 36 | plt.savefig("../images/time_absa_yelp.pdf", format="pdf") 37 | -------------------------------------------------------------------------------- /plot/plot_time_mr_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | fracs = [1180, 160, 648, 1912] 5 | labels = ["GenePT", "Selective Masking", "TaskPT", "Saved Cost"] 6 | 7 | plt.figure(figsize=(9, 9)) 8 | 9 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 10 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 11 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 12 | 13 | plt.axes(aspect=1) 14 | plt.pie( 15 | x=fracs, 16 | labels=labels, 17 | startangle=90, 18 | colors=["lightskyblue", "gold", "lightgreen", "white"], 19 | wedgeprops={'linewidth': 0.5, 'edgecolor': "black"}, 20 | explode=[0, 0, 0, 0.06], 21 | shadow=True, 22 | labeldistance=10, 23 | radius=1, 24 | autopct='%3.1f %%', 25 | pctdistance=0.8, 26 | textprops={ 27 | 'size': 16 28 | } 29 | ) 30 | 31 | font1 = {'family': 'Times New Roman', 32 | 'weight': 'normal', 33 | 'size': 11.5, 34 | } 35 | plt.legend(loc="upper right", prop=font1) 36 | plt.savefig("../images/time_mr_amazon.pdf", format="pdf") 37 | -------------------------------------------------------------------------------- /scripts/run_all_pipeline.sh: 
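This script is presumably invoked as `bash scripts/run_all_pipeline.sh config/test.json`, since it forwards its first argument to convert_config.py. convert_config.py flattens the nested JSON config into upper-cased, `E_`-prefixed shell variables in config/bash_config.sh, which the other pipeline scripts then source; with the shipped config/test.json the generated file looks roughly like this (abridged sketch, not the full output):

    E_GENEPT_BERT_MODEL=pretrain_bert_model/bert-base-uncased/
    E_SELECTIVE_MASKING_FINETUNE_BERT_DATA_DIR=data/datasets/ABSA/14data_lap/
    E_SELECTIVE_MASKING_FINETUNE_BERT_OUTPUT_DIR=results/test/origin/CKPT_1M/
    E_SELECTIVE_MASKING_FINETUNE_BERT_GPU_LIST=0,
    ...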
-------------------------------------------------------------------------------- 1 | # setup envs 2 | echo -e "\033[42;37m Set Up Enviroments\033[0m" 3 | python3 convert_config.py $1 4 | cat config/bash_config.sh 5 | source config/bash_config.sh 6 | 7 | # Selective Masking - finetune BERT 8 | echo -e "\033[42;37m Selective Masking - Finetune BERT\033[0m" 9 | bash scripts/finetune_origin.sh 10 | 11 | # Selective Masking - downstream mask 12 | echo -e "\033[42;37m Selective Masking - Downstream Mask\033[0m" 13 | bash data/create_data_rule/run.sh 14 | 15 | # Selective Masking - train nn 16 | echo -e "\033[42;37m Selective Masking - Train NN\033[0m" 17 | bash scripts/run_mask_model.sh 18 | 19 | # Selective Masking - in-domain mask 20 | echo -e "\033[42;37m Selective Masking - In-domain Mask\033[0m" 21 | bash data/create_data_model/run.sh 22 | 23 | # TaskPT 24 | echo -e "\033[42;37m TaskPT\033[0m" 25 | bash scripts/run_pretraining.sh 26 | 27 | # Fine-tuning 28 | echo -e "\033[42;37m Fine-tuning\033[0m" 29 | bash scripts/finetune_ckpt_all_seed.sh 30 | 31 | # Gather results of different seed 32 | echo -e "\033[42;37m Gather results of different seed\033[0m" 33 | python3 gather_results.py ${E_FINE_TUNING_OUTPUT_DIR} 34 | -------------------------------------------------------------------------------- /plot/plot_time_absa_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | fracs = [1244, 136, 829, 2160] 5 | labels = ["GenePT", "Selective Masking", "TaskPT", "Saved Cost"] 6 | 7 | plt.figure(figsize=(9, 9)) 8 | 9 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 10 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 11 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 12 | 13 | plt.axes(aspect=1) 14 | plt.pie( 15 | x=fracs, 16 | labels=labels, 17 | startangle=90, 18 | # colors=["lightskyblue", "orange", "cyan", "lightcoral", "gold", "lightgreen", "white"], 19 | colors=["lightskyblue", "gold", "lightgreen", "white"], 20 | wedgeprops={'linewidth': 0.5, 'edgecolor': "black"}, 21 | explode=[0, 0, 0, 0.06], 22 | shadow=True, 23 | labeldistance=10, 24 | radius=1, 25 | autopct='%3.1f %%', 26 | pctdistance=0.8, 27 | textprops={ 28 | 'size': 16 29 | } 30 | ) 31 | 32 | font1 = {'family': 'Times New Roman', 33 | 'weight': 'normal', 34 | 'size': 11.5, 35 | } 36 | plt.legend(loc="upper right", prop=font1) 37 | plt.savefig("../images/time_absa_amazon.pdf", format="pdf") 38 | -------------------------------------------------------------------------------- /scripts/run_pretraining.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source config/bash_config.sh 3 | 4 | PER_GPU_TRN_BS=28 5 | LR=2e-5 6 | NUM_GPUS=4 7 | WARM_UP=0.2 8 | TRN_STEPS=50 9 | SAVE_STEPS=4 10 | GRAD_ACC=1 11 | SEED=88 12 | 13 | DATA_DIR=${E_TASKPT_DATA_DIR} 14 | BERT_MODEL=${E_GENEPT_BERT_MODEL} 15 | RESULTS_DIR=${E_TASKPT_OUTPUT_DIR} 16 | CHECKPOINTS_DIR=${RESULTS_DIR}/checkpoints/ 17 | 18 | mkdir -p $CHECKPOINTS_DIR 19 | 20 | echo $DATA_DIR 21 | INPUT_DIR=$DATA_DIR 22 | CMD="run_pretraining.py" 23 | CMD+=" --input_dir=$DATA_DIR" 24 | CMD+=" --output_dir=$CHECKPOINTS_DIR" 25 | CMD+=" --ckpt=${CKPT}" 26 | CMD+=" --bert_model=${BERT_MODEL}" 27 | CMD+=" --train_batch_size=${PER_GPU_TRN_BS}" 28 | CMD+=" --max_seq_length=256" 29 | CMD+=" --max_predictions_per_seq=80" 30 | CMD+=" --max_steps=$TRN_STEPS" 31 | CMD+=" --warmup_proportion=$WARM_UP" 32 | CMD+=" 
--num_steps_per_checkpoint=$SAVE_STEPS" 33 | CMD+=" --learning_rate=$LR" 34 | CMD+=" --seed=$SEED" 35 | CMD+=" --fp16" 36 | 37 | export CUDA_VISIBLE_DEVICES=${E_TASKPT_GPU_LIST} 38 | if [ "$NUM_GPUS" -gt 1 ] ; then 39 | CMD="python3 -m torch.distributed.launch --nproc_per_node=$NUM_GPUS $CMD" 40 | else 41 | CMD="python3 $CMD" 42 | fi 43 | 44 | echo ${CMD} 45 | 46 | ${CMD} -------------------------------------------------------------------------------- /sig-test/res-yelp.py: -------------------------------------------------------------------------------- 1 | import scipy.stats as stats 2 | 3 | D = { 4 | "semres-yelp": { 5 | "model": [93.2, 91.4, 91.8, 90.1, 92.5, 92.4, 91.3, 92.7, 91.5, 91.5], 6 | "rand": [88.1, 91.2, 88.9, 91.7, 90.5, 91.6, 91.0, 92.4, 92.1, 89.7] 7 | }, 8 | "semres-amazon": { 9 | "model": [91.9, 91.5, 91.2, 92.2, 91.2, 92.3, 91.9, 90.5, 91.9, 91.1], 10 | "rand": [89.8, 90.7, 89.0, 89.7, 90.5, 90.6, 90.8, 91.4, 90.2, 91.2] 11 | }, 12 | "mr-yelp": { 13 | "model": [88.4, 87.2, 87.5, 87.8, 87.8, 88.6, 87.7, 87.7, 88.8, 87.6], 14 | "rand": [86.8, 87.0, 87.0, 86.6, 86.3, 87.2, 87.1, 87.5, 86.9, 87.9] 15 | }, 16 | "mr-amazon": { 17 | "model": [89.3, 89.5, 90.1, 88.9, 89.0, 90.0, 89.6, 89.5, 89.3, 89.6], 18 | "rand": [87.6, 88.5, 88.2, 88.0, 88.7, 88.9, 87.9, 89.3, 87.9, 88.5] 19 | }, 20 | "lap-yelp": { 21 | "model": [76.5, 75.1, 74.0, 76.2, 75.9, 75.1, 75.7, 76.6, 72.7, 74.6], 22 | "rand": [73.3, 71.8, 76.6, 76.2, 67.9, 73.4, 71.8, 75.4, 77.0, 74.0] 23 | } 24 | } 25 | 26 | res = {} 27 | 28 | for k, v in D.items(): 29 | print(sum(v["model"]) / 10) 30 | print(sum(v["rand"]) / 10) 31 | _, p = stats.f_oneway(v["model"], v["rand"]) 32 | res[k] = p 33 | 34 | print(res) 35 | -------------------------------------------------------------------------------- /plot/plot_time_mr_yelp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | # fracs = [1180, 10, 20, 10, 70, 648, 1912] 5 | fracs = [1180, 110, 648, 1912] 6 | 7 | # labels = ["GenePT", "Finetune BERT", "Rule Base", "Train NN", "NN base", "TaskPT", "Saved Cost"] 8 | labels = ["GenePT", "Selective Masking", "TaskPT", "Saved Cost"] 9 | 10 | plt.figure(figsize=(9, 9)) 11 | 12 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 13 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 14 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 15 | 16 | plt.axes(aspect=1) 17 | plt.pie( 18 | x=fracs, 19 | labels=labels, 20 | startangle=90, 21 | colors=["lightskyblue", "gold", "lightgreen", "white"], 22 | # colors=["lightskyblue", "orange", "cyan", "lightcoral", "gold", "lightgreen", "white"], 23 | wedgeprops={'linewidth': 0.5, 'edgecolor': "black"}, 24 | # explode=[0, 0, 0, 0, 0, 0, 0.06], 25 | explode=[0, 0, 0, 0.06], 26 | shadow=True, 27 | labeldistance=10, 28 | radius=1, 29 | autopct='%3.1f %%', 30 | pctdistance=0.8, 31 | textprops={ 32 | 'size': 16 33 | } 34 | ) 35 | 36 | font1 = {'family': 'Times New Roman', 37 | 'weight': 'normal', 38 | 'size': 11.5, 39 | } 40 | plt.legend(loc="upper right", prop=font1) 41 | plt.savefig("../images/time_mr_yelp.pdf", format="pdf") 42 | -------------------------------------------------------------------------------- /config/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "GenePT": { 3 | "BERT_MODEL": "pretrain_bert_model/bert-base-uncased/" 4 | }, 5 | "Selective_Masking": { 6 | "Finetune_BERT": { 7 | "DATA_DIR": 
"data/datasets/ABSA/14data_lap/", 8 | "OUTPUT_DIR": "results/test/origin/CKPT_1M/", 9 | "GPU_LIST": "0," 10 | }, 11 | "Downstream_Mask": { 12 | "BERT_MODEL": "results/test/origin/CKPT_1M/42/best_model/", 13 | "DATA_DIR": "data/datasets/ABSA/14data_lap/", 14 | "OUTPUT_DIR": "data/datasets/test/full_rule_mask/", 15 | "GPU_LIST": "(0 1)", 16 | "TASK_NAME": "absa_term" 17 | }, 18 | "Train_NN": { 19 | "DATA_DIR": "data/datasets/test/full_rule_mask/merged/", 20 | "OUTPUT_DIR": "results/test/full_mask_generator/", 21 | "GPU_LIST": "0," 22 | }, 23 | "In_domain_Mask": { 24 | "BERT_MODEL": "results/test/full_mask_generator/best_model/", 25 | "DATA_DIR": "data/datasets/YELP-AMAZON/amazon_review_full_csv", 26 | "OUTPUT_DIR": "SelectiveMasking/data/datasets/test/full_amazon/", 27 | "GPU_LIST": "(0 1)", 28 | "TASK_NAME": "amazon" 29 | } 30 | }, 31 | "TaskPT": { 32 | "DATA_DIR": "data/datasets/test/full_amazon/model/merged/", 33 | "OUTPUT_DIR": "results/test/full_amazon/model/", 34 | "GPU_LIST": "0,1,2,3" 35 | }, 36 | "Fine_tuning": { 37 | "CKPT": "results/test/full_amazon/model/checkpoints/best_ckpt.pt", 38 | "DATA_DIR": "data/datasets/ABSA/14data_lap/", 39 | "OUTPUT_DIR": "results/test/full_amazon/model/", 40 | "GPU_LIST": "0," 41 | } 42 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Selective Masking 2 | 3 | Source code for "Train No Evil: Selective Masking for Task-Guided Pre-Training" 4 | 5 | ## Download Data 6 | 7 | The datasets can be downloaded from this [link](https://drive.google.com/file/d/1dnDQO6kCNOe2iCpDl-xJ4XXRKLXq-5yw/view?usp=sharing). The datasets need to be put in `data/datasets`. 8 | 9 | ## Run the Whole Pipeline 10 | 11 | 1. Modify `config/test.json` for input path, output path, BERT model path, GPU usage etc. 12 | 13 | 2. run `bash scripts/run_all_pipeline.sh` . 14 | 15 | ## Run each step 16 | 17 | The meaning of each step can be found in the appendix of our paper. The input/output paths are also set in `config/test.json`. Run `python3 convert_config.py config/test.json` to convert the .json file to a .sh file. 18 | 19 | ### 1 GenePT 20 | 21 | We use the training scripts from for general pre-training. 22 | 23 | ### 2 Selective Masking 24 | 25 | #### 2.1 Finetune BERT 26 | 27 | ```[bash] 28 | bash scripts/finetune_origin.sh 29 | ``` 30 | 31 | #### 2.2 Downstream Mask 32 | 33 | ```[bash] 34 | bash data/create_data_rule/run.sh. 
35 | ``` 36 | 37 | #### 2.3 Train NN 38 | 39 | ```[bash] 40 | bash scripts/run_mask_model.sh 41 | ``` 42 | 43 | #### 2.4 In-domain Mask 44 | 45 | ```[bash] 46 | bash data/create_data_model/run.sh 47 | ``` 48 | 49 | ### 3 TaskPT 50 | 51 | ```[bash] 52 | bash scripts/run_pretraining.sh 53 | ``` 54 | 55 | ### 4 Fine-tune 56 | 57 | ```[bash] 58 | bash scripts/finetune_ckpt_all_seed.sh 59 | python3 gather_results.py $PATH_TO_THE_FINETUNE_OUTPUT 60 | ``` 61 | 62 | ## Cite 63 | 64 | If you use the code, please cite this paper: 65 | 66 | ```[] 67 | @inproceedings{gu2020train, 68 | title={Train No Evil: Selective Masking for Task-Guided Pre-Training}, 69 | author={Yuxian Gu and Zhengyan Zhang and Xiaozhi Wang and Zhiyuan Liu and Maosong Sun}, 70 | year={2020}, 71 | booktitle={Proceedings of EMNLP 2020}, 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | #Data 12 | data/*/*/ 13 | data/*/*.zip 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json 118 | 119 | # Pyre type checker 120 | .pyre/ 121 | 122 | # vscode 123 | .vscode 124 | 125 | # TF code 126 | tensorflow_code 127 | 128 | # Models 129 | models 130 | 131 | *.un~ 132 | upload.sh 133 | -------------------------------------------------------------------------------- /data/rand_mask_gen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import spacy 6 | import sys 7 | import collections 8 | from tqdm import tqdm 9 | from torch.utils.data.distributed import DistributedSampler 10 | 
from torch.nn.functional import softmax 11 | 12 | sys.path.append("../") 13 | from model.tokenization import BertTokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | MaskedTokenInstance = collections.namedtuple("MaskedTokenInstance", ["tokens", "info"]) 17 | MaskedItemInfo = collections.namedtuple("MaskedItemInfo", ["current_pos", "sen_doc_pos", "sen_right_id", "doc_ground_truth"]) 18 | 19 | 20 | class InputFeatures(object): 21 | def __init__(self, input_ids, input_mask, segment_ids): 22 | self.input_ids = input_ids 23 | self.input_mask = input_mask 24 | self.segment_ids = segment_ids 25 | 26 | 27 | class RandMask(nn.Module): 28 | def __init__(self, mask_rate, bert_model, do_lower_case, max_seq_length): 29 | super(RandMask, self).__init__() 30 | self.mask_rate = mask_rate 31 | self.max_seq_length = max_seq_length 32 | self.tokenizer = BertTokenizer.from_pretrained( 33 | bert_model, do_lower_case=do_lower_case) 34 | self.vocab = list(self.tokenizer.vocab.keys()) 35 | 36 | def forward(self, data, all_labels, dupe_factor, rng): 37 | # data: not tokenized 38 | all_documents = [] 39 | for _ in range(dupe_factor): 40 | for line in tqdm(data): 41 | all_documents.append([]) 42 | tokens = self.tokenizer.tokenize(line) 43 | cand_indexes = [i for i in range(len(tokens))] 44 | rng.shuffle(cand_indexes) 45 | masked_info = [{} for token in tokens] 46 | masked_token = None 47 | masked_lms_len = 0 48 | num_to_predict = max(1, int(round(len(tokens) * self.mask_rate))) 49 | for index in cand_indexes: 50 | if masked_lms_len >= num_to_predict: 51 | break 52 | if rng.random() < 0.8: 53 | masked_token = "[MASK]" 54 | else: 55 | if rng.random() < 0.5: 56 | masked_token = tokens[index] 57 | else: 58 | masked_token = self.vocab[rng.randint(0, len(self.vocab) - 1)] 59 | 60 | masked_info[index]["mask"] = masked_token 61 | masked_info[index]["label"] = tokens[index] 62 | masked_lms_len += 1 63 | all_documents[-1].append(MaskedTokenInstance(tokens=tokens, info=masked_info)) 64 | return all_documents 65 | -------------------------------------------------------------------------------- /plot/plot_mr_yelp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | font1 = {'family': 'Times New Roman', 6 | 'weight': 'normal', 7 | 'size': 25, 8 | } 9 | 10 | font2 = {'family': 'Times New Roman', 11 | 'weight': 'normal', 12 | 'size': 35, 13 | } 14 | 15 | fig = plt.figure(figsize=(10, 10)) 16 | sub = fig.add_subplot(111) 17 | 18 | shiftY = 0.4 19 | shiftX = 35 20 | ms = 11 21 | 22 | fontsize1 = 28 23 | fontsize2 = 35 24 | 25 | model_steps_30 = [300, 464] 26 | model_score_30 = [81.34, 85.43] 27 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 28 | sub.text(model_steps_30[-1], model_score_30[-1] + shiftY, "300k\nTask", ha='center', va='center', fontsize=fontsize1) 29 | 30 | rand_steps_30 = [300, 480] 31 | rand_score_30 = [81.34, 84.91] 32 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 33 | sub.text(rand_steps_30[-1] + shiftX, rand_score_30[-1], "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 34 | 35 | model_steps_20 = [200, 368] 36 | model_score_20 = [80.99, 84.55] 37 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 38 | sub.text(model_steps_20[-1], model_score_20[-1] + shiftY, "200k\nTask", ha='center', va='center', fontsize=fontsize1) 39 | 40 | rand_steps_20 = [200, 368] 41 | rand_score_20 = [80.99, 83.37] 42 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 43 | 
sub.text(rand_steps_20[-1], rand_score_20[-1] + shiftY, "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 44 | 45 | model_steps_10 = [100, 256] 46 | model_score_10 = [79.7, 83.50] 47 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 48 | sub.text(model_steps_10[-1], model_score_10[-1] + shiftY, "100k\nTask", ha='center', va='center', fontsize=fontsize1) 49 | 50 | rand_steps_10 = [100, 252] 51 | rand_score_10 = [79.7, 82.3] 52 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 53 | sub.text(rand_steps_10[-1], rand_score_10[-1] + shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 54 | 55 | main_steps = [100, 200, 300] 56 | main_score = [79.7, 80.99, 81.34] 57 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 58 | 59 | for step, score in zip(main_steps, main_score): 60 | sub.text(step + shiftX *4 / 5, score - shiftY / 4, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 61 | sub.hlines(87.37, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 62 | sub.text(210, 87, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 63 | sub.text(48, 87.37, "87.4-", ha='center', va='center', fontsize=fontsize2) 64 | 65 | plt.grid() 66 | plt.tick_params(labelsize=fontsize2) 67 | 68 | plt.xlabel("k Steps", font2) 69 | plt.ylabel("Acc.(%)", font2) 70 | 71 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 72 | plt.savefig("../images/mr_yelp.pdf", format="pdf") 73 | -------------------------------------------------------------------------------- /data/merge_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import sys 3 | import os 4 | import numpy as np 5 | import collections 6 | from tqdm import tqdm 7 | 8 | origin_dir = sys.argv[1] 9 | num_files = int(sys.argv[2]) 10 | 11 | features_trn, features_dev = collections.OrderedDict(), collections.OrderedDict() 12 | num_instances, max_seq_length, max_predictions_per_seq = 0, 0, 0 13 | data = [] 14 | 15 | dev_rate = 0.1 16 | 17 | for i in range(num_files): 18 | filename = os.path.join(origin_dir, "{}.hdf5".format(i)) 19 | with h5py.File(filename, "r") as f: 20 | num_inst = f["input_ids"].shape[0] 21 | max_seq_length = f["input_ids"].shape[1] 22 | max_predictions_per_seq = f["masked_lm_positions"].shape[1] 23 | num_instances += num_inst 24 | for k in tqdm(range(num_inst), desc="loading {}".format(filename)): 25 | data.append({ 26 | "input_ids": f["input_ids"][i, :], 27 | "input_mask": f["input_mask"][i, :], 28 | "segment_ids": f["segment_ids"][i, :], 29 | "masked_lm_positions": f["masked_lm_positions"][i, :], 30 | "masked_lm_ids": f["masked_lm_ids"][i, :], 31 | "next_sentence_labels": f["next_sentence_labels"][i] 32 | }) 33 | 34 | train_data = data[:int((1 - dev_rate) * num_instances)] 35 | dev_data = data[int((1 - dev_rate) * num_instances):] 36 | 37 | for feat, d, name in [(features_trn, train_data, "train.hdf5"), (features_dev, dev_data, "dev/dev.hdf5")]: 38 | feat["input_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32") 39 | feat["input_mask"] = np.zeros([num_instances, max_seq_length], dtype="int32") 40 | feat["segment_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32") 41 | feat["masked_lm_positions"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32") 42 | feat["masked_lm_ids"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32") 43 | feat["next_sentence_labels"] = np.zeros(num_instances, 
dtype="int32") 44 | 45 | for i, inst in enumerate(tqdm(d, desc="train data")): 46 | for key in feat: 47 | feat[key][i] = inst[key] 48 | 49 | f = h5py.File(os.path.join(origin_dir, "merged", name), 'w') 50 | f.create_dataset("input_ids", data=feat["input_ids"], dtype='i4', compression='gzip') 51 | f.create_dataset("input_mask", data=feat["input_mask"], dtype='i1', compression='gzip') 52 | f.create_dataset("segment_ids", data=feat["segment_ids"], dtype='i1', compression='gzip') 53 | f.create_dataset("masked_lm_positions", data=feat["masked_lm_positions"], dtype='i4', compression='gzip') 54 | f.create_dataset("masked_lm_ids", data=feat["masked_lm_ids"], dtype='i4', compression='gzip') 55 | f.create_dataset("next_sentence_labels", data=feat["next_sentence_labels"], dtype='i1', compression='gzip') 56 | f.flush() 57 | f.close() 58 | -------------------------------------------------------------------------------- /plot/plot_mr_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | font1 = {'family': 'Times New Roman', 6 | 'weight': 'normal', 7 | 'size': 25, 8 | } 9 | 10 | font2 = {'family': 'Times New Roman', 11 | 'weight': 'normal', 12 | 'size': 35, 13 | } 14 | 15 | fig = plt.figure(figsize=(10, 10)) 16 | sub = fig.add_subplot(111) 17 | 18 | shiftY = 0.4 19 | shiftX = 35 20 | ms = 11 21 | 22 | fontsize1 = 28 23 | fontsize2 = 35 24 | 25 | model_steps_30 = [300, 480] 26 | model_score_30 = [81.34, 87.19] 27 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 28 | sub.text(model_steps_30[-1] + shiftX, model_score_30[-1], "300k\nTask", ha='center', va='center', fontsize=fontsize1) 29 | 30 | rand_steps_30 = [300, 488] 31 | rand_score_30 = [81.34, 85.85] 32 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 33 | sub.text(rand_steps_30[-1], rand_score_30[-1] + shiftY, "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 34 | 35 | model_steps_20 = [200, 380] 36 | model_score_20 = [80.99, 86.43] 37 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 38 | sub.text(model_steps_20[-1] + shiftX, model_score_20[-1], "200k\nTask", ha='center', va='center', fontsize=fontsize1) 39 | 40 | rand_steps_20 = [200, 384] 41 | rand_score_20 = [80.99, 85.22] 42 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 43 | sub.text(rand_steps_20[-1], rand_score_20[-1] + shiftY, "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 44 | 45 | model_steps_10 = [100, 264] 46 | model_score_10 = [79.7, 85.32] 47 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 48 | sub.text(model_steps_10[-1], model_score_10[-1] + shiftY, "100k\nTask", ha='center', va='center', fontsize=fontsize1) 49 | 50 | rand_steps_10 = [100, 272] 51 | rand_score_10 = [79.7, 84.57] 52 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 53 | sub.text(rand_steps_10[-1] + shiftX, rand_score_10[-1] + shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 54 | 55 | main_steps = [100, 200, 300] 56 | main_score = [79.7, 80.99, 81.34] 57 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 58 | 59 | for step, score in zip(main_steps, main_score): 60 | sub.text(step + shiftX / 5 * 4, score - shiftY / 4, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 61 | sub.hlines(87.37, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 62 | sub.text(210, 87, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 63 | sub.text(48, 87.37, "87.4-", ha='center', va='center', 
fontsize=fontsize2) 64 | 65 | plt.grid() 66 | plt.tick_params(labelsize=fontsize2) 67 | 68 | plt.xlabel("k Steps", font2) 69 | plt.ylabel("Acc.(%)", font2) 70 | 71 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 72 | plt.savefig("../images/mr_amazon.pdf", foramt="pdf") 73 | -------------------------------------------------------------------------------- /plot/plot_absa_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | font1 = {'family': 'Times New Roman', 6 | 'weight': 'normal', 7 | 'size': 25, 8 | } 9 | 10 | font2 = {'family': 'Times New Roman', 11 | 'weight': 'normal', 12 | 'size': 35, 13 | } 14 | 15 | fig = plt.figure(figsize=(10, 10)) 16 | sub = fig.add_subplot(111) 17 | 18 | shiftY = 0.4 19 | shiftX = 35 20 | ms = 11 21 | 22 | fontsize1 = 28 23 | fontsize2 = 35 24 | 25 | model_steps_30 = [300, 480] 26 | model_score_30 = [84.42, 88.98] 27 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 28 | sub.text(model_steps_30[-1] + shiftX, model_score_30[-1], "300k\nTask", ha='center', va='center', fontsize=fontsize1) 29 | 30 | rand_steps_30 = [300, 480] 31 | rand_score_30 = [84.42, 87.76] 32 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 33 | sub.text(rand_steps_30[-1], rand_score_30[-1] + shiftY, "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 34 | 35 | model_steps_20 = [200, 366] 36 | model_score_20 = [82.68, 88.06] 37 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 38 | sub.text(model_steps_20[-1], model_score_20[-1] + shiftY, "200k\nTask", ha='center', va='center', fontsize=fontsize1) 39 | 40 | rand_steps_20 = [200, 384] 41 | rand_score_20 = [82.68, 87.22] 42 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 43 | sub.text(rand_steps_20[-1] + shiftX * 4 / 3, rand_score_20[-1], "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 44 | 45 | model_steps_10 = [100, 252] 46 | model_score_10 = [81.95, 87.84] 47 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 48 | sub.text(model_steps_10[-1] - shiftX, model_score_10[-1], "100k\nTask", ha='center', va='center', fontsize=fontsize1) 49 | 50 | rand_steps_10 = [100, 244] 51 | rand_score_10 = [81.95, 87.08] 52 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 53 | sub.text(rand_steps_10[-1] + shiftX, rand_score_10[-1] - shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 54 | 55 | main_steps = [100, 200, 300] 56 | main_score = [81.95, 82.68, 84.42] 57 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 58 | 59 | for step, score in zip(main_steps, main_score): 60 | sub.text(step + shiftX, score - shiftY / 2, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 61 | sub.hlines(88.60, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 62 | sub.text(210, 88.9, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 63 | sub.text(48, 88.60, "88.6-", ha='center', va='center', fontsize=fontsize2) 64 | 65 | plt.grid() 66 | plt.tick_params(labelsize=fontsize2) 67 | 68 | plt.xlabel("k Steps", font2) 69 | plt.ylabel("Acc.(%)", font2) 70 | 71 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 72 | plt.savefig("../images/absa_amazon.pdf", foramt="pdf") 73 | -------------------------------------------------------------------------------- /plot/plot_lap_amazon.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | font1 = {'family': 'Times New Roman', 6 | 'weight': 'normal', 7 | 'size': 25, 8 | } 9 | 10 | font2 = {'family': 'Times New Roman', 11 | 'weight': 'normal', 12 | 'size': 35, 13 | } 14 | 15 | fig = plt.figure(figsize=(10, 10)) 16 | sub = fig.add_subplot(111) 17 | 18 | shiftY = 0.4 19 | shiftX = 35 20 | ms = 11 21 | 22 | fontsize1 = 28 23 | fontsize2 = 35 24 | 25 | model_steps_30 = [300, 480] 26 | model_score_30 = [69.89, 73.24] 27 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 28 | sub.text(model_steps_30[-1] + shiftX, model_score_30[-1], "300k\nTask", ha='center', va='center', fontsize=fontsize1) 29 | 30 | rand_steps_30 = [300, 480] 31 | rand_score_30 = [69.89, 72.20] 32 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 33 | sub.text(rand_steps_30[-1], rand_score_30[-1] + shiftY, "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 34 | 35 | model_steps_20 = [200, 366] 36 | model_score_20 = [69.14, 73.02] 37 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 38 | sub.text(model_steps_20[-1], model_score_20[-1] + shiftY, "200k\nTask", ha='center', va='center', fontsize=fontsize1) 39 | 40 | rand_steps_20 = [200, 384] 41 | rand_score_20 = [69.14, 72.04] 42 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 43 | sub.text(rand_steps_20[-1] + shiftX * 4 / 3, rand_score_20[-1], "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 44 | 45 | model_steps_10 = [100, 252] 46 | model_score_10 = [68.78, 72.98] 47 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 48 | sub.text(model_steps_10[-1], model_score_10[-1] + shiftY, "100k\nTask", ha='center', va='center', fontsize=fontsize1) 49 | 50 | rand_steps_10 = [100, 244] 51 | rand_score_10 = [68.78, 71.78] 52 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 53 | sub.text(rand_steps_10[-1] + shiftX, rand_score_10[-1] - shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 54 | 55 | main_steps = [100, 200, 300] 56 | main_score = [68.78, 69.14, 69.89] 57 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 58 | 59 | for step, score in zip(main_steps, main_score): 60 | sub.text(step + shiftX, score - shiftY / 2, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 61 | sub.hlines(72.57, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 62 | sub.text(210, 72.70, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 63 | sub.text(48, 72.6, "72.6-", ha='center', va='center', fontsize=fontsize2) 64 | 65 | plt.grid() 66 | plt.tick_params(labelsize=fontsize2) 67 | 68 | plt.xlabel("k Steps", font2) 69 | plt.ylabel("Acc.(%)", font2) 70 | 71 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 72 | plt.savefig("../images/lap_amazon.pdf", foramt="pdf") 73 | -------------------------------------------------------------------------------- /plot/plot_lap_yelp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.backends.backend_pdf import PdfPages 4 | 5 | 6 | font1 = {'family': 'Times New Roman', 7 | 'weight': 'normal', 8 | 'size': 25, 9 | } 10 | 11 | font2 = {'family': 'Times New Roman', 12 | 'weight': 'normal', 13 | 'size': 35, 14 | } 15 | 16 | fig = plt.figure(figsize=(10, 10)) 17 | sub = fig.add_subplot(111) 18 | 19 | 
shiftY = 0.4 20 | shiftX = 35 21 | ms = 11 22 | 23 | fontsize1 = 28 24 | fontsize2 = 35 25 | 26 | model_steps_30 = [300, 488] 27 | model_score_30 = [69.89, 73.10] 28 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 29 | sub.text(model_steps_30[-1] + shiftX, model_score_30[-1], "300k\nTask", ha='center', va='center', fontsize=fontsize1) 30 | 31 | rand_steps_30 = [300, 480] 32 | rand_score_30 = [69.89, 72.06] 33 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 34 | sub.text(rand_steps_30[-1], rand_score_30[-1] + shiftY, "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 35 | 36 | model_steps_20 = [200, 380] 37 | model_score_20 = [69.14, 72.82] 38 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 39 | sub.text(model_steps_20[-1], model_score_20[-1] + shiftY, "200k\nTask", ha='center', va='center', fontsize=fontsize1) 40 | 41 | rand_steps_20 = [200, 380] 42 | rand_score_20 = [69.14, 71.20] 43 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 44 | sub.text(rand_steps_20[-1], rand_score_20[-1] + shiftY, "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 45 | 46 | model_steps_10 = [100, 244] 47 | model_score_10 = [68.78, 71.33] 48 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 49 | sub.text(model_steps_10[-1], model_score_10[-1] + shiftY, "100k\nTask", ha='center', va='center', fontsize=fontsize1) 50 | 51 | rand_steps_10 = [100, 252] 52 | rand_score_10 = [68.78, 70.69] 53 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 54 | sub.text(rand_steps_10[-1] + shiftX, rand_score_10[-1] + shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 55 | 56 | main_steps = [100, 200, 300] 57 | main_score = [68.78, 69.14, 69.89] 58 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 59 | 60 | for step, score in zip(main_steps, main_score): 61 | sub.text(step + shiftX / 3 * 2, score - shiftY / 3 * 2, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 62 | sub.hlines(72.57, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 63 | sub.text(210, 72.70, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 64 | sub.text(48, 72.57, "72.6-", ha='center', va='center', fontsize=fontsize2) 65 | 66 | plt.grid() 67 | plt.tick_params(labelsize=fontsize2) 68 | 69 | plt.xlabel("k Steps", font2) 70 | plt.ylabel("Acc.(%)", font2) 71 | 72 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 73 | plt.savefig("../images/lap_yelp.pdf", format="pdf") 74 | -------------------------------------------------------------------------------- /plot/plot_absa_yelp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.backends.backend_pdf import PdfPages 4 | 5 | 6 | font1 = {'family': 'Times New Roman', 7 | 'weight': 'normal', 8 | 'size': 25, 9 | } 10 | 11 | font2 = {'family': 'Times New Roman', 12 | 'weight': 'normal', 13 | 'size': 35, 14 | } 15 | 16 | fig = plt.figure(figsize=(10, 10)) 17 | sub = fig.add_subplot(111) 18 | 19 | shiftY = 0.4 20 | shiftX = 35 21 | ms = 11 22 | 23 | fontsize1 = 28 24 | fontsize2 = 35 25 | 26 | model_steps_30 = [300, 488] 27 | model_score_30 = [84.42, 89.49] 28 | l2, = sub.plot(model_steps_30, model_score_30, 'b^-', ms=ms) 29 | sub.text(model_steps_30[-1] + shiftX, model_score_30[-1], "300k\nTask", ha='center', va='center', fontsize=fontsize1) 30 | 31 | rand_steps_30 = [300, 480] 32 | rand_score_30 = [84.42, 
87.47] 33 | l3, = sub.plot(rand_steps_30, rand_score_30, 'gs-', ms=ms) 34 | sub.text(rand_steps_30[-1], rand_score_30[-1] + shiftY, "300k\nRand.", ha='center', va='center', fontsize=fontsize1) 35 | 36 | model_steps_20 = [200, 380] 37 | model_score_20 = [82.68, 88.47] 38 | sub.plot(model_steps_20, model_score_20, 'b^-', ms=ms) 39 | sub.text(model_steps_20[-1], model_score_20[-1] + shiftY, "200k\nTask", ha='center', va='center', fontsize=fontsize1) 40 | 41 | rand_steps_20 = [200, 380] 42 | rand_score_20 = [82.68, 87.4] 43 | sub.plot(rand_steps_20, rand_score_20, 'gs-', ms=ms) 44 | sub.text(rand_steps_20[-1], rand_score_20[-1] + shiftY, "200k\nRand.", ha='center', va='center', fontsize=fontsize1) 45 | 46 | model_steps_10 = [100, 244] 47 | model_score_10 = [81.95, 88.31] 48 | sub.plot(model_steps_10, model_score_10, 'b^-', ms=ms) 49 | sub.text(model_steps_10[-1] + shiftX, model_score_10[-1] - shiftY * 2 / 3, "100k\nTask", ha='center', va='center', fontsize=fontsize1) 50 | 51 | rand_steps_10 = [100, 252] 52 | rand_score_10 = [81.95, 86.54] 53 | sub.plot(rand_steps_10, rand_score_10, 'gs-', ms=ms) 54 | sub.text(rand_steps_10[-1], rand_score_10[-1] + shiftY, "100k\nRand.", ha='center', va='center', fontsize=fontsize1) 55 | 56 | main_steps = [100, 200, 300] 57 | main_score = [81.95, 82.68, 84.42] 58 | l1, = sub.plot(main_steps, main_score, 'ro-', ms=ms) 59 | 60 | for step, score in zip(main_steps, main_score): 61 | sub.text(step + shiftX / 3 * 2, score - shiftY / 3 * 2, str(step) + "k", ha='center', va='center', fontsize=fontsize1) 62 | sub.hlines(88.60, 100, 500, colors="gray", linestyles="dashed", linewidth=3) 63 | sub.text(210, 88.9, "Fully-trained (1M steps)", ha='center', va='center', fontsize=27) 64 | sub.text(48, 88.60, "88.6-", ha='center', va='center', fontsize=fontsize2) 65 | 66 | plt.grid() 67 | plt.tick_params(labelsize=fontsize2) 68 | 69 | plt.xlabel("k Steps", font2) 70 | plt.ylabel("Acc.(%)", font2) 71 | 72 | plt.legend(handles=[l1, l2, l3], labels=['General Pre-train', 'Selective Mask', 'Random Mask'], loc='lower right', prop=font1) 73 | plt.savefig("../images/absa_yelp.pdf", format="pdf") 74 | -------------------------------------------------------------------------------- /model/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | # from apex.optimizers import FP16_Optimizer 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class LRScheduler(_LRScheduler): 9 | def __init__(self, optimizer, last_epoch=-1): 10 | # Check if using mixed precision training 11 | self.mixed_training = False 12 | base_optimizer = optimizer 13 | # if isinstance(optimizer, FP16_Optimizer): 14 | # self.mixed_training = True 15 | # self.fp16_optimizer = optimizer 16 | # base_optimizer = optimizer.optimizer 17 | # Check that optimizer param is valid 18 | # elif not isinstance(optimizer, Optimizer): 19 | # raise TypeError('{} is not an Optimizer'.format( 20 | # type(optimizer).__name__)) 21 | 22 | super(LRScheduler, self).__init__(base_optimizer, last_epoch) 23 | 24 | def step(self, epoch=None): 25 | # Set the current training step 26 | # ('epoch' is used to be consistent with _LRScheduler) 27 | if self.mixed_training: 28 | # The assumption is that the step will be constant 29 | state_dict = self.optimizer.state[self.optimizer.param_groups[0]['params'][0]] 30 | if 'step' in state_dict: 31 | self.last_epoch = state_dict['step'] + 1 32 | else: 33 | self.last_epoch = 1 34 
| else: 35 | self.last_epoch = epoch if epoch is not None else self.last_epoch + 1 36 | 37 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 38 | param_group['lr'] = lr 39 | 40 | 41 | class CosineWarmupScheduler(LRScheduler): 42 | """ 43 | Applies a warm up period to the learning rate. 44 | """ 45 | 46 | def __init__(self, optimizer, warmup, total_steps, last_epoch=-1): 47 | self.warmup = warmup 48 | self.total_steps = total_steps 49 | super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch) 50 | 51 | def get_lr(self): 52 | progress = self.last_epoch / self.total_steps 53 | if progress < self.warmup: 54 | return [base_lr * progress / self.warmup for base_lr in self.base_lrs] 55 | else: 56 | return [base_lr * (0.5 * (1.0 + math.cos(math.pi * progress))) for base_lr in self.base_lrs] 57 | 58 | 59 | class ConstantWarmupScheduler(LRScheduler): 60 | """ 61 | Applies a warm up period to the learning rate. 62 | """ 63 | 64 | def __init__(self, optimizer, warmup, total_steps, last_epoch=-1): 65 | self.warmup = warmup 66 | self.total_steps = total_steps 67 | super(ConstantWarmupScheduler, self).__init__(optimizer, last_epoch) 68 | 69 | def get_lr(self): 70 | progress = self.last_epoch / self.total_steps 71 | if progress < self.warmup: 72 | return [base_lr * progress / self.warmup for base_lr in self.base_lrs] 73 | else: 74 | return self.base_lrs 75 | 76 | 77 | class LinearWarmUpScheduler(LRScheduler): 78 | """ 79 | Applies a warm up period to the learning rate. 80 | """ 81 | 82 | def __init__(self, optimizer, warmup, total_steps, last_epoch=-1): 83 | self.warmup = warmup 84 | self.total_steps = total_steps 85 | super(LinearWarmUpScheduler, self).__init__(optimizer, last_epoch) 86 | 87 | def get_lr(self): 88 | progress = self.last_epoch / self.total_steps 89 | if progress < self.warmup: 90 | return [base_lr * progress / self.warmup for base_lr in self.base_lrs] 91 | else: 92 | return [base_lr * max((progress - 1.0)/(self.warmup - 1.0), 0.) for base_lr in self.base_lrs] 93 | -------------------------------------------------------------------------------- /model/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except AttributeError: 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 
130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
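# Downloading into a NamedTemporaryFile and copying into the cache only after the transfer
# completes means an interrupted download never leaves a partial file at cache_path; the
# temporary file is discarded automatically when the context manager below exits.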
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext 250 | -------------------------------------------------------------------------------- /model/fused_adam_local.py: -------------------------------------------------------------------------------- 1 | import types 2 | import importlib 3 | 4 | import math 5 | import torch 6 | 7 | def warmup_cosine(x, warmup=0.002): 8 | if x < warmup: 9 | return x/warmup 10 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 11 | 12 | def warmup_constant(x, warmup=0.002): 13 | if x < warmup: 14 | return x/warmup 15 | return 1.0 16 | 17 | def warmup_linear(x, warmup=0.002): 18 | if x < warmup: 19 | return x/warmup 20 | return 1.0 - x 21 | 22 | SCHEDULES = { 23 | 'warmup_cosine':warmup_cosine, 24 | 'warmup_constant':warmup_constant, 25 | 'warmup_linear':warmup_linear, 26 | } 27 | 28 | class FusedAdamBert(torch.optim.Optimizer): 29 | 30 | """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via 31 | ``python setup.py install --cuda_ext --cpp_ext``. 32 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 33 | Arguments: 34 | params (iterable): iterable of parameters to optimize or dicts defining 35 | parameter groups. 36 | lr (float, optional): learning rate. (default: 1e-3) 37 | betas (Tuple[float, float], optional): coefficients used for computing 38 | running averages of gradient and its square. (default: (0.9, 0.999)) 39 | eps (float, optional): term added to the denominator to improve 40 | numerical stability. (default: 1e-8) 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 43 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 44 | (default: False) NOT SUPPORTED in FusedAdam! 
45 | eps_inside_sqrt (boolean, optional): in the 'update parameters' step, 46 | adds eps to the bias-corrected second moment estimate before 47 | evaluating square root instead of adding it to the square root of 48 | second moment estimate as in the original paper. (default: False) 49 | .. _Adam\: A Method for Stochastic Optimization: 50 | https://arxiv.org/abs/1412.6980 51 | .. _On the Convergence of Adam and Beyond: 52 | https://openreview.net/forum?id=ryQu7f-RZ 53 | """ 54 | 55 | # def __init__(self, params, 56 | # lr=1e-3, bias_correction = True, 57 | # betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False, 58 | # weight_decay=0., max_grad_norm=0., amsgrad=False): 59 | 60 | def __init__(self, params, lr=1e-3, warmup=-1, t_total=-1, bias_correction=False, betas=(0.9, 0.999), schedule='warmup_linear', 61 | eps=1e-6, eps_inside_sqrt = False, weight_decay=0., max_grad_norm=1.0, amsgrad=False): 62 | 63 | 64 | global fused_adam_cuda 65 | fused_adam_cuda = importlib.import_module("fused_adam_cuda") 66 | 67 | if amsgrad: 68 | raise RuntimeError('FusedAdam does not support the AMSGrad variant.') 69 | defaults = dict(lr=lr, bias_correction=bias_correction, 70 | betas=betas, eps=eps, weight_decay=weight_decay, 71 | max_grad_norm=max_grad_norm) 72 | super(FusedAdamBert, self).__init__(params, defaults) 73 | print("LOCAL FUSED ADAM") 74 | self.eps_mode = 0 if eps_inside_sqrt else 1 75 | self.schedule = schedule 76 | self.t_total = t_total 77 | self.warmup = warmup 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | print("LR {}".format(lr_scheduled)) 93 | return lr 94 | 95 | def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): 96 | """Performs a single optimization step. 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 100 | grads (list of tensors, optional): weight gradient to use for the 101 | optimizer update. If gradients have type torch.half, parameters 102 | are expected to be in type torch.float. (default: None) 103 | output params (list of tensors, optional): A reduced precision copy 104 | of the updated weights written out in addition to the regular 105 | updated weights. Have to be of same type as gradients. (default: None) 106 | scale (float, optional): factor to divide gradient tensor values 107 | by before applying to weights. 
(default: 1) 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | if grads is None: 114 | grads_group = [None]*len(self.param_groups) 115 | # backward compatibility 116 | # assuming a list/generator of parameter means single group 117 | elif isinstance(grads, types.GeneratorType): 118 | grads_group = [grads] 119 | elif type(grads[0])!=list: 120 | grads_group = [grads] 121 | else: 122 | grads_group = grads 123 | 124 | if output_params is None: 125 | output_params_group = [None]*len(self.param_groups) 126 | elif isinstance(output_params, types.GeneratorType): 127 | output_params_group = [output_params] 128 | elif type(output_params[0])!=list: 129 | output_params_group = [output_params] 130 | else: 131 | output_params_group = output_params 132 | 133 | if grad_norms is None: 134 | grad_norms = [None]*len(self.param_groups) 135 | 136 | #Compute global norm 137 | global_norm = 0.0 138 | for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, 139 | output_params_group, grad_norms): 140 | global_norm = (global_norm ** 2 + grad_norm ** 2) ** 0.5 141 | 142 | for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms): 143 | if grads_this_group is None: 144 | grads_this_group = [None]*len(group['params']) 145 | if output_params_this_group is None: 146 | output_params_this_group = [None]*len(group['params']) 147 | 148 | # compute combined scale factor for this group 149 | combined_scale = scale 150 | if group['max_grad_norm'] > 0: 151 | # norm is in fact norm*scale 152 | clip = ((global_norm / scale) + 1e-6) / group['max_grad_norm'] 153 | if clip > 1: 154 | combined_scale = clip * scale 155 | 156 | bias_correction = 1 if group['bias_correction'] else 0 157 | 158 | for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group): 159 | #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients 160 | if p.grad is None and grad is None: 161 | continue 162 | if grad is None: 163 | grad = p.grad.data 164 | if grad.is_sparse: 165 | raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead') 166 | 167 | state = self.state[p] 168 | 169 | # State initialization 170 | if len(state) == 0: 171 | state['step'] = 0 172 | # Exponential moving average of gradient values 173 | state['exp_avg'] = torch.zeros_like(p.data) 174 | # Exponential moving average of squared gradient values 175 | state['exp_avg_sq'] = torch.zeros_like(p.data) 176 | 177 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 178 | beta1, beta2 = group['betas'] 179 | 180 | state['step'] += 1 181 | 182 | out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param 183 | #Changes sharath 184 | 185 | schedule_fct = SCHEDULES[self.schedule] 186 | #schedule_fct(state['step']/self.t_total, self.warmup) 187 | #step_lr = group['lr'] * schedule_fct(state['step']/self.t_total, self.warmup) 188 | #step_lr = group['lr'] * scale#schedule_fct(state['step']/self.t_total, self.warmup)# schedule_fct(state['step']/group['t_total'], group['warmup']) 189 | #print(scale, step_lr) 190 | #print(group['lr']) 191 | fused_adam_cuda.adam(p.data, 192 | out_p, 193 | exp_avg, 194 | exp_avg_sq, 195 | grad, 196 | group['lr'], #step_lr,#group['lr'], 197 | beta1, 198 | beta2, 199 | group['eps'], 200 | combined_scale, 201 | state['step'], 202 | 
self.eps_mode, 203 | bias_correction, 204 | group['weight_decay']) 205 | return loss 206 | -------------------------------------------------------------------------------- /model/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | # from fused_adam_local import FusedAdam 23 | from apex.optimizers import FusedAdam 24 | 25 | def warmup_cosine(x, warmup=0.002): 26 | if x < warmup: 27 | return x/warmup 28 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 29 | 30 | def warmup_constant(x, warmup=0.002): 31 | if x < warmup: 32 | return x/warmup 33 | return 1.0 34 | 35 | def warmup_linear(x, warmup=0.002): 36 | if x < warmup: 37 | return x/warmup 38 | # return (1.0 - x) 39 | 40 | return max((x - 1.)/ (warmup - 1.), 0.) 41 | 42 | SCHEDULES = { 43 | 'warmup_cosine':warmup_cosine, 44 | 'warmup_constant':warmup_constant, 45 | 'warmup_linear':warmup_linear, 46 | } 47 | 48 | 49 | class BertAdam(Optimizer): 50 | """Implements BERT version of Adam algorithm with weight decay fix. 51 | Params: 52 | lr: learning rate 53 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 54 | t_total: total number of training steps for the learning 55 | rate schedule, -1 means constant learning rate. Default: -1 56 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 57 | b1: Adams b1. Default: 0.9 58 | b2: Adams b2. Default: 0.999 59 | e: Adams epsilon. Default: 1e-6 60 | weight_decay: Weight decay. Default: 0.01 61 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 62 | """ 63 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 64 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 65 | max_grad_norm=1.0): 66 | if lr is not required and lr < 0.0: 67 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 68 | if schedule not in SCHEDULES: 69 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 70 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 71 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 72 | if not 0.0 <= b1 < 1.0: 73 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 74 | if not 0.0 <= b2 < 1.0: 75 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 76 | if not e >= 0.0: 77 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 78 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 79 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 80 | max_grad_norm=max_grad_norm) 81 | super(BertAdam, self).__init__(params, defaults) 82 | 83 | def get_lr(self): 84 | lr = [] 85 | for group in self.param_groups: 86 | for p in group['params']: 87 | state = self.state[p] 88 | if len(state) == 0: 89 | return [0] 90 | if group['t_total'] != -1: 91 | schedule_fct = SCHEDULES[group['schedule']] 92 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 93 | else: 94 | lr_scheduled = group['lr'] 95 | lr.append(lr_scheduled) 96 | return lr 97 | 98 | def step(self, closure=None): 99 | """Performs a single optimization step. 100 | 101 | Arguments: 102 | closure (callable, optional): A closure that reevaluates the model 103 | and returns the loss. 104 | """ 105 | loss = None 106 | if closure is not None: 107 | loss = closure() 108 | 109 | for group in self.param_groups: 110 | for p in group['params']: 111 | if p.grad is None: 112 | continue 113 | grad = p.grad.data 114 | if grad.is_sparse: 115 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 116 | 117 | state = self.state[p] 118 | 119 | # State initialization 120 | if len(state) == 0: 121 | state['step'] = 0 122 | # Exponential moving average of gradient values 123 | state['next_m'] = torch.zeros_like(p.data) 124 | # Exponential moving average of squared gradient values 125 | state['next_v'] = torch.zeros_like(p.data) 126 | 127 | next_m, next_v = state['next_m'], state['next_v'] 128 | beta1, beta2 = group['b1'], group['b2'] 129 | 130 | # Add grad clipping 131 | if group['max_grad_norm'] > 0: 132 | clip_grad_norm_(p, group['max_grad_norm']) 133 | 134 | # Decay the first and second moment running average coefficient 135 | # In-place operations to update the averages at the same time 136 | next_m.mul_(beta1).add_(1 - beta1, grad) 137 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 138 | update = next_m / (next_v.sqrt() + group['e']) 139 | 140 | # Just adding the square of the weights to the loss function is *not* 141 | # the correct way of using L2 regularization/weight decay with Adam, 142 | # since that will interact with the m and v parameters in strange ways. 143 | # 144 | # Instead we want to decay the weights in a manner that doesn't interact 145 | # with the m/v parameters. This is equivalent to adding the square 146 | # of the weights to the loss with plain (non-momentum) SGD. 
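# With the decay term added below, the applied update is
#     p <- p - lr_scheduled * ( m / (sqrt(v) + eps) + weight_decay * p )
# i.e. decoupled weight decay computed from the raw weights, and (as noted further down)
# no bias correction is applied to the moment estimates m and v.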
147 | if group['weight_decay'] > 0.0: 148 | update += group['weight_decay'] * p.data 149 | 150 | if group['t_total'] != -1: 151 | schedule_fct = SCHEDULES[group['schedule']] 152 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 153 | else: 154 | lr_scheduled = group['lr'] 155 | 156 | update_with_lr = lr_scheduled * update 157 | p.data.add_(-update_with_lr) 158 | 159 | state['step'] += 1 160 | 161 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 162 | # No bias correction 163 | # bias_correction1 = 1 - beta1 ** state['step'] 164 | # bias_correction2 = 1 - beta2 ** state['step'] 165 | 166 | return loss 167 | 168 | # ======================================================================= 169 | class BertAdam_FP16(FusedAdam): 170 | """Implements BERT version of Adam algorithm with weight decay fix. 171 | Params: 172 | lr: learning rate 173 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 174 | t_total: total number of training steps for the learning 175 | rate schedule, -1 means constant learning rate. Default: -1 176 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 177 | b1: Adams b1. Default: 0.9 178 | b2: Adams b2. Default: 0.999 179 | e: Adams epsilon. Default: 1e-6 180 | weight_decay: Weight decay. Default: 0.01 181 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 182 | """ 183 | def __init__(self, params, lr, warmup=-1, t_total=-1, bias_correction=False, schedule='warmup_linear', 184 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 185 | max_grad_norm=1.0): 186 | if not lr >= 0.0: 187 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 188 | if schedule not in SCHEDULES: 189 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 190 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 191 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 192 | if not 0.0 <= b1 < 1.0: 193 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 194 | if not 0.0 <= b2 < 1.0: 195 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 196 | if not e >= 0.0: 197 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 198 | # defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 199 | # b1=b1, b2=b2, e=e, weight_decay=weight_decay, 200 | # max_grad_norm=max_grad_norm) 201 | super(BertAdam_FP16, self).__init__(params, lr=lr, bias_correction=bias_correction, betas=(b1, b2), eps=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm)#defaults) 202 | 203 | def get_lr(self): 204 | lr = [] 205 | for group in self.param_groups: 206 | for p in group['params']: 207 | state = self.state[p] 208 | if len(state) == 0: 209 | print("returning", state) 210 | return [0] 211 | if group['t_total'] != -1: 212 | schedule_fct = SCHEDULES[group['schedule']] 213 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 214 | else: 215 | lr_scheduled = group['lr'] 216 | lr.append(lr_scheduled) 217 | print("LR {}".format(lr_scheduled)) 218 | return lr 219 | -------------------------------------------------------------------------------- /model/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | import six 24 | from io import open 25 | 26 | from file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | home = os.environ.get('HOME') 30 | 31 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 32 | # 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 33 | # 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 34 | # 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 35 | # 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 36 | # 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 37 | # 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 38 | # 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 39 | 40 | 'bert-base-uncased': home + "/SelectiveMasking/vocab/uncased_L-12_H-768_A-12/vocab.txt", 41 | 'bert-large-uncased': home + "/SelectiveMasking/vocab/uncased_L-24_H-1024_A-16/vocab.txt", 42 | 'bert-base-cased': home + "/SelectiveMasking/vocab/cased_L-12_H-768_A-12/vocab.txt", 43 | 'bert-large-cased': home + "/SelectiveMasking/vocab/cased_L-24_H-1024_A-16/vocab.txt", 44 | 'bert-base-multilingual-uncased': home + "/SelectiveMasking/vocab/multilingual_L-12_H-768_A-12/vocab.txt", 45 | 'bert-base-multilingual-cased': home + "/SelectiveMasking/vocab/multi_cased_L-12_H-768_A-12/vocab.txt", 46 | 'bert-base-chinese': home + "/SelectiveMasking/vocab/chinese_L-12_H-768_A-12/vocab.txt", 47 | } 48 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 49 | 'bert-base-uncased': 512, 50 | 'bert-large-uncased': 512, 51 | 'bert-base-cased': 512, 52 | 'bert-large-cased': 512, 53 | 'bert-base-multilingual-uncased': 512, 54 | 'bert-base-multilingual-cased': 512, 55 | 'bert-base-chinese': 512, 56 | } 57 | VOCAB_NAME = 'vocab.txt' 58 | 59 | def load_vocab(vocab_file): 60 | """Loads a vocabulary file into a dictionary.""" 61 | vocab = collections.OrderedDict() 62 | index = 0 63 | with open(vocab_file, "r", encoding="utf-8") as reader: 64 | while True: 65 | token = reader.readline() 66 | if not token: 67 | break 68 | token = token.strip() 69 | vocab[token] = index 70 | index += 1 71 | return vocab 72 | 73 | 74 | def whitespace_tokenize(text): 75 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 76 | text = text.strip() 77 | if not text: 78 | return [] 79 | tokens = text.split() 80 | return tokens 81 | 82 | 83 | class BertTokenizer(object): 84 | """Runs end-to-end tokenization: 
punctuation splitting + wordpiece""" 85 | 86 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 87 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 88 | if not os.path.isfile(vocab_file): 89 | raise ValueError( 90 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 91 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 92 | self.vocab = load_vocab(vocab_file) 93 | self.ids_to_tokens = collections.OrderedDict( 94 | [(ids, tok) for tok, ids in self.vocab.items()]) 95 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 96 | never_split=never_split) 97 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 98 | self.max_len = max_len if max_len is not None else int(1e12) 99 | 100 | def tokenize(self, text): 101 | split_tokens = [] 102 | for token in self.basic_tokenizer.tokenize(text): 103 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 104 | split_tokens.append(sub_token) 105 | return split_tokens 106 | 107 | def convert_tokens_to_ids(self, tokens): 108 | """Converts a sequence of tokens into ids using the vocab.""" 109 | ids = [] 110 | for token in tokens: 111 | ids.append(self.vocab[token]) 112 | if len(ids) > self.max_len: 113 | raise ValueError( 114 | "Token indices sequence length is longer than the specified maximum " 115 | " sequence length for this BERT model ({} > {}). Running this" 116 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 117 | ) 118 | return ids 119 | 120 | def convert_ids_to_tokens(self, ids): 121 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 122 | tokens = [] 123 | for i in ids: 124 | tokens.append(self.ids_to_tokens[i]) 125 | return tokens 126 | 127 | def save_vocab(self, vocab_path): 128 | index = 0 129 | if os.path.isdir(vocab_path): 130 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 131 | else: 132 | vocab_file = vocab_path 133 | with open(vocab_file, "w", encoding="utf-8") as writer: 134 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 135 | if index != token_index: 136 | logger.warning( 137 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 138 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 139 | ) 140 | index = token_index 141 | writer.write(token + "\n") 142 | index += 1 143 | return (vocab_file,) 144 | 145 | @classmethod 146 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 147 | """ 148 | Instantiate a PreTrainedBertModel from a pre-trained model file. 149 | Download and cache the pre-trained model file if needed. 150 | """ 151 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 152 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 153 | else: 154 | vocab_file = pretrained_model_name_or_path 155 | if os.path.isdir(vocab_file): 156 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 157 | # redirect to the cache, if necessary 158 | try: 159 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 160 | except EnvironmentError: 161 | logger.error( 162 | "Model name '{}' was not found in model name list ({}). 
" 163 | "We assumed '{}' was a path or url but couldn't find any file " 164 | "associated to this path or url.".format( 165 | pretrained_model_name_or_path, 166 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 167 | vocab_file)) 168 | return None 169 | if resolved_vocab_file == vocab_file: 170 | logger.info("loading vocabulary file {}".format(vocab_file)) 171 | else: 172 | logger.info("loading vocabulary file {} from cache at {}".format( 173 | vocab_file, resolved_vocab_file)) 174 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 175 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 176 | # than the number of positional embeddings 177 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 178 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 179 | # Instantiate tokenizer. 180 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 181 | return tokenizer 182 | 183 | 184 | class BasicTokenizer(object): 185 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 186 | 187 | def __init__(self, 188 | do_lower_case=True, 189 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 190 | """Constructs a BasicTokenizer. 191 | 192 | Args: 193 | do_lower_case: Whether to lower case the input. 194 | """ 195 | self.do_lower_case = do_lower_case 196 | self.never_split = never_split 197 | 198 | def tokenize(self, text): 199 | """Tokenizes a piece of text.""" 200 | text = self._clean_text(text) 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | orig_tokens = whitespace_tokenize(text) 209 | split_tokens = [] 210 | for token in orig_tokens: 211 | if self.do_lower_case and token not in self.never_split: 212 | token = token.lower() 213 | token = self._run_strip_accents(token) 214 | split_tokens.extend(self._run_split_on_punc(token)) 215 | 216 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 217 | return output_tokens 218 | 219 | def _run_strip_accents(self, text): 220 | """Strips accents from a piece of text.""" 221 | text = unicodedata.normalize("NFD", text) 222 | output = [] 223 | for char in text: 224 | cat = unicodedata.category(char) 225 | if cat == "Mn": 226 | continue 227 | output.append(char) 228 | return "".join(output) 229 | 230 | def _run_split_on_punc(self, text): 231 | """Splits punctuation on a piece of text.""" 232 | if text in self.never_split: 233 | return [text] 234 | chars = list(text) 235 | i = 0 236 | start_new_word = True 237 | output = [] 238 | while i < len(chars): 239 | char = chars[i] 240 | if _is_punctuation(char): 241 | output.append([char]) 242 | start_new_word = True 243 | else: 244 | if start_new_word: 245 | output.append([]) 246 | start_new_word = False 247 | output[-1].append(char) 248 | i += 1 249 | 250 | return ["".join(x) for x in output] 251 | 252 | def _tokenize_chinese_chars(self, text): 253 | """Adds whitespace around any CJK character.""" 254 | output = [] 255 | for char in text: 256 | cp = ord(char) 257 | if self._is_chinese_char(cp): 258 | output.append(" ") 259 | output.append(char) 260 | output.append(" ") 261 | else: 262 | output.append(char) 263 | return "".join(output) 264 | 265 | def _is_chinese_char(self, cp): 266 | """Checks whether CP is the codepoint of a CJK character.""" 267 | # This defines a "chinese character" as anything in the CJK Unicode block: 268 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 269 | # 270 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 271 | # despite its name. The modern Korean Hangul alphabet is a different block, 272 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 273 | # space-separated words, so they are not treated specially and handled 274 | # like the all of the other languages. 275 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 276 | (cp >= 0x3400 and cp <= 0x4DBF) or # 277 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 278 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 279 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 280 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 281 | (cp >= 0xF900 and cp <= 0xFAFF) or # 282 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 283 | return True 284 | 285 | return False 286 | 287 | def _clean_text(self, text): 288 | """Performs invalid character removal and whitespace cleanup on text.""" 289 | output = [] 290 | for char in text: 291 | cp = ord(char) 292 | if cp == 0 or cp == 0xfffd or _is_control(char): 293 | continue 294 | if _is_whitespace(char): 295 | output.append(" ") 296 | else: 297 | output.append(char) 298 | return "".join(output) 299 | 300 | 301 | class WordpieceTokenizer(object): 302 | """Runs WordPiece tokenization.""" 303 | 304 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 305 | self.vocab = vocab 306 | self.unk_token = unk_token 307 | self.max_input_chars_per_word = max_input_chars_per_word 308 | 309 | def tokenize(self, text): 310 | """Tokenizes a piece of text into its word pieces. 
311 | 312 | This uses a greedy longest-match-first algorithm to perform tokenization 313 | using the given vocabulary. 314 | 315 | For example: 316 | input = "unaffable" 317 | output = ["un", "##aff", "##able"] 318 | 319 | Args: 320 | text: A single token or whitespace separated tokens. This should have 321 | already been passed through `BasicTokenizer`. 322 | 323 | Returns: 324 | A list of wordpiece tokens. 325 | """ 326 | 327 | output_tokens = [] 328 | for token in whitespace_tokenize(text): 329 | chars = list(token) 330 | if len(chars) > self.max_input_chars_per_word: 331 | output_tokens.append(self.unk_token) 332 | continue 333 | 334 | is_bad = False 335 | start = 0 336 | sub_tokens = [] 337 | while start < len(chars): 338 | end = len(chars) 339 | cur_substr = None 340 | while start < end: 341 | substr = "".join(chars[start:end]) 342 | if start > 0: 343 | substr = "##" + substr 344 | if substr in self.vocab: 345 | cur_substr = substr 346 | break 347 | end -= 1 348 | if cur_substr is None: 349 | is_bad = True 350 | break 351 | sub_tokens.append(cur_substr) 352 | start = end 353 | 354 | if is_bad: 355 | output_tokens.append(self.unk_token) 356 | else: 357 | output_tokens.extend(sub_tokens) 358 | return output_tokens 359 | 360 | 361 | def _is_whitespace(char): 362 | """Checks whether `chars` is a whitespace character.""" 363 | # \t, \n, and \r are technically contorl characters but we treat them 364 | # as whitespace since they are generally considered as such. 365 | if char == " " or char == "\t" or char == "\n" or char == "\r": 366 | return True 367 | cat = unicodedata.category(char) 368 | if cat == "Zs": 369 | return True 370 | return False 371 | 372 | 373 | def _is_control(char): 374 | """Checks whether `chars` is a control character.""" 375 | # These are technically control characters but we count them as whitespace 376 | # characters. 377 | if char == "\t" or char == "\n" or char == "\r": 378 | return False 379 | cat = unicodedata.category(char) 380 | if cat.startswith("C"): 381 | return True 382 | return False 383 | 384 | 385 | def _is_punctuation(char): 386 | """Checks whether `chars` is a punctuation character.""" 387 | cp = ord(char) 388 | # We treat all non-letter/number ASCII as punctuation. 389 | # Characters such as "^", "$", and "`" are not in the Unicode 390 | # Punctuation class but we treat them as punctuation anyways, for 391 | # consistency. 
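# The four ASCII ranges checked below cover the blocks !-/ (33-47), :-@ (58-64),
# [-` (91-96) and {-~ (123-126).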
392 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 393 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 394 | return True 395 | cat = unicodedata.category(char) 396 | if cat.startswith("P"): 397 | return True 398 | return False 399 | -------------------------------------------------------------------------------- /data/create_data.py: -------------------------------------------------------------------------------- 1 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | 4 | import torch 5 | import argparse 6 | import logging 7 | import os 8 | import random 9 | import h5py 10 | import numpy as np 11 | import collections 12 | import json 13 | import pickle 14 | import sys 15 | from tqdm import tqdm, trange 16 | 17 | sys.path.append("../") 18 | 19 | import model.tokenization as tokenization 20 | from tokenization import BertTokenizer 21 | from data.data_utils import processors 22 | from data.sc_mask_gen import SC, ModelGen, ASC 23 | from data.rand_mask_gen import RandMask 24 | 25 | class TrainingInstance(object): 26 | """A single training instance (sentence pair).""" 27 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next): 28 | self.tokens = tokens 29 | self.segment_ids = segment_ids 30 | self.is_random_next = is_random_next 31 | self.masked_lm_positions = masked_lm_positions 32 | self.masked_lm_labels = masked_lm_labels 33 | 34 | def __str__(self): 35 | s = "" 36 | s += "tokens: %s\n" % (" ".join( 37 | [tokenization.printable_text(x) for x in self.tokens])) 38 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 39 | s += "is_random_next: %s\n" % self.is_random_next 40 | s += "masked_lm_positions: %s\n" % (" ".join( 41 | [str(x) for x in self.masked_lm_positions])) 42 | s += "masked_lm_labels: %s\n" % (" ".join( 43 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 44 | s += "\n" 45 | return s 46 | 47 | def __repr__(self): 48 | return self.__str__() 49 | 50 | 51 | def write_instance_to_example_file(instances, tokenizer, max_seq_length, 52 | max_predictions_per_seq, output_file): 53 | """Create TF example files from `TrainingInstance`s.""" 54 | print(output_file) 55 | total_written = 0 56 | features = collections.OrderedDict() 57 | 58 | num_instances = len(instances) 59 | features["input_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32") 60 | features["input_mask"] = np.zeros([num_instances, max_seq_length], dtype="int32") 61 | features["segment_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32") 62 | features["masked_lm_positions"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32") 63 | features["masked_lm_ids"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32") 64 | features["next_sentence_labels"] = np.zeros(num_instances, dtype="int32") 65 | 66 | 67 | for inst_index, instance in enumerate(tqdm(instances, desc="Writing Instances")): 68 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 69 | input_mask = [1] * len(input_ids) 70 | segment_ids = list(instance.segment_ids) 71 | assert len(input_ids) <= max_seq_length 72 | 73 | while len(input_ids) < max_seq_length: 74 | input_ids.append(0) 75 | input_mask.append(0) 76 | segment_ids.append(0) 77 | 78 | assert len(input_ids) == max_seq_length 79 | assert len(input_mask) == max_seq_length 80 | assert len(segment_ids) == max_seq_length 81 | 82 | masked_lm_positions = 
list(instance.masked_lm_positions) 83 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 84 | masked_lm_weights = [1.0] * len(masked_lm_ids) 85 | 86 | while len(masked_lm_positions) < max_predictions_per_seq: 87 | masked_lm_positions.append(0) 88 | masked_lm_ids.append(0) 89 | masked_lm_weights.append(0.0) 90 | 91 | next_sentence_label = 1 if instance.is_random_next else 0 92 | 93 | features["input_ids"][inst_index] = input_ids 94 | features["input_mask"][inst_index] = input_mask 95 | features["segment_ids"][inst_index] = segment_ids 96 | features["masked_lm_positions"][inst_index] = masked_lm_positions 97 | features["masked_lm_ids"][inst_index] = masked_lm_ids 98 | features["next_sentence_labels"][inst_index] = next_sentence_label 99 | 100 | total_written += 1 101 | 102 | print("saving data") 103 | f= h5py.File(output_file, 'w') 104 | f.create_dataset("input_ids", data=features["input_ids"], dtype='i4', compression='gzip') 105 | f.create_dataset("input_mask", data=features["input_mask"], dtype='i1', compression='gzip') 106 | f.create_dataset("segment_ids", data=features["segment_ids"], dtype='i1', compression='gzip') 107 | f.create_dataset("masked_lm_positions", data=features["masked_lm_positions"], dtype='i4', compression='gzip') 108 | f.create_dataset("masked_lm_ids", data=features["masked_lm_ids"], dtype='i4', compression='gzip') 109 | f.create_dataset("next_sentence_labels", data=features["next_sentence_labels"], dtype='i1', compression='gzip') 110 | f.flush() 111 | f.close() 112 | 113 | def write_labeled_data(labeled_data, output_file): 114 | with open(output_file, "wb") as f: 115 | pickle.dump(labeled_data, f) 116 | 117 | def create_training_instances(data, all_labels, task_name, generator, max_seq_length, dupe_factor, short_seq_prob, masked_lm_prob, max_predictions_per_seq, rng, with_rand=False): 118 | """Create `TrainingInstance`s from raw text.""" 119 | 120 | # Remove empty documents 121 | if with_rand: 122 | all_documents, rand_all_documents = generator(data, all_labels, dupe_factor, rng) 123 | print(len(all_documents), len(rand_all_documents)) 124 | else: 125 | all_documents = generator(data, all_labels, dupe_factor, rng) 126 | print(len(all_documents)) 127 | 128 | instances = [] 129 | all_documents = [x for x in all_documents if x] 130 | rng.shuffle(all_documents) 131 | for document_index in range(len(all_documents)): 132 | instances.extend(create_instances_from_document(all_documents, document_index, max_seq_length, short_seq_prob, 133 | masked_lm_prob, max_predictions_per_seq, rng)) 134 | 135 | rng.shuffle(instances) 136 | 137 | labeled_data = [] 138 | for document in all_documents: 139 | for sentence in document: 140 | labeled_data.append((sentence.tokens, [1 if x else 0 for x in sentence.info])) 141 | 142 | if with_rand: 143 | rand_instances = [] 144 | rand_all_documents = [x for x in rand_all_documents if x] 145 | rng.shuffle(rand_all_documents) 146 | for document_index in range(len(rand_all_documents)): 147 | rand_instances.extend(create_instances_from_document(rand_all_documents, document_index, max_seq_length, short_seq_prob, 148 | masked_lm_prob, max_predictions_per_seq, rng)) 149 | 150 | rng.shuffle(rand_instances) 151 | 152 | return instances, rand_instances, labeled_data 153 | else: 154 | return instances, labeled_data 155 | 156 | 157 | def create_instances_from_document( 158 | all_documents, document_index, max_seq_length, short_seq_prob, 159 | masked_lm_prob, max_predictions_per_seq, rng): 160 | """Creates `TrainingInstance`s for a 
single document.""" 161 | 162 | # document: MaskedTokenInstance: (tokens, info) 163 | document = all_documents[document_index] 164 | 165 | # Account for [CLS], [SEP] 166 | max_num_tokens = max_seq_length - 2 167 | 168 | target_seq_length = max_num_tokens 169 | if rng.random() < short_seq_prob: 170 | target_seq_length = rng.randint(2, max_num_tokens) 171 | 172 | instances = [] 173 | current_chunk = [] 174 | current_length = 0 175 | i = 0 176 | while i < len(document): 177 | segment = document[i] # segment: MaskedTokenInstance (tokens, info) 178 | current_chunk.append(segment) 179 | current_length += len(segment.tokens) 180 | if i == len(document) - 1 or current_length >= target_seq_length: 181 | if current_chunk: 182 | tokens_a = [] 183 | m_info_a = [] 184 | for j in range(len(current_chunk)): 185 | tokens_a.extend(current_chunk[j].tokens) 186 | m_info_a.extend(current_chunk[j].info) 187 | truncate_seq_pair(tokens_a, m_info_a, [], [], max_num_tokens, rng) 188 | 189 | assert len(tokens_a) >= 1 190 | 191 | tokens = [] 192 | m_info = [] 193 | segment_ids = [] 194 | tokens.append("[CLS]") 195 | m_info.append({}) 196 | segment_ids.append(0) 197 | for token, info in zip(tokens_a, m_info_a): 198 | tokens.append(token) 199 | m_info.append(info) 200 | segment_ids.append(0) 201 | 202 | tokens.append("[SEP]") 203 | m_info.append({}) 204 | segment_ids.append(0) 205 | 206 | masked_lm_positions = [index for index in range(len(m_info)) if m_info[index]] 207 | if len(masked_lm_positions) > max_predictions_per_seq: 208 | rng.shuffle(masked_lm_positions) 209 | masked_lm_positions = masked_lm_positions[0:max_predictions_per_seq] 210 | masked_lm_positions.sort() 211 | masked_lm_labels = [m_info[pos]["label"] for pos in masked_lm_positions] 212 | 213 | for pos in masked_lm_positions: 214 | tokens[pos] = m_info[pos]["mask"] 215 | 216 | is_random_next = False 217 | instance = TrainingInstance( 218 | tokens=tokens, 219 | segment_ids=segment_ids, 220 | is_random_next=is_random_next, 221 | masked_lm_positions=masked_lm_positions, 222 | masked_lm_labels=masked_lm_labels) 223 | instances.append(instance) 224 | current_chunk = [] 225 | current_length = 0 226 | i += 1 227 | return instances 228 | 229 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) 230 | MaskedTokenInstance = collections.namedtuple("MaskedTokenInstance", ["tokens", "info"]) 231 | 232 | def truncate_seq_pair(tokens_a, m_info_a, tokens_b, m_info_b, max_num_tokens, rng): 233 | """Truncates a pair of sequences to a maximum sequence length.""" 234 | while True: 235 | total_length = len(tokens_a) + len(tokens_b) 236 | if total_length <= max_num_tokens: 237 | break 238 | 239 | (trunc_tokens, trunc_info) = (tokens_a, m_info_a) if len(tokens_a) > len(tokens_b) else (tokens_b, m_info_b) 240 | assert len(trunc_tokens) >= 1 241 | 242 | # We want to sometimes truncate from the front and sometimes from the 243 | # back to add more randomness and avoid biases. 244 | if rng.random() < 0.5: 245 | del trunc_tokens[0] 246 | del trunc_info[0] 247 | else: 248 | trunc_tokens.pop() 249 | trunc_info.pop() 250 | 251 | 252 | def main(): 253 | print(torch.cuda.is_available()) 254 | parser = argparse.ArgumentParser() 255 | ## Required parameters 256 | parser.add_argument("--input_dir", 257 | default=None, 258 | type=str, 259 | required=True, 260 | help="The input train corpus. 
can be directory with .txt files or a path to a single file") 261 | parser.add_argument("--output_dir", 262 | default=None, 263 | type=str, 264 | required=True, 265 | help="The output file where the model checkpoints will be written.") 266 | 267 | ## Other parameters 268 | 269 | # bool 270 | parser.add_argument("--mode", 271 | type=str, 272 | ) 273 | 274 | # str 275 | parser.add_argument("--bert_model", 276 | default="bert-large-uncased", 277 | type=str, 278 | required=False, 279 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 280 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 281 | parser.add_argument("--task_name", 282 | default="", 283 | type=str, 284 | required=False) 285 | parser.add_argument("--local_rank", 286 | default=0, 287 | type=int) 288 | 289 | # int 290 | parser.add_argument("--max_seq_length", 291 | default=128, 292 | type=int, 293 | help="The maximum total input sequence length after WordPiece tokenization. \n" 294 | "Sequences longer than this will be truncated, and sequences shorter \n" 295 | "than this will be padded.") 296 | parser.add_argument("--dupe_factor", 297 | default=1, 298 | type=int, 299 | help="Number of times to duplicate the input data (with different masks).") 300 | parser.add_argument("--max_predictions_per_seq", 301 | default=20, 302 | type=int, 303 | help="Maximum sequence length.") 304 | parser.add_argument("--sentence_batch_size", 305 | default=32, 306 | type=int) 307 | parser.add_argument("--top_sen_rate", 308 | default=0.8, 309 | type=float) 310 | parser.add_argument("--threshold", 311 | default=0.2, 312 | type=float) 313 | 314 | 315 | # floats 316 | 317 | parser.add_argument("--masked_lm_prob", 318 | default=0.15, 319 | type=float, 320 | help="Masked LM probability.") 321 | 322 | parser.add_argument("--short_seq_prob", 323 | default=0.1, 324 | type=float, 325 | help="Probability to create a sequence shorter than maximum sequence length") 326 | 327 | parser.add_argument("--do_lower_case", 328 | action='store_true', 329 | help="Whether to lower case the input text. 
True for uncased models, False for cased models.") 330 | parser.add_argument('--random_seed', 331 | type=int, 332 | default=12345, 333 | help="random seed for initialization") 334 | parser.add_argument('--part', 335 | type=int, 336 | default=0) 337 | parser.add_argument('--max_proc', 338 | type=int, 339 | default=1) 340 | parser.add_argument('--with_rand', 341 | action='store_true' 342 | ) 343 | parser.add_argument('--split_part', 344 | type=int 345 | ) 346 | 347 | args = parser.parse_args() 348 | print(args) 349 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 350 | logger = logging.getLogger(__name__) 351 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 352 | datefmt='%m/%d/%Y %H:%M:%S', 353 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 354 | rng = random.Random(args.random_seed) 355 | np.random.seed(args.random_seed) 356 | torch.manual_seed(args.random_seed) 357 | 358 | print("creating instance from {}".format(args.input_dir)) 359 | processor = processors[args.task_name]() 360 | eval_examples = processor.get_pretrain_examples(args.input_dir, args.part, args.max_proc) 361 | if args.task_name == "absa" or args.task_name == "absa_term": 362 | data = eval_examples 363 | all_labels = None 364 | else: 365 | data = [example.text_a for example in eval_examples] 366 | all_labels = [example.label for example in eval_examples] 367 | 368 | del eval_examples 369 | 370 | label_list = processor.get_labels() 371 | logger.info("Bert Model: {}".format(args.bert_model)) 372 | 373 | if args.mode == "rand": 374 | print("Mode: rand") 375 | generator = RandMask(args.masked_lm_prob, args.bert_model, args.do_lower_case, args.max_seq_length) 376 | elif args.mode == "rule": 377 | print("Mode: rule") 378 | if args.task_name == "absa" or args.task_name == "absa_term": 379 | generator = ASC(args.masked_lm_prob, args.top_sen_rate, args.threshold, args.bert_model, args.do_lower_case, args.max_seq_length, label_list, args.sentence_batch_size) 380 | else: 381 | generator = SC(args.masked_lm_prob, args.top_sen_rate, args.threshold, args.bert_model, args.do_lower_case, args.max_seq_length, label_list, args.sentence_batch_size) 382 | else: 383 | print("Mode: model") 384 | generator = ModelGen(args.masked_lm_prob, args.bert_model, args.do_lower_case, args.max_seq_length, args.sentence_batch_size, with_rand=args.with_rand) 385 | 386 | if args.with_rand: 387 | instances, rand_instances, labeled_data = create_training_instances( 388 | data, all_labels, args.task_name, generator, args.max_seq_length, args.dupe_factor, 389 | args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq, 390 | rng, with_rand=args.with_rand) 391 | else: 392 | instances, labeled_data = create_training_instances( 393 | data, all_labels, args.task_name, generator, args.max_seq_length, args.dupe_factor, 394 | args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq, 395 | rng, with_rand=args.with_rand) 396 | 397 | if args.part >= 0: 398 | output_file = os.path.join(args.output_dir, "model", "{}.hdf5".format(args.part)) 399 | if args.with_rand: 400 | rand_output_file = os.path.join(args.output_dir, "rand", "{}.hdf5".format(args.part)) 401 | labeled_output_file = os.path.join(args.output_dir, "{}.pkl".format(args.part)) 402 | else: 403 | output_file = os.path.join(args.output_dir, "model", "0.hdf5") 404 | if args.with_rand: 405 | rand_output_file = os.path.join(args.output_dir, "rand", "0.hdf5") 406 | labeled_output_file = 
os.path.join(args.output_dir, "0.pkl") 407 | 408 | if args.mode == "rule": 409 | print("Writing labeled data(.pkl) for rule mode") 410 | write_labeled_data(labeled_data, labeled_output_file) 411 | else: 412 | print("Writing masked data(.hdf5) for model mode") 413 | if args.with_rand: 414 | print("Num instances: {}. Num rand instance: {}".format(len(instances), len(rand_instances))) 415 | write_instance_to_example_file(instances, tokenizer, args.max_seq_length, args.max_predictions_per_seq, output_file) 416 | write_instance_to_example_file(rand_instances, tokenizer, args.max_seq_length, args.max_predictions_per_seq, rand_output_file) 417 | else: 418 | print("Num instances: {}.".format(len(instances))) 419 | write_instance_to_example_file(instances, tokenizer, args.max_seq_length, args.max_predictions_per_seq, labeled_output_file) 420 | 421 | if __name__ == "__main__": 422 | main() 423 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import csv 7 | import os 8 | import logging 9 | import argparse 10 | import random 11 | import h5py 12 | from tqdm import tqdm, trange 13 | import os 14 | import numpy as np 15 | import torch 16 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset 17 | from torch.utils.data.distributed import DistributedSampler 18 | import math 19 | from apex import amp 20 | import json 21 | 22 | from model.tokenization import BertTokenizer 23 | from model.modeling import BertForMaskedLM, BertConfig 24 | from model.optimization import BertAdam, BertAdam_FP16 25 | from model.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 26 | from model.schedulers import LinearWarmUpScheduler 27 | 28 | from apex.optimizers import FusedAdam 29 | from apex.parallel import DistributedDataParallel as DDP 30 | 31 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 32 | datefmt = '%m/%d/%Y %H:%M:%S', 33 | level = logging.INFO) 34 | logger = logging.getLogger(__name__) 35 | 36 | class pretraining_dataset(Dataset): 37 | 38 | def __init__(self, input_file, max_pred_length): 39 | self.input_file = input_file 40 | self.max_pred_length = max_pred_length 41 | f = h5py.File(input_file, "r") 42 | self.input_ids = np.asarray(f["input_ids"][:]).astype(np.int64)#[num_instances x max_seq_length]) 43 | self.input_masks = np.asarray(f["input_mask"][:]).astype(np.int64) #[num_instances x max_seq_length] 44 | self.segment_ids = np.asarray(f["segment_ids"][:]).astype(np.int64) #[num_instances x max_seq_length] 45 | self.masked_lm_positions = np.asarray(f["masked_lm_positions"][:]).astype(np.int64) #[num_instances x max_pred_length] 46 | self.masked_lm_ids= np.asarray(f["masked_lm_ids"][:]).astype(np.int64) #[num_instances x max_pred_length] 47 | self.next_sentence_labels = np.asarray(f["next_sentence_labels"][:]).astype(np.int64) # [num_instances] 48 | f.close() 49 | 50 | def __len__(self): 51 | 'Denotes the total number of samples' 52 | return len(self.input_ids) 53 | 54 | def __getitem__(self, index): 55 | 56 | input_ids= torch.from_numpy(self.input_ids[index]) # [max_seq_length] 57 | input_mask = torch.from_numpy(self.input_masks[index]) #[max_seq_length] 58 | segment_ids = torch.from_numpy(self.segment_ids[index])# [max_seq_length] 59 | masked_lm_positions = 
torch.from_numpy(self.masked_lm_positions[index]) #[max_pred_length] 60 | masked_lm_ids = torch.from_numpy(self.masked_lm_ids[index]) #[max_pred_length] 61 | next_sentence_labels = torch.from_numpy(np.asarray(self.next_sentence_labels[index])) #[1] 62 | 63 | masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1 64 | index = self.max_pred_length 65 | # store number of masked tokens in index 66 | if len((masked_lm_positions == 0).nonzero()) != 0: 67 | index = (masked_lm_positions == 0).nonzero()[0].item() 68 | masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] 69 | 70 | return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] 71 | 72 | def main(): 73 | 74 | parser = argparse.ArgumentParser() 75 | 76 | ## Required parameters 77 | parser.add_argument("--input_dir", 78 | default=None, 79 | type=str, 80 | required=True, 81 | help="The input data dir. Should contain .hdf5 files for the task.") 82 | 83 | parser.add_argument("--bert_model", default="bert-large-uncased", type=str, 84 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 85 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 86 | 87 | parser.add_argument("--output_dir", 88 | default=None, 89 | type=str, 90 | required=True, 91 | help="The output directory where the model checkpoints will be written.") 92 | 93 | ## Other parameters 94 | parser.add_argument("--config_file", 95 | default=None, 96 | type=str, 97 | help="The BERT model config") 98 | parser.add_argument("--ckpt", 99 | default="", 100 | type=str) 101 | parser.add_argument("--max_seq_length", 102 | default=512, 103 | type=int, 104 | help="The maximum total input sequence length after WordPiece tokenization. \n" 105 | "Sequences longer than this will be truncated, and sequences shorter \n" 106 | "than this will be padded.") 107 | parser.add_argument("--max_predictions_per_seq", 108 | default=80, 109 | type=int, 110 | help="The maximum total of masked tokens in input sequence") 111 | parser.add_argument("--train_batch_size", 112 | default=32, 113 | type=int, 114 | help="Total batch size for training.") 115 | parser.add_argument("--learning_rate", 116 | default=5e-5, 117 | type=float, 118 | help="The initial learning rate for Adam.") 119 | parser.add_argument("--num_train_epochs", 120 | default=3.0, 121 | type=float, 122 | help="Total number of training epochs to perform.") 123 | parser.add_argument("--max_steps", 124 | default=1000, 125 | type=float, 126 | help="Total number of training steps to perform.") 127 | parser.add_argument("--warmup_proportion", 128 | default=0.01, 129 | type=float, 130 | help="Proportion of training to perform linear learning rate warmup for. 
" 131 | "E.g., 0.1 = 10%% of training.") 132 | parser.add_argument("--local_rank", 133 | type=int, 134 | default=-1, 135 | help="local_rank for distributed training on gpus") 136 | parser.add_argument('--seed', 137 | type=int, 138 | default=42, 139 | help="random seed for initialization") 140 | parser.add_argument('--gradient_accumulation_steps', 141 | type=int, 142 | default=1, 143 | help="Number of updates steps to accumualte before performing a backward/update pass.") 144 | parser.add_argument('--fp16', 145 | default=False, 146 | action='store_true', 147 | help="Whether to use 16-bit float precision instead of 32-bit") 148 | parser.add_argument('--loss_scale', 149 | type=float, default=0.0, 150 | help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 151 | parser.add_argument('--log_freq', 152 | type=float, default=500, 153 | help='frequency of logging loss.') 154 | parser.add_argument('--checkpoint_activations', 155 | default=False, 156 | action='store_true', 157 | help="Whether to use gradient checkpointing") 158 | parser.add_argument("--resume_from_checkpoint", 159 | default=False, 160 | action='store_true', 161 | help="Whether to resume training from checkpoint.") 162 | parser.add_argument('--resume_step', 163 | type=int, 164 | default=-1, 165 | help="Step to resume training from.") 166 | parser.add_argument('--num_steps_per_checkpoint', 167 | type=int, 168 | default=2000, 169 | help="Number of update steps until a model checkpoint is saved to disk.") 170 | parser.add_argument('--dev_data_file', 171 | type=str, 172 | default="dev/dev.hdf5") 173 | parser.add_argument('--dev_batch_size', 174 | type=int, 175 | default=16) 176 | parser.add_argument("--save_total_limit", type=int, default=10) 177 | 178 | args = parser.parse_args() 179 | 180 | random.seed(args.seed) 181 | np.random.seed(args.seed) 182 | torch.manual_seed(args.seed) 183 | 184 | min_dev_loss = 1000000 185 | best_step = 0 186 | 187 | assert(torch.cuda.is_available()) 188 | print(args.local_rank) 189 | if args.local_rank == -1: 190 | device = torch.device("cuda") 191 | n_gpu = torch.cuda.device_count() 192 | else: 193 | torch.cuda.set_device(args.local_rank) 194 | device = torch.device("cuda", args.local_rank) 195 | n_gpu = 1 196 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 197 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 198 | 199 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 200 | 201 | if args.gradient_accumulation_steps < 1: 202 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 203 | args.gradient_accumulation_steps)) 204 | if args.train_batch_size % args.gradient_accumulation_steps != 0: 205 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( 206 | args.gradient_accumulation_steps, args.train_batch_size)) 207 | 208 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 209 | 210 | 211 | 212 | if not args.resume_from_checkpoint and os.path.exists(args.output_dir) and (os.listdir(args.output_dir) and os.listdir(args.output_dir)!=['logfile.txt']): 213 | logger.warning("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 214 | # raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 215 | 216 | if not args.resume_from_checkpoint: 217 | os.makedirs(args.output_dir, 
exist_ok=True) 218 | 219 | # Prepare model 220 | if args.config_file: 221 | config = BertConfig.from_json_file(args.config_file) 222 | 223 | if args.bert_model: 224 | model = BertForMaskedLM.from_pretrained(args.bert_model) 225 | else: 226 | model = BertForMaskedLM(config) 227 | 228 | print(args.ckpt) 229 | if args.ckpt: 230 | print("load from", args.ckpt) 231 | ckpt = torch.load(args.ckpt, map_location='cpu') 232 | if model in ckpt: 233 | ckpt = ckpt['model'] 234 | model.load_state_dict(ckpt, strict=False) 235 | 236 | pretrained_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 237 | torch.save(model.state_dict(), pretrained_model_file) 238 | 239 | if not args.resume_from_checkpoint: 240 | global_step = 0 241 | else: 242 | if args.resume_step == -1: 243 | model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")] 244 | args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names]) 245 | 246 | global_step = args.resume_step 247 | 248 | checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu") 249 | model.load_state_dict(checkpoint['model'], strict=False) 250 | 251 | print("resume step from ", args.resume_step) 252 | 253 | model.to(device) 254 | 255 | # Prepare optimizer 256 | param_optimizer = list(model.named_parameters()) 257 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 258 | optimizer_grouped_parameters = [ 259 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 260 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 261 | ] 262 | 263 | if args.fp16: 264 | optimizer = FusedAdam(optimizer_grouped_parameters, 265 | lr=args.learning_rate, 266 | bias_correction=False, 267 | weight_decay=0.01) 268 | 269 | if args.loss_scale == 0: 270 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale="dynamic") 271 | else: 272 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale=args.loss_scale) 273 | 274 | scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps) 275 | 276 | else: 277 | optimizer = BertAdam(optimizer_grouped_parameters, 278 | lr=args.learning_rate, 279 | warmup=args.warmup_proportion, 280 | t_total=args.max_steps) 281 | 282 | if args.resume_from_checkpoint: 283 | optimizer.load_state_dict(checkpoint['optimizer']) 284 | 285 | if args.local_rank != -1: 286 | model = DDP(model) 287 | elif n_gpu > 1: 288 | model = torch.nn.DataParallel(model) 289 | 290 | files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if os.path.isfile(os.path.join(args.input_dir, f))] 291 | files.sort() 292 | 293 | num_files = len(files) 294 | 295 | logger.info("***** Loading Dev Data *****") 296 | dev_data = pretraining_dataset(input_file=os.path.join(args.input_dir, args.dev_data_file), max_pred_length=args.max_predictions_per_seq) 297 | if args.local_rank == -1: 298 | dev_sampler = RandomSampler(dev_data) 299 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size * n_gpu, num_workers=4, pin_memory=True) 300 | else: 301 | dev_sampler = DistributedSampler(dev_data) 302 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size, num_workers=4, pin_memory=True) 303 | 304 | logger.info("***** Running training *****") 305 | logger.info(" Batch size = 
{}".format(args.train_batch_size)) 306 | logger.info(" LR = {}".format(args.learning_rate)) 307 | 308 | model.train() 309 | logger.info(" Training. . .") 310 | 311 | most_recent_ckpts_paths = [] 312 | 313 | tr_loss = 0.0 # total added training loss 314 | average_loss = 0.0 # averaged loss every args.log_freq steps 315 | epoch = 0 316 | training_steps = 0 317 | while True: 318 | if not args.resume_from_checkpoint: 319 | random.shuffle(files) 320 | f_start_id = 0 321 | else: 322 | f_start_id = checkpoint['files'][0] 323 | files = checkpoint['files'][1:] 324 | args.resume_from_checkpoint = False 325 | for f_id in range(f_start_id, len(files)): 326 | data_file = files[f_id] 327 | logger.info("file no {} file {}".format(f_id, data_file)) 328 | train_data = pretraining_dataset(input_file=data_file, max_pred_length=args.max_predictions_per_seq) 329 | 330 | if args.local_rank == -1: 331 | train_sampler = RandomSampler(train_data) 332 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size * n_gpu, num_workers=4, pin_memory=True) 333 | else: 334 | train_sampler = DistributedSampler(train_data) 335 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=4, pin_memory=True) 336 | 337 | for step, batch in enumerate(tqdm(train_dataloader, desc="File Iteration")): 338 | model.train() 339 | training_steps += 1 340 | batch = [t.to(device) for t in batch] 341 | input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch#\ 342 | loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, masked_lm_labels=masked_lm_labels,checkpoint_activations=args.checkpoint_activations) 343 | if n_gpu > 1: 344 | loss = loss.mean() # mean() to average on multi-gpu. 
345 | 346 | if args.gradient_accumulation_steps > 1: 347 | loss = loss / args.gradient_accumulation_steps 348 | 349 | if args.fp16: 350 | with amp.scale_loss(loss, optimizer) as scaled_loss: 351 | scaled_loss.backward() 352 | else: 353 | loss.backward() 354 | tr_loss += loss.item() 355 | average_loss += loss.item() 356 | 357 | if training_steps % args.gradient_accumulation_steps == 0: 358 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 359 | scheduler.step() 360 | optimizer.step() 361 | optimizer.zero_grad() 362 | global_step += 1 363 | 364 | if training_steps == 1 * args.gradient_accumulation_steps: 365 | logger.info("Global Step:{} Average Loss = {} Step Loss = {} LR {}".format(global_step, average_loss, 366 | loss.item(), optimizer.param_groups[0]['lr'])) 367 | 368 | if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: 369 | logger.info("Global Step:{} Average Loss = {} Step Loss = {} LR {}".format(global_step, average_loss / args.log_freq, 370 | loss.item(), optimizer.param_groups[0]['lr'])) 371 | average_loss = 0 372 | 373 | if training_steps % (args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0: 374 | logger.info("Begin Eval") 375 | model.eval() 376 | with torch.no_grad(): 377 | dev_global_step = 0 378 | dev_final_loss = 0.0 379 | for dev_step, dev_batch in enumerate(tqdm(dev_dataloader, desc="Evaluating")): 380 | batch = [t.to(device) for t in batch] 381 | dev_input_ids, dev_segment_ids, dev_input_mask, dev_masked_lm_labels, dev_next_sentence_labels = batch 382 | loss = model(input_ids=dev_input_ids, token_type_ids=dev_segment_ids, attention_mask=dev_input_mask, masked_lm_labels=dev_masked_lm_labels) 383 | dev_final_loss += loss 384 | dev_global_step += 1 385 | dev_final_loss /= dev_global_step 386 | if (torch.distributed.is_initialized()): 387 | dev_final_loss /= torch.distributed.get_world_size() 388 | torch.distributed.all_reduce(dev_final_loss) 389 | logger.info("Dev Loss: {}".format(dev_final_loss.item())) 390 | if dev_final_loss < min_dev_loss: 391 | best_step = global_step 392 | min_dev_loss = dev_final_loss 393 | if (not torch.distributed.is_initialized() or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)): 394 | logger.info("** ** * Saving best dev loss model ** ** * at step {}".format(best_step)) 395 | dev_model_to_save = model.module if hasattr(model, 'module') else model 396 | output_save_file = os.path.join(args.output_dir, "best_ckpt.pt") 397 | torch.save({'model' : dev_model_to_save.state_dict(), 398 | 'optimizer' : optimizer.state_dict(), 399 | 'files' : [f_id] + files}, output_save_file) 400 | 401 | if (not torch.distributed.is_initialized() or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)): 402 | # Save a trained model 403 | logger.info("** ** * Saving fine - tuned model ** ** * ") 404 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 405 | output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)) 406 | 407 | torch.save({'model' : model_to_save.state_dict(), 408 | 'optimizer' : optimizer.state_dict(), 409 | 'files' : [f_id] + files}, output_save_file) 410 | 411 | most_recent_ckpts_paths.append(output_save_file) 412 | if len(most_recent_ckpts_paths) > args.save_total_limit: 413 | ckpt_to_be_removed = most_recent_ckpts_paths.pop(0) 414 | os.remove(ckpt_to_be_removed) 415 | 416 | if global_step >= args.max_steps: 417 | tr_loss = tr_loss * args.gradient_accumulation_steps / 
training_steps 418 | if (torch.distributed.is_initialized()): 419 | tr_loss /= torch.distributed.get_world_size() 420 | print(tr_loss) 421 | torch.distributed.all_reduce(torch.tensor(tr_loss).cuda()) 422 | logger.info("Total Steps:{} Final Loss = {}".format(training_steps, tr_loss)) 423 | 424 | with open(os.path.join(args.output_dir, "valid_results.txt"), "w") as f: 425 | f.write("Min dev loss: {}\nBest step: {}\n".format(min_dev_loss, best_step)) 426 | 427 | return 428 | del train_dataloader 429 | del train_sampler 430 | del train_data 431 | 432 | torch.cuda.empty_cache() 433 | epoch += 1 434 | 435 | if __name__ == "__main__": 436 | main() 437 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import sys 7 | import random 8 | from tqdm import tqdm, trange 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 14 | TensorDataset, Dataset) 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.nn import CrossEntropyLoss, MSELoss 17 | 18 | from model.modeling_classification import BertForSequenceClassification, WEIGHTS_NAME, CONFIG_NAME, VOCAB_NAME 19 | from model.tokenization import BertTokenizer 20 | from model.optimization import BertAdam, warmup_linear 21 | 22 | from data.data_utils import processors, output_modes, convert_examples_to_features, compute_metrics 23 | 24 | if sys.version_info[0] == 2: 25 | import cPickle as pickle 26 | else: 27 | import pickle 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | class InputDataset(Dataset): 33 | def __init__(self, input_ids, attn_masks, segment_ids, labels): 34 | self.input_ids = input_ids 35 | self.attn_masks = attn_masks 36 | self.segment_ids = segment_ids 37 | self.labels = labels 38 | 39 | self.pads = { 40 | "input_ids": 0, 41 | "attention_mask": 0, 42 | "token_type_ids": 0 43 | } 44 | 45 | def __len__(self): 46 | return len(self.input_ids) 47 | 48 | def __getitem__(self, item): 49 | return { 50 | "input_ids": self.input_ids[item], 51 | "attention_mask": self.attn_masks[item], 52 | "token_type_ids": self.segment_ids[item], 53 | }, { 54 | "labels": self.labels[item] 55 | } 56 | 57 | def collate(self, example): 58 | seq_insts = [e[0] for e in example] 59 | int_insts = [e[1] for e in example] 60 | max_length = max([len(x["input_ids"]) for x in seq_insts]) 61 | 62 | inputs = {} 63 | labels = {} 64 | 65 | for key in seq_insts[0].keys(): 66 | seq = [inst[key] + [self.pads[key]] * (max_length - len(inst[key])) for inst in seq_insts] 67 | inputs[key] = torch.tensor(seq, dtype=torch.long) 68 | for key in int_insts[0].keys(): 69 | labels[key] = torch.tensor([inst[key] for inst in int_insts]) 70 | 71 | return inputs, labels 72 | 73 | def main(): 74 | parser = argparse.ArgumentParser() 75 | 76 | ## Required parameters 77 | parser.add_argument("--data_dir", 78 | default=None, 79 | type=str, 80 | required=True, 81 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 82 | parser.add_argument("--bert_model", default=None, type=str, required=True, 83 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 84 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 85 | "bert-base-multilingual-cased, bert-base-chinese.") 86 | parser.add_argument("--task_name", 87 | default=None, 88 | type=str, 89 | required=True, 90 | help="The name of the task to train.") 91 | parser.add_argument("--output_dir", 92 | default=None, 93 | type=str, 94 | required=True, 95 | help="The output directory where the model predictions and checkpoints will be written.") 96 | 97 | ## Other parameters 98 | parser.add_argument("--vocab_file", 99 | default="", 100 | type=str) 101 | parser.add_argument("--cache_dir", 102 | default="", 103 | type=str, 104 | help="Where do you want to store the pre-trained models downloaded from s3") 105 | parser.add_argument("--max_seq_length", 106 | default=128, 107 | type=int, 108 | help="The maximum total input sequence length after WordPiece tokenization. \n" 109 | "Sequences longer than this will be truncated, and sequences shorter \n" 110 | "than this will be padded.") 111 | parser.add_argument("--do_train", 112 | action='store_true', 113 | help="Whether to run training.") 114 | parser.add_argument("--do_eval", 115 | action='store_true', 116 | help="Whether to run eval on the dev set.") 117 | parser.add_argument("--do_lower_case", 118 | action='store_true', 119 | help="Set this flag if you are using an uncased model.") 120 | parser.add_argument("--train_batch_size", 121 | default=32, 122 | type=int, 123 | help="Total batch size for training.") 124 | parser.add_argument("--eval_batch_size", 125 | default=8, 126 | type=int, 127 | help="Total batch size for eval.") 128 | parser.add_argument("--learning_rate", 129 | default=5e-5, 130 | type=float, 131 | help="The initial learning rate for Adam.") 132 | parser.add_argument("--num_train_epochs", 133 | default=3.0, 134 | type=float, 135 | help="Total number of training epochs to perform.") 136 | parser.add_argument("--warmup_proportion", 137 | default=0.1, 138 | type=float, 139 | help="Proportion of training to perform linear learning rate warmup for. " 140 | "E.g., 0.1 = 10%% of training.") 141 | parser.add_argument("--no_cuda", 142 | action='store_true', 143 | help="Whether not to use CUDA when available") 144 | parser.add_argument('--overwrite_output_dir', 145 | action='store_true', 146 | help="Overwrite the content of the output directory") 147 | parser.add_argument("--local_rank", 148 | type=int, 149 | default=-1, 150 | help="local_rank for distributed training on gpus") 151 | parser.add_argument('--seed', 152 | type=int, 153 | default=42, 154 | help="random seed for initialization") 155 | parser.add_argument('--gradient_accumulation_steps', 156 | type=int, 157 | default=1, 158 | help="Number of updates steps to accumulate before performing a backward/update pass.") 159 | parser.add_argument('--fp16', 160 | action='store_true', 161 | help="Whether to use 16-bit float precision instead of 32-bit") 162 | parser.add_argument("--fp16_opt_level", 163 | type=str, 164 | default="O1", 165 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 166 | "See details at https://nvidia.github.io/apex/amp.html") 167 | parser.add_argument('--loss_scale', 168 | type=float, default=0, 169 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 170 | "0 (default value): dynamic loss scaling.\n" 171 | "Positive power of 2: static loss scaling value.\n") 172 | parser.add_argument("--ckpt", type=str, help="ckpt position") 173 | parser.add_argument("--save_all", action="store_true") 174 | parser.add_argument("--output_dev_detail", action="store_true") 175 | args = parser.parse_args() 176 | 177 | if args.local_rank == -1 or args.no_cuda: 178 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 179 | n_gpu = torch.cuda.device_count() 180 | else: 181 | torch.cuda.set_device(args.local_rank) 182 | device = torch.device("cuda", args.local_rank) 183 | n_gpu = 1 184 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 185 | torch.distributed.init_process_group(backend='nccl') 186 | args.device = device 187 | 188 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 189 | datefmt = '%m/%d/%Y %H:%M:%S', 190 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 191 | 192 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 193 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 194 | 195 | if args.gradient_accumulation_steps < 1: 196 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 197 | args.gradient_accumulation_steps)) 198 | 199 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 200 | 201 | random.seed(args.seed) 202 | np.random.seed(args.seed) 203 | torch.manual_seed(args.seed) 204 | if n_gpu > 0: 205 | torch.cuda.manual_seed_all(args.seed) 206 | 207 | if not args.do_train and not args.do_eval: 208 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 209 | 210 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 211 | os.makedirs(args.output_dir) 212 | 213 | task_name = args.task_name.lower() 214 | 215 | if task_name not in processors: 216 | raise ValueError("Task not found: %s" % (task_name)) 217 | 218 | processor = processors[task_name]() 219 | output_mode = output_modes[task_name] 220 | 221 | label_list = processor.get_labels() 222 | num_labels = len(label_list) 223 | 224 | if args.local_rank not in [-1, 0]: 225 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 226 | 227 | if args.vocab_file: 228 | tokenizer = BertTokenizer(args.vocab_file, args.do_lower_case) 229 | else: 230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 231 | 232 | model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) 233 | 234 | if args.ckpt: 235 | print("load from", args.ckpt) 236 | model_dict = model.state_dict() 237 | ckpt = torch.load(args.ckpt) 238 | if "model" in ckpt: 239 | pretrained_dict = ckpt['model'] 240 | else: 241 | pretrained_dict = ckpt 242 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys() and k not in ["classifier.weight", "classifier.bias"]} 243 | model_dict.update(new_dict) 244 | print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict))) 245 | model.load_state_dict(model_dict) 246 | 247 | if args.local_rank == 0: 248 | torch.distributed.barrier() 249 | 250 | model.to(device) 251 | if args.local_rank != -1: 252 | model = torch.nn.parallel.DistributedDataParallel(model, 253 | device_ids=[args.local_rank], 254 | 
output_device=args.local_rank, 255 | find_unused_parameters=True) 256 | elif n_gpu > 1: 257 | model = torch.nn.DataParallel(model) 258 | 259 | global_step = 0 260 | nb_tr_steps = 0 261 | tr_loss = 0 262 | 263 | if args.do_train: 264 | # Prepare data loader 265 | train_examples = processor.get_train_examples(args.data_dir) 266 | print(len(train_examples)) 267 | cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format( 268 | list(filter(None, args.bert_model.split('/'))).pop(), 269 | str(args.max_seq_length), 270 | str(task_name))) 271 | try: 272 | with open(cached_train_features_file, "rb") as reader: 273 | train_features = pickle.load(reader) 274 | except: 275 | train_features = convert_examples_to_features( 276 | train_examples, label_list, args.max_seq_length, tokenizer, output_mode) 277 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 278 | logger.info(" Saving train features into cached file %s", cached_train_features_file) 279 | with open(cached_train_features_file, "wb") as writer: 280 | pickle.dump(train_features, writer) 281 | 282 | all_input_ids = [f.input_ids for f in train_features] 283 | all_input_mask = [f.input_mask for f in train_features] 284 | all_segment_ids = [f.segment_ids for f in train_features] 285 | 286 | if output_mode == "classification": 287 | all_label_ids = [f.label_id for f in train_features] 288 | elif output_mode == "regression": 289 | all_label_ids = [f.label_id for f in train_features] 290 | 291 | train_data = InputDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 292 | 293 | if args.local_rank == -1: 294 | train_sampler = RandomSampler(train_data) 295 | else: 296 | train_sampler = DistributedSampler(train_data) 297 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=train_data.collate) 298 | 299 | num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 300 | 301 | # Prepare optimizer 302 | 303 | param_optimizer = list(model.named_parameters()) 304 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 305 | optimizer_grouped_parameters = [ 306 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 307 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 308 | ] 309 | optimizer = BertAdam(optimizer_grouped_parameters, 310 | lr=args.learning_rate, 311 | warmup=args.warmup_proportion, 312 | t_total=num_train_optimization_steps) 313 | if args.fp16: 314 | try: 315 | from apex import amp 316 | except ImportError: 317 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 318 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 319 | 320 | logger.info("***** Running training *****") 321 | logger.info(" Num examples = %d", len(train_examples)) 322 | logger.info(" Batch size = %d", args.train_batch_size) 323 | logger.info(" Num steps = %d", num_train_optimization_steps) 324 | 325 | os.makedirs(os.path.join(args.output_dir, "all_models"), exist_ok=True) 326 | model.train() 327 | for e in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]): 328 | tr_loss = 0 329 | nb_tr_examples, nb_tr_steps = 0, 0 330 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): 331 | inputs, labels = batch 332 | for key in 
inputs.keys(): 333 | inputs[key] = inputs[key].to(args.device) 334 | for key in labels.keys(): 335 | labels[key] = labels[key].to(args.device) 336 | # define a new function to compute loss values for both output_modes 337 | label_ids = labels["labels"] 338 | logits = model(**inputs) 339 | 340 | if output_mode == "classification": 341 | loss_fct = CrossEntropyLoss() 342 | loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 343 | elif output_mode == "regression": 344 | loss_fct = MSELoss() 345 | loss = loss_fct(logits.view(-1), label_ids.view(-1)) 346 | 347 | if n_gpu > 1: 348 | loss = loss.mean() # mean() to average on multi-gpu. 349 | if args.gradient_accumulation_steps > 1: 350 | loss = loss / args.gradient_accumulation_steps 351 | 352 | if args.fp16: 353 | with amp.scale_loss(loss, optimizer) as scaled_loss: 354 | scaled_loss.backward() 355 | else: 356 | loss.backward() 357 | 358 | tr_loss += loss.item() 359 | nb_tr_steps += 1 360 | if (step + 1) % args.gradient_accumulation_steps == 0: 361 | optimizer.step() 362 | optimizer.zero_grad() 363 | global_step += 1 364 | # save each epoch 365 | model_to_save = model.module if hasattr(model, 'module') else model 366 | output_model_file = os.path.join(args.output_dir, "all_models", "e{}_{}".format(e, WEIGHTS_NAME)) 367 | torch.save(model_to_save.state_dict(), output_model_file) 368 | 369 | ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 370 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 371 | output_args_file = os.path.join(args.output_dir, 'training_args.bin') 372 | torch.save(args, output_args_file) 373 | else: 374 | model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) 375 | 376 | ### Evaluation 377 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 378 | best_acc = 0 379 | best_epoch = 0 380 | val_res_file = os.path.join(args.output_dir, "valid_results.txt") 381 | val_f = open(val_res_file, "w") 382 | if args.output_dev_detail: 383 | logger.info("***** Dev Eval results *****") 384 | for e in tqdm(range(int(args.num_train_epochs)), desc="Epoch on dev"): 385 | weight_path = os.path.join(args.output_dir, "all_models", "e{}_{}".format(e, WEIGHTS_NAME)) 386 | result = evaluate(args, model, weight_path, processor, device, task_name, "dev", label_list, tokenizer, output_mode, num_labels, show_detail=False) 387 | if result["acc"] > best_acc: 388 | best_acc = result["acc"] 389 | best_epoch = e 390 | 391 | if args.output_dev_detail: 392 | logger.info("Epoch {}".format(e)) 393 | val_f.write("Epoch {}\n".format(e)) 394 | for key in sorted(result.keys()): 395 | if args.output_dev_detail: 396 | logger.info("{} = {}".format(key, str(result[key]))) 397 | val_f.write("{} = {}\n".format(key, str(result[key]))) 398 | val_f.write("\n") 399 | 400 | logger.info("\nBest epoch: {}. Best val acc: {}".format(best_epoch, best_acc)) 401 | val_f.write("Best epoch: {}. 
Best val acc: {}\n".format(best_epoch, best_acc)) 402 | val_f.close() 403 | 404 | test_weight_path = os.path.join(args.output_dir, "all_models", "e{}_{}".format(best_epoch, WEIGHTS_NAME)) 405 | test_result = evaluate(args, model, test_weight_path, processor, device, task_name, "test", label_list, tokenizer, output_mode, num_labels) 406 | test_res_file = os.path.join(args.output_dir, "test_results.txt") 407 | 408 | logger.info("***** Test Eval results *****") 409 | with open(test_res_file, "w") as test_f: 410 | for key in sorted(test_result.keys()): 411 | logger.info("{} = {}".format(key, str(test_result[key]))) 412 | test_f.write("{} = {}\n".format(key, str(test_result[key]))) 413 | 414 | best_model_dir = os.path.join(args.output_dir, "best_model") 415 | os.makedirs(best_model_dir, exist_ok=True) 416 | os.system("cp {} {}/{}".format(test_weight_path, best_model_dir, WEIGHTS_NAME)) 417 | with open(os.path.join(best_model_dir, CONFIG_NAME), 'w') as f: 418 | f.write(model_to_save.config.to_json_string()) 419 | tokenizer.save_vocab(os.path.join(best_model_dir, VOCAB_NAME)) 420 | 421 | if not args.save_all: 422 | os.system("rm -r {}".format(os.path.join(args.output_dir, "all_models"))) 423 | 424 | 425 | def evaluate(args, model, weight_path, processor, device, task_name, mode, label_list, tokenizer, output_mode, num_labels, show_detail=True): 426 | model.load_state_dict(torch.load(weight_path)) 427 | model.to(device) 428 | 429 | if show_detail: 430 | print("Loading From: ", weight_path) 431 | 432 | if mode == "test": 433 | eval_examples = processor.get_test_examples(args.data_dir) 434 | cached_eval_features_file = os.path.join(args.data_dir, 'test_{0}_{1}_{2}'.format( 435 | list(filter(None, args.bert_model.split('/'))).pop(), 436 | str(args.max_seq_length), 437 | str(task_name))) 438 | else: 439 | eval_examples = processor.get_dev_examples(args.data_dir) 440 | cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( 441 | list(filter(None, args.bert_model.split('/'))).pop(), 442 | str(args.max_seq_length), 443 | str(task_name))) 444 | 445 | try: 446 | with open(cached_eval_features_file, "rb") as reader: 447 | eval_features = pickle.load(reader) 448 | except: 449 | eval_features = convert_examples_to_features( 450 | eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) 451 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 452 | logger.info( 453 | " Saving eval features into cached file %s", cached_eval_features_file) 454 | with open(cached_eval_features_file, "wb") as writer: 455 | pickle.dump(eval_features, writer) 456 | 457 | if show_detail: 458 | logger.info("***** Running evaluation *****") 459 | logger.info(" Num examples = %d", len(eval_examples)) 460 | logger.info(" Batch size = %d", args.eval_batch_size) 461 | all_input_ids = [f.input_ids for f in eval_features] 462 | all_input_mask = [f.input_mask for f in eval_features] 463 | all_segment_ids = [f.segment_ids for f in eval_features] 464 | 465 | if output_mode == "classification": 466 | all_label_ids = [f.label_id for f in eval_features] 467 | elif output_mode == "regression": 468 | all_label_ids = [f.label_id for f in eval_features] 469 | 470 | eval_data = InputDataset( 471 | all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 472 | # Run prediction for full data 473 | if args.local_rank == -1: 474 | eval_sampler = SequentialSampler(eval_data) 475 | else: 476 | # Note that this sampler samples randomly 477 | eval_sampler = DistributedSampler(eval_data) 478 | 
eval_dataloader = DataLoader( 479 | eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=eval_data.collate) 480 | 481 | model.eval() 482 | eval_loss = 0 483 | nb_eval_steps = 0 484 | preds = [] 485 | out_label_ids = None 486 | 487 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=(not show_detail)): 488 | inputs, labels = batch 489 | for key in inputs.keys(): 490 | inputs[key] = inputs[key].to(args.device) 491 | for key in labels.keys(): 492 | labels[key] = labels[key].to(args.device) 493 | # define a new function to compute loss values for both output_modes 494 | label_ids = labels["labels"] 495 | 496 | with torch.no_grad(): 497 | logits = model(**inputs) 498 | 499 | # create eval loss and other metric required by the task 500 | if output_mode == "classification": 501 | loss_fct = CrossEntropyLoss() 502 | tmp_eval_loss = loss_fct( 503 | logits.view(-1, num_labels), label_ids.view(-1)) 504 | elif output_mode == "regression": 505 | loss_fct = MSELoss() 506 | tmp_eval_loss = loss_fct( 507 | logits.view(-1), label_ids.view(-1)) 508 | 509 | eval_loss += tmp_eval_loss.mean().item() 510 | nb_eval_steps += 1 511 | if len(preds) == 0: 512 | preds.append(logits.detach().cpu().numpy()) 513 | out_label_ids = label_ids.detach().cpu().numpy() 514 | else: 515 | preds[0] = np.append( 516 | preds[0], logits.detach().cpu().numpy(), axis=0) 517 | out_label_ids = np.append( 518 | out_label_ids, label_ids.detach().cpu().numpy(), axis=0) 519 | 520 | eval_loss = eval_loss / nb_eval_steps 521 | preds = preds[0] 522 | if output_mode == "classification": 523 | preds = np.argmax(preds, axis=1) 524 | elif output_mode == "regression": 525 | preds = np.squeeze(preds) 526 | result = compute_metrics(task_name, preds, out_label_ids) 527 | 528 | return result 529 | 530 | 531 | if __name__ == "__main__": 532 | main() 533 | -------------------------------------------------------------------------------- /data/sc_mask_gen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import spacy 6 | import sys 7 | import collections 8 | import multiprocessing 9 | from spacy.lang.en import English 10 | from tqdm import tqdm 11 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.functional import softmax 14 | 15 | sys.path.append("../") 16 | from model.modeling_classification import BertForSequenceClassification, BertForTokenClassification 17 | from model.tokenization import BertTokenizer 18 | 19 | logger = logging.getLogger(__name__) 20 | MaskedTokenInstance = collections.namedtuple("MaskedTokenInstance", ["tokens", "info"]) 21 | MaskedItemInfo = collections.namedtuple("MaskedItemInfo", ["current_pos", "sen_doc_pos", "sen_right_id", "doc_ground_truth"]) 22 | nlp = English() 23 | sentencizer = nlp.create_pipe("sentencizer") 24 | nlp.add_pipe(sentencizer) 25 | 26 | class InputFeatures(object): 27 | def __init__(self, input_ids, input_mask, segment_ids=None): 28 | self.input_ids = input_ids 29 | self.input_mask = input_mask 30 | self.segment_ids = segment_ids 31 | 32 | class SC(nn.Module): 33 | def __init__(self, mask_rate, top_sen_rate, threshold, bert_model, do_lower_case, max_seq_length, label_list, sen_batch_size, use_gpu=True): 34 | super(SC, self).__init__() 35 | self.mask_rate = mask_rate 36 | self.top_sen_rate = top_sen_rate 37 | self.threshold = 
threshold 38 | self.label_list = label_list 39 | self.num_labels = len(self.label_list) 40 | self.max_seq_length = max_seq_length 41 | self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case) 42 | self.model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=self.num_labels) 43 | self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu") 44 | self.model.to(self.device) 45 | self.n_gpu = torch.cuda.device_count() 46 | self.sen_batch_size = sen_batch_size 47 | self.vocab = list(self.tokenizer.vocab.keys()) 48 | if self.n_gpu > 1: 49 | self.model = torch.nn.DataParallel(self.model) 50 | 51 | def convert_examples_to_features(self, data): 52 | features = [] 53 | for (ex_index, tokens_a) in enumerate(data): 54 | if ex_index % 10000 == 0: 55 | logger.info("Writing example %d of %d" % (ex_index, len(data))) 56 | if len(tokens_a) > self.max_seq_length - 2: 57 | tokens_a = tokens_a[:(self.max_seq_length - 2)] 58 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 59 | 60 | segment_ids = [0] * len(tokens) 61 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 62 | input_mask = [1] * len(input_ids) 63 | 64 | padding = [0] * (self.max_seq_length - len(input_ids)) 65 | input_ids += padding 66 | input_mask += padding 67 | segment_ids += padding 68 | 69 | assert len(input_ids) == self.max_seq_length 70 | assert len(input_mask) == self.max_seq_length 71 | assert len(segment_ids) == self.max_seq_length 72 | 73 | features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)) 74 | 75 | return features 76 | 77 | def evaluate(self, data, batch_size): 78 | eval_features = self.convert_examples_to_features(data) 79 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 80 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 81 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 82 | 83 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids) 84 | eval_sampler = SequentialSampler(eval_data) 85 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size) 86 | 87 | self.model.eval() 88 | preds = [] 89 | for input_ids, input_mask, segment_ids, in tqdm(eval_dataloader, desc="Evaluating"): 90 | input_ids = input_ids.to(self.device) 91 | input_mask = input_mask.to(self.device) 92 | segment_ids = segment_ids.to(self.device) 93 | with torch.no_grad(): 94 | logits = self.model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 95 | logits = softmax(logits, dim=1) 96 | if len(preds) == 0: 97 | preds.append(logits.detach().cpu().numpy()) 98 | else: 99 | preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0) 100 | 101 | preds_arg = np.argmax(preds[0], axis=1) 102 | return preds_arg, preds[0] 103 | 104 | def create_mask(self, mask_poses, sen, rng): 105 | masked_info = [{} for token in sen] 106 | for pos in mask_poses: 107 | lexeme = nlp.vocab[sen[pos]] 108 | if lexeme.is_stop: 109 | continue 110 | if rng.random() < 0.8: 111 | mask_token = "[MASK]" 112 | else: 113 | if rng.random() < 0.5: 114 | mask_token = sen[pos] 115 | else: 116 | mask_token = self.vocab[rng.randint(0, len(self.vocab) - 1)] 117 | masked_info[pos]["mask"] = mask_token 118 | masked_info[pos]["label"] = sen[pos] 119 | return masked_info 120 | 121 | def create_reverse_mask(self, mask_poses, sen, rng): 122 | reverse_mask_poses = [i for i in range(len(sen)) if i not in 
mask_poses] 123 | rng.shuffle(reverse_mask_poses) 124 | cand_indexes = reverse_mask_poses[0:max(1, int(self.mask_rate * len(sen)))] 125 | masked_info = [{} for token in sen] 126 | for cand_index in cand_indexes: 127 | if rng.random() < 0.8: 128 | mask_token = "[MASK]" 129 | else: 130 | if rng.random() < 0.5: 131 | mask_token = sen[cand_index] 132 | else: 133 | mask_token = self.vocab[rng.randint(0, len(self.vocab) - 1)] 134 | masked_info[cand_index]["mask"] = mask_token 135 | masked_info[cand_index]["label"] = sen[cand_index] 136 | return masked_info 137 | 138 | def forward(self, data, all_labels, dupe_factor, rng): 139 | # convert label to ids 140 | doc_num = len(data) 141 | label_map = {label : i for i, label in enumerate(self.label_list)} 142 | all_label_ids = [label_map[label] for label in all_labels] 143 | 144 | # convert data, segment data to sentences 145 | sentences = [] 146 | sen_doc_ids = [] # [0, 0, ..., 0, 1, 1, ..., 1, ...] 147 | for (doc_id, doc) in enumerate(data): 148 | doc = nlp(doc) 149 | tL = [self.tokenizer.tokenize(sen.text) for sen in doc.sents] 150 | sentences.extend(tL) 151 | sen_doc_ids.extend([doc_id] * len(tL)) 152 | 153 | logger.info("Begin eval for all sentence") 154 | sens_preds, sens_pred_scores = self.evaluate(sentences, self.sen_batch_size) 155 | 156 | right_sens = [] 157 | right_preds = [] 158 | right_scores = [] 159 | right_sen_doc_ids = [] 160 | right_sen_doc_poses = [] 161 | i = 0 162 | for doc_id in range(doc_num): 163 | ds = [] 164 | while i < len(sen_doc_ids) and sen_doc_ids[i] == doc_id: 165 | sen_pred = sens_preds[i] 166 | doc_ground_truth = all_label_ids[doc_id] 167 | # compare with ground truth 168 | if doc_ground_truth == sen_pred: 169 | ds.append((sentences[i], doc_id, i, sen_pred, sens_pred_scores[i][doc_ground_truth])) 170 | i += 1 171 | if len(ds) == 0: 172 | continue 173 | ds = sorted(ds, key=lambda x : x[-1], reverse=True) 174 | t_sen, t_sen_doc_id, t_sen_doc_pos, t_pred, t_score = zip(*ds[0:max(int(self.top_sen_rate * len(ds)), 1)]) # select top sentences 175 | right_sens.extend(t_sen) 176 | right_preds.extend(t_pred) 177 | right_scores.extend(t_score) 178 | right_sen_doc_ids.extend(t_sen_doc_id) 179 | right_sen_doc_poses.extend(t_sen_doc_pos) 180 | 181 | right_sens_num = len(right_sens) 182 | # convert right sentence to reverse 183 | 184 | masked_sens = [] 185 | masked_item_infos = [] 186 | 187 | # init 188 | for sen_right_id, (sen_doc_pos, sen) in enumerate(zip(right_sen_doc_poses, right_sens)): 189 | masked_sens.append(sen[0:1]) 190 | masked_item_infos.append({"sen_doc_pos": sen_doc_pos, "sen_right_id": sen_right_id, "doc_ground_truth": all_label_ids[sen_doc_ids[sen_doc_pos]]}) 191 | 192 | mask_poses_d = {} 193 | mask_pos = 0 194 | 195 | while len(masked_sens) != 0: 196 | _, mask_sens_scores = self.evaluate(masked_sens, self.sen_batch_size) 197 | masked_sens_num = len(masked_sens) 198 | temp_masked_sens = [] 199 | temp_masked_item_infos = [] 200 | for masked_sen, masked_item_info, mask_sens_score in zip(masked_sens, masked_item_infos, mask_sens_scores): 201 | sen_doc_pos = masked_item_info["sen_doc_pos"] 202 | doc_ground_truth = masked_item_info["doc_ground_truth"] 203 | sen_right_id = masked_item_info["sen_right_id"] 204 | origin_score = right_scores[sen_right_id] 205 | if origin_score - mask_sens_score[doc_ground_truth] < self.threshold: 206 | # choose as mask 207 | if sen_doc_pos in mask_poses_d: 208 | mask_poses_d[sen_doc_pos].append(mask_pos) 209 | else: 210 | mask_poses_d[sen_doc_pos] = [mask_pos] 211 | masked_sen.pop() 212 | 213 
| # add next token 214 | if mask_pos + 1 < len(right_sens[sen_right_id]): 215 | masked_sen.append(right_sens[sen_right_id][mask_pos + 1]) 216 | temp_masked_sens.append(masked_sen) 217 | temp_masked_item_infos.append(masked_item_info) 218 | 219 | masked_sens = temp_masked_sens 220 | masked_item_infos = temp_masked_item_infos 221 | mask_pos += 1 222 | 223 | all_documents = [] 224 | 225 | for _ in range(dupe_factor): 226 | i = 0 227 | for doc_id in tqdm(range(doc_num), desc="Generating All Documents"): 228 | all_documents.append([]) 229 | while i < len(sen_doc_ids) and doc_id == sen_doc_ids[i]: 230 | mask_poses = [] 231 | if i in mask_poses_d: 232 | mask_poses = mask_poses_d[i] 233 | m_info = self.create_mask(mask_poses, sentences[i], rng) 234 | all_documents[-1].append(MaskedTokenInstance(tokens=sentences[i], info=m_info)) 235 | i += 1 236 | return all_documents 237 | 238 | class ASC(nn.Module): 239 | def __init__(self, mask_rate, top_sen_rate, threshold, bert_model, do_lower_case, max_seq_length, label_list, sen_batch_size, use_gpu=True): 240 | super(ASC, self).__init__() 241 | self.mask_rate = mask_rate 242 | self.top_sen_rate = top_sen_rate 243 | self.threshold = threshold 244 | self.label_list = label_list 245 | self.num_labels = len(self.label_list) 246 | self.max_seq_length = max_seq_length 247 | self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case) 248 | self.model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=self.num_labels) 249 | self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu") 250 | print(self.device) 251 | self.model.to(self.device) 252 | self.n_gpu = torch.cuda.device_count() 253 | self.sen_batch_size = sen_batch_size 254 | self.vocab = list(self.tokenizer.vocab.keys()) 255 | if self.n_gpu > 1: 256 | self.model = torch.nn.DataParallel(self.model) 257 | 258 | def evaluate(self, data, batch_size): 259 | eval_features = self.convert_examples_to_features(data) 260 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 261 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 262 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 263 | 264 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids) 265 | eval_sampler = SequentialSampler(eval_data) 266 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size) 267 | 268 | self.model.eval() 269 | preds = [] 270 | for input_ids, input_mask, segment_ids, in tqdm(eval_dataloader, desc="Evaluating"): 271 | input_ids = input_ids.to(self.device) 272 | input_mask = input_mask.to(self.device) 273 | segment_ids = segment_ids.to(self.device) 274 | with torch.no_grad(): 275 | logits = self.model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 276 | logits = softmax(logits, dim=1) 277 | if len(preds) == 0: 278 | preds.append(logits.detach().cpu().numpy()) 279 | else: 280 | preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0) 281 | 282 | preds_arg = np.argmax(preds[0], axis=1) 283 | return preds_arg, preds[0] 284 | 285 | def create_mask(self, mask_poses, sen, rng): 286 | masked_info = [{} for token in sen] 287 | for pos in mask_poses: 288 | lexeme = nlp.vocab[sen[pos]] 289 | if lexeme.is_stop: 290 | # print("stop words: ", sen[pos]) 291 | continue 292 | if rng.random() < 0.8: 293 | mask_token = "[MASK]" 294 | else: 295 | if rng.random() < 0.5: 296 | mask_token = sen[pos] 
297 | else: 298 | mask_token = self.vocab[rng.randint( 299 | 0, len(self.vocab) - 1)] 300 | masked_info[pos]["mask"] = mask_token 301 | masked_info[pos]["label"] = sen[pos] 302 | return masked_info 303 | 304 | def convert_examples_to_features(self, data): 305 | features = [] 306 | for (ex_index, item) in enumerate(data): 307 | tokens_b = item["text"] 308 | if len(tokens_b) > self.max_seq_length - 2: 309 | tokens_b = tokens_b[:(self.max_seq_length - 2)] 310 | tokens_a = item["aspect"] 311 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] 312 | 313 | segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1) 314 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 315 | input_mask = [1] * len(input_ids) 316 | 317 | padding = [0] * (self.max_seq_length - len(input_ids)) 318 | input_ids += padding 319 | input_mask += padding 320 | segment_ids += padding 321 | assert len(input_ids) == self.max_seq_length 322 | assert len(input_mask) == self.max_seq_length 323 | assert len(segment_ids) == self.max_seq_length 324 | features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)) 325 | 326 | return features 327 | 328 | def forward(self, data, all_labels, dupe_factor, rng): 329 | # data[i]: {"text": ... , "facts": ["aspect1": label1, "aspect2": label2, ...]} 330 | doc_num = len(data) 331 | label_map = {label : i for i, label in enumerate(self.label_list)} 332 | 333 | sen_doc_ids = [] 334 | sentences = [] 335 | texts = [] 336 | for (doc_id, doc) in enumerate(data): 337 | text = self.tokenizer.tokenize(doc["text"]) 338 | texts.append(text) 339 | for fact in doc["facts"]: 340 | prefix = fact["category"] if "category" in fact else fact["term"] 341 | if fact["polarity"] != "conflict": 342 | sentences.append({"text": text, "aspect": self.tokenizer.tokenize(prefix), "label": label_map[fact["polarity"]]}) 343 | sen_doc_ids.append(doc_id) 344 | 345 | logger.info("Begin eval for all sentence") 346 | sens_preds, sens_pred_scores = self.evaluate(sentences, self.sen_batch_size) 347 | 348 | right_sens = [] 349 | right_scores = [] 350 | right_sen_doc_ids = [] 351 | i = 0 352 | for sen_id in range(len(sentences)): 353 | if sens_preds[sen_id] == sentences[sen_id]["label"]: 354 | right_sens.append(sentences[sen_id]) 355 | right_sen_doc_ids.append(sen_doc_ids[sen_id]) 356 | right_scores.append(sens_pred_scores[sen_id][sentences[sen_id]["label"]]) 357 | 358 | masked_sens = [] 359 | masked_item_infos = [] 360 | 361 | # init 362 | for sen_right_id, sen in enumerate(right_sens): 363 | masked_sens.append({"text": sen["text"][0:1], "aspect": sen["aspect"], "label": sen["label"], "sen_right_id": sen_right_id}) 364 | 365 | mask_poses_L = [set() for i in range(doc_num)] 366 | mask_pos = 0 367 | while len(masked_sens) != 0: 368 | _, mask_sens_scores = self.evaluate(masked_sens, self.sen_batch_size) 369 | masked_sens_num = len(masked_sens) 370 | temp_masked_sens = [] 371 | temp_masked_item_infos = [] 372 | for masked_sen, mask_sens_score in zip(masked_sens, mask_sens_scores): 373 | doc_ground_truth = masked_sen["label"] 374 | sen_right_id = masked_sen["sen_right_id"] 375 | origin_score = right_scores[sen_right_id] 376 | right_sen_doc_id = right_sen_doc_ids[sen_right_id] 377 | 378 | if origin_score - mask_sens_score[doc_ground_truth] < self.threshold: 379 | mask_poses_L[right_sen_doc_id].add(mask_pos) 380 | masked_sen["text"].pop() 381 | 382 | if mask_pos + 1 < len(right_sens[sen_right_id]["text"]): 383 | 
masked_sen["text"].append(right_sens[sen_right_id]["text"][mask_pos + 1]) 384 | temp_masked_sens.append(masked_sen) 385 | 386 | masked_sens = temp_masked_sens 387 | mask_pos += 1 388 | 389 | 390 | all_documents = [] 391 | for doc_id in range(doc_num): 392 | mask_poses = mask_poses_L[doc_id] 393 | 394 | for _ in range(dupe_factor): 395 | for doc_id in tqdm(range(doc_num), desc="Generating All Documents"): 396 | mask_poses = mask_poses_L[doc_id] 397 | m_info = self.create_mask(mask_poses, texts[doc_id], rng) 398 | all_documents.append([MaskedTokenInstance(tokens=texts[doc_id], info=m_info)]) 399 | 400 | return all_documents 401 | 402 | class ModelGen(nn.Module): 403 | def __init__(self, mask_rate, bert_model, do_lower_case, max_seq_length, sen_batch_size, with_rand=False, use_gpu=True): 404 | super(ModelGen, self).__init__() 405 | self.mask_rate = mask_rate 406 | self.max_seq_length = max_seq_length 407 | self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case) 408 | self.model = BertForTokenClassification.from_pretrained(bert_model, num_labels=2) 409 | self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu") 410 | self.model.to(self.device) 411 | self.n_gpu = torch.cuda.device_count() 412 | self.sen_batch_size = sen_batch_size 413 | self.vocab = list(self.tokenizer.vocab.keys()) 414 | self.with_rand = with_rand 415 | if self.n_gpu > 1: 416 | self.model = torch.nn.DataParallel(self.model) 417 | 418 | def create_mask(self, mask_poses, sen, rng): 419 | masked_info = [{} for token in sen] 420 | for pos in mask_poses: 421 | if rng.random() < 0.8: 422 | mask_token = "[MASK]" 423 | else: 424 | if rng.random() < 0.5: 425 | mask_token = sen[pos] 426 | else: 427 | mask_token = self.vocab[rng.randint(0, len(self.vocab) - 1)] 428 | masked_info[pos]["mask"] = mask_token 429 | masked_info[pos]["label"] = sen[pos] 430 | return masked_info 431 | 432 | def convert_examples_to_features(self, data): 433 | features = [] 434 | for tokens in tqdm(data, desc="converting to features"): 435 | if len(tokens) >= self.max_seq_length - 1: 436 | tokens = tokens[0:(self.max_seq_length - 2)] 437 | ntokens = [] 438 | ntokens.append("[CLS]") 439 | for token in tokens: 440 | ntokens.append(token) 441 | ntokens.append("[SEP]") 442 | input_ids = self.tokenizer.convert_tokens_to_ids(ntokens) 443 | input_mask = [1] * len(input_ids) 444 | while len(input_ids) < self.max_seq_length: 445 | input_ids.append(0) 446 | input_mask.append(0) 447 | 448 | features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask)) 449 | 450 | return features 451 | 452 | def evaluate(self, data, batch_size): 453 | eval_features = self.convert_examples_to_features(data) 454 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 455 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 456 | del eval_features 457 | eval_data = TensorDataset(all_input_ids, all_input_mask) 458 | eval_sampler = SequentialSampler(eval_data) 459 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size) 460 | 461 | self.model.eval() 462 | preds = [] 463 | all_res = [] 464 | all_logits = [] 465 | for input_ids, input_mask in tqdm(eval_dataloader, desc="Evaluating"): 466 | input_ids = input_ids.to(self.device) 467 | input_mask = input_mask.to(self.device) 468 | with torch.no_grad(): 469 | logits = self.model(input_ids, attention_mask=input_mask) 470 | 471 | res = torch.argmax(logits, dim=2).detach().cpu().numpy() 
472 | logits = logits.detach().cpu().numpy() 473 | all_res.extend(res) 474 | all_logits.extend(logits) 475 | 476 | N = len(all_res) 477 | for i in tqdm(range(0, N), desc="Begin CPU"): 478 | r, m, l = all_res[i], all_input_mask[i], all_logits[i] 479 | K = len(m) 480 | t = [] 481 | for j in range(1, K): 482 | mm, rr, ll = m[j], r[j], l[j] 483 | if mm == 1: 484 | t.append((rr, ll[rr])) 485 | t.pop() # pop out [SEP] 486 | preds.append(t) 487 | return preds 488 | 489 | 490 | def forward(self, data, all_labels, dupe_factor, rng): 491 | # data: not tokenized 492 | doc_num = len(data) 493 | # convert data, segment data to sentences 494 | sentences = [] 495 | sen_doc_ids = [] # [0, 0, ..., 0, 1, 1, ..., 1, ...] 496 | for (doc_id, doc) in enumerate(tqdm(data)): 497 | doc = nlp(doc) 498 | tL = [self.tokenizer.tokenize(sen.text) for sen in doc.sents] 499 | sentences.extend(tL) 500 | sen_doc_ids.extend([doc_id] * len(tL)) 501 | del tL 502 | 503 | preds = self.evaluate(sentences, self.sen_batch_size) 504 | 505 | all_documents = [] 506 | rand_all_documents = [] 507 | for _ in range(dupe_factor): 508 | i = 0 509 | for doc_id in tqdm(range(doc_num), desc="Generating All Documents"): 510 | all_documents.append([]) 511 | if self.with_rand: 512 | rand_all_documents.append([]) 513 | while i < len(sen_doc_ids) and doc_id == sen_doc_ids[i]: 514 | mask_poses = [(pos, pred[1]) for (pos, pred) in enumerate(preds[i]) if pred[0] == 1] 515 | mask_poses = sorted(mask_poses, key=lambda x: x[1], reverse=True) 516 | max_mask_num = int(max(1, self.mask_rate * len(sentences[i]))) 517 | mask_poses = [pos for pos, _ in mask_poses[0:max_mask_num]] 518 | m_info = self.create_mask(mask_poses, sentences[i], rng) 519 | all_documents[-1].append(MaskedTokenInstance(tokens=sentences[i], info=m_info)) 520 | if self.with_rand: 521 | cand_indexes = [i for i in range(len(sentences[i]))] 522 | rng.shuffle(cand_indexes) 523 | rand_mask_poses = cand_indexes[0:len(mask_poses)] 524 | rand_m_info = self.create_mask(rand_mask_poses, sentences[i], rng) 525 | rand_all_documents[-1].append(MaskedTokenInstance(tokens=sentences[i], info=rand_m_info)) 526 | i += 1 527 | if self.with_rand: 528 | print("with rand") 529 | return all_documents, rand_all_documents 530 | else: 531 | return all_documents 532 | -------------------------------------------------------------------------------- /mask_model_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | import pickle 8 | import numpy as np 9 | from tqdm import tqdm, trange 10 | from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn 15 | from torch.nn import CrossEntropyLoss 16 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) 17 | from torch.utils.data.distributed import DistributedSampler 18 | 19 | from model.modeling_classification import (CONFIG_NAME, WEIGHTS_NAME, VOCAB_NAME, BertConfig, BertForTokenClassification) 20 | from model.optimization import BertAdam 21 | from model.tokenization import BertTokenizer 22 | 23 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 24 | datefmt='%m/%d/%Y %H:%M:%S', 25 | level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | class InputExample(object): 29 | """A single 
training/test example for simple sequence classification.""" 30 | 31 | def __init__(self, guid, text_a, text_b=None, label=None): 32 | """Constructs a InputExample. 33 | 34 | Args: 35 | guid: Unique id for the example. 36 | text_a: string. The untokenized text of the first sequence. For single 37 | sequence tasks, only this sequence must be specified. 38 | text_b: (Optional) string. The untokenized text of the second sequence. 39 | Only must be specified for sequence pair tasks. 40 | label: (Optional) string. The label of the example. This should be 41 | specified for train and dev examples, but not for test examples. 42 | """ 43 | self.guid = guid 44 | self.text_a = text_a 45 | self.text_b = text_b 46 | self.label = label 47 | 48 | 49 | class InputFeatures(object): 50 | """A single set of features of data.""" 51 | 52 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 53 | self.input_ids = input_ids 54 | self.input_mask = input_mask 55 | self.segment_ids = segment_ids 56 | self.label_id = label_id 57 | 58 | 59 | def readfile(filename): 60 | ''' 61 | read file 62 | return format [(['I', 'like', 'Marvel'], [0, 1, 0]), (), ...] 63 | ''' 64 | f = open(filename, "rb") 65 | return pickle.load(f) 66 | 67 | 68 | class DataProcessor(object): 69 | """Base class for data converters for sequence classification data sets.""" 70 | 71 | def get_train_examples(self, data_dir): 72 | """Gets a collection of `InputExample`s for the train set.""" 73 | raise NotImplementedError() 74 | 75 | def get_dev_examples(self, data_dir): 76 | """Gets a collection of `InputExample`s for the dev set.""" 77 | raise NotImplementedError() 78 | 79 | def get_labels(self): 80 | """Gets the list of labels for this data set.""" 81 | raise NotImplementedError() 82 | 83 | @classmethod 84 | def _read_tsv(cls, input_file, quotechar=None): 85 | """Reads a tab separated value file.""" 86 | return readfile(input_file) 87 | 88 | 89 | class MaskGenProcessor(DataProcessor): 90 | """Processor for the CoNLL-2003 data set.""" 91 | 92 | def get_train_examples(self, data_dir): 93 | """See base class.""" 94 | return self._create_examples( 95 | self._read_tsv(os.path.join(data_dir, "train.pkl")), "train") 96 | 97 | def get_dev_examples(self, data_dir): 98 | """See base class.""" 99 | return self._create_examples( 100 | self._read_tsv(os.path.join(data_dir, "valid.pkl")), "dev") 101 | 102 | def get_test_examples(self, data_dir): 103 | """See base class.""" 104 | return self._create_examples( 105 | self._read_tsv(os.path.join(data_dir, "test.pkl")), "test") 106 | 107 | def get_labels(self): 108 | return [0, 1] 109 | 110 | def _create_examples(self, lines, set_type): 111 | examples = [] 112 | for i, (sentence, label) in enumerate(lines): 113 | guid = "%s-%s" % (set_type, i) 114 | text_a = sentence 115 | text_b = None 116 | label = label 117 | examples.append(InputExample( 118 | guid=guid, text_a=text_a, text_b=text_b, label=label)) 119 | return examples 120 | 121 | 122 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 123 | """Loads a data file into a list of `InputBatch`s.""" 124 | 125 | # label_map = {label: i for i, label in enumerate(label_list, 1)} 126 | 127 | features = [] 128 | for (ex_index, example) in enumerate(tqdm(examples, desc="processing")): 129 | tokens = example.text_a 130 | labels = example.label 131 | if len(tokens) >= max_seq_length - 1: 132 | tokens = tokens[0:(max_seq_length - 2)] 133 | labels = labels[0:(max_seq_length - 2)] 134 | ntokens = [] 135 | segment_ids = [] 136 | 
label_ids = [] 137 | ntokens.append("[CLS]") 138 | segment_ids.append(0) 139 | label_ids.append(0) # label 0 for CLS 140 | for i, token in enumerate(tokens): 141 | ntokens.append(token) 142 | segment_ids.append(0) 143 | label_ids.append(labels[i]) 144 | ntokens.append("[SEP]") 145 | segment_ids.append(0) 146 | label_ids.append(0) # label 0 for SEP 147 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 148 | input_mask = [1] * len(input_ids) 149 | while len(input_ids) < max_seq_length: 150 | input_ids.append(0) 151 | input_mask.append(0) 152 | segment_ids.append(0) 153 | label_ids.append(0) 154 | while len(label_ids) < max_seq_length: 155 | label_ids.append(0) 156 | assert len(input_ids) == max_seq_length 157 | assert len(input_mask) == max_seq_length 158 | assert len(segment_ids) == max_seq_length 159 | assert len(label_ids) == max_seq_length 160 | 161 | features.append(InputFeatures(input_ids=input_ids, 162 | input_mask=input_mask, 163 | segment_ids=segment_ids, 164 | label_id=label_ids)) 165 | return features 166 | 167 | def main(): 168 | parser = argparse.ArgumentParser() 169 | 170 | # Required parameters 171 | parser.add_argument("--data_dir", 172 | default=None, 173 | type=str, 174 | required=True, 175 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 176 | parser.add_argument("--bert_model", default=None, type=str, required=True, 177 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 178 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 179 | "bert-base-multilingual-cased, bert-base-chinese.") 180 | parser.add_argument("--task_name", 181 | default=None, 182 | type=str, 183 | required=True, 184 | help="The name of the task to train.") 185 | parser.add_argument("--output_dir", 186 | default=None, 187 | type=str, 188 | required=True, 189 | help="The output directory where the model predictions and checkpoints will be written.") 190 | 191 | ## Other parameters 192 | parser.add_argument("--ckpt", 193 | default="", 194 | type=str) 195 | parser.add_argument("--vocab_file", 196 | default="", 197 | type=str,) 198 | parser.add_argument("--cache_dir", 199 | default="", 200 | type=str, 201 | help="Where do you want to store the pre-trained models downloaded from s3") 202 | parser.add_argument("--max_seq_length", 203 | default=128, 204 | type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 206 | "Sequences longer than this will be truncated, and sequences shorter \n" 207 | "than this will be padded.") 208 | parser.add_argument("--do_train", 209 | action='store_true', 210 | help="Whether to run training.") 211 | parser.add_argument("--do_eval", 212 | action='store_true', 213 | help="Whether to run eval on the dev set.") 214 | parser.add_argument("--do_lower_case", 215 | action='store_true', 216 | help="Set this flag if you are using an uncased model.") 217 | parser.add_argument("--train_batch_size", 218 | default=32, 219 | type=int, 220 | help="Total batch size for training.") 221 | parser.add_argument("--eval_batch_size", 222 | default=8, 223 | type=int, 224 | help="Total batch size for eval.") 225 | parser.add_argument("--learning_rate", 226 | default=5e-5, 227 | type=float, 228 | help="The initial learning rate for Adam.") 229 | parser.add_argument("--num_train_epochs", 230 | default=3.0, 231 | type=float, 232 | help="Total number of training epochs to perform.") 233 | parser.add_argument("--warmup_proportion", 234 | default=0.1, 235 | type=float, 236 | help="Proportion of training to perform linear learning rate warmup for. " 237 | "E.g., 0.1 = 10%% of training.") 238 | parser.add_argument("--no_cuda", 239 | action='store_true', 240 | help="Whether not to use CUDA when available") 241 | parser.add_argument("--local_rank", 242 | type=int, 243 | default=-1, 244 | help="local_rank for distributed training on gpus") 245 | parser.add_argument('--seed', 246 | type=int, 247 | default=42, 248 | help="random seed for initialization") 249 | parser.add_argument('--gradient_accumulation_steps', 250 | type=int, 251 | default=1, 252 | help="Number of updates steps to accumulate before performing a backward/update pass.") 253 | parser.add_argument('--fp16', 254 | action='store_true', 255 | help="Whether to use 16-bit float precision instead of 32-bit") 256 | parser.add_argument("--fp16_opt_level", 257 | type=str, 258 | default="O1", 259 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 260 | "See details at https://nvidia.github.io/apex/amp.html") 261 | parser.add_argument('--loss_scale', 262 | type=float, default=0, 263 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 264 | "0 (default value): dynamic loss scaling.\n" 265 | "Positive power of 2: static loss scaling value.\n") 266 | parser.add_argument('--sample_weight', type=float, default=1) 267 | parser.add_argument("--save_all", action="store_true") 268 | args = parser.parse_args() 269 | 270 | processors = {"maskgen": MaskGenProcessor} 271 | 272 | if args.local_rank == -1 or args.no_cuda: 273 | device = torch.device( 274 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 275 | n_gpu = torch.cuda.device_count() 276 | else: 277 | torch.cuda.set_device(args.local_rank) 278 | device = torch.device("cuda", args.local_rank) 279 | n_gpu = 1 280 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 281 | torch.distributed.init_process_group(backend='nccl') 282 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 283 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 284 | 285 | if args.gradient_accumulation_steps < 1: 286 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 287 | args.gradient_accumulation_steps)) 288 | 289 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 290 | 291 | random.seed(args.seed) 292 | np.random.seed(args.seed) 293 | torch.manual_seed(args.seed) 294 | if not args.do_train and not args.do_eval: 295 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 296 | 297 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 298 | # raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 299 | logger.warning("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 300 | 301 | if not os.path.exists(args.output_dir): 302 | os.makedirs(args.output_dir) 303 | 304 | task_name = args.task_name.lower() 305 | 306 | if task_name not in processors: 307 | raise ValueError("Task not found: %s" % (task_name)) 308 | 309 | processor = processors[task_name]() 310 | 311 | label_list = processor.get_labels() 312 | num_labels = len(label_list) 313 | 314 | if args.local_rank not in [-1, 0]: 315 | # Make sure only the first process in distributed training will download model & vocab 316 | torch.distributed.barrier() 317 | 318 | if args.vocab_file: 319 | tokenizer = BertTokenizer(args.vocab_file, args.do_lower_case) 320 | else: 321 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 322 | 323 | # Prepare model 324 | model = BertForTokenClassification.from_pretrained(args.bert_model, num_labels=num_labels) 325 | 326 | if args.ckpt: 327 | print("load from", args.ckpt) 328 | model_dict = model.state_dict() 329 | ckpt = torch.load(args.ckpt) 330 | pretrained_dict = ckpt['model'] 331 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 332 | model_dict.update(new_dict) 333 | print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict))) 334 | model.load_state_dict(model_dict) 335 | 336 | if args.local_rank == 0: 337 | torch.distributed.barrier() 338 | 339 | model.to(device) 340 | if args.local_rank != -1: 341 | try: 342 | from apex.parallel import DistributedDataParallel as DDP 343 | except ImportError: 344 | raise ImportError( 345 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 346 | 347 | model = DDP(model, device_ids=[args.local_rank], 
output_device=args.local_rank, find_unused_parameters=True) 348 | elif n_gpu > 1: 349 | model = torch.nn.DataParallel(model) 350 | 351 | global_step = 0 352 | nb_tr_steps = 0 353 | tr_loss = 0 354 | 355 | train_examples = None 356 | num_train_optimization_steps = None 357 | if args.do_train: 358 | train_examples = processor.get_train_examples(args.data_dir) 359 | 360 | if args.fp16: 361 | sample_weight = torch.HalfTensor([1.0, args.sample_weight]).cuda() 362 | else: 363 | sample_weight = torch.FloatTensor([1.0, args.sample_weight]).cuda() 364 | 365 | cached_train_features_file = os.path.join(args.data_dir, 'train_{}_{}_{}'.format(list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(task_name))) 366 | try: 367 | with open(cached_train_features_file, "rb") as reader: 368 | logger.info("Load from cache dir: {}".format(cached_train_features_file)) 369 | train_features = pickle.load(reader) 370 | except: 371 | train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer) 372 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 373 | logger.info("Saving train features into cached file {}".format(cached_train_features_file)) 374 | with open(cached_train_features_file, "wb") as writer: 375 | pickle.dump(train_features, writer) 376 | 377 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 378 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 379 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 380 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 381 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 382 | 383 | if args.local_rank == -1: 384 | train_sampler = RandomSampler(train_data) 385 | else: 386 | train_sampler = DistributedSampler(train_data) 387 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 388 | 389 | num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 390 | 391 | param_optimizer = list(model.named_parameters()) 392 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 393 | optimizer_grouped_parameters = [ 394 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 395 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 396 | ] 397 | optimizer = BertAdam(optimizer_grouped_parameters, 398 | lr=args.learning_rate, 399 | warmup=args.warmup_proportion, 400 | t_total=num_train_optimization_steps) 401 | if args.fp16: 402 | try: 403 | from apex import amp 404 | except ImportError: 405 | raise ImportError( 406 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 407 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 408 | 409 | label_map = {i: label for i, label in enumerate(label_list, 1)} 410 | 411 | logger.info("***** Running training *****") 412 | logger.info(" Num examples = %d", len(train_examples)) 413 | logger.info(" Batch size = %d", args.train_batch_size) 414 | logger.info(" Num steps = %d", num_train_optimization_steps) 415 | 416 | os.makedirs(os.path.join(args.output_dir, "all_models"), exist_ok=True) 417 | model.train() 418 | for e in trange(int(args.num_train_epochs), desc="Epoch"): 419 | tr_loss = 0 
420 | nb_tr_steps = 0 421 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 422 | batch = tuple(t.to(device) for t in batch) 423 | input_ids, input_mask, segment_ids, label_ids = batch 424 | loss = model(input_ids, segment_ids, input_mask, label_ids, weight=sample_weight) 425 | if n_gpu > 1: 426 | loss = loss.mean() # mean() to average on multi-gpu. 427 | if args.gradient_accumulation_steps > 1: 428 | loss = loss / args.gradient_accumulation_steps 429 | 430 | if args.fp16: 431 | with amp.scale_loss(loss, optimizer) as scaled_loss: 432 | scaled_loss.backward() 433 | else: 434 | loss.backward() 435 | 436 | tr_loss += loss.item() 437 | nb_tr_steps += 1 438 | if (step + 1) % args.gradient_accumulation_steps == 0: 439 | optimizer.step() 440 | optimizer.zero_grad() 441 | global_step += 1 442 | # save each epoch 443 | model_to_save = model.module if hasattr(model, 'module') else model 444 | output_model_file = os.path.join(args.output_dir, "all_models", "e{}_{}".format(e, WEIGHTS_NAME)) 445 | torch.save(model_to_save.state_dict(), output_model_file) 446 | 447 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 448 | output_args_file = os.path.join(args.output_dir, 'training_args.bin') 449 | torch.save(args, output_args_file) 450 | else: 451 | model = BertForTokenClassification.from_pretrained(args.bert_model, num_labels=num_labels) 452 | 453 | ### Evaluation 454 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 455 | best_f1 = 0 456 | best_epoch = 0 457 | val_res_file = os.path.join(args.output_dir, "valid_results.txt") 458 | val_f = open(val_res_file, "w") 459 | logger.info("***** Dev Eval results *****") 460 | for e in range(int(args.num_train_epochs)): 461 | weight_path = os.path.join(args.output_dir, "all_models", "e{}_{}".format(e, WEIGHTS_NAME)) 462 | model.load_state_dict(torch.load(weight_path)) 463 | model.to(device) 464 | eval_examples = processor.get_dev_examples(args.data_dir) 465 | 466 | cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( 467 | list(filter(None, args.bert_model.split('/'))).pop(), 468 | str(args.max_seq_length), 469 | str(task_name))) 470 | try: 471 | with open(cached_eval_features_file, "rb") as reader: 472 | eval_features = pickle.load(reader) 473 | except: 474 | eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer) 475 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 476 | logger.info(" Saving eval features into cached file %s", cached_eval_features_file) 477 | with open(cached_eval_features_file, "wb") as writer: 478 | pickle.dump(eval_features, writer) 479 | 480 | logger.info("***** Running evaluation *****") 481 | logger.info(" Num examples = %d", len(eval_examples)) 482 | logger.info(" Batch size = %d", args.eval_batch_size) 483 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 484 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 485 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 486 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 487 | 488 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 489 | # Run prediction for full data 490 | if args.local_rank == -1: 491 | eval_sampler = SequentialSampler(eval_data) 492 | else: 493 | eval_sampler = DistributedSampler(eval_data) # Note that 
this sampler samples randomly 494 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 495 | 496 | model.eval() 497 | y_true_L = [] 498 | y_pred_L = [] 499 | 500 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 501 | input_ids = input_ids.to(device) 502 | input_mask = input_mask.to(device) 503 | segment_ids = segment_ids.to(device) 504 | label_ids = label_ids.to(device) 505 | 506 | with torch.no_grad(): 507 | logits = model(input_ids, segment_ids, input_mask) 508 | 509 | logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2) 510 | logits = logits.detach().cpu().numpy() 511 | label_ids = label_ids.to('cpu').numpy() 512 | input_mask = input_mask.to('cpu').numpy() 513 | 514 | y_true = [[str(x) for x in L] for L in label_ids] 515 | y_pred = [[str(x) for x in L] for L in logits] 516 | 517 | for (m, t, p) in zip(input_mask, y_true, y_pred): 518 | for mm, tt, pp in zip(m, t, p): 519 | if mm == 1: 520 | y_true_L.append(int(tt)) 521 | y_pred_L.append(int(pp)) 522 | 523 | acc = accuracy_score(y_true_L, y_pred_L) 524 | f1 = f1_score(y_true_L, y_pred_L) 525 | recall = recall_score(y_true_L, y_pred_L) 526 | prec = precision_score(y_true_L, y_pred_L) 527 | 528 | if f1 > best_f1: 529 | best_f1 = f1 530 | best_epoch = e 531 | 532 | result = { 533 | "acc": acc, 534 | "f1": f1, 535 | "recall": recall, 536 | "prec": prec 537 | } 538 | 539 | logger.info("Epoch {}".format(e)) 540 | val_f.write("Epoch {}\n".format(e)) 541 | for key in sorted(result.keys()): 542 | logger.info("{} = {}".format(key, str(result[key]))) 543 | val_f.write("{} = {}\n".format(key, str(result[key]))) 544 | val_f.write("\n") 545 | 546 | logger.info("\nBest epoch: {}. Best val f1: {}".format(best_epoch, best_f1)) 547 | val_f.write("Best epoch: {}. Best val f1: {}\n".format(best_epoch, best_f1)) 548 | val_f.close() 549 | 550 | best_weight_path = os.path.join(args.output_dir, "all_models", "e{}_{}".format(best_epoch, WEIGHTS_NAME)) 551 | best_model_dir = os.path.join(args.output_dir, "best_model") 552 | os.makedirs(best_model_dir, exist_ok=True) 553 | os.system("cp {} {}/{}".format(best_weight_path, best_model_dir, WEIGHTS_NAME)) 554 | with open(os.path.join(best_model_dir, CONFIG_NAME), 'w') as f: 555 | f.write((model.module if hasattr(model, 'module') else model).config.to_json_string()) # unwrap here: model_to_save is only defined in the --do_train branch, so reusing it would raise a NameError in an eval-only run 556 | tokenizer.save_vocab(os.path.join(best_model_dir, VOCAB_NAME)) 557 | 558 | if not args.save_all: 559 | os.system("rm -r {}".format(os.path.join(args.output_dir, "all_models"))) 560 | 561 | if __name__ == "__main__": 562 | main() 563 | --------------------------------------------------------------------------------
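Usage sketch for mask_model_pretrain.py: the repository's scripts/run_mask_model.sh presumably provides the canonical invocation; the command below is only a minimal illustration built from the flags defined in the script's argument parser, with placeholder paths (not paths taken from the repository). MaskGenProcessor reads train.pkl and valid.pkl from --data_dir.

    # minimal sketch; paths are placeholders
    python3 mask_model_pretrain.py \
        --data_dir /path/to/mask_gen_data \
        --bert_model bert-base-uncased \
        --task_name maskgen \
        --output_dir /path/to/mask_model_out \
        --do_train --do_eval --do_lower_case \
        --max_seq_length 128 \
        --train_batch_size 32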