├── requirements.txt ├── tests ├── bitwise_and.py ├── gather.py ├── model_fn.py ├── construct_label.py ├── tpu_operation.py ├── tile_repeat.py └── cumsum.py ├── bert ├── __init__.py ├── optimization.py └── tokenization.py ├── scripts ├── data │ ├── generate_tfrecord_dataset.sh │ ├── transform_ckpt_pytorch_to_tf.sh │ ├── download_pretrained_mlm.sh │ └── preprocess_ontonotes_annfiles.sh └── models │ ├── quoref_tpu.sh │ ├── squad_tpu.sh │ ├── mention_gpu.sh │ ├── corefqa_gpu.sh │ ├── mention_tpu.sh │ └── corefqa_tpu.sh ├── data_utils ├── config_utils.py └── conll.py ├── conll-2012 └── scorer │ └── v8.01 │ ├── scorer.pl │ ├── scorer.bat │ └── README.txt ├── .gitignore ├── func_builders ├── input_fn_builder.py └── model_fn_builder.py ├── utils ├── load_pytorch_to_tf.py ├── util.py ├── metrics.py └── radam.py ├── run ├── transform_spanbert_pytorch_to_tf.py ├── run_mention_proposal.py ├── run_corefqa.py └── build_dataset_to_tfrecord.py ├── README.md ├── models └── mention_proposal.py └── logs └── corefqa_log.txt /requirements.txt: -------------------------------------------------------------------------------- 1 | pyhocon 2 | tensorboard==1.15.0 3 | tensorflow-estimator==1.15.1 4 | tensorflow-gpu==1.15.0 5 | pyyaml==5.2 6 | scikit-learn==0.19.1 7 | scipy 8 | torch==1.2.0 9 | -------------------------------------------------------------------------------- /tests/bitwise_and.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # 9 | 10 | 11 | import tensorflow as tf 12 | 13 | 14 | if __name__ == "__main__": 15 | sess = tf.compat.v1.InteractiveSession() 16 | lhs = tf.constant([0, 5, 3, 14], dtype=tf.int32) 17 | rhs = tf.constant([5, 0, 7, 11], dtype=tf.int32) 18 | 19 | res = tf.bitwise.bitwise_and(lhs, rhs) 20 | res.eval() 21 | # array([ 0, 0, 3, 10], dtype=int32) 22 | sess.close() 23 | 24 | 25 | -------------------------------------------------------------------------------- /tests/gather.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | 7 | # author: xiaoy li 8 | 9 | 10 | import tensorflow as tf 11 | 12 | 13 | 14 | if __name__ == "__main__": 15 | sess = tf.compat.v1.InteractiveSession() 16 | lhs = tf.zeros((4, 3)) 17 | 18 | slice_lhs = tf.gather(lhs, 1) 19 | # slice_lhs_nd = tf.gather_nd(lhs, 1) 20 | 21 | slice_lhs = tf.gather(lhs, [1, 2]) 22 | 23 | slice_lhs.eval() 24 | # slice_lhs_nd.eval() 25 | # array([ 0, 0, 3, 10], dtype=int32) 26 | sess.close() -------------------------------------------------------------------------------- /tests/model_fn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # test config in model fn builder 9 | 10 | 11 | 12 | 13 | def model_fn(config): 14 | 15 | def mention_proposal_fn(): 16 | print("the number of document is : ") 17 | print(config.document_number) 18 | print(config.number_window_size) 19 | 20 | return mention_proposal_fn 21 | 22 | 23 | class Config: 24 | number_window_size = 2 25 | document_number = 3 26 | 27 | 28 | 29 | 30 | if __name__ == "__main__": 31 | config = Config() 32 | get_model_fn = model_fn(config) 33 | get_model_fn() -------------------------------------------------------------------------------- /bert/__init__.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /scripts/data/generate_tfrecord_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # generate train/dev/test tfrecord files for training the model. 9 | # example: 10 | # bash generate_tfrecord_dataset.sh /path-to-conll-coreference-resolution-dataset /path-to-save-tfrecord-for-training /cased_L-12_H-768_A-12/vocab.txt 11 | 12 | 13 | 14 | REPO_PATH=/home/lixiaoya/coref-tf 15 | export PYTHONPATH=$REPO_PATH 16 | 17 | source_dir=$1 18 | target_dir=$2 19 | vocab_file=$3 20 | 21 | mkdir -p ${target_dir} 22 | 23 | 24 | python3 ${REPO_PATH}/run/build_dataset_to_tfrecord.py \ 25 | --source_files_dir $source_dir \ 26 | --target_output_dir $target_dir \ 27 | --num_window 2 \ 28 | --window_size 64 \ 29 | --max_num_mention 50 \ 30 | --max_num_cluster 40 \ 31 | --vocab_file $vocab_file \ 32 | --language english \ 33 | --max_doc_length 600 -------------------------------------------------------------------------------- /data_utils/config_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # config utils for the mention proposal and corefqa 9 | 10 | 11 | 12 | import os 13 | import json 14 | import tensorflow as tf 15 | 16 | 17 | class ModelConfig(object): 18 | def __init__(self, tf_flags, output_dir, model_sign="model"): 19 | key_value_pairs = tf_flags.flag_values_dict() 20 | 21 | for item_key, item_value in key_value_pairs.items(): 22 | self.__dict__[item_key] = item_value 23 | 24 | self.output_dir = output_dir 25 | config_path = os.path.join(self.output_dir, "{}_config.json".format(model_sign)) 26 | 27 | def logging_configs(self): 28 | tf.logging.info("$*$"*30) 29 | tf.logging.info("****** print model configs : ******") 30 | tf.logging.info("$*$"*30) 31 | 32 | for item_key, item_value in self.__dict__.items(): 33 | tf.logging.info("{} : {}".format(str(item_key), str(item_value))) 34 | -------------------------------------------------------------------------------- /tests/construct_label.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # desc: 7 | # construct labels 8 | 9 | 10 | import tensorflow as tf 11 | 12 | 13 | 14 | if __name__ == "__main__": 15 | sess = tf.compat.v1.InteractiveSession() 16 | gold_starts = tf.constant([1, 2, 3, 4]) 17 | gold_ends = tf.constant([2, 3, 4, 5]) 18 | num_word = 10 19 | gold_mention_sparse_label = tf.stack([gold_starts, gold_ends], axis=1) 
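    # illustrative note: scatter_nd writes a 1 at every gold (start, end) coordinate of a [num_word, num_word] matrix,
    # and gather_nd then looks each candidate span up in that matrix, so a candidate gets label 1 iff it exactly matches
    # a gold span. For the constants above, the final gold_span_label below should evaluate to [[1], [1], [0]]:
    # candidates (1, 2) and (4, 5) are gold mentions, (5, 5) is not.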
20 | gold_mention_sparse_label.eval() 21 | gold_span_value = tf.reshape(tf.ones_like(gold_starts, tf.int32), [-1]) 22 | gold_span_shape = tf.constant([num_word, num_word]) 23 | gold_span_label = tf.cast(tf.scatter_nd(gold_mention_sparse_label, gold_span_value, gold_span_shape), tf.int32) 24 | gold_span_label.eval() 25 | 26 | candidate_start = tf.constant([1, 4, 5]) 27 | candidate_end = tf.constant([2, 5, 5]) 28 | candidate_span = tf.stack([candidate_start, candidate_end], axis=1) 29 | 30 | gold_span_label = tf.gather_nd(gold_span_label, tf.expand_dims(candidate_span, 1)) 31 | gold_span_label.eval() 32 | -------------------------------------------------------------------------------- /scripts/models/quoref_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # finetune the squad2.0-finetuned spanbert checkpoint on quoref for data augmentation. 9 | 10 | 11 | 12 | REPO_PATH=/home/shannon/coref-tf 13 | export TPU_NAME=tf-tpu 14 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 15 | QUOREF_DIR=gs://qa_tasks/quoref 16 | BERT_DIR=gs://corefqa_output_squad/spanbert_large_squad2_2e-5 17 | OUTPUT_DIR=gs://corefqa_output_quoref/spanbert_large_squad2_best_quoref_3e-5 18 | 19 | 20 | python3 ${REPO_PATH}/run_quoref.py \ 21 | --vocab_file=$BERT_DIR/vocab.txt \ 22 | --bert_config_file=$BERT_DIR/bert_config.json \ 23 | --init_checkpoint=$BERT_DIR/best_bert_model.ckpt \ 24 | --do_train=True \ 25 | --train_file=$QUOREF_DIR/quoref-train-v0.1.json \ 26 | --do_predict=True \ 27 | --predict_file=$QUOREF_DIR/quoref-dev-v0.1.json \ 28 | --train_batch_size=8 \ 29 | --learning_rate=3e-5 \ 30 | --num_train_epochs=5 \ 31 | --max_seq_length=384 \ 32 | --do_lower_case=False \ 33 | --doc_stride=128 \ 34 | --output_dir=${OUTPUT_DIR} \ 35 | --use_tpu=True \ 36 | --tpu_name=$TPU_NAME -------------------------------------------------------------------------------- /scripts/models/squad_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # finetune the spanbert model on squad 2.0 for data augmentation. 
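# note: the checkpoint written to OUTPUT_DIR below is what scripts/models/quoref_tpu.sh later reads as its BERT_DIR,
# so run this script before the quoref one; the gs:// buckets and TPU_NAME are specific to this setup and need to be
# edited for a different project.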
9 | 10 | 11 | 12 | REPO_PATH=/home/shannon/coref-tf 13 | export TPU_NAME=tf-tpu 14 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 15 | SQUAD_DIR=gs://qa_tasks/squad2 16 | BERT_DIR=gs://pretrained_mlm_checkpoint/spanbert_large_tf 17 | OUTPUT_DIR=gs://corefqa_output_squad/spanbert_large_squad2_2e-5 18 | 19 | 20 | python3 ${REPO_PATH}/run/run_squad.py \ 21 | --vocab_file=$BERT_DIR/vocab.txt \ 22 | --bert_config_file=$BERT_DIR/bert_config.json \ 23 | --init_checkpoint=$BERT_DIR/bert_model.ckpt \ 24 | --do_train=True \ 25 | --train_file=$SQUAD_DIR/train-v2.0.json \ 26 | --do_predict=True \ 27 | --predict_file=$SQUAD_DIR/dev-v2.0.json \ 28 | --train_batch_size=8 \ 29 | --learning_rate=2e-5 \ 30 | --num_train_epochs=4.0 \ 31 | --max_seq_length=384 \ 32 | --do_lower_case=False \ 33 | --doc_stride=128 \ 34 | --output_dir=${OUTPUT_DIR} \ 35 | --use_tpu=True \ 36 | --tpu_name=$TPU_NAME \ 37 | --version_2_with_negative=True -------------------------------------------------------------------------------- /scripts/models/mention_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # train the mention proposal model on a single gpu. 9 | 10 | 11 | 12 | REPO_PATH=/home/lixiaoya/xiaoy_tf 13 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 14 | 15 | output_dir=/xiaoya/mention_proposal_output 16 | bert_dir=/xiaoya/pretrain_ckpt/uncased_L-2_H-128_A-2 17 | data_dir=/xiaoya/corefqa_data/final_overlap_64_2 18 | 19 | 20 | rm -rf ${output_dir} 21 | mkdir -p ${output_dir} 22 | 23 | 24 | 25 | CUDA_VISIBLE_DEVICES=3 python3 ${REPO_PATH}/run/run_mention_proposal.py \ 26 | --output_dir=${output_dir} \ 27 | --bert_config_file=${bert_dir}/bert_config_nodropout.json \ 28 | --init_checkpoint=${bert_dir}/bert_model.ckpt \ 29 | --vocab_file=${bert_dir}/vocab.txt \ 30 | --logfile_path=${output_dir}/train.log \ 31 | --num_epochs=20 \ 32 | --keep_checkpoint_max=50 \ 33 | --save_checkpoints_steps=500 \ 34 | --train_file=${data_dir}/train.64.english.tfrecord \ 35 | --dev_file=${data_dir}/dev.64.english.tfrecord \ 36 | --test_file=${data_dir}/test.64.english.tfrecord \ 37 | --do_train=True \ 38 | --do_eval=False \ 39 | --do_predict=False \ 40 | --learning_rate=1e-5 \ 41 | --dropout_rate=0.0 \ 42 | --mention_threshold=0.5 \ 43 | --hidden_size=128 \ 44 | --num_docs=5604 \ 45 | --window_size=64 \ 46 | --num_window=2 \ 47 | --max_num_mention=20 \ 48 | --start_end_share=False \ 49 | --loss_start_ratio=0.3 \ 50 | --loss_end_ratio=0.3 \ 51 | --loss_span_ratio=0.3 \ 52 | --use_tpu=False \ 53 | --seed=2333 54 | -------------------------------------------------------------------------------- /conll-2012/scorer/v8.01/scorer.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | 3 | BEGIN { 4 | $d = $0; 5 | $d =~ s/\/[^\/][^\/]*$//g; 6 | 7 | if ($d eq $0) { 8 | unshift(@INC, "lib"); 9 | } 10 | else { 11 | unshift(@INC, $d . "/lib"); 12 | } 13 | } 14 | 15 | use strict; 16 | use CorScorer; 17 | 18 | if (@ARGV < 3) { 19 | print q| 20 | use: scorer.pl <metric> <keys_file> <response_file> [name] 21 | 22 | metric: the metric desired to score the results: 23 | muc: MUCScorer (Vilain et al, 1995) 24 | bcub: B-Cubed (Bagga and Baldwin, 1998) 25 | ceafm: CEAF (Luo et al, 2005) using mention-based similarity 26 | ceafe: CEAF (Luo et al, 2005) using entity-based similarity 27 | blanc: BLANC 28 | all: uses all the metrics to score 29 | 30 | keys_file: file with expected coreference chains in SemEval format 31 | 32 | response_file: file with output of coreference system (SemEval format) 33 | 34 | name: [optional] the name of the document to score. If name is not 35 | given, all the documents in the dataset will be scored. If given 36 | name is "none" then all the documents are scored but only total 37 | results are shown. 38 | 39 | |; 40 | exit; 41 | } 42 | 43 | my $metric = shift(@ARGV); 44 | if ($metric !~ /^(muc|bcub|ceafm|ceafe|blanc|all)/i) { 45 | print "Invalid metric\n"; 46 | exit; 47 | } 48 | 49 | if ($metric eq 'all') { 50 | foreach my $m ('muc', 'bcub', 'ceafm', 'ceafe', 'blanc') { 51 | print "\nMETRIC $m:\n"; 52 | &CorScorer::Score($m, @ARGV); 53 | } 54 | } 55 | else { 56 | &CorScorer::Score($metric, @ARGV); 57 | } 58 | 59 | -------------------------------------------------------------------------------- /scripts/data/transform_ckpt_pytorch_to_tf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # transform the trained spanbert language model from pytorch (.bin) to tensorflow (.ckpt). 9 | # PLEASE NOTICE: a BERT TF checkpoint of the same scale (Base/Large) is also required. 10 | 11 | 12 | 13 | REPO_PATH=/home/lixiaoya/coref-tf 14 | export PYTHONPATH=${REPO_PATH} 15 | 16 | 17 | MODEL_NAME=$1 18 | PATH_TO_SPANBERT_PYTORCH_DIR=$2 19 | PATH_TO_SAME_SCALE_BERT_TF_DIR=$3 20 | PATH_TO_SAVE_SPANBERT_TF_DIR=$4 21 | 22 | 23 | if [[ $MODEL_NAME == "spanbert_base" ]]; then 24 | # spanbert base 25 | echo "Transform SpanBERT Cased Base from Pytorch To TF" 26 | python3 ${REPO_PATH}/run/transform_spanbert_pytorch_to_tf.py \ 27 | --spanbert_config_path $PATH_TO_SPANBERT_PYTORCH_DIR/config.json \ 28 | --bert_tf_ckpt_path $PATH_TO_SAME_SCALE_BERT_TF_DIR/bert_model.ckpt \ 29 | --spanbert_pytorch_bin_path $PATH_TO_SPANBERT_PYTORCH_DIR/pytorch_model.bin \ 30 | --output_spanbert_tf_dir $PATH_TO_SAVE_SPANBERT_TF_DIR 31 | elif [[ $MODEL_NAME == "spanbert_large" ]]; then 32 | # spanbert large 33 | echo "Transform SpanBERT Cased Large from Pytorch To TF" 34 | python3 ${REPO_PATH}/run/transform_spanbert_pytorch_to_tf.py \ 35 | --spanbert_config_path $PATH_TO_SPANBERT_PYTORCH_DIR/config.json \ 36 | --bert_tf_ckpt_path $PATH_TO_SAME_SCALE_BERT_TF_DIR/bert_model.ckpt \ 37 | --spanbert_pytorch_bin_path $PATH_TO_SPANBERT_PYTORCH_DIR/pytorch_model.bin \ 38 | --output_spanbert_tf_dir $PATH_TO_SAVE_SPANBERT_TF_DIR 39 | else 40 | echo 'Unknown argument 1 (Model Sign)' 41 | fi -------------------------------------------------------------------------------- /scripts/models/corefqa_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # train the corefqa model and evaluate intermediate checkpoints on the dev and test sets. 
9 | 10 | 11 | 12 | REPO_PATH=/home/lixiaoya/xiaoy_tf 13 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 14 | 15 | output_dir=/xiaoya/mention_proposal_output 16 | bert_dir=/xiaoya/pretrain_ckpt/uncased_L-2_H-128_A-2 17 | data_dir=/xiaoya/corefqa_data/final_overlap_64_2 18 | 19 | 20 | rm -rf ${output_dir} 21 | mkdir -p ${output_dir} 22 | 23 | 24 | 25 | CUDA_VISIBLE_DEVICES=3 python3 ${REPO_PATH}/run/run_corefqa.py \ 26 | --output_dir=${output_dir} \ 27 | --bert_config_file=${bert_dir}/bert_config_nodropout.json \ 28 | --init_checkpoint=${bert_dir}/bert_model.ckpt \ 29 | --vocab_file=${bert_dir}/vocab.txt \ 30 | --logfile_path=${output_dir}/train.log \ 31 | --num_epochs=20 \ 32 | --keep_checkpoint_max=50 \ 33 | --save_checkpoints_steps=500 \ 34 | --train_file=${data_dir}/train.64.english.tfrecord \ 35 | --dev_file=${data_dir}/dev.64.english.tfrecord \ 36 | --test_file=${data_dir}/test.64.english.tfrecord \ 37 | --do_train=True \ 38 | --do_eval=False \ 39 | --do_predict=False \ 40 | --learning_rate=1e-5 \ 41 | --dropout_rate=0.0 \ 42 | --mention_threshold=0.5 \ 43 | --hidden_size=128 \ 44 | --num_docs=5604 \ 45 | --window_size=64 \ 46 | --num_window=2 \ 47 | --max_num_mention=20 \ 48 | --start_end_share=False \ 49 | --max_span_width=20 \ 50 | --max_candidate_mentions=50 \ 51 | --top_span_ratio=0.2 \ 52 | --max_top_antecedents=30 \ 53 | --max_query_len=150 \ 54 | --max_context_len=150 \ 55 | --sec_qa_mention_score=False \ 56 | --use_tpu=False \ 57 | --seed=2333 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /scripts/models/mention_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # clean code and add comments 9 | 10 | 11 | 12 | REPO_PATH=/home/xiaoyli1110/xiaoya/Coref-tf 13 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 14 | export TPU_NAME=tensorflow-tpu 15 | export TPU_ZONE=europe-west4-a 16 | export GCP_PROJECT=xiaoyli-20-04-274510 17 | 18 | output_dir=gs://europe_mention_proposal/output_bertlarge 19 | bert_dir=gs://europe_pretrain_mlm/uncased_L-2_H-128_A-2 20 | data_dir=gs://europe_corefqa_data/final_overlap_64_2 21 | 22 | 23 | 24 | python3 ${REPO_PATH}/run/run_mention_proposal.py \ 25 | --output_dir=${output_dir} \ 26 | --bert_config_file=${bert_dir}/bert_config_nodropout.json \ 27 | --init_checkpoint=${bert_dir}/bert_model.ckpt \ 28 | --vocab_file=${bert_dir}/vocab.txt \ 29 | --logfile_path=${output_dir}/train.log \ 30 | --num_epochs=20 \ 31 | --keep_checkpoint_max=50 \ 32 | --save_checkpoints_steps=500 \ 33 | --train_file=${data_dir}/train.64.english.tfrecord \ 34 | --dev_file=${data_dir}/dev.64.english.tfrecord \ 35 | --test_file=${data_dir}/test.64.english.tfrecord \ 36 | --do_train=True \ 37 | --do_eval=False \ 38 | --do_predict=False \ 39 | --learning_rate=1e-5 \ 40 | --dropout_rate=0.0 \ 41 | --mention_threshold=0.5 \ 42 | --hidden_size=128 \ 43 | --num_docs=5604 \ 44 | --window_size=64 \ 45 | --num_window=2 \ 46 | --max_num_mention=20 \ 47 | --start_end_share=False \ 48 | --loss_start_ratio=0.3 \ 49 | --loss_end_ratio=0.3 \ 50 | --loss_span_ratio=0.3 \ 51 | --use_tpu=True \ 52 | --tpu_name=$TPU_NAME \ 53 | --tpu_zone=$TPU_ZONE \ 54 | --gcp_project=$GCP_PROJECT \ 55 | --num_tpu_cores=1 \ 56 | --seed=2333 57 | -------------------------------------------------------------------------------- /tests/tpu_operation.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # descripiton: 8 | # test math operations in tpu 9 | 10 | 11 | import tensorflow as tf 12 | from tensorflow.contrib import tpu 13 | from tensorflow.contrib.cluster_resolver import TPUClusterResolver 14 | 15 | 16 | 17 | TPU_NAME = "tensorflow-tpu" 18 | TPU_ZONE = "us-central1-f" 19 | GCP_PROJECT = "xiaoyli-20-04-274510" 20 | 21 | 22 | 23 | if __name__ == "__main__": 24 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_NAME, zone=TPU_ZONE, project=GCP_PROJECT) 25 | # tpu_cluster_resolver = TPUClusterResolver(tpu=['tensorflow-tpu']).get_master() 26 | tf.config.experimental_connect_to_cluster(tpu_cluster_resolver) 27 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) 28 | 29 | scores = tf.constant([1.0, 2.3, 3.2, 4.3, 1.5, 1.8, 98, 2.9]) 30 | k = 2 31 | 32 | def test_top_k(): 33 | top_scores, top_index = tf.nn.top_k(scores, k) 34 | return top_scores, top_index 35 | 36 | test_op = test_top_k 37 | 38 | # with tf.compat.v1.InteractiveSession(tpu_cluster_resolver) as sess: 39 | with tf.compat.v1.Session(tpu_cluster_resolver) as sess: 40 | sess.run(tpu.initialize_system()) 41 | 42 | scores = tf.constant([1.0, 2.3, 3.2, 4.3, 1.5, 1.8, 98, 2.9]) 43 | k = 2 44 | print("ALL Devices: ", tf.config.experimental_list_devices()) 45 | 46 | top_scores, top_index = tf.nn.top_k(scores, k) 47 | 48 | print(top_scores.eval()) 49 | print(top_index.eval()) 50 | 51 | sess.run(tpu.shutdown_system()) 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /scripts/models/corefqa_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # clean code and add comments 9 | 10 | 11 | 12 | REPO_PATH=/home/xiaoyli1110/xiaoya/Coref-tf 13 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 14 | export TPU_NAME=tensorflow-tpu 15 | export TPU_ZONE=europe-west4-a 16 | export GCP_PROJECT=xiaoyli-20-04-274510 17 | 18 | output_dir=gs://europe_mention_proposal/output_bertlarge 19 | bert_dir=gs://europe_pretrain_mlm/uncased_L-2_H-128_A-2 20 | data_dir=gs://europe_corefqa_data/final_overlap_64_2 21 | 22 | 23 | 24 | python3 ${REPO_PATH}/run/run_corefqa.py \ 25 | --output_dir=${output_dir} \ 26 | --bert_config_file=${bert_dir}/bert_config_nodropout.json \ 27 | --init_checkpoint=${bert_dir}/bert_model.ckpt \ 28 | --vocab_file=${bert_dir}/vocab.txt \ 29 | --logfile_path=${output_dir}/train.log \ 30 | --num_epochs=20 \ 31 | --keep_checkpoint_max=50 \ 32 | --save_checkpoints_steps=500 \ 33 | --train_file=${data_dir}/train.64.english.tfrecord \ 34 | --dev_file=${data_dir}/dev.64.english.tfrecord \ 35 | --test_file=${data_dir}/test.64.english.tfrecord \ 36 | --do_train=True \ 37 | --do_eval=False \ 38 | --do_predict=False \ 39 | --learning_rate=1e-5 \ 40 | --dropout_rate=0.0 \ 41 | --mention_threshold=0.5 \ 42 | --hidden_size=128 \ 43 | --num_docs=5604 \ 44 | --window_size=64 \ 45 | --num_window=2 \ 46 | --max_num_mention=20 \ 47 | --start_end_share=False \ 48 | --max_span_width=20 \ 49 | --max_candidate_mentions=50 \ 50 | --top_span_ratio=0.2 \ 51 | --max_top_antecedents=30 \ 52 | --max_query_len=150 \ 53 | --max_context_len=150 \ 54 | --sec_qa_mention_score=False \ 55 | --use_tpu=True \ 56 | --tpu_name=$TPU_NAME \ 57 | --tpu_zone=$TPU_ZONE \ 
58 | --gcp_project=$GCP_PROJECT \ 59 | --num_tpu_cores=1 \ 60 | --seed=2333 61 | -------------------------------------------------------------------------------- /conll-2012/scorer/v8.01/scorer.bat: -------------------------------------------------------------------------------- 1 | @rem = '--*-Perl-*-- 2 | @echo off 3 | if "%OS%" == "Windows_NT" goto WinNT 4 | perl -x -S "%0" %1 %2 %3 %4 %5 %6 %7 %8 %9 5 | goto endofperl 6 | :WinNT 7 | perl -x -S %0 %* 8 | if NOT "%COMSPEC%" == "%SystemRoot%\system32\cmd.exe" goto endofperl 9 | if %errorlevel% == 9009 echo You do not have Perl in your PATH. 10 | if errorlevel 1 goto script_failed_so_exit_with_non_zero_val 2>nul 11 | goto endofperl 12 | @rem '; 13 | #!perl 14 | #line 15 15 | 16 | BEGIN { 17 | $d = $0; 18 | $d =~ s/\/[^\/][^\/]*$//g; 19 | push(@INC, $d."/lib"); 20 | } 21 | 22 | use strict; 23 | use CorScorer; 24 | 25 | if (@ARGV < 3) { 26 | print q| 27 | use: scorer.bat [name] 28 | 29 | metric: the metric desired to score the results: 30 | muc: MUCScorer (Vilain et al, 1995) 31 | bcub: B-Cubed (Bagga and Baldwin, 1998) 32 | ceafm: CEAF (Luo et al, 2005) using mention-based similarity 33 | ceafe: CEAF (Luo et al, 2005) using entity-based similarity 34 | all: uses all the metrics to score 35 | 36 | keys_file: file with expected coreference chains in SemEval format 37 | 38 | response_file: file with output of coreference system (SemEval format) 39 | 40 | name: [optional] the name of the document to score. If name is not 41 | given, all the documents in the dataset will be scored. If given 42 | name is "none" then all the documents are scored but only total 43 | results are shown. 44 | 45 | |; 46 | exit; 47 | } 48 | 49 | my $metric = shift (@ARGV); 50 | if ($metric !~ /^(muc|bcub|ceafm|ceafe|all)/i) { 51 | print "Invalid metric\n"; 52 | exit; 53 | } 54 | 55 | 56 | if ($metric eq 'all') { 57 | foreach my $m ('muc', 'bcub', 'ceafm', 'ceafe') { 58 | print "\nMETRIC $m:\n"; 59 | &CorScorer::Score( $m, @ARGV ); 60 | } 61 | } 62 | else { 63 | &CorScorer::Score( $metric, @ARGV ); 64 | } 65 | 66 | __END__ 67 | :endofperl 68 | -------------------------------------------------------------------------------- /scripts/data/download_pretrained_mlm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # Author: xiaoy li 7 | # description: 8 | # download pretrained model ckpt 9 | 10 | 11 | 12 | BERT_PRETRAIN_CKPT=$1 13 | MODEL_NAME=$2 14 | 15 | 16 | if [[ $MODEL_NAME == "bert_base" ]]; then 17 | mkdir -p $BERT_PRETRAIN_CKPT 18 | echo "DownLoad BERT Cased Base" 19 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip -P $BERT_PRETRAIN_CKPT 20 | unzip $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip -d $BERT_PRETRAIN_CKPT 21 | rm $BERT_PRETRAIN_CKPT/cased_L-12_H-768_A-12.zip 22 | elif [[ $MODEL_NAME == "bert_large" ]]; then 23 | echo "DownLoad BERT Cased Large" 24 | wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip -P $BERT_PRETRAIN_CKPT 25 | unzip $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip -d $BERT_PRETRAIN_CKPT 26 | rm $BERT_PRETRAIN_CKPT/cased_L-24_H-1024_A-16.zip 27 | elif [[ $MODEL_NAME == "spanbert_base" ]]; then 28 | echo "DownLoad Span-BERT Cased Base" 29 | wget https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf_base.tar.gz -P $BERT_PRETRAIN_CKPT 30 | tar -zxvf $BERT_PRETRAIN_CKPT/spanbert_hf_base.tar.gz -C $BERT_PRETRAIN_CKPT 31 | rm $BERT_PRETRAIN_CKPT/spanbert_hf_base.tar.gz 32 
| elif [[ $MODEL_NAME == "spanbert_large" ]]; then 33 | echo "DownLoad Span-BERT Cased Large" 34 | wget https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf.tar.gz -P $BERT_PRETRAIN_CKPT 35 | tar -zxvf $BERT_PRETRAIN_CKPT/spanbert_hf.tar.gz -C $BERT_PRETRAIN_CKPT 36 | rm $BERT_PRETRAIN_CKPT/spanbert_hf.tar.gz 37 | elif [[ $MODEL_NAME == "bert_tiny" ]]; then 38 | echo "DownLoad BERT Uncased Tiny; helps to debug on GPU." 39 | wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip -P $BERT_PRETRAIN_CKPT 40 | unzip $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip -d $BERT_PRETRAIN_CKPT 41 | rm $BERT_PRETRAIN_CKPT/uncased_L-2_H-128_A-2.zip 42 | else 43 | echo 'Unknown argument 2 (Model Sign)' 44 | fi -------------------------------------------------------------------------------- /scripts/data/preprocess_ontonotes_annfiles.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # author: xiaoy li 5 | # description: 6 | # generate annotated CONLL-2012 coreference resolution datasets from the official released OntoNotes 5.0 dataset. 7 | # 8 | ###################################### 9 | # NOTICE: 10 | ###################################### 11 | # the scripts only work with python 2. 12 | # if you want to run with python 3, please refer to https://github.com/huggingface/neuralcoref/blob/master/neuralcoref/train/training.md#get-the-data 13 | # Thanks for their amazing work! 14 | # 15 | # Reference: 16 | # https://github.com/huggingface/neuralcoref/blob/master/neuralcoref/train/training.md#get-the-data 17 | # https://github.com/mandarjoshi90/coref 18 | # 19 | 20 | 21 | path_to_ontonotes_directory=$1 22 | path_to_save_processed_data_directory=$2 23 | language=$3 24 | 25 | 26 | dlx() { 27 | wget -P $path_to_save_processed_data_directory $1/$2 28 | tar -xvzf $path_to_save_processed_data_directory/$2 -C $path_to_save_processed_data_directory 29 | rm $path_to_save_processed_data_directory/$2 30 | } 31 | 32 | 33 | conll_url=http://conll.cemantix.org/2012/download 34 | dlx $conll_url conll-2012-train.v4.tar.gz 35 | dlx $conll_url conll-2012-development.v4.tar.gz 36 | dlx $conll_url/test conll-2012-test-key.tar.gz 37 | dlx $conll_url/test conll-2012-test-official.v9.tar.gz 38 | 39 | dlx $conll_url conll-2012-scripts.v3.tar.gz 40 | dlx http://conll.cemantix.org/download reference-coreference-scorers.v8.01.tar.gz 41 | 42 | bash $path_to_save_processed_data_directory/conll-2012/v3/scripts/skeleton2conll.sh -D $path_to_ontonotes_directory/data/files/data $path_to_save_processed_data_directory/conll-2012 43 | 44 | function compile_partition() { 45 | rm -f $path_to_save_processed_data_directory/$2.$5.$3$4 46 | cat $path_to_save_processed_data_directory/conll-2012/$3/data/$1/data/$5/annotations/*/*/*/*.$3$4 >> $path_to_save_processed_data_directory/$2.$5.$3$4 47 | } 48 | 49 | function compile_language() { 50 | compile_partition development dev v4 _gold_conll $1 51 | compile_partition train train v4 _gold_conll $1 52 | compile_partition test test v4 _gold_conll $1 53 | } 54 | 55 | compile_language $language 56 | 57 | 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | sftp-config.json 7 | 8 | .DS_Store 9 | 10 | *.py.swp 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 
develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ -------------------------------------------------------------------------------- /func_builders/input_fn_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # 9 | # 10 | 11 | 12 | import tensorflow as tf 13 | 14 | 15 | def file_based_input_fn_builder(input_file, num_window=None, window_size=None, max_num_mention=None, is_training=False, drop_remainder=True): 16 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 17 | name_to_features = { 18 | 'sentence_map': tf.FixedLenFeature([num_window * window_size], tf.int64), 19 | 'text_len': tf.FixedLenFeature([num_window], tf.int64), 20 | 'subtoken_map': tf.FixedLenFeature([num_window * window_size], tf.int64), 21 | 'speaker_ids': tf.FixedLenFeature([num_window * window_size], tf.int64), 22 | 'flattened_input_ids': tf.FixedLenFeature([num_window * window_size], tf.int64), 23 | 'flattened_input_mask': tf.FixedLenFeature([num_window * window_size], tf.int64), 24 | 'span_starts': tf.FixedLenFeature([max_num_mention], tf.int64), 25 | 'span_ends': tf.FixedLenFeature([max_num_mention], tf.int64), 26 | 'cluster_ids': tf.FixedLenFeature([max_num_mention], tf.int64), 27 | } 28 | 29 | 30 | def 
_decode_record(record, name_to_features): 31 | """Decodes a record to a TensorFlow example.""" 32 | example = tf.io.parse_single_example(record, name_to_features) 33 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 34 | # So cast all int64 to int32. 35 | for name in list(example.keys()): 36 | t = example[name] 37 | if t.dtype == tf.int64: 38 | t = tf.to_int32(t) 39 | example[name] = t 40 | return example 41 | 42 | 43 | def input_fn_from_tfrecord(params): 44 | """The actual input function.""" 45 | batch_size = params["batch_size"] 46 | 47 | # For training, we want a lot of parallel reading and shuffling. 48 | # For eval, we want no shuffling and parallel reading doesn't matter. 49 | d = tf.data.TFRecordDataset(input_file) 50 | if is_training: 51 | d = d.repeat() 52 | d = d.shuffle(buffer_size=100) 53 | 54 | d = d.apply( 55 | tf.contrib.data.map_and_batch( 56 | lambda record: _decode_record(record, name_to_features), 57 | batch_size=batch_size, 58 | drop_remainder=drop_remainder)) 59 | 60 | return d 61 | 62 | return input_fn_from_tfrecord -------------------------------------------------------------------------------- /tests/tile_repeat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | a_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4], [5, 6, 7], [8, 9, 10]]) 8 | 9 | print(a_np.shape) 10 | # exit() 11 | 12 | def shape(x, dim): 13 | return x.get_shape()[dim].value or tf.shape(x)[dim] 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | original_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4], [5, 6, 7], [8, 9, 10]]) 19 | sess = tf.compat.v1.InteractiveSession() 20 | start_scores = tf.convert_to_tensor(tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4], [5, 6, 7], [8, 9, 10]])) 21 | print(tf.shape(start_scores)) 22 | # exit() 23 | expand_scores = tf.tile(tf.expand_dims(start_scores, 2), [1, 1, 3]) 24 | # (6, 3, 3) 25 | print(expand_scores.eval()) 26 | print(shape(expand_scores, 0)) 27 | print(shape(expand_scores, 1)) 28 | print(shape(expand_scores, 2)) 29 | print("=*="*20) 30 | # tf.convert_to_tensor(data_np, np.float32) 31 | # ndarray_scores = tf.make_ndarray(expand_scores) 32 | # ndarray_scores = tf.convert_to_tensor(expand_scores, np.int32) 33 | # print(ndarray_scores) 34 | # exit() 35 | ndarray_scores = np.array([[[1, 1, 1], [ 2 , 2 , 2], [ 3 , 3 , 3]], 36 | [[ 4 , 4 , 4],[ 5 , 5 , 5],[ 6 , 6 , 6]], 37 | [[ 7 , 7 , 7],[ 8 , 8 , 8],[ 9 , 9 , 9]], 38 | [[ 2 , 2 , 2], [ 3 , 3 , 3], [ 4 , 4 , 4]], 39 | [[ 5 , 5 , 5], [ 6 , 6 , 6], [ 7 , 7 , 7]], 40 | [[ 8 , 8 , 8], [ 9 , 9 , 9], [10 , 10 ,10]]]) 41 | print("$="*20) 42 | print("test_a is : {}".format(str(ndarray_scores[2, 2, 2]))) 43 | print("test_b is : {}".format(str(original_array[2, 2]))) 44 | print("^-"*20) 45 | print("test_a is : {}".format(str(ndarray_scores[2, 1, 1]))) 46 | print("test_b is : {}".format(str(original_array[2, 1]))) 47 | print("^-"*20) 48 | print("test_a is : {}".format(str(ndarray_scores[2, 0, 0]))) 49 | print("test_b is : {}".format(str(original_array[2, 0]))) 50 | sess.close() 51 | # span_scores[k][i][j] = start_scores[k][i] + end_scores[k][j] 52 | # start_scores[k][i][j] = start_scores[k][i] 53 | # end_scores[k][i][j] = end_scores[k][j] 54 | 55 | # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4], [5, 6, 7], [8, 9, 10]] 56 | 57 | """ 58 | [[[1, 1, 1], [ 2 , 2 , 2], [ 3 , 3 , 3]], 59 | [[ 4 , 4 , 4],[ 5 , 5 , 5],[ 6 , 6 , 6]], 60 | [[ 7 , 7 , 
7],[ 8 , 8 , 8],[ 9 , 9 , 9]], 61 | [[ 2 , 2 , 2], [ 3 , 3 , 3], [ 4 , 4 , 4]], 62 | [[ 5 , 5 , 5], [ 6 , 6 , 6], [ 7 , 7 , 7]], 63 | [[ 8 , 8 , 8], [ 9 , 9 , 9], [10 , 10 ,10]]] 64 | """ 65 | 66 | -------------------------------------------------------------------------------- /tests/cumsum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | input_a: 6 | array([[0, 1, 0, 1, 1], 7 | [0, 1, 1, 1, 1], 8 | [0, 1, 1, 0, 1], 9 | [1, 1, 1, 1, 1], 10 | [0, 1, 1, 1, 0]], dtype=int32) 11 | cum_input_b: 12 | array([[1, 2, 3, 4, 5], 13 | [1, 2, 3, 4, 5], 14 | [1, 2, 3, 4, 5], 15 | [1, 2, 3, 4, 5], 16 | [1, 2, 3, 4, 5]], dtype=int32) 17 | input_c: 18 | array([[0, 2, 0, 4, 5], 19 | [0, 2, 3, 4, 5], 20 | [0, 2, 3, 0, 5], 21 | [1, 2, 3, 4, 5], 22 | [0, 2, 3, 4, 0]], dtype=int32) 23 | input_c: 24 | array([[ 1, 2, 3, 4, 5], 25 | [129, 130, 131, 132, 133], 26 | [257, 258, 259, 260, 261], 27 | [385, 386, 387, 388, 389], 28 | [513, 514, 515, 516, 517]], dtype=int32) 29 | input_d: 30 | array([[ 0, 2, 0, 4, 5], 31 | [ 0, 130, 131, 132, 133], 32 | [ 0, 258, 259, 0, 261], 33 | [385, 386, 387, 388, 389], 34 | [ 0, 514, 515, 516, 0]], dtype=int32) 35 | flat_input_d: 36 | array([ 0, 2, 0, 4, 5, 0, 130, 131, 132, 133, 0, 258, 259, 37 | 0, 261, 385, 386, 387, 388, 389, 0, 514, 515, 516, 0], dtype=int32) 38 | boolean_mask: 39 | array([False, True, False, True, True, False, True, True, True, 40 | True, False, True, True, False, True, True, True, True, 41 | True, True, False, True, True, True, False]) 42 | input_f: 43 | array([ 2, 4, 5, 130, 131, 132, 133, 258, 259, 261, 385, 386, 387, 44 | 388, 389, 514, 515, 516], dtype=int32) 45 | """ 46 | 47 | 48 | 49 | 50 | import tensorflow as tf 51 | 52 | 53 | if __name__ == "__main__": 54 | sess = tf.compat.v1.InteractiveSession() 55 | input_a = tf.constant([ 56 | [0, 1, 0, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 0, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 0]]) 57 | ones_input_b = tf.ones_like(input_a, tf.int32) 58 | cum_input_b = tf.math.cumsum(ones_input_b, axis=1) 59 | cum_input_b.eval() 60 | # input_c = tf.math.multiply(cum_input_b, input_a) 61 | # input_c.eval() 62 | seq_len = 128 63 | offset = tf.tile(tf.expand_dims(tf.range(5) * 128, 1), [1, 5]) 64 | offset.eval() 65 | input_e = offset + cum_input_b 66 | input_e.eval() 67 | 68 | input_d = tf.math.multiply(input_e, input_a) 69 | input_d.eval() 70 | flat_input_d = tf.reshape(input_d, [-1]) 71 | flat_input_d.eval() 72 | 73 | boolean_mask = tf.math.greater(flat_input_d, tf.zeros_like(flat_input_d, tf.int32)) 74 | boolean_mask.eval() 75 | 76 | input_f = tf.boolean_mask(flat_input_d, boolean_mask) 77 | input_f.eval() 78 | 79 | sess.close() 80 | 81 | 82 | -------------------------------------------------------------------------------- /utils/load_pytorch_to_tf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # transform pretrained model checkpoints from [pytorch *.bin] to [tensorflow *.ckpt] 9 | # Reference: 10 | # https://github.com/mandarjoshi90/coref/blob/master/pytorch_to_tf.py 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import tensorflow as tf 16 | from tensorflow.python.framework import ops 17 | from tensorflow.python.ops import variable_scope as vs 18 | 19 | 20 | tensors_to_transpose = ( 21 | "dense/kernel", 22 | "attention/self/query", 23 | "attention/self/key", 24 | 
"attention/self/value" 25 | ) 26 | 27 | var_map = ( 28 | ('layer.', 'layer_'), 29 | ('word_embeddings.weight', 'word_embeddings'), 30 | ('position_embeddings.weight', 'position_embeddings'), 31 | ('token_type_embeddings.weight', 'token_type_embeddings'), 32 | ('.', '/'), 33 | ('LayerNorm/weight', 'LayerNorm/gamma'), 34 | ('LayerNorm/bias', 'LayerNorm/beta'), 35 | ('weight', 'kernel') 36 | ) 37 | 38 | 39 | def to_tf_var_name(name: str): 40 | for patt, repl in iter(var_map): 41 | name = name.replace(patt, repl) 42 | return '{}'.format(name) 43 | 44 | 45 | def my_convert_keys(model): 46 | converted = {} 47 | for k_pt, v in model.items(): 48 | k_tf = to_tf_var_name(k_pt) 49 | converted[k_tf] = v 50 | return converted 51 | 52 | 53 | def load_from_pytorch_checkpoint(checkpoint, assignment_map): 54 | pytorch_model = torch.load(checkpoint, map_location='cpu') 55 | pt_model_with_tf_keys = my_convert_keys(pytorch_model) 56 | for _, name in assignment_map.items(): 57 | store_vars = vs._get_default_variable_store()._vars 58 | var = store_vars.get(name, None) 59 | assert var is not None 60 | if name not in pt_model_with_tf_keys: 61 | print('WARNING:', name, 'not found in original model.') 62 | continue 63 | array = pt_model_with_tf_keys[name].cpu().numpy() 64 | if any([x in name for x in tensors_to_transpose]): 65 | array = array.transpose() 66 | assert tuple(var.get_shape().as_list()) == tuple(array.shape) 67 | init_value = ops.convert_to_tensor(array, dtype=np.float32) 68 | var._initial_value = init_value 69 | var._initializer_op = var.assign(init_value) 70 | 71 | 72 | def print_vars(pytorch_ckpt, tf_ckpt): 73 | tf_vars = tf.train.list_variables(tf_ckpt) 74 | tf_vars = {k: v for (k, v) in tf_vars} 75 | pytorch_model = torch.load(pytorch_ckpt) 76 | pt_model_with_tf_keys = my_convert_keys(pytorch_model) 77 | only_pytorch, only_tf, common = [], [], [] 78 | tf_only = set(tf_vars.keys()) 79 | for k, v in pt_model_with_tf_keys.items(): 80 | if k in tf_vars: 81 | common.append(k) 82 | tf_only.remove(k) 83 | else: 84 | only_pytorch.append(k) 85 | print('-------------------') 86 | print('Common', len(common)) 87 | for k in common: 88 | array = pt_model_with_tf_keys[k].cpu().numpy() 89 | if any([x in k for x in tensors_to_transpose]): 90 | array = array.transpose() 91 | tf_shape = tuple(tf_vars[k]) 92 | pt_shape = tuple(array.shape) 93 | if tf_shape != pt_shape: 94 | print(k, tf_shape, pt_shape) 95 | print('-------------------') 96 | print('Pytorch only', len(only_pytorch)) 97 | for k in only_pytorch: 98 | print(k, pt_model_with_tf_keys[k].size()) 99 | print('-------------------') 100 | print('TF only', len(tf_only)) 101 | for k in tf_only: 102 | print(k, tf_vars[k]) 103 | 104 | 105 | -------------------------------------------------------------------------------- /conll-2012/scorer/v8.01/README.txt: -------------------------------------------------------------------------------- 1 | NAME 2 | CorScorer: Perl package for scoring coreference resolution systems 3 | using different metrics. 4 | 5 | 6 | VERSION 7 | v8.01 -- reference implementations of MUC, B-cubed, CEAF and BLANC metrics. 8 | 9 | 10 | CHANGES SINCE v8.0 11 | - fixed a bug that crashed the BLANC scorer when a duplicate singleton 12 | mention was present in the response. 13 | 14 | INSTALLATION 15 | Requirements: 16 | 1. Perl: downloadable from http://perl.org 17 | 2. 
Algorithm-Munkres: included in this package and downloadable 18 | from CPAN http://search.cpan.org/~tpederse/Algorithm-Munkres-0.08 19 | 20 | USE 21 | This package is distributed with two scripts to execute the scorer from 22 | the command line. 23 | 24 | Windows (tm): scorer.bat 25 | Linux: scorer.pl 26 | 27 | 28 | SYNOPSIS 29 | use CorScorer; 30 | 31 | $metric = 'ceafm'; 32 | 33 | # Scores the whole dataset 34 | &CorScorer::Score($metric, $keys_file, $response_file); 35 | 36 | # Scores one file 37 | &CorScorer::Score($metric, $keys_file, $response_file, $name); 38 | 39 | 40 | INPUT 41 | metric: the metric desired to score the results: 42 | muc: MUCScorer (Vilain et al, 1995) 43 | bcub: B-Cubed (Bagga and Baldwin, 1998) 44 | ceafm: CEAF (Luo et al., 2005) using mention-based similarity 45 | ceafe: CEAF (Luo et al., 2005) using entity-based similarity 46 | blanc: BLANC (Luo et al., 2014) BLANC metric for gold and predicted mentions 47 | all: uses all the metrics to score 48 | 49 | keys_file: file with expected coreference chains in CoNLL-2011/2012 format 50 | 51 | response_file: file with output of coreference system (CoNLL-2011/2012 format) 52 | 53 | name: [optional] the name of the document to score. If name is not 54 | given, all the documents in the dataset will be scored. If given 55 | name is "none" then all the documents are scored but only total 56 | results are shown. 57 | 58 | 59 | OUTPUT 60 | The score subroutine returns an array with four values in this order: 61 | 1) Recall numerator 62 | 2) Recall denominator 63 | 3) Precision numerator 64 | 4) Precision denominator 65 | 66 | Also recall, precision and F1 are printed in the standard output when variable 67 | $VERBOSE is not null. 68 | 69 | Final scores: 70 | Recall = recall_numerator / recall_denominator 71 | Precision = precision_numerator / precision_denominator 72 | F1 = 2 * Recall * Precision / (Recall + Precision) 73 | 74 | Identification of mentions 75 | An scorer for identification of mentions (recall, precision and F1) is also included. 76 | Mentions from system response are compared with key mentions. This version performs 77 | strict mention matching as was used in the CoNLL-2011 and 2012 shared tasks. 78 | 79 | AUTHORS 80 | Emili Sapena, Universitat Politècnica de Catalunya, http://www.lsi.upc.edu/~esapena, esapena lsi.upc.edu 81 | Sameer Pradhan, sameer.pradhan childrens.harvard.edu 82 | Sebastian Martschat, sebastian.martschat h-its.org 83 | Xiaoqiang Luo, xql google.com 84 | 85 | COPYRIGHT AND LICENSE 86 | Copyright (C) 2009-2011, Emili Sapena esapena lsi.upc.edu 87 | 2011-2014, Sameer Pradhan sameer.pradhan childrens.harvard.edu 88 | 89 | This program is free software; you can redistribute it and/or modify it 90 | under the terms of the GNU General Public License as published by the 91 | Free Software Foundation; either version 2 of the License, or (at your 92 | option) any later version. This program is distributed in the hope that 93 | it will be useful, but WITHOUT ANY WARRANTY; without even the implied 94 | warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 95 | GNU General Public License for more details. 96 | 97 | You should have received a copy of the GNU General Public License along 98 | with this program; if not, write to the Free Software Foundation, Inc., 99 | 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 
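COMMAND-LINE EXAMPLE
   Scoring a single metric from a shell, in the same form used by
   data_utils/conll.py in this repository (the key and response paths
   below are placeholders):

      perl scorer.pl muc <keys_file> <response_file> none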
100 | 101 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | import codecs 7 | import collections 8 | import errno 9 | import os 10 | import shutil 11 | import pyhocon 12 | import tensorflow as tf 13 | from models import corefqa 14 | from models import mention_proposal 15 | 16 | 17 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2]) 18 | 19 | 20 | def get_model(config, model_sign="corefqa"): 21 | if model_sign == "corefqa": 22 | return corefqa.CorefQAModel(config) 23 | else: 24 | return mention_proposal.MentionProposalModel(config) 25 | 26 | 27 | def initialize_from_env(eval_test=False, config_params="train_spanbert_base", config_file="experiments_tinybert.conf", use_tpu=False, print_info=False): 28 | if not use_tpu: 29 | print("loading experiments.conf ... ") 30 | config = pyhocon.ConfigFactory.parse_file(os.path.join(repo_path, config_file)) 31 | else: 32 | print("loading experiments_tpu.conf ... ") 33 | config = pyhocon.ConfigFactory.parse_file(os.path.join(repo_path, config_file)) 34 | 35 | config = config[config_params] 36 | 37 | if print_info: 38 | tf.logging.info("%*%"*20) 39 | tf.logging.info("%*%"*20) 40 | tf.logging.info("%%%%%%%% Configs are showed as follows : %%%%%%%%") 41 | for tmp_key, tmp_value in config.items(): 42 | tf.logging.info(str(tmp_key) + " : " + str(tmp_value)) 43 | 44 | tf.logging.info("%*%"*20) 45 | tf.logging.info("%*%"*20) 46 | 47 | config["log_dir"] = mkdirs(os.path.join(config["log_root"], config_params)) 48 | 49 | if print_info: 50 | tf.logging.info(pyhocon.HOCONConverter.convert(config, "hocon")) 51 | return config 52 | 53 | 54 | def copy_checkpoint(source, target): 55 | for ext in (".index", ".data-00000-of-00001"): 56 | shutil.copyfile(source + ext, target + ext) 57 | 58 | 59 | def make_summary(value_dict): 60 | return tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=v) for k, v in value_dict.items()]) 61 | 62 | 63 | def flatten(l): 64 | return [item for sublist in l for item in sublist] 65 | 66 | 67 | def set_gpus(*gpus): 68 | # pass 69 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus) 70 | print("Setting CUDA_VISIBLE_DEVICES to: {}".format(os.environ["CUDA_VISIBLE_DEVICES"])) 71 | 72 | 73 | def mkdirs(path): 74 | try: 75 | os.makedirs(path) 76 | except OSError as exception: 77 | if exception.errno != errno.EEXIST: 78 | raise 79 | return path 80 | 81 | 82 | def load_char_dict(char_vocab_path): 83 | vocab = [u""] 84 | with codecs.open(char_vocab_path, encoding="utf-8") as f: 85 | vocab.extend(l.strip() for l in f.readlines()) 86 | char_dict = collections.defaultdict(int) 87 | char_dict.update({c: i for i, c in enumerate(vocab)}) 88 | return char_dict 89 | 90 | 91 | def maybe_divide(x, y): 92 | return 0 if y == 0 else x / float(y) 93 | 94 | 95 | 96 | def shape(x, dim): 97 | return x.get_shape()[dim].value or tf.shape(x)[dim] 98 | 99 | 100 | def ffnn(inputs, num_hidden_layers, hidden_size, output_size, dropout, 101 | output_weights_initializer=tf.truncated_normal_initializer(stddev=0.02), 102 | hidden_initializer=tf.truncated_normal_initializer(stddev=0.02)): 103 | if len(inputs.get_shape()) > 3: 104 | raise ValueError("FFNN with rank {} not supported".format(len(inputs.get_shape()))) 105 | current_inputs = inputs 106 | hidden_weights = tf.get_variable("hidden_weights", [hidden_size, output_size], 107 | 
initializer=hidden_initializer) 108 | hidden_bias = tf.get_variable("hidden_bias", [output_size], initializer=tf.zeros_initializer()) 109 | current_outputs = tf.nn.relu(tf.nn.xw_plus_b(current_inputs, hidden_weights, hidden_bias)) 110 | 111 | return current_outputs 112 | 113 | 114 | def batch_gather(emb, indices): 115 | batch_size = shape(emb, 0) 116 | seqlen = shape(emb, 1) 117 | if len(emb.get_shape()) > 2: 118 | emb_size = shape(emb, 2) 119 | else: 120 | emb_size = 1 121 | flattened_emb = tf.reshape(emb, [batch_size * seqlen, emb_size]) # [batch_size * seqlen, emb] 122 | offset = tf.expand_dims(tf.range(batch_size) * seqlen, 1) # [batch_size, 1] 123 | gathered = tf.gather(flattened_emb, indices + offset) # [batch_size, num_indices, emb] 124 | if len(emb.get_shape()) == 2: 125 | gathered = tf.squeeze(gathered, 2) # [batch_size, num_indices] 126 | return gathered 127 | 128 | -------------------------------------------------------------------------------- /data_utils/conll.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # use the offical conll-2012 evaluation scripts to evaluate the trained model and save the evaluation results into files. 9 | 10 | 11 | import os 12 | import re 13 | import collections 14 | import operator 15 | import subprocess 16 | import tempfile 17 | 18 | 19 | 20 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2]) 21 | BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") 22 | COREF_RESULTS_REGEX = re.compile( 23 | r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", 24 | re.DOTALL) 25 | 26 | 27 | 28 | def get_doc_key(doc_id, part): 29 | return "{}_{}".format(doc_id, int(part)) 30 | 31 | 32 | def output_conll(input_file, output_file, predictions, subtoken_map): 33 | prediction_map = {} 34 | for doc_key, clusters in predictions.items(): 35 | start_map = collections.defaultdict(list) 36 | end_map = collections.defaultdict(list) 37 | word_map = collections.defaultdict(list) 38 | for cluster_id, mentions in enumerate(clusters): 39 | for start, end in mentions: 40 | start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end] 41 | if start == end: 42 | word_map[start].append(cluster_id) 43 | else: 44 | start_map[start].append((cluster_id, end)) 45 | end_map[end].append((cluster_id, start)) 46 | for k, v in start_map.items(): 47 | start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)] 48 | for k, v in end_map.items(): 49 | end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)] 50 | prediction_map[doc_key] = (start_map, end_map, word_map) 51 | 52 | word_index = 0 53 | for line in input_file.readlines(): 54 | row = line.split() 55 | if len(row) == 0: 56 | output_file.write("\n") 57 | elif row[0].startswith("#"): 58 | begin_match = re.match(BEGIN_DOCUMENT_REGEX, line) 59 | if begin_match: 60 | doc_key = get_doc_key(begin_match.group(1), begin_match.group(2)) 61 | start_map, end_map, word_map = prediction_map[doc_key] 62 | word_index = 0 63 | output_file.write(line) 64 | output_file.write("\n") 65 | else: 66 | assert get_doc_key(row[0], row[1]) == doc_key 67 | coref_list = [] 68 | if word_index in end_map: 69 | for cluster_id in end_map[word_index]: 70 | coref_list.append("{})".format(cluster_id)) 71 | if word_index in 
word_map: 72 | for cluster_id in word_map[word_index]: 73 | coref_list.append("({})".format(cluster_id)) 74 | if word_index in start_map: 75 | for cluster_id in start_map[word_index]: 76 | coref_list.append("({}".format(cluster_id)) 77 | 78 | if len(coref_list) == 0: 79 | row[-1] = "-" 80 | else: 81 | row[-1] = "|".join(coref_list) 82 | 83 | output_file.write(" ".join(row)) 84 | output_file.write("\n") 85 | word_index += 1 86 | 87 | 88 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=False): 89 | cmd = ["perl", os.path.join(REPO_PATH, "conll-2012/scorer/v8.01/scorer.pl"), metric, gold_path, predicted_path, "none"] 90 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 91 | stdout, stderr = process.communicate() 92 | process.wait() 93 | 94 | stdout = stdout.decode("utf-8") 95 | if stderr is not None: 96 | print(stderr) 97 | 98 | if official_stdout: 99 | print("Official result for {}".format(metric)) 100 | print(stdout) 101 | 102 | coref_results_match = re.match(COREF_RESULTS_REGEX, stdout) 103 | recall = float(coref_results_match.group(1)) 104 | precision = float(coref_results_match.group(2)) 105 | f1 = float(coref_results_match.group(3)) 106 | return {"r": recall, "p": precision, "f": f1} 107 | 108 | 109 | def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=False): 110 | with tempfile.NamedTemporaryFile(delete=False, mode="w") as prediction_file: 111 | with open(gold_path, "r") as gold_file: 112 | output_conll(gold_file, prediction_file, predictions, subtoken_maps) 113 | print("Predicted conll file: {}".format(prediction_file.name)) 114 | return {m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in 115 | ("muc", "bcub", "ceafe")} 116 | -------------------------------------------------------------------------------- /run/transform_spanbert_pytorch_to_tf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # transform pytorch .bin models to tensorflow ckpt 9 | 10 | 11 | import os 12 | import sys 13 | import shutil 14 | import torch 15 | import argparse 16 | import random 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2]) 21 | 22 | if REPO_PATH not in sys.path: 23 | sys.path.insert(0, REPO_PATH) 24 | 25 | 26 | from bert import modeling 27 | from utils.load_pytorch_to_tf import load_from_pytorch_checkpoint 28 | 29 | 30 | def load_models(bert_config_path, ): 31 | bert_config = modeling.BertConfig.from_json_file(bert_config_path) 32 | input_ids = tf.ones((8, 128), tf.int32) 33 | 34 | model = modeling.BertModel( 35 | config=bert_config, 36 | is_training=False, 37 | input_ids=input_ids, 38 | use_one_hot_embeddings=False, 39 | scope="bert") 40 | 41 | return model, bert_config 42 | 43 | 44 | def copy_checkpoint(source, target): 45 | for ext in (".index", ".data-00000-of-00001"): 46 | shutil.copyfile(source + ext, target + ext) 47 | 48 | 49 | def main(bert_config_path, bert_ckpt_path, pytorch_init_checkpoint, output_tf_dir): 50 | 51 | with tf.Session() as session: 52 | model, bert_config = load_models(bert_config_path) 53 | tvars = tf.trainable_variables() 54 | assignment_map, initialized_variable_names = modeling.get_assignment_map_from_checkpoint(tvars, bert_ckpt_path) 55 | session.run(tf.global_variables_initializer()) 56 | init_from_checkpoint = load_from_pytorch_checkpoint 57 | 
init_from_checkpoint(pytorch_init_checkpoint, assignment_map) 58 | 59 | for var in tvars: 60 | init_string = "" 61 | if var.name in initialized_variable_names: 62 | init_string = ", *INIT_FROM_CKPT*" 63 | print("name = %s, shape = %s%s" % (var.name, var.shape, init_string)) 64 | 65 | saver = tf.train.Saver() 66 | saver.save(session, os.path.join(output_tf_dir, "model"), global_step=100) 67 | copy_checkpoint(os.path.join(output_tf_dir, "model-{}".format(str(100))), os.path.join(output_tf_dir, "bert_model.ckpt")) 68 | print("=*="*30) 69 | print("save models : {}".format(output_tf_dir)) 70 | print("=*="*30) 71 | 72 | 73 | def parse_args(): 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--spanbert_config_path", default="/home/lixiaoya/spanbert_base_cased/config.json", type=str) 76 | parser.add_argument("--bert_tf_ckpt_path", default="/home/lixiaoya/cased_L-12_H-768_A-12/bert_model.ckpt", type=str) 77 | parser.add_argument("--spanbert_pytorch_bin_path", default="/home/lixiaoya/spanbert_base_cased/pytorch_model.bin", type=str) 78 | parser.add_argument("--output_spanbert_tf_dir", default="/home/lixiaoya/tf_spanbert_base_case", type=str) 79 | parser.add_argument("--seed", default=2333, type=int) 80 | 81 | 82 | args = parser.parse_args() 83 | 84 | random.seed(args.seed) 85 | np.random.seed(args.seed) 86 | torch.manual_seed(args.seed) 87 | tf.set_random_seed(args.seed) 88 | torch.cuda.manual_seed_all(args.seed) 89 | 90 | os.makedirs(args.output_spanbert_tf_dir, exist_ok=True) 91 | 92 | try: 93 | shutil(args.spanbert_config_path, args.output_spanbert_tf_dir) 94 | except: 95 | print("#=#"*30) 96 | print("copy spanbert_config from {} to {}".format(args.spanbert_config_path, args.output_spanbert_tf_dir)) 97 | 98 | return args 99 | 100 | 101 | if __name__ == "__main__": 102 | 103 | args_config = parse_args() 104 | 105 | main(args_config.spanbert_config_path, args_config.bert_tf_ckpt_path, args_config.spanbert_pytorch_bin_path, args_config.output_spanbert_tf_dir) 106 | 107 | # 108 | # Please refer to scripts/data/transform_ckpt_pytorch_to_tf.sh 109 | # 110 | 111 | # for spanbert large 112 | # 113 | # python3 transform_spanbert_pytorch_to_tf.py \ 114 | # --spanbert_config_path /xiaoya/pretrain_ckpt/spanbert_large_cased/config.json \ 115 | # --bert_tf_ckpt_path /xiaoya/pretrain_ckpt/cased_L-24_H-1024_A-16/bert_model.ckpt \ 116 | # --spanbert_pytorch_bin_path /xiaoya/pretrain_ckpt/spanbert_large_cased/pytorch_model.bin \ 117 | # --output_spanbert_tf_dir /xiaoya/pretrain_ckpt/tf_spanbert_large_cased 118 | 119 | 120 | # for spanbert base 121 | # 122 | # python3 transform_spanbert_pytorch_to_tf.py \ 123 | # --spanbert_config_path /xiaoya/pretrain_ckpt/spanbert_base_cased/config.json \ 124 | # --bert_tf_ckpt_path /xiaoya/pretrain_ckpt/cased_L-12_H-768_A-12/bert_model.ckpt \ 125 | # --spanbert_pytorch_bin_path /xiaoya/pretrain_ckpt/spanbert_base_cased/pytorch_model.bin \ 126 | # --output_spanbert_tf_dir /xiaoya/pretrain_ckpt/tf_spanbert_base_cased 127 | 128 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | import numpy as np 7 | from collections import Counter 8 | from scipy.optimize import linear_sum_assignment 9 | 10 | 11 | def mention_proposal_prediction(config, current_doc_result, concat_only=True): 12 | """ 13 | current_doc_result: 14 | "total_loss": total_loss, 15 | "start_scores": 
start_scores, 16 | "start_gold": gold_starts, 17 | "end_gold": gold_ends, 18 | "end_scores": end_scores, 19 | "span_scores": span_scores, 20 | "span_gold": span_mention 21 | 22 | """ 23 | 24 | span_scores = current_doc_result["span_scores"] 25 | span_gold = current_doc_result["span_gold"] 26 | 27 | if concat_only: 28 | scores = span_scores 29 | else: 30 | start_scores = current_doc_result["start_scores"], 31 | end_scores = current_doc_result["end_scores"] 32 | # start_scores = tf.tile(tf.expand_dims(start_scores, 2), [1, 1, config["max_segment_len"]]) 33 | start_scores = np.tile(np.expand_dims(start_scores, axis=2), (1, 1, config["max_segment_len"])) 34 | end_scores = np.tile(np.expand_dims(end_scores, axis=2), (1, 1, config["max_segment_len"])) 35 | start_scores = np.reshape(start_scores, [-1, config["max_segment_len"], config["max_segment_len"]]) 36 | end_scores = np.reshape(end_scores, [-1, config["max_segment_len"], config["max_segment_len"]]) 37 | 38 | # end_scores -> max_training_sent, max_segment_len 39 | scores = (start_scores + end_scores + span_scores)/3 40 | 41 | pred_span_label = scores >= 0.5 42 | pred_span_label = np.reshape(pred_span_label, [-1]) 43 | gold_span_label = np.reshape(span_gold, [-1]) 44 | 45 | return pred_span_label, gold_span_label 46 | 47 | 48 | def f1(p_num, p_den, r_num, r_den, beta=1): 49 | p = 0 if p_den == 0 else p_num / float(p_den) 50 | r = 0 if r_den == 0 else r_num / float(r_den) 51 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 52 | 53 | 54 | class CorefEvaluator(object): 55 | def __init__(self): 56 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 57 | 58 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 59 | for e in self.evaluators: 60 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 61 | 62 | def get_f1(self): 63 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 64 | 65 | def get_recall(self): 66 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 67 | 68 | def get_precision(self): 69 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 70 | 71 | def get_prf(self): 72 | return self.get_precision(), self.get_recall(), self.get_f1() 73 | 74 | 75 | class Evaluator(object): 76 | def __init__(self, metric, beta=1): 77 | self.p_num = 0 78 | self.p_den = 0 79 | self.r_num = 0 80 | self.r_den = 0 81 | self.metric = metric 82 | self.beta = beta 83 | 84 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 85 | if self.metric == ceafe: 86 | pn, pd, rn, rd = self.metric(predicted, gold) 87 | else: 88 | pn, pd = self.metric(predicted, mention_to_gold) 89 | rn, rd = self.metric(gold, mention_to_predicted) 90 | self.p_num += pn 91 | self.p_den += pd 92 | self.r_num += rn 93 | self.r_den += rd 94 | 95 | def get_f1(self): 96 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 97 | 98 | def get_recall(self): 99 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 100 | 101 | def get_precision(self): 102 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 103 | 104 | def get_prf(self): 105 | return self.get_precision(), self.get_recall(), self.get_f1() 106 | 107 | def get_counts(self): 108 | return self.p_num, self.p_den, self.r_num, self.r_den 109 | 110 | 111 | def evaluate_documents(documents, metric, beta=1): 112 | evaluator = Evaluator(metric, beta=beta) 113 | for document in documents: 114 | evaluator.update(document) 115 
| return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1() 116 | 117 | 118 | def b_cubed(clusters, mention_to_gold): 119 | num, dem = 0, 0 120 | 121 | for c in clusters: 122 | if len(c) == 1: 123 | continue 124 | 125 | gold_counts = Counter() 126 | correct = 0 127 | for m in c: 128 | if m in mention_to_gold: 129 | gold_counts[tuple(mention_to_gold[m])] += 1 130 | for c2, count in gold_counts.items(): 131 | if len(c2) != 1: 132 | correct += count * count 133 | 134 | num += correct / float(len(c)) 135 | dem += len(c) 136 | 137 | return num, dem 138 | 139 | 140 | def muc(clusters, mention_to_gold): 141 | tp, p = 0, 0 142 | for c in clusters: 143 | p += len(c) - 1 144 | tp += len(c) 145 | linked = set() 146 | for m in c: 147 | if m in mention_to_gold: 148 | linked.add(mention_to_gold[m]) 149 | else: 150 | tp -= 1 151 | tp -= len(linked) 152 | return tp, p 153 | 154 | 155 | def phi4(c1, c2): 156 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) 157 | 158 | 159 | def ceafe(clusters, gold_clusters): 160 | clusters = [c for c in clusters if len(c) != 1] 161 | scores = np.zeros((len(gold_clusters), len(clusters))) 162 | for i in range(len(gold_clusters)): 163 | for j in range(len(clusters)): 164 | scores[i, j] = phi4(gold_clusters[i], clusters[j]) 165 | row_ind, col_ind = linear_sum_assignment(-scores) 166 | similarity = sum(scores[row_ind, col_ind]) 167 | return similarity, len(clusters), similarity, len(gold_clusters) 168 | 169 | 170 | def lea(clusters, mention_to_gold): 171 | num, dem = 0, 0 172 | 173 | for c in clusters: 174 | if len(c) == 1: 175 | continue 176 | 177 | common_links = 0 178 | all_links = len(c) * (len(c) - 1) / 2.0 179 | for i, m in enumerate(c): 180 | if m in mention_to_gold: 181 | for m2 in c[i + 1:]: 182 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: 183 | common_links += 1 184 | 185 | num += len(c) * common_links / float(all_links) 186 | dem += len(c) 187 | 188 | return num, dem 189 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 
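  # Illustration of the decay below (hypothetical step counts, not from the original file):
  # with power=1.0 and end_learning_rate=0.0, tf.train.polynomial_decay reduces to
  #   lr_t = init_lr * (1 - global_step / num_train_steps),
  # so for init_lr=3e-5 and num_train_steps=10000 the rate is 3e-5 at step 0,
  # 1.5e-5 at step 5000, and 0.0 at the final step.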
32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 
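      # In symbols, the ops below compute
      #   next_m = beta_1 * m + (1 - beta_1) * grad
      #   next_v = beta_2 * v + (1 - beta_2) * grad^2
      #   update = next_m / (sqrt(next_v) + epsilon)
      # No bias-correction factors 1/(1 - beta_1^t), 1/(1 - beta_2^t) are applied in this
      # implementation, and weight decay (when not excluded) is added to `update` before
      # scaling by the learning rate, i.e. decoupled from the Adam moments.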
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /utils/radam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | from tensorflow.python import ( 8 | ops, math_ops, state_ops, control_flow_ops, resource_variable_ops) 9 | from tensorflow.python.training.optimizer import Optimizer 10 | 11 | __all__ = ['RAdam'] 12 | 13 | 14 | class RAdam(Optimizer): 15 | """Rectified Adam (RAdam) optimizer. 16 | According to the paper 17 | [On The Variance Of The Adaptive Learning Rate And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf). 18 | """ 19 | 20 | def __init__(self, 21 | learning_rate=0.001, 22 | beta1=0.9, 23 | beta2=0.999, 24 | epsilon=1e-8, 25 | amsgrad=False, 26 | use_locking=False, 27 | name='RAdam'): 28 | r"""Construct a new Rectified Adam optimizer. 29 | Args: 30 | learning_rate: A Tensor or a floating point value. The learning rate. 31 | beta1: A float value or a constant float tensor. The exponential decay 32 | rate for the 1st moment estimates. 33 | beta2: A float value or a constant float tensor. The exponential decay 34 | rate for the 2nd moment estimates. 35 | epsilon: A small constant for numerical stability. This epsilon is 36 | "epsilon hat" in the Kingma and Ba paper (in the formula just before 37 | Section 2.1), not the epsilon in Algorithm 1 of the paper. 38 | amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from 39 | the paper "On the Convergence of Adam and beyond". 40 | use_locking: If `True` use locks for update operations. 41 | name: Optional name for the operations created when applying gradients. 42 | Defaults to "Adam". 
@compatibility(eager) When eager execution is 43 | enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be 44 | a callable that takes no arguments and returns the actual value to use. 45 | This can be useful for changing these values across different 46 | invocations of optimizer functions. @end_compatibility 47 | """ 48 | 49 | super(RAdam, self).__init__(use_locking, name) 50 | self._lr = learning_rate 51 | self._beta1 = beta1 52 | self._beta2 = beta2 53 | self._epsilon = epsilon 54 | self._amsgrad = amsgrad 55 | 56 | def _get_beta_accumulators(self): 57 | with ops.init_scope(): 58 | graph = ops.get_default_graph() 59 | return (self._get_non_slot_variable("beta1_power", graph=graph), 60 | self._get_non_slot_variable("beta2_power", graph=graph), 61 | ) 62 | 63 | def _get_niter(self): 64 | with ops.init_scope(): 65 | graph = ops.get_default_graph() 66 | return self._get_non_slot_variable("niter", graph=graph) 67 | 68 | def _create_slots(self, var_list): 69 | first_var = min(var_list, key=lambda x: x.name) 70 | self._create_non_slot_variable( 71 | initial_value=self._beta1, name="beta1_power", colocate_with=first_var) 72 | self._create_non_slot_variable( 73 | initial_value=self._beta2, name="beta2_power", colocate_with=first_var) 74 | self._create_non_slot_variable( 75 | initial_value=1, name="niter", colocate_with=first_var) 76 | for var in var_list: 77 | self._zeros_slot(var, 'm', self._name) 78 | self._zeros_slot(var, 'v', self._name) 79 | if self._amsgrad: 80 | for var in var_list: 81 | self._zeros_slot(var, 'vhat', self._name) 82 | 83 | def _prepare(self): 84 | learning_rate = self._call_if_callable(self._lr) 85 | beta1 = self._call_if_callable(self._beta1) 86 | beta2 = self._call_if_callable(self._beta2) 87 | epsilon = self._call_if_callable(self._epsilon) 88 | 89 | self._lr_t = ops.convert_to_tensor(learning_rate, name="learning_rate") 90 | self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") 91 | self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") 92 | self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") 93 | 94 | def _apply_dense_shared(self, grad, var): 95 | var_dtype = var.dtype.base_dtype 96 | beta1_power, beta2_power = self._get_beta_accumulators() 97 | beta1_power = math_ops.cast(beta1_power, var_dtype) 98 | beta2_power = math_ops.cast(beta2_power, var_dtype) 99 | niter = self._get_niter() 100 | niter = math_ops.cast(niter, var_dtype) 101 | lr_t = math_ops.cast(self._lr_t, var_dtype) 102 | beta1_t = math_ops.cast(self._beta1_t, var_dtype) 103 | beta2_t = math_ops.cast(self._beta2_t, var_dtype) 104 | epsilon_t = math_ops.cast(self._epsilon_t, var_dtype) 105 | 106 | sma_inf = 2.0 / (1.0 - beta2_t) - 1.0 107 | sma_t = sma_inf - 2.0 * niter * beta2_power / (1.0 - beta2_power) 108 | 109 | m = self.get_slot(var, 'm') 110 | m_t = state_ops.assign(m, 111 | beta1_t * m + (1.0 - beta1_t) * grad, 112 | use_locking=self._use_locking) 113 | m_corr_t = m_t / (1.0 - beta1_power) 114 | 115 | v = self.get_slot(var, 'v') 116 | v_t = state_ops.assign(v, 117 | beta2_t * v + (1.0 - beta2_t) * math_ops.square(grad), 118 | use_locking=self._use_locking) 119 | 120 | if self._amsgrad: 121 | vhat = self.get_slot(var, 'vhat') 122 | vhat_t = state_ops.assign(vhat, 123 | math_ops.maximum(vhat, v_t), 124 | use_locking=self._use_locking) 125 | v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power) + epsilon_t) 126 | else: 127 | v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t) 128 | 129 | r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * 130 | (sma_t - 
2.0) / (sma_inf - 2.0) * 131 | sma_inf / sma_t) 132 | 133 | var_t = tf.where(sma_t > 5.0, r_t * m_corr_t / v_corr_t, m_corr_t) 134 | 135 | var_update = state_ops.assign_sub(var, 136 | lr_t * var_t, 137 | use_locking=self._use_locking) 138 | 139 | updates = [var_update, m_t, v_t] 140 | if self._amsgrad: 141 | updates.append(vhat_t) 142 | return control_flow_ops.group(*updates) 143 | 144 | def _apply_dense(self, grad, var): 145 | return self._apply_dense_shared(grad, var) 146 | 147 | def _resource_apply_dense(self, grad, var): 148 | return self._apply_dense_shared(grad, var.handle) 149 | 150 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 151 | var_dtype = var.dtype.base_dtype 152 | beta1_power, beta2_power = self._get_beta_accumulators() 153 | beta1_power = math_ops.cast(beta1_power, var_dtype) 154 | beta2_power = math_ops.cast(beta2_power, var_dtype) 155 | niter = self._get_niter() 156 | niter = math_ops.cast(niter, var_dtype) 157 | lr_t = math_ops.cast(self._lr_t, var_dtype) 158 | beta1_t = math_ops.cast(self._beta1_t, var_dtype) 159 | beta2_t = math_ops.cast(self._beta2_t, var_dtype) 160 | epsilon_t = math_ops.cast(self._epsilon_t, var_dtype) 161 | 162 | sma_inf = 2.0 / (1.0 - beta2_t) - 1.0 163 | sma_t = sma_inf - 2.0 * niter * beta2_power / (1.0 - beta2_power) 164 | 165 | m = self.get_slot(var, 'm') 166 | m_t = state_ops.assign(m, beta1_t * m, use_locking=self._use_locking) 167 | with ops.control_dependencies([m_t]): 168 | m_t = scatter_add(m, indices, grad * (1 - beta1_t)) 169 | m_corr_t = m_t / (1.0 - beta1_power) 170 | 171 | v = self.get_slot(var, 'v') 172 | v_t = state_ops.assign(v, beta2_t * v, use_locking=self._use_locking) 173 | with ops.control_dependencies([v_t]): 174 | v_t = scatter_add(v, indices, (1.0 - beta2_t) * math_ops.square(grad)) 175 | 176 | if self._amsgrad: 177 | vhat = self.get_slot(var, 'vhat') 178 | vhat_t = state_ops.assign(vhat, 179 | math_ops.maximum(vhat, v_t), 180 | use_locking=self._use_locking) 181 | v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power) + epsilon_t) 182 | else: 183 | v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t) 184 | 185 | r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * 186 | (sma_t - 2.0) / (sma_inf - 2.0) * 187 | sma_inf / sma_t) 188 | 189 | var_t = tf.where(sma_t > 5.0, r_t * m_corr_t / v_corr_t, m_corr_t) 190 | 191 | var_update = state_ops.assign_sub(var, 192 | lr_t * var_t, 193 | use_locking=self._use_locking) 194 | 195 | updates = [var_update, m_t, v_t] 196 | if self._amsgrad: 197 | updates.append(vhat_t) 198 | return control_flow_ops.group(*updates) 199 | 200 | def _apply_sparse(self, grad, var): 201 | return self._apply_sparse_shared( 202 | grad.values, 203 | var, 204 | grad.indices, 205 | lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda 206 | x, 207 | i, 208 | v, 209 | use_locking=self._use_locking)) 210 | 211 | def _resource_apply_sparse(self, grad, var, indices): 212 | return self._apply_sparse_shared(grad, var, indices, 213 | self._resource_scatter_add) 214 | 215 | def _resource_scatter_add(self, x, i, v): 216 | with ops.control_dependencies( 217 | [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 218 | return x.value() 219 | 220 | def _finish(self, update_ops, name_scope): 221 | # Update the power accumulators. 
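    # beta1_power and beta2_power accumulate beta1^t and beta2^t, and niter tracks the
    # step count t. _apply_dense_shared / _apply_sparse_shared above use them for the
    # bias-corrected moments and for the SMA length
    #   sma_t = sma_inf - 2*t*beta2^t / (1 - beta2^t),
    # which selects the rectified step (sma_t > 5.0) or the plain momentum step.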
222 | with ops.control_dependencies(update_ops): 223 | beta1_power, beta2_power = self._get_beta_accumulators() 224 | niter = self._get_niter() 225 | with ops.colocate_with(beta1_power): 226 | update_beta1 = beta1_power.assign( 227 | beta1_power * self._beta1_t, use_locking=self._use_locking) 228 | update_beta2 = beta2_power.assign( 229 | beta2_power * self._beta2_t, use_locking=self._use_locking) 230 | update_niter = niter.assign( 231 | niter + 1, use_locking=self._use_locking) 232 | return control_flow_ops.group( 233 | *update_ops + [update_beta1, update_beta2, update_niter], name=name_scope) 234 | -------------------------------------------------------------------------------- /func_builders/model_fn_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # 9 | # 10 | 11 | 12 | 13 | import tensorflow as tf 14 | from utils import util 15 | from utils.radam import RAdam 16 | 17 | 18 | def model_fn_builder(config, model_sign="mention_proposal"): 19 | 20 | def mention_proposal_model_fn(features, labels, mode, params): 21 | """The `model_fn` for TPUEstimator.""" 22 | input_ids = features["flattened_input_ids"] 23 | input_mask = features["flattened_input_mask"] 24 | text_len = features["text_len"] 25 | speaker_ids = features["speaker_ids"] 26 | gold_starts = features["span_starts"] 27 | gold_ends = features["span_ends"] 28 | cluster_ids = features["cluster_ids"] 29 | sentence_map = features["sentence_map"] 30 | 31 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 32 | 33 | model = util.get_model(config, model_sign="mention_proposal") 34 | 35 | if config.use_tpu: 36 | def tpu_scaffold(): 37 | return tf.train.Scaffold() 38 | scaffold_fn = tpu_scaffold 39 | else: 40 | scaffold_fn = None 41 | 42 | if mode == tf.estimator.ModeKeys.TRAIN: 43 | tf.logging.info("****************************** tf.estimator.ModeKeys.TRAIN ******************************") 44 | tf.logging.info("********* Features *********") 45 | for name in sorted(features.keys()): 46 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 47 | 48 | instance = (input_ids, input_mask, sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 49 | total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(instance, is_training) 50 | gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(gold_starts, gold_ends, text_len) 51 | 52 | if config.use_tpu: 53 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08) 54 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 55 | train_op = optimizer.minimize(total_loss, tf.train.get_global_step()) 56 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 57 | mode=mode, 58 | loss=total_loss, 59 | train_op=train_op, 60 | scaffold_fn=scaffold_fn) 61 | else: 62 | optimizer = RAdam(learning_rate=config.learning_rate, epsilon=1e-8, beta1=0.9, beta2=0.999) 63 | train_op = optimizer.minimize(total_loss, tf.train.get_global_step()) 64 | 65 | train_logging_hook = tf.train.LoggingTensorHook({"loss": total_loss}, every_n_iter=1) 66 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 67 | mode=mode, 68 | loss=total_loss, 69 | train_op=train_op, 70 | scaffold_fn=scaffold_fn, 71 | training_hooks=[train_logging_hook]) 72 | 73 | elif mode == tf.estimator.ModeKeys.EVAL: 
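            # Evaluation recomputes the proposal scores and reports span-level precision and
            # recall: start/end scores are tiled to [num_windows, window_size, window_size],
            # averaged with the span scores, thresholded at config.mention_threshold, and
            # compared against the flattened gold span labels. F1 is derived from the two
            # reported metrics afterwards by the run script.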
74 | tf.logging.info("****************************** tf.estimator.ModeKeys.EVAL ******************************") 75 | 76 | instance = (input_ids, input_mask, sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 77 | total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(instance, is_training) 78 | total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(instance, is_training) 79 | gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(gold_starts, gold_ends, text_len) 80 | 81 | def metric_fn(start_scores, end_scores, span_scores, gold_span_label): 82 | start_scores = tf.reshape(start_scores, [-1, config.window_size]) 83 | end_scores = tf.reshape(end_scores, [-1, config.window_size]) 84 | start_scores = tf.tile(tf.expand_dims(start_scores, 2), [1, 1, config.window_size]) 85 | end_scores = tf.tile(tf.expand_dims(end_scores, 2), [1, 1, config.window_size]) 86 | sce_span_scores = (start_scores + end_scores + span_scores)/ 3 87 | pred_span_label = tf.cast(tf.reshape(tf.math.greater_equal(sce_span_scores, config.mention_threshold), [-1]), tf.bool) 88 | 89 | gold_span_label = tf.cast(tf.reshape(gold_span_sequence_labels, [-1]), tf.bool) 90 | 91 | return {"precision": tf.compat.v1.metrics.precision(gold_span_label, pred_span_label), 92 | "recall": tf.compat.v1.metrics.recall(gold_span_label, pred_span_label)} 93 | 94 | eval_metrics = (metric_fn, [start_scores, end_scores, span_scores]) 95 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 96 | mode=tf.estimator.ModeKeys.EVAL, 97 | loss=total_loss, 98 | eval_metrics=eval_metrics, 99 | scaffold_fn=scaffold_fn) 100 | 101 | elif mode == tf.estimator.ModeKeys.PREDICT: 102 | tf.logging.info("****************************** tf.estimator.ModeKeys.PREDICT ******************************") 103 | 104 | instance = (input_ids, input_mask, sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 105 | total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(instance, is_training) 106 | gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(gold_starts, gold_ends, text_len) 107 | predictions = { 108 | "total_loss": total_loss, 109 | "start_scores": start_scores, 110 | "start_gold": gold_starts, 111 | "end_gold": gold_ends, 112 | "end_scores": end_scores, 113 | "span_scores": span_scores 114 | } 115 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 116 | mode=tf.estimator.ModeKeys.PREDICT, 117 | predictions=predictions, 118 | scaffold_fn=scaffold_fn) 119 | else: 120 | raise ValueError("Please check the the mode ! 
") 121 | 122 | return output_spec 123 | 124 | 125 | def corefqa_model_fn(features, labels, mode, params): 126 | 127 | """The `model_fn` for TPUEstimator.""" 128 | input_ids = features["flattened_input_ids"] 129 | input_mask = features["flattened_input_mask"] 130 | text_len = features["text_len"] 131 | speaker_ids = features["speaker_ids"] 132 | gold_starts = features["span_starts"] 133 | gold_ends = features["span_ends"] 134 | cluster_ids = features["cluster_ids"] 135 | sentence_map = features["sentence_map"] 136 | 137 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 138 | 139 | model = util.get_model(config, model_sign="corefqa") 140 | 141 | if config.use_tpu: 142 | tf.logging.info("****************************** Training on TPU ******************************") 143 | def tpu_scaffold(): 144 | return tf.train.Scaffold() 145 | scaffold_fn = tpu_scaffold 146 | else: 147 | scaffold_fn = None 148 | 149 | 150 | if mode == tf.estimator.ModeKeys.TRAIN: 151 | tf.logging.info("****************************** tf.estimator.ModeKeys.TRAIN ******************************") 152 | tf.logging.info("********* Features *********") 153 | for name in sorted(features.keys()): 154 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 155 | 156 | instance = (input_ids, input_mask, sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 157 | total_loss, (topk_mention_start_indices, topk_mention_end_indices), (forward_topc_mention_start_indices, forward_topc_mention_end_indices), top_mention_span_linking_scores = model.get_coreference_resolution_and_loss(instance, is_training, use_tpu=config.use_tpu) 158 | 159 | if config.use_tpu: 160 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08) 161 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 162 | train_op = optimizer.minimize(total_loss, tf.train.get_global_step()) 163 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 164 | mode=tf.estimator.ModeKeys.TRAIN, 165 | loss=total_loss, 166 | train_op=train_op, 167 | scaffold_fn=scaffold_fn) 168 | else: 169 | optimizer = RAdam(learning_rate=config.learning_rate, epsilon=1e-8, beta1=0.9, beta2=0.999) 170 | train_op = optimizer.minimize(total_loss, tf.train.get_global_step()) 171 | 172 | training_logging_hook = tf.train.LoggingTensorHook({"loss": total_loss}, every_n_iter=1) 173 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 174 | mode=tf.estimator.ModeKeys.TRAIN, 175 | loss=total_loss, 176 | train_op=train_op, 177 | scaffold_fn=scaffold_fn, 178 | training_hooks=[training_logging_hook]) 179 | 180 | 181 | elif mode == tf.estimator.ModeKeys.EVAL: 182 | tf.logging.info("****************************** tf.estimator.ModeKeys.EVAL ******************************") 183 | tf.logging.info("@@@@@ MERELY support tf.estimator.ModeKeys.PREDICT ! @@@@@") 184 | tf.logging.info("@@@@@ YOU can EVAL your checkpoints after the training process. 
@@@@@") 185 | tf.logging.info("****************************** tf.estimator.ModeKeys.EVAL ******************************") 186 | 187 | elif mode == tf.estimator.ModeKeys.PREDICT : 188 | tf.logging.info("****************************** tf.estimator.ModeKeys.PREDICT ******************************") 189 | 190 | instance = (input_ids, input_mask, sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 191 | total_loss, (topk_mention_start_indices, topk_mention_end_indices), (forward_topc_mention_start_indices, forward_topc_mention_end_indices), top_mention_span_linking_scores = model.get_coreference_resolution_and_loss(instance, True, use_tpu=config.use_tpu) 192 | 193 | top_antecedent = tf.math.argmax(top_mention_span_linking_scores, axis=-1) 194 | predictions = { 195 | "total_loss": total_loss, 196 | "topk_span_starts": topk_mention_start_indices, 197 | "topk_span_ends": topk_mention_end_indices, 198 | "top_antecedent_scores": top_mention_span_linking_scores, 199 | "top_antecedent": top_antecedent, 200 | "cluster_ids" : cluster_ids, 201 | "gold_starts": gold_starts, 202 | "gold_ends": gold_ends} 203 | 204 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT, 205 | predictions=predictions, 206 | scaffold_fn=scaffold_fn) 207 | else: 208 | raise ValueError("Please check the the mode ! ") 209 | return output_spec 210 | 211 | 212 | if model_sign == "mention_proposal": 213 | return mention_proposal_model_fn 214 | elif model_sign == "corefqa": 215 | return corefqa_model_fn 216 | else: 217 | raise ValueError("Please check the model sign! Only support [mention_proposal] and [corefqa] .") 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /run/run_mention_proposal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | """ 6 | this file contains pre-training and testing the mention proposal model 7 | """ 8 | 9 | import os 10 | import math 11 | import random 12 | import logging 13 | import numpy as np 14 | import tensorflow as tf 15 | from data_utils.config_utils import ModelConfig 16 | from func_builders.model_fn_builder import model_fn_builder 17 | from func_builders.input_fn_builder import file_based_input_fn_builder 18 | from utils.metrics import mention_proposal_prediction 19 | 20 | tf.app.flags.DEFINE_string('f', '', 'kernel') 21 | flags = tf.app.flags 22 | 23 | flags.DEFINE_string("output_dir", "data", "The output directory of the model training.") 24 | flags.DEFINE_string("bert_config_file", "/home/uncased_L-2_H-128_A-2/config.json", "The config json file corresponding to the pre-trained BERT model.") 25 | flags.DEFINE_string("init_checkpoint", "/home/uncased_L-2_H-128_A-2/bert_model.ckpt", "Initial checkpoint (usually from a pre-trained BERT model).") 26 | flags.DEFINE_string("vocab_file", "/home/uncased_L-2_H-128_A-2/vocab.txt", "The vocabulary file that the BERT model was trained on.") 27 | flags.DEFINE_string("logfile_path", "/home/lixiaoya/spanbert_large_mention_proposal.log", "the path to the exported log file.") 28 | flags.DEFINE_integer("num_epochs", 20, "Total number of training epochs to perform.") 29 | flags.DEFINE_integer("keep_checkpoint_max", 30, "How many checkpoint models keep at most.") 30 | flags.DEFINE_integer("save_checkpoints_steps", 500, "Save checkpoint every X updates steps.") 31 | 32 | 33 | flags.DEFINE_string("train_file", 
"/home/lixiaoya/train.english.tfrecord", "TFRecord file for training. E.g., train.english.tfrecord") 34 | flags.DEFINE_string("dev_file", "/home/lixiaoya/dev.english.tfrecord", "TFRecord file for validating. E.g., dev.english.tfrecord") 35 | flags.DEFINE_string("test_file", "/home/lixiaoya/test.english.tfrecord", "TFRecord file for testing. E.g., test.english.tfrecord") 36 | 37 | 38 | flags.DEFINE_bool("do_train", True, "Whether to train a model.") 39 | flags.DEFINE_bool("do_eval", False, "whether to do evaluation: evaluation is done on a set of trained checkpoints, the checkpoint with the best score on the dev set will be selected.") 40 | flags.DEFINE_bool("do_predict", False, "Whether to test (only) one trained model.") 41 | flags.DEFINE_string("eval_checkpoint", "/home/lixiaoya/mention_proposal_output_dir/bert_model.ckpt", "[Optional] The saved checkpoint for evaluation (usually after the training process).") 42 | flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.") 43 | 44 | 45 | flags.DEFINE_float("learning_rate", 3e-5, "The initial learning rate for Adam.") 46 | flags.DEFINE_float("dropout_rate", 0.3, "Dropout rate for the training process.") 47 | flags.DEFINE_float("mention_threshold", 0.5, "The threshold for determining whether the span is a mention.") 48 | flags.DEFINE_integer("hidden_size", 128, "The size of hidden layers for the pre-trained model.") 49 | flags.DEFINE_integer("num_docs", 5604, "[Optional] The number of documents in the training files. Only need to change when conduct experiments on the small test sets.") 50 | flags.DEFINE_integer("window_size", 384, "The number of sliding window size. Each document is split into a set of subdocuments with length set to window_size.") 51 | flags.DEFINE_integer("num_window", 5, "The max number of windows for one document. This is used for fitting a document into fix shape for TF computation. \ 52 | If a document is longer than num_window*window_size, the exceeding part will be abandoned. This only affects training and does not affect test, since the all \ 53 | docs in the test set is shorter than num_window*window_size") 54 | flags.DEFINE_integer("max_num_mention", 30, "The max number of mentions in one document.") 55 | flags.DEFINE_bool("start_end_share", False, "Whether only to use [start, end] embedding to calculate the start/end scores.") 56 | flags.DEFINE_float("loss_start_ratio", 0.3, "As described in the paper, the loss for a span being a mention is -loss_start_ratio* log p(the start of the given span is a start).") 57 | flags.DEFINE_float("loss_end_ratio", 0.3, "As described in the paper, the loss for a span being a mention is -loss_end_ratio* log p(the end of the given span is a end).") 58 | flags.DEFINE_float("loss_span_ratio", 0.4, "As described in the paper, the loss for a span being a mention is -loss_span_ratio* log p(the start and the end forms a span).") 59 | 60 | 61 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 62 | flags.DEFINE_string("tpu_name", None, "The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") 63 | flags.DEFINE_string("tpu_zone", None, "[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt to automatically detect the GCE project from metadata.") 64 | flags.DEFINE_string("gcp_project", None, "[Optional] Project name for the Cloud TPU-enabled project. 
If not specified, we will attempt to automatically detect the GCE project from metadata.") 65 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 66 | flags.DEFINE_integer("num_tpu_cores", 1, "[Optional] Only used if `use_tpu` is True. Total number of TPU cores to use.") 67 | flags.DEFINE_integer("seed", 2333, "[Optional] Random seed for initialization.") 68 | FLAGS = tf.flags.FLAGS 69 | 70 | 71 | 72 | format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 73 | logging.basicConfig(format=format, filename=FLAGS.logfile_path, level=logging.INFO) 74 | logger = logging.getLogger(__name__) 75 | logger.setLevel(logging.INFO) 76 | 77 | 78 | 79 | def main(_): 80 | 81 | tf.logging.set_verbosity(tf.logging.INFO) 82 | num_train_steps = FLAGS.num_docs * FLAGS.num_epochs 83 | # num_train_steps = 100 84 | keep_chceckpoint_max = max(math.ceil(num_train_steps / FLAGS.save_checkpoints_steps), FLAGS.keep_checkpoint_max) 85 | 86 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 87 | raise ValueError("At least one of `do_train`, `do_eval` or `do_predict' must be True.") 88 | 89 | tf.gfile.MakeDirs(FLAGS.output_dir) 90 | tpu_cluster_resolver = None 91 | if FLAGS.use_tpu and FLAGS.tpu_name: 92 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 93 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 94 | tf.config.experimental_connect_to_cluster(tpu_cluster_resolver) 95 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) 96 | 97 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 98 | run_config = tf.contrib.tpu.RunConfig( 99 | cluster=tpu_cluster_resolver, 100 | master=FLAGS.master, 101 | # evaluation_master=FLAGS.master, 102 | model_dir=FLAGS.output_dir, 103 | keep_checkpoint_max = keep_chceckpoint_max, 104 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 105 | # session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True), 106 | tpu_config=tf.contrib.tpu.TPUConfig( 107 | iterations_per_loop=FLAGS.iterations_per_loop, 108 | num_shards=FLAGS.num_tpu_cores, 109 | per_host_input_for_training=is_per_host)) 110 | 111 | 112 | model_config = ModelConfig(FLAGS, FLAGS.output_dir) 113 | model_config.logging_configs() 114 | 115 | model_fn = model_fn_builder(model_config, model_sign="mention_proposal") 116 | estimator = tf.contrib.tpu.TPUEstimator( 117 | use_tpu=FLAGS.use_tpu, 118 | # eval_on_tpu=FLAGS.use_tpu, 119 | warm_start_from=tf.estimator.WarmStartSettings(FLAGS.init_checkpoint, 120 | vars_to_warm_start="bert*"), 121 | model_fn=model_fn, 122 | config=run_config, 123 | train_batch_size=1, 124 | predict_batch_size=1) 125 | 126 | 127 | if FLAGS.do_train: 128 | estimator.train(input_fn=file_based_input_fn_builder(model_config.train_file, num_window=model_config.num_window, 129 | window_size=model_config.window_size, max_num_mention=model_config.max_num_mention, is_training=True, drop_remainder=True), max_steps=num_train_steps) 130 | 131 | 132 | if FLAGS.do_eval: 133 | # doing evaluation on a set of trained checkpoints, the checkpoint with the best score on the dev set will be selected. 
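        # Sketch of the selection loop below: with the default save_checkpoints_steps=500 the
        # candidate paths are model.ckpt-0, model.ckpt-500, model.ckpt-1000, ... below
        # num_train_steps; checkpoints_iterator[1:] skips the untrained ckpt-0, each remaining
        # checkpoint is scored on the dev set, and the checkpoint with the best dev F1 is then
        # evaluated once on the test set.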
134 | best_dev_f1, best_dev_prec, best_dev_rec, test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best = 0, 0, 0, 0, 0, 0 135 | best_ckpt_path = "" 136 | checkpoints_iterator = [os.path.join(FLAGS.eval_dir, "model.ckpt-{}".format(str(int(ckpt_idx)))) for ckpt_idx in range(0, num_train_steps, FLAGS.save_checkpoints_steps)] 137 | for checkpoint_path in checkpoints_iterator[1:]: 138 | eval_dev_result = estimator.evaluate(input_fn=file_based_input_fn_builder(FLAGS.dev_file, num_window=FLAGS.num_window, 139 | window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, is_training=False, drop_remainder=False), 140 | steps=698, checkpoint_path=checkpoint_path) 141 | dev_f1 = 2*eval_dev_result["precision"] * eval_dev_result["recall"] / (eval_dev_result["precision"] + eval_dev_result["recall"]+1e-10) 142 | tf.logging.info("***** Current ckpt path is ***** : {}".format(checkpoint_path)) 143 | tf.logging.info("***** EVAL ON DEV SET *****") 144 | tf.logging.info("***** [DEV EVAL] ***** : precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(eval_dev_result["precision"], eval_dev_result["recall"], dev_f1)) 145 | if dev_f1 > best_dev_f1: 146 | best_dev_f1, best_dev_prec, best_dev_rec = dev_f1, eval_dev_result["precision"], eval_dev_result["recall"] 147 | best_ckpt_path = checkpoint_path 148 | eval_test_result = estimator.evaluate(input_fn=file_based_input_fn_builder(FLAGS.test_file, 149 | num_window=FLAGS.num_window, window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, 150 | is_training=False, drop_remainder=False),steps=698, checkpoint_path=checkpoint_path) 151 | test_f1 = 2*eval_test_result["precision"] * eval_test_result["recall"] / (eval_test_result["precision"] + eval_test_result["recall"]+1e-10) 152 | test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best = test_f1, eval_test_result["precision"], eval_test_result["recall"] 153 | tf.logging.info("***** EVAL ON TEST SET *****") 154 | tf.logging.info("***** [TEST EVAL] ***** : precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(eval_test_result["precision"], eval_test_result["recall"], test_f1)) 155 | tf.logging.info("*"*20) 156 | tf.logging.info("- @@@@@ the path to the BEST DEV result is : {}".format(best_ckpt_path)) 157 | tf.logging.info("- @@@@@ BEST DEV F1 : {:.4f}, Precision : {:.4f}, Recall : {:.4f},".format(best_dev_f1, best_dev_prec, best_dev_rec)) 158 | tf.logging.info("- @@@@@ TEST when DEV best F1 : {:.4f}, Precision : {:.4f}, Recall : {:.4f},".format(test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best)) 159 | tf.logging.info("- @@@@@ mention_proposal_only_concate {}".format(FLAGS.mention_proposal_only_concate)) 160 | 161 | 162 | if FLAGS.do_predict: 163 | tp, fp, fn = 0, 0, 0 164 | epsilon = 1e-10 165 | for doc_output in estimator.predict(file_based_input_fn_builder(FLAGS.test_file, 166 | num_window=FLAGS.num_window, window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, 167 | is_training=False, drop_remainder=False), checkpoint_path=FLAGS.eval_checkpoint, yield_single_examples=False): 168 | # iterate over each doc for evaluation 169 | pred_span_label, gold_span_label = mention_proposal_prediction(FLAGS, doc_output) 170 | 171 | tem_tp = np.logical_and(pred_span_label, gold_span_label).sum() 172 | tem_fp = np.logical_and(pred_span_label, np.logical_not(gold_span_label)).sum() 173 | tem_fn = np.logical_and(np.logical_not(pred_span_label), gold_span_label).sum() 174 | 175 | tp += tem_tp 176 | fp += tem_fp 177 | fn += tem_fn 178 | 179 | p = tp / 
(tp+fp+epsilon) 180 | r = tp / (tp+fn+epsilon) 181 | f = 2*p*r/(p+r+epsilon) 182 | tf.logging.info("Average precision: {:.4f}, Average recall: {:.4f}, Average F1 {:.4f}".format(p, r, f)) 183 | 184 | 185 | 186 | if __name__ == '__main__': 187 | # set the random seed. 188 | random.seed(FLAGS.seed) 189 | np.random.seed(FLAGS.seed) 190 | tf.set_random_seed(FLAGS.seed) 191 | # start train/evaluate the model. 192 | tf.app.run() 193 | 194 | 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /run/run_corefqa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | this file contains training and testing the CorefQA model. 6 | """ 7 | 8 | import os 9 | import math 10 | import logging 11 | import random 12 | import numpy as np 13 | import tensorflow as tf 14 | from utils import util 15 | from utils import metrics 16 | from data_utils.config_utils import ModelConfig 17 | from func_builders.model_fn_builder import model_fn_builder 18 | from func_builders.input_fn_builder import file_based_input_fn_builder 19 | 20 | 21 | tf.app.flags.DEFINE_string('f', '', 'kernel') 22 | flags = tf.app.flags 23 | 24 | flags.DEFINE_string("output_dir", "data", "The output directory.") 25 | flags.DEFINE_string("bert_config_file", "/home/uncased_L-2_H-128_A-2/config.json", "The config json file corresponding to the pre-trained BERT model.") 26 | flags.DEFINE_string("init_checkpoint", "/home/uncased_L-2_H-128_A-2/bert_model.ckpt", "Initial checkpoint (usually from a pre-trained BERT model).") 27 | flags.DEFINE_string("vocab_file", "/home/uncased_L-2_H-128_A-2/vocab.txt", "The vocabulary file that the BERT model was trained on.") 28 | flags.DEFINE_string("logfile_path", "/home/lixiaoya/spanbert_large_mention_proposal.log", "the path to the exported log file.") 29 | flags.DEFINE_integer("num_epochs", 20, "Total number of training epochs to perform.") 30 | flags.DEFINE_integer("keep_checkpoint_max", 30, "How many checkpoint models keep at most.") 31 | flags.DEFINE_integer("save_checkpoints_steps", 500, "Save checkpoint every X updates steps.") 32 | 33 | 34 | flags.DEFINE_string("train_file", "/home/lixiaoya/train.english.tfrecord", "TFRecord file for training. E.g., train.english.tfrecord") 35 | flags.DEFINE_string("dev_file", "/home/lixiaoya/dev.english.tfrecord", "TFRecord file for validating. E.g., dev.english.tfrecord") 36 | flags.DEFINE_string("test_file", "/home/lixiaoya/test.english.tfrecord", "TFRecord file for testing. 
E.g., test.english.tfrecord") 37 | 38 | 39 | flags.DEFINE_bool("do_train", True, "Whether to train a model.") 40 | flags.DEFINE_bool("do_eval", False, "Whether to do evaluation: evaluation is done on a set of trained checkpoints, the model will select the best one on the dev set, and report result on the test set") 41 | flags.DEFINE_bool("do_predict", False, "Whether to test (only) one trained model.") 42 | flags.DEFINE_string("eval_checkpoint", "/home/lixiaoya/mention_proposal_output_dir/bert_model.ckpt", "[Optional] The saved checkpoint for evaluation (usually after the training process).") 43 | flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.") 44 | 45 | 46 | flags.DEFINE_float("learning_rate", 3e-5, "The initial learning rate for Adam.") 47 | flags.DEFINE_float("dropout_rate", 0.3, "Dropout rate for the training process.") 48 | flags.DEFINE_float("mention_threshold", 0.5, "The threshold for determining whether the span is a mention.") 49 | flags.DEFINE_integer("hidden_size", 128, "The size of hidden layers for the pre-trained model.") 50 | flags.DEFINE_integer("num_docs", 5604, "[Optional] The number of documents in the training files. Only need to change when conduct experiments on the small test sets.") 51 | flags.DEFINE_integer("window_size", 384, "The number of sliding window size. Each document is split into a set of subdocuments with length set to window_size.") 52 | flags.DEFINE_integer("num_window", 5, "The max number of windows for one document. This is used for fitting a document into fix shape for TF computation. \ 53 | If a document is longer than num_window*window_size, the exceeding part will be abandoned. This only affects training and does not affect test, since the all \ 54 | docs in the test set is shorter than num_window*window_size") 55 | flags.DEFINE_integer("max_num_mention", 30, "The max number of mentions in one document.") 56 | flags.DEFINE_bool("start_end_share", False, "Whether only to use [start, end] embedding to calculate the start/end scores.") 57 | flags.DEFINE_integer("max_span_width", 5, "The max length of a mention.") 58 | flags.DEFINE_integer("max_candidate_mentions", 30, "The number of candidate mentions.") 59 | flags.DEFINE_float("top_span_ratio", 0.2, "The ratio of.") 60 | flags.DEFINE_integer("max_top_antecedents", 30, "The number of top_antecedents candidate mentions.") 61 | flags.DEFINE_integer("max_query_len", 150, ".") 62 | flags.DEFINE_integer("max_context_len", 150, ".") 63 | flags.DEFINE_bool("sec_qa_mention_score", False, "Whether to use TPU or GPU/CPU.") 64 | 65 | 66 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 67 | flags.DEFINE_string("tpu_name", None, "The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") 68 | flags.DEFINE_string("tpu_zone", None, "[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt to automatically detect the GCE project from metadata.") 69 | flags.DEFINE_string("gcp_project", None, "[Optional] Project name for the Cloud TPU-enabled project. If not specified, we will attempt to automatically detect the GCE project from metadata.") 70 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 71 | flags.DEFINE_integer("num_tpu_cores", 1, "[Optional] Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 72 | flags.DEFINE_integer("seed", 2333, "[Optional] Random seed for initialization.") 73 | 74 | 75 | FLAGS = tf.flags.FLAGS 76 | 77 | 78 | format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 79 | logging.basicConfig(format=format, filename=FLAGS.logfile_path, level=logging.INFO) 80 | logger = logging.getLogger(__name__) 81 | logger.setLevel(logging.INFO) 82 | 83 | 84 | 85 | def main(_): 86 | 87 | tf.logging.set_verbosity(tf.logging.INFO) 88 | num_train_steps = FLAGS.num_docs * FLAGS.num_epochs 89 | 90 | 91 | keep_chceckpoint_max = max(math.ceil(num_train_steps / FLAGS.save_checkpoints_steps), FLAGS.keep_checkpoint_max) 92 | 93 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 94 | raise ValueError("At least one of `do_train`, `do_eval` or `do_predict' must be True.") 95 | 96 | tf.gfile.MakeDirs(FLAGS.output_dir) 97 | tpu_cluster_resolver = None 98 | if FLAGS.use_tpu and FLAGS.tpu_name: 99 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 100 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 101 | tf.config.experimental_connect_to_cluster(tpu_cluster_resolver) 102 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) 103 | 104 | 105 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 106 | run_config = tf.contrib.tpu.RunConfig( 107 | cluster=tpu_cluster_resolver, 108 | master=FLAGS.master, 109 | model_dir=FLAGS.output_dir, 110 | evaluation_master=FLAGS.master, 111 | keep_checkpoint_max = keep_chceckpoint_max, 112 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 113 | session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True), 114 | tpu_config=tf.contrib.tpu.TPUConfig( 115 | iterations_per_loop=FLAGS.iterations_per_loop, 116 | num_shards=FLAGS.num_tpu_cores, 117 | per_host_input_for_training=is_per_host)) 118 | 119 | 120 | model_config = ModelConfig(FLAGS, FLAGS.output_dir) 121 | model_config.logging_configs() 122 | 123 | 124 | model_fn = model_fn_builder(model_config, model_sign="corefqa") 125 | estimator = tf.contrib.tpu.TPUEstimator( 126 | use_tpu=FLAGS.use_tpu, 127 | eval_on_tpu=FLAGS.use_tpu, 128 | warm_start_from=tf.estimator.WarmStartSettings(FLAGS.init_checkpoint, 129 | vars_to_warm_start="bert*"), 130 | model_fn=model_fn, 131 | config=run_config, 132 | train_batch_size=1, 133 | eval_batch_size=1, 134 | predict_batch_size=1) 135 | 136 | 137 | if FLAGS.do_train: 138 | estimator.train(input_fn=file_based_input_fn_builder(FLAGS.train_file, num_window=FLAGS.num_window, 139 | window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, is_training=True, drop_remainder=True), 140 | max_steps=num_train_steps) 141 | 142 | 143 | if FLAGS.do_eval: 144 | best_dev_f1, best_dev_prec, best_dev_rec, test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best = 0, 0, 0, 0, 0, 0 145 | best_ckpt_path = "" 146 | checkpoints_iterator = [os.path.join(FLAGS.eval_dir, "model.ckpt-{}".format(str(int(ckpt_idx)))) for ckpt_idx in range(0, num_train_steps+1, FLAGS.save_checkpoints_steps)] 147 | model = util.get_model(model_config, model_sign="corefqa") 148 | for checkpoint_path in checkpoints_iterator[1:]: 149 | dev_coref_evaluator = metrics.CorefEvaluator() 150 | for result in estimator.predict(file_based_input_fn_builder(FLAGS.dev_file, num_window=FLAGS.num_window, 151 | window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, is_training=False, drop_remainder=False), 152 | steps=698, 
checkpoint_path=checkpoint_path, yield_single_examples=False): 153 | 154 | predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold = model.evaluate(result["topk_span_starts"], result["topk_span_ends"], result["top_antecedent"], 155 | result["cluster_ids"], result["gold_starts"], result["gold_ends"]) 156 | dev_coref_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) 157 | dev_prec, dev_rec, dev_f1 = dev_coref_evaluator.get_prf() 158 | tf.logging.info("***** Current ckpt path is ***** : {}".format(checkpoint_path)) 159 | tf.logging.info("***** EVAL ON DEV SET *****") 160 | tf.logging.info("***** [DEV EVAL] ***** : precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(dev_prec, dev_rec, dev_f1)) 161 | if dev_f1 > best_dev_f1: 162 | best_ckpt_path = checkpoint_path 163 | best_dev_f1 = dev_f1 164 | best_dev_prec = dev_prec 165 | best_dev_rec = dev_rec 166 | test_coref_evaluator = metrics.CorefEvaluator() 167 | for result in estimator.predict(file_based_input_fn_builder(FLAGS.test_file, 168 | num_window=FLAGS.num_window, window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, 169 | is_training=False, drop_remainder=False),steps=698, checkpoint_path=checkpoint_path, yield_single_examples=False): 170 | predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold = model.evaluate(result["topk_span_starts"], result["topk_span_ends"], result["top_antecedent"], 171 | result["cluster_ids"], result["gold_starts"], result["gold_ends"]) 172 | test_coref_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) 173 | 174 | test_pre, test_rec, test_f1 = test_coref_evaluator.get_prf() 175 | test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best = test_f1, test_pre, test_rec 176 | tf.logging.info("***** EVAL ON TEST SET *****") 177 | tf.logging.info("***** [TEST EVAL] ***** : precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(test_pre, test_rec, test_f1)) 178 | 179 | tf.logging.info("*"*20) 180 | tf.logging.info("- @@@@@ the path to the BEST DEV result is : {}".format(best_ckpt_path)) 181 | tf.logging.info("- @@@@@ BEST DEV F1 : {:.4f}, Precision : {:.4f}, Recall : {:.4f},".format(best_dev_f1, best_dev_prec, best_dev_rec)) 182 | tf.logging.info("- @@@@@ TEST when DEV best F1 : {:.4f}, Precision : {:.4f}, Recall : {:.4f},".format(test_f1_when_dev_best, test_prec_when_dev_best, test_rec_when_dev_best)) 183 | 184 | 185 | if FLAGS.do_predict: 186 | coref_evaluator = metrics.CorefEvaluator() 187 | model = util.get_model(model_config, model_sign="corefqa") 188 | for result in estimator.predict(file_based_input_fn_builder(FLAGS.test_file, 189 | num_window=FLAGS.num_window, window_size=FLAGS.window_size, max_num_mention=FLAGS.max_num_mention, 190 | is_training=False, drop_remainder=False),steps=698, checkpoint_path=checkpoint_path, yield_single_examples=False): 191 | 192 | predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold = model.evaluate(result["topk_span_starts"], result["topk_span_ends"], 193 | result["top_antecedent"], result["cluster_ids"], result["gold_starts"], result["gold_ends"]) 194 | coref_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) 195 | 196 | p, r, f = coref_evaluator.get_prf() 197 | tf.logging.info("Average precision: {:.4f}, Average recall: {:.4f}, Average F1 {:.4f}".format(p, r, f)) 198 | 199 | 200 | 201 | if __name__ == '__main__': 202 | # set the random seed. 
203 | random.seed(FLAGS.seed) 204 | np.random.seed(FLAGS.seed) 205 | tf.set_random_seed(FLAGS.seed) 206 | # start train/evaluate the model. 207 | tf.app.run() 208 | 209 | 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower 
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CorefQA: Coreference Resolution as Query-based Span Prediction 2 | 3 | The repository contains the code of the recent research advances in [Shannon.AI](http://www.shannonai.com). Please post github issues or email xiaoya_li@shannonai.com for relevant questions. 4 | 5 | 6 | 7 | **CorefQA: Coreference Resolution as Query-based Span Prediction**
8 | Wei Wu, Fei Wang, Arianna Yuan, Fei Wu and Jiwei Li
9 | In ACL 2020. [paper](https://arxiv.org/abs/1911.01746)
10 | If you find this repo helpful, please cite the following: 11 | ```latex 12 | @article{wu2019coreference, 13 | title={Coreference Resolution as Query-based Span Prediction}, 14 | author={Wu, Wei and Wang, Fei and Yuan, Arianna and Wu, Fei and Li, Jiwei}, 15 | journal={arXiv preprint arXiv:1911.01746}, 16 | year={2019} 17 | } 18 | ``` 19 | 20 | 21 | ## Contents 22 | - [Overview](#overview) 23 | - [Hardware Requirements](#hardware-requirements) 24 | - [Install Package Dependencies](#install-package-dependencies) 25 | - [Data Preprocess](#data-preprocess) 26 | - [Download Pretrained MLM](#download-pretrained-mlm) 27 | - [Training](#training) 28 | - [Finetune the SpanBERT Model on the Combination of Squad and Quoref Datasets](#finetune-the-spanbert-model-on-the-combination-of-squad-and-quoref-datasets) 29 | - [Train the CorefQA Model on the CoNLL-2012 Coreference Resolution Task](#train-the-corefqa-model-on-the-conll-2012-coreference-resolution-task) 30 | - [Evaluation and Prediction](#evaluation-and-prediction) 31 | - [Download the Final CorefQA Model](#download-the-final-corefqa-model) 32 | - [Descriptions of Directories](#descriptions-of-directories) 33 | - [Acknowledgement](#acknowledgement) 34 | - [Useful Materials](#useful-materials) 35 | - [Contact](#contact) 36 | 37 | 38 | ## Overview 39 | The model introduces +3.5 (83.1) F1 performance boost over previous SOTA coreference models on the CoNLL benchmark. The current codebase is written in Tensorflow. We plan to release the PyTorch version soon. The current code version only supports training on TPUs and testing on GPUs (due to the annoying features of TF and TPUs). You thus have to bear the trouble of transferring all saved checkpoints from TPUs to GPUs for evaluation (we will fix this soon). Please follow the parameter setting in the log directionary to reproduce the performance. 40 | 41 | 42 | | Model | F1 (%) | 43 | | -------------- |:------:| 44 | | Previous SOTA (Joshi et al., 2019a) | 79.6 | 45 | | CorefQA + SpanBERT-large | 83.1 | 46 | 47 | 48 | ## Hardware Requirements 49 | TPU for training: Cloud TPU v3-8 device (128G memory) with Tensorflow 1.15 Python 3.5 50 | 51 | GPU for evaluation: with CUDA 10.0 Tensorflow 1.15 Python 3.5 52 | 53 | ## Install Package Dependencies 54 | 55 | ```shell 56 | $ python3 -m pip install --user virtualenv 57 | $ virtualenv --python=python3.5 ~/corefqa_venv 58 | $ source ~/corefqa_venv/bin/activate 59 | $ cd CorefQA 60 | $ pip install -r requirements.txt 61 | # If you are using TPU, please run the following commands: 62 | $ pip install --upgrade google-api-python-client 63 | $ pip install --upgrade oauth2client 64 | ``` 65 | 66 | ## Data Preprocess 67 | 68 | 1) Download the offical released [Ontonotes 5.0 (LDC2013T19)](https://catalog.ldc.upenn.edu/LDC2013T19).
69 | 2) Preprocess the Ontonotes 5.0 annotation files for the CoNLL-2012 coreference resolution task.<br>
70 | Run the following command with **Python 2**: 71 | `bash ./scripts/data/preprocess_ontonotes_annfiles.sh <path_to_ontonotes> <output_dir> <language>`<br>
72 | and it will create `{train/dev/test}.{language}.v4_gold_conll` files in the directory `<output_dir>`.<br>
73 | `<language>` can be `english`, `arabic` or `chinese`. In this paper, we set `<language>` to `english`.<br>
74 | If you want to use **Python 3**, please refer to the 75 | [guideline](https://github.com/huggingface/neuralcoref/blob/master/neuralcoref/train/training.md#get-the-data)
76 | 3) Generate TFRecord files for experiments.
77 | Run the following command with **Python 3**: `bash ./scripts/data/generate_tfrecord_dataset.sh <path_to_gold_conll_dir> <output_dir> <vocab_file>` 78 | and it will create `{train/dev/test}.overlap.corefqa.{language}.tfrecord` files in the directory `<output_dir>`.<br>
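
Before moving on, it can help to check that the TFRecord files were written as expected. The snippet below is a minimal sketch (not part of the repo) that counts the serialized documents in one file and prints its feature keys; the file path is a placeholder and TensorFlow 1.15 is assumed.

```python
# Minimal sketch (not part of this repo): sanity-check a generated tfrecord file
# with TensorFlow 1.15. The path below is a placeholder for one of your files.
import tensorflow as tf

tfrecord_path = "/path-to-save-tfrecord/train.overlap.corefqa.english.tfrecord"

num_records = 0
first_example_keys = None
for raw_record in tf.python_io.tf_record_iterator(tfrecord_path):
    if first_example_keys is None:
        example = tf.train.Example.FromString(raw_record)
        # The exact feature names are defined by run/build_dataset_to_tfrecord.py;
        # here we only list whatever keys the first serialized example contains.
        first_example_keys = sorted(example.features.feature.keys())
    num_records += 1

print("number of serialized documents:", num_records)
print("feature keys of the first example:", first_example_keys)
```
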
79 | 80 | ## Download Pretrained MLM 81 | In our experiments, we used pretrained masked language models to initialize the mention_proposal and corefqa models. 82 | 83 | 1) Download the pretrained models.<br>
84 | Run `bash ./scripts/data/download_pretrained_mlm.sh <path_to_save_models> <model_sign>` to download and unzip the pretrained MLM models.<br>
85 | `<model_sign>` should take the value of `[bert_base, bert_large, spanbert_base, spanbert_large, bert_tiny]`. 86 | 87 | - `bert_base, bert_large, spanbert_base, spanbert_large` are trained with a cased (uppercase and lowercase tokens) vocabulary, so use the cased train/dev/test coreference datasets with them. 88 | - `bert_tiny` is trained with an uncased (lowercase tokens) vocabulary; we use the tinyBERT model for fast debugging, so use the uncased train/dev/test coreference datasets with it (see the tokenizer sketch below).<br>
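
To make the cased/uncased distinction concrete, the sketch below (not part of the repo) loads two downloaded vocabularies with the repo's own `bert/tokenization.py`; the vocabulary paths are placeholders, and `do_lower_case` must match the model you picked.

```python
# Minimal sketch: the cased/uncased choice maps to do_lower_case in FullTokenizer.
# The vocab paths are placeholders for wherever download_pretrained_mlm.sh saved the models.
from bert import tokenization  # requires the repo root on PYTHONPATH

# Cased model (e.g. spanbert_large): keep the original casing of the input.
cased_tokenizer = tokenization.FullTokenizer(
    vocab_file="/path-to-pretrained-mlm/spanbert_large/vocab.txt",
    do_lower_case=False)

# Uncased model (bert_tiny, used for fast debugging): lower-case the input.
uncased_tokenizer = tokenization.FullTokenizer(
    vocab_file="/path-to-pretrained-mlm/bert_tiny/vocab.txt",
    do_lower_case=True)

print(cased_tokenizer.tokenize("Barack Obama visited Shanghai ."))
print(uncased_tokenizer.tokenize("Barack Obama visited Shanghai ."))
```
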
89 | 90 | 2) Transform SpanBERT from `PyTorch` to `TensorFlow`.<br>
91 | 92 | After downloading `bert_<model_scale>` to `<bert_tf_dir>` and `spanbert_<model_scale>` to `<spanbert_pytorch_dir>`, you can start transforming the SpanBERT model to TensorFlow, and the converted model is saved to the directory `<spanbert_tf_dir>`. `<model_scale>` should take the value of `[base, large]`.<br>
93 | 94 | We need to transform the SpanBERT checkpoints from PyTorch to TF because the officially released models were trained with PyTorch. 95 | Run `bash ./scripts/data/transform_ckpt_pytorch_to_tf.sh <spanbert_pytorch_dir> <bert_tf_dir> <model_sign> <spanbert_tf_dir>` 96 | and the `<model_sign>` checkpoint in TF will be saved in `<spanbert_tf_dir>`. 97 | 98 | - `<model_sign>` should take the value of `[spanbert_base, spanbert_large]`. 99 | - `<model_scale>` indicates that the `bert_model.ckpt` in `<bert_tf_dir>` should have the same scale (base, large) as the `bert_model.bin` in `<spanbert_pytorch_dir>`. 100 | 101 | 102 | ## Training 103 | 104 | Following the pipeline described in the paper, you need to:<br>
105 | 1) load a pretrained SpanBERT model.
106 | 2) finetune the SpanBERT model on the combination of Squad and Quoref datasets.
107 | 3) pretrain the mention proposal model on the coref dataset.
108 | 4) jointly train the mention proposal model and the mention linking model.
109 | 110 | **Notice:** We provide the options of both pretraining these models yourself and loading our pretrained models for 2) and 3).<br>
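
Whichever option you choose for 2) and 3), it is worth checking that the TensorFlow checkpoint you plan to initialize from (a converted SpanBERT model or one of the released checkpoints) is readable before launching a TPU job. A minimal sketch (not part of the repo), with a placeholder checkpoint path:

```python
# Minimal sketch: list the variables stored in a TF checkpoint before training.
# The checkpoint prefix is a placeholder; a local path or a gs:// path both work.
import tensorflow as tf

ckpt_prefix = "/path-to-spanbert-tf/bert_model.ckpt"

for name, shape in tf.train.list_variables(ckpt_prefix):
    # The training script warm-starts variables matching "bert*", so those should be present.
    if name.startswith("bert"):
        print(name, shape)
```
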
111 | 112 | ### Finetune the SpanBERT Model on the Combination of Squad and Quoref Datasets 113 | We finetune the SpanBERT model on the [SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/) and [Quoref](https://allennlp.org/quoref) QA tasks for data augmentation before the coreference resolution task. 114 | 115 | 1. You can directly download the pretrained model on the datasets. 116 | Download Data Augmentation Models on Squad and Quoref [link](https://www.dropbox.com/s/lqjc6kfe0w34jt0/finetune_spanbert_large_squad2.tar.gz?dl=0)
117 | Run `./scripts/data/download_squad2_finetune_model.sh <model_scale> <path_to_save_model>` to download the SpanBERT model finetuned on SQuAD 2.0.<br>
118 | The `<model_scale>` should take the value of `[base, large]`.<br>
119 | The `<path_to_save_model>` is the path to save the SpanBERT model finetuned on the SQuAD 2.0 dataset.<br>
120 | 121 | 122 | 2. Or start to finetune the SpanBERT model on QA tasks yourself. 123 | - Download SQuAD 2.0 [train](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json) and [dev](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json) sets. 124 | - Download Quoref [train and dev](https://quoref-dataset.s3-us-west-2.amazonaws.com/train_and_dev/quoref-train-dev-v0.1.zip) sets. 125 | - Finetune the SpanBERT model on Google Could V3-8 TPU. 126 | 127 | For Squad 2.0, Run the script in [./script/model/squad_tpu.sh](https://github.com/ShannonAI/CorefQA/blob/master/scripts/models/squad_tpu.sh) 128 | ```bash 129 | 130 | REPO_PATH=/home/shannon/coref-tf 131 | export TPU_NAME=tf-tpu 132 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 133 | SQUAD_DIR=gs://qa_tasks/squad2 134 | BERT_DIR=gs://pretrained_mlm_checkpoint/spanbert_large_tf 135 | OUTPUT_DIR=gs://corefqa_output_squad/spanbert_large_squad2_2e-5 136 | 137 | python3 ${REPO_PATH}/run/run_squad.py \ 138 | --vocab_file=$BERT_DIR/vocab.txt \ 139 | --bert_config_file=$BERT_DIR/bert_config.json \ 140 | --init_checkpoint=$BERT_DIR/bert_model.ckpt \ 141 | --do_train=True \ 142 | --train_file=$SQUAD_DIR/train-v2.0.json \ 143 | --do_predict=True \ 144 | --predict_file=$SQUAD_DIR/dev-v2.0.json \ 145 | --train_batch_size=8 \ 146 | --learning_rate=2e-5 \ 147 | --num_train_epochs=4.0 \ 148 | --max_seq_length=384 \ 149 | --do_lower_case=False \ 150 | --doc_stride=128 \ 151 | --output_dir=${OUTPUT_DIR} \ 152 | --use_tpu=True \ 153 | --tpu_name=$TPU_NAME \ 154 | --version_2_with_negative=True 155 | ``` 156 | After getting the best model (choose based on the performance on dev set) on `SQuAD2.0`, you should start finetuning the saved model on `Quoref`.
157 | 158 | Run the script in [./script/model/quoref_tpu.sh](https://github.com/ShannonAI/CorefQA/blob/master/scripts/models/quoref_tpu.sh) 159 | ```bash 160 | 161 | REPO_PATH=/home/shannon/coref-tf 162 | export TPU_NAME=tf-tpu 163 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 164 | QUOREF_DIR=gs://qa_tasks/quoref 165 | BERT_DIR=gs://corefqa_output_squad/panbert_large_squad2_2e-5 166 | OUTPUT_DIR=gs://corefqa_output_quoref/spanbert_large_squad2_best_quoref_3e-5 167 | 168 | python3 ${REPO_PATH}/run_quoref.py \ 169 | --vocab_file=$BERT_DIR/vocab.txt \ 170 | --bert_config_file=$BERT_DIR/bert_config.json \ 171 | --init_checkpoint=$BERT_DIR/best_bert_model.ckpt \ 172 | --do_train=True \ 173 | --train_file=$QUOREF_DIR/quoref-train-v0.1.json \ 174 | --do_predict=True \ 175 | --predict_file=$QUOREF_DIR/quoref-dev-v0.1.json \ 176 | --train_batch_size=8 \ 177 | --learning_rate=3e-5 \ 178 | --num_train_epochs=5 \ 179 | --max_seq_length=384 \ 180 | --do_lower_case=False \ 181 | --doc_stride=128 \ 182 | --output_dir=${OUTPUT_DIR} \ 183 | --use_tpu=True \ 184 | --tpu_name=$TPU_NAME 185 | ``` 186 | We use the best model (choose based on the performance on DEV set) on `Quoref` to initialize the CorefQA Model. 187 | 188 | ### Train the CorefQA Model on the CoNLL-2012 Coreference Resolution Task 189 | 1.1 Your can you can download the pre-trained mention proposal model (including [model](https://storage.googleapis.com/public_model_checkpoints/mention_proposal/model.ckpt-22000.data-00000-of-00001), [meta](https://storage.googleapis.com/public_model_checkpoints/mention_proposal/model.ckpt-22000.meta) and [index](https://storage.googleapis.com/public_model_checkpoints/mention_proposal/model.ckpt-22000.index)). 190 | 191 | 1.2. Or train the mention proposal model yourself. 192 | 193 | The script can be found in [./script/model/mention_tpu.sh](https://github.com/ShannonAI/CorefQA/blob/master/scripts/models/mention_tpu.sh). 194 | 195 | ```bash 196 | 197 | REPO_PATH=/home/shannon/coref-tf 198 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 199 | export TPU_NAME=tf-tpu 200 | export TPU_ZONE=europe-west4-a 201 | export GCP_PROJECT=xiaoyli-20-01-4820 202 | 203 | BERT_DIR=gs://corefqa_output_quoref/spanbert_large_squad2_best_quoref_1e-5 204 | DATA_DIR=gs://corefqa_data/final_overlap_384_6 205 | OUTPUT_DIR=gs://corefqa_output_mention_proposal/squad_quoref_large_384_6_1e5_8_0.2 206 | 207 | python3 ${REPO_PATH}/run/run_mention_proposal.py \ 208 | --output_dir=$OUTPUT_DIR \ 209 | --bert_config_file=$BERT_DIR/bert_config.json \ 210 | --init_checkpoint=$BERT_DIR/bert_model.ckpt \ 211 | --vocab_file=$BERT_DIR/vocab.txt \ 212 | --logfile_path=$OUTPUT_DIR/train.log \ 213 | --num_epochs=8 \ 214 | --keep_checkpoint_max=50 \ 215 | --save_checkpoints_steps=500 \ 216 | --train_file=$DATA_DIR/train.corefqa.english.tfrecord \ 217 | --dev_file=$DATA_DIR/dev.corefqa.english.tfrecord \ 218 | --test_file=$DATA_DIR/test.corefqa.english.tfrecord \ 219 | --do_train=True \ 220 | --do_eval=False \ 221 | --do_predict=False \ 222 | --learning_rate=1e-5 \ 223 | --dropout_rate=0.2 \ 224 | --mention_threshold=0.5 \ 225 | --hidden_size=1024 \ 226 | --num_docs=5604 \ 227 | --window_size=384 \ 228 | --num_window=6 \ 229 | --max_num_mention=60 \ 230 | --start_end_share=False \ 231 | --loss_start_ratio=0.3 \ 232 | --loss_end_ratio=0.3 \ 233 | --loss_span_ratio=0.3 \ 234 | --use_tpu=True \ 235 | --tpu_name=$TPU_NAME \ 236 | --tpu_zone=$TPU_ZONE \ 237 | --gcp_project=$GCP_PROJECT \ 238 | --num_tpu_cores=1 \ 239 | --seed=2333 240 | ``` 241 | 242 | 2. 
Jointly train the mention proposal model and linking model on CoNLL-12.
243 | 244 | After getting the best mention proposal model on the dev set, start jointly training the mention proposal and linking tasks. 245 | 246 | Run and the script can be found in [./script/model/corefqa_tpu.sh](https://github.com/ShannonAI/CorefQA/blob/master/scripts/models/corefqa_tpu.sh) 247 | 248 | ```bash 249 | 250 | REPO_PATH=/home/shannon/coref-tf 251 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 252 | export TPU_NAME=tf-tpu 253 | export TPU_ZONE=europe-west4-a 254 | export GCP_PROJECT=xiaoyli-20-01-4820 255 | 256 | BERT_DIR=gs://corefqa_output_mention_proposal/output_bertlarge 257 | DATA_DIR=gs://corefqa_data/final_overlap_384_6 258 | OUTPUT_DIR=gs://corefqa_output_corefqa/squad_quoref_mention_large_384_6_8e4_8_0.2 259 | 260 | python3 ${REPO_PATH}/run/run_corefqa.py \ 261 | --output_dir=$OUTPUT_DIR \ 262 | --bert_config_file=$BERT_DIR/bert_config.json \ 263 | --init_checkpoint=$BERT_DIR/best_bert_model.ckpt \ 264 | --vocab_file=$BERT_DIR/vocab.txt \ 265 | --logfile_path=$OUTPUT_DIR/train.log \ 266 | --num_epochs=8 \ 267 | --keep_checkpoint_max=50 \ 268 | --save_checkpoints_steps=500 \ 269 | --train_file=$DATA_DIR/train.corefqa.english.tfrecord \ 270 | --dev_file=$DATA_DIR/dev.corefqa.english.tfrecord \ 271 | --test_file=$DATA_DIR/test.corefqa.english.tfrecord \ 272 | --do_train=True \ 273 | --do_eval=False \ 274 | --do_predict=False \ 275 | --learning_rate=8e-4 \ 276 | --dropout_rate=0.2 \ 277 | --mention_threshold=0.5 \ 278 | --hidden_size=1024 \ 279 | --num_docs=5604 \ 280 | --window_size=384 \ 281 | --num_window=6 \ 282 | --max_num_mention=50 \ 283 | --start_end_share=False \ 284 | --max_span_width=10 \ 285 | --max_candiate_mentions=100 \ 286 | --top_span_ratio=0.2 \ 287 | --max_top_antecedents=30 \ 288 | --max_query_len=150 \ 289 | --max_context_len=150 \ 290 | --sec_qa_mention_score=False \ 291 | --use_tpu=True \ 292 | --tpu_name=$TPU_NAME \ 293 | --tpu_zone=$TPU_ZONE \ 294 | --gcp_project=$GCP_PROJECT \ 295 | --num_tpu_cores=1 \ 296 | --seed=2333 297 | ``` 298 | 299 | ## Evaluation and Prediction 300 | 301 | Currently, the evaluation is conducted on a set of saved checkpoints after the training process, and DO NOT support evaluation during training. Please transfer all checkpoints (the output directory is set `--output_dir=` when running the `run_.py`) from TPUs to GPUs for evaluation. 302 | This can be achieved by downloading the output directory from the Google Cloud Storage.
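
A minimal sketch of that download step (not part of the repo), assuming TensorFlow 1.15 with GCS support and that your Google Cloud credentials are already configured; both paths are placeholders. Using `gsutil -m cp -r` from the command line works just as well.

```python
# Minimal sketch: copy a TPU output directory from GCS to the local GPU machine
# with tf.gfile, so that the saved checkpoints can be evaluated locally.
import os
import tensorflow as tf

gcs_output_dir = "gs://your-bucket/corefqa_output_dir"      # placeholder
local_output_dir = "/path-to-local/corefqa_output_dir"      # placeholder

tf.gfile.MakeDirs(local_output_dir)
for filename in tf.gfile.ListDirectory(gcs_output_dir):
    src = os.path.join(gcs_output_dir, filename)
    if tf.gfile.IsDirectory(src):
        continue  # skip nested directories such as summary folders
    tf.gfile.Copy(src, os.path.join(local_output_dir, filename), overwrite=True)
```
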
303 | 304 | 305 | The performance on the test set is obtained by using the model achieving the highest F1-score on the dev set.
306 | Set `--do_eval=True`, `--do_train=False` and `--do_predict=False` for `run_<model_sign>.py` to start the evaluation process on a set of saved checkpoints. The other parameters should be the same as in the training process. 307 | `<model_sign>` should take the value of `[mention_proposal, corefqa]`.<br>
308 | 309 | The codebase also provides the option of evaluating a single model/checkpoint. Please set `--do_eval=False`, `--do_train=False` and `--do_predict=True` for `run_<model_sign>.py`, together with the checkpoint path `--eval_checkpoint=<path_to_checkpoint>`. 310 | `<model_sign>` should take the value of `[mention_proposal, corefqa]`. 311 |
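
The evaluation run logs one `[DEV EVAL]` / `[TEST EVAL]` line per checkpoint (see `logs/corefqa_log.txt` for the format). A minimal sketch (not part of the repo) for recovering the best dev checkpoint from such a log afterwards; the log path is a placeholder:

```python
# Minimal sketch: parse an evaluation log (format as in logs/corefqa_log.txt)
# and report the checkpoint with the highest F1 on the dev set.
import re

log_path = "/path-to-output-dir/train.log"  # placeholder

ckpt_pattern = re.compile(r"(gs://\S+/model\.ckpt-\d+)")
dev_f1_pattern = re.compile(r"\[DEV EVAL\].*f1: ([0-9.]+)")

best_f1, best_ckpt, current_ckpt = -1.0, None, None
with open(log_path) as log_file:
    for line in log_file:
        ckpt_match = ckpt_pattern.search(line)
        if ckpt_match:
            current_ckpt = ckpt_match.group(1)
        f1_match = dev_f1_pattern.search(line)
        if f1_match and float(f1_match.group(1)) > best_f1:
            best_f1, best_ckpt = float(f1_match.group(1)), current_ckpt

print("best dev f1: {:.4f} at checkpoint: {}".format(best_f1, best_ckpt))
```
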
312 | 313 | ## Download the Final CorefQA Model 314 | You can download the final CorefQA model at [link](https://drive.google.com/file/d/1RPYsS2dDxYyii7-3NkBNG7VtuA96NBLf/view?usp=sharing) and follow the instructions in the prediciton to obtain the score reported in the paper. 315 | 316 | 317 | ## Descriptions of Directories 318 | 319 | Name | Descriptions 320 | ----------- | ------------- 321 | bert | BERT modules (model,tokenizer,optimization) ref to the `google-research/bert` repository. 322 | conll-2012 | offical evaluation scripts for CoNLL2012 shared task. 323 | data_utils | modules for processing training data. 324 | func_builders | the input dataloader and model constructor for CorefQA. 325 | logs | the log files in our experiments. 326 | models | an implementation of CorefQA/MentionProposal models based on TF. 327 | run | modules for data preparation and training models. 328 | scripts/data | scripts for data preparation and loading pretrained models. 329 | scripts/models | scripts for {train/evaluate} {mention_proposal/corefqa} models on {TPU/GPU}. 330 | utils | modules including metrics、optimizers. 331 | 332 | 333 | 334 | ## Acknowledgement 335 | 336 | Many thanks to `Yuxian Meng` and the previous work `https://github.com/mandarjoshi90/coref`. 337 | 338 | ## Useful Materials 339 | 340 | - TPU Quick Start [link](https://cloud.google.com/tpu/docs/quickstart) 341 | - TPU Available Operations [link](https://cloud.google.com/tpu/docs/tensorflow-ops) 342 | 343 | ## Contact 344 | 345 | Feel free to discuss papers/code with us through issues/emails! 346 | -------------------------------------------------------------------------------- /models/mention_proposal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # the mention proposal model for pre-training the Span-BERT model. 9 | 10 | 11 | import os 12 | import sys 13 | 14 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2]) 15 | if repo_path not in sys.path: 16 | sys.path.insert(0, repo_path) 17 | 18 | import tensorflow as tf 19 | from bert import modeling 20 | 21 | 22 | 23 | class MentionProposalModel(object): 24 | def __init__(self, config): 25 | self.config = config 26 | self.bert_config = modeling.BertConfig.from_json_file(config.bert_config_file) 27 | self.bert_config.hidden_dropout_prob = config.dropout_rate 28 | 29 | def get_mention_proposal_and_loss(self, instance, is_training, use_tpu=False): 30 | """ 31 | Desc: 32 | forward function for training mention proposal module. 33 | Args: 34 | instance: a tuple of train/dev/test data instance. 35 | e.g., (flat_input_ids, flat_doc_overlap_input_mask, flat_sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids) 36 | is_training: True/False is in the training process. 37 | """ 38 | self.use_tpu = use_tpu 39 | self.dropout = self.get_dropout(self.config.dropout_rate, is_training) 40 | 41 | flat_input_ids, flat_doc_overlap_input_mask, flat_sentence_map, text_len, speaker_ids, gold_starts, gold_ends, cluster_ids = instance 42 | # flat_input_ids: (num_window, window_size) 43 | # flat_doc_overlap_input_mask: (num_window, window_size) 44 | # flat_sentence_map: (num_window, window_size) 45 | # text_len: dynamic length and is padded to fix length 46 | # gold_start: (max_num_mention), mention start index in the original (NON-OVERLAP) document. Pad with -1 to the fix length max_num_mention. 
47 | # gold_end: (max_num_mention), mention end index in the original (NON-OVERLAP) document. Pad with -1 to the fix length max_num_mention. 48 | # cluster_ids/speaker_ids is not used in the mention proposal model. 49 | 50 | flat_input_ids = tf.math.maximum(flat_input_ids, tf.zeros_like(flat_input_ids, tf.int32)) # (num_window * window_size) 51 | 52 | flat_doc_overlap_input_mask = tf.where(tf.math.greater_equal(flat_doc_overlap_input_mask, 0), 53 | x=tf.ones_like(flat_doc_overlap_input_mask, tf.int32), y=tf.zeros_like(flat_doc_overlap_input_mask, tf.int32)) # (num_window * window_size) 54 | # flat_doc_overlap_input_mask = tf.math.maximum(flat_doc_overlap_input_mask, tf.zeros_like(flat_doc_overlap_input_mask, tf.int32)) 55 | flat_sentence_map = tf.math.maximum(flat_sentence_map, tf.zeros_like(flat_sentence_map, tf.int32)) # (num_window * window_size) 56 | 57 | gold_start_end_mask = tf.cast(tf.math.greater_equal(gold_starts, tf.zeros_like(gold_starts, tf.int32)), tf.bool) # (max_num_mention) 58 | gold_start_index_labels = self.boolean_mask_1d(gold_starts, gold_start_end_mask, name_scope="gold_starts", use_tpu=self.use_tpu) # (num_of_mention) 59 | gold_end_index_labels = self.boolean_mask_1d(gold_ends, gold_start_end_mask, name_scope="gold_ends", use_tpu=self.use_tpu) # (num_of_mention) 60 | 61 | text_len = tf.math.maximum(text_len, tf.zeros_like(text_len, tf.int32)) # (num_of_non_empty_window) 62 | num_subtoken_in_doc = tf.math.reduce_sum(text_len) # the value should be num_subtoken_in_doc 63 | 64 | input_ids = tf.reshape(flat_input_ids, [-1, self.config.window_size]) # (num_window, window_size) 65 | input_mask = tf.ones_like(input_ids, tf.int32) # (num_window, window_size) 66 | 67 | model = modeling.BertModel(config=self.bert_config, is_training=is_training, 68 | input_ids=input_ids, input_mask=input_mask, 69 | use_one_hot_embeddings=False, scope='bert') 70 | 71 | doc_overlap_window_embs = model.get_sequence_output() # (num_window, window_size, hidden_size) 72 | doc_overlap_input_mask = tf.reshape(flat_doc_overlap_input_mask, [self.config.num_window, self.config.window_size]) # (num_window, window_size) 73 | 74 | doc_flat_embs = self.transform_overlap_windows_to_original_doc(doc_overlap_window_embs, doc_overlap_input_mask) 75 | doc_flat_embs = tf.reshape(doc_flat_embs, [-1, self.config.hidden_size]) # (num_subtoken_in_doc, hidden_size) 76 | 77 | expand_start_embs = tf.tile(tf.expand_dims(doc_flat_embs, 1), [1, num_subtoken_in_doc, 1]) # (num_subtoken_in_doc, num_subtoken_in_doc, hidden_size) 78 | expand_end_embs = tf.tile(tf.expand_dims(doc_flat_embs, 0), [num_subtoken_in_doc, 1, 1]) # (num_subtoken_in_doc, num_subtoken_in_doc, hidden_size) 79 | expand_mention_span_embs = tf.concat([expand_start_embs, expand_end_embs], axis=-1) # (num_subtoken_in_doc, num_subtoken_in_doc, 2*hidden_size) 80 | expand_mention_span_embs = tf.reshape(expand_mention_span_embs, [-1, self.config.hidden_size*2]) 81 | span_sequence_logits = self.ffnn(expand_mention_span_embs, self.config.hidden_size*2, 1, dropout=self.dropout, name_scope="mention_span") # (num_subtoken_in_doc * num_subtoken_in_doc) 82 | 83 | if self.config.start_end_share: 84 | start_end_sequence_logits = self.ffnn(doc_flat_embs, self.config.hidden_size, 2, dropout=self.dropout, name_scope="mention_start_end") # (num_subtoken_in_doc, 2) 85 | start_sequence_logits, end_sequence_logits = tf.split(start_end_sequence_logits, axis=1) 86 | # start_sequence_logits -> (num_subtoken_in_doc, 1) 87 | # end_sequence_logits -> (num_subtoken_in_doc, 1) 88 | else: 
89 | start_sequence_logits = self.ffnn(doc_flat_embs, self.config.hidden_size, 1, dropout=self.dropout, name_scope="mention_start") # (num_subtoken_in_doc) 90 | end_sequence_logits = self.ffnn(doc_flat_embs, self.config.hidden_size, 1, dropout=self.dropout, name_scope="mention_end") # (num_subtoken_in_doc) 91 | 92 | gold_start_sequence_labels = self.scatter_gold_index_to_label_sequence(gold_start_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc) 93 | gold_end_sequence_labels = self.scatter_gold_index_to_label_sequence(gold_end_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc) 94 | 95 | start_loss, start_sequence_probabilities = self.compute_score_and_loss(start_sequence_logits, gold_start_sequence_labels) 96 | end_loss, end_sequence_probabilities = self.compute_score_and_loss(end_sequence_logits, gold_end_sequence_labels) 97 | # *_loss -> a scalar 98 | # *_sequence_scores -> (num_subtoken_in_doc) 99 | 100 | gold_span_sequence_labels = self.scatter_span_sequence_labels(gold_start_index_labels, gold_end_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc * num_subtoken_in_doc) 101 | span_loss, span_sequence_probabilities = self.compute_score_and_loss(span_sequence_logits, gold_span_sequence_labels) 102 | # span_loss -> a scalar 103 | # span_sequence_probabilities -> (num_subtoken_in_doc * num_subtoken_in_doc) 104 | 105 | total_loss = self.config.loss_start_ratio * start_loss + self.config.loss_end_ratio * end_loss + self.config.loss_span_ratio * span_loss 106 | return total_loss, start_sequence_probabilities, end_sequence_probabilities, span_sequence_probabilities 107 | 108 | 109 | def get_gold_mention_sequence_labels_from_pad_index(self, pad_gold_start_index_labels, pad_gold_end_index_labels, pad_text_len): 110 | """ 111 | Desc: 112 | the original gold labels is padded to the fixed length and only contains the position index of gold mentions. 113 | return the gold sequence of labels for evaluation. 114 | Args: 115 | pad_gold_start_index_labels: a tf.int32 tensor with a fixed length (self.config.max_num_mention). 116 | every element in the tensor is the start position index for the mentions. 117 | pad_gold_end_index_labels: a tf.int32 tensor with a fixed length (self.config.max_num_mention). 118 | every element in the tensor is the end position index of the mentions. 119 | pad_text_len: a tf.int32 tensor with a fixed length (self.config.num_window). 120 | every positive element in the tensor indicates that the number of subtokens in the window. 121 | Returns: 122 | gold_start_sequence_labels: a tf.int32 tensor with the shape of (num_subtoken_in_doc). 123 | if the element in the tensor equals to 0, this subtoken is not a start for a mention. 124 | if the elemtn in the tensor equals to 1, this subtoken is a start for a mention. 125 | gold_end_sequence_labels: a tf.int32 tensor with the shape of (num_subtoken_in_doc). 126 | if the element in the tensor equals to 0, this subtoken is not a end for a mention. 127 | if the elemtn in the tensor equals to 1, this subtoken is a end for a mention. 128 | gold_span_sequence_labels: a tf.int32 tensor with the shape of (num_subtoken_in_doc, num_subtoken_in_doc)/ 129 | if the element[i][j] equals to 0, this subtokens from $i$ to $j$ is not a mention. 130 | if the element[i][j] equals to 1, this subtokens from $i$ to $j$ is a mention. 
131 | """ 132 | text_len = tf.math.maximum(pad_text_len, tf.zeros_like(pad_text_len, tf.int32)) # (num_of_non_empty_window) 133 | num_subtoken_in_doc = tf.math.reduce_sum(text_len) # the value should be num_subtoken_in_doc 134 | 135 | gold_start_end_mask = tf.cast(tf.math.greater_equal(pad_gold_start_index_labels, tf.zeros_like(pad_gold_start_index_labels, tf.int32)), tf.bool) # (max_num_mention) 136 | gold_start_index_labels = self.boolean_mask_1d(pad_gold_start_index_labels, gold_start_end_mask, name_scope="gold_starts", use_tpu=self.use_tpu) # (num_of_mention) 137 | gold_end_index_labels = self.boolean_mask_1d(pad_gold_end_index_labels, gold_start_end_mask, name_scope="gold_ends", use_tpu=self.use_tpu) # (num_of_mention) 138 | 139 | gold_start_sequence_labels = self.scatter_gold_index_to_label_sequence(gold_start_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc) 140 | gold_end_sequence_labels = self.scatter_gold_index_to_label_sequence(gold_end_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc) 141 | gold_span_sequence_labels = self.scatter_span_sequence_labels(gold_start_index_labels, gold_end_index_labels, num_subtoken_in_doc) # (num_subtoken_in_doc, num_subtoken_in_doc) 142 | 143 | return gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels 144 | 145 | 146 | def scatter_gold_index_to_label_sequence(self, gold_index_labels, expect_length_of_labels): 147 | """ 148 | Desc: 149 | transform the mention start/end position index tf.int32 Tensor to a tf.int32 Tensor with 1/0 labels for the input subtoken sequences. 150 | 1 denotes this subtoken is the start/end for a mention. 151 | Args: 152 | gold_index_labels: a tf.int32 Tensor with mention start/end position index in the original document. 153 | expect_length_of_labels: the number of subtokens in the original document. 154 | """ 155 | gold_labels_pos = tf.reshape(gold_index_labels, [-1, 1]) # (num_of_mention, 1) 156 | gold_value = tf.reshape(tf.ones_like(gold_index_labels), [-1]) # (num_of_mention) 157 | label_shape = tf.Variable(expect_length_of_labels) 158 | label_shape = tf.reshape(label_shape, [1]) # [1] 159 | gold_sequence_labels = tf.cast(tf.scatter_nd(gold_labels_pos, gold_value, label_shape), tf.int32) # (num_subtoken_in_doc) 160 | return gold_sequence_labels 161 | 162 | 163 | def scatter_span_sequence_labels(self, gold_start_index_labels, gold_end_index_labels, expect_length_of_labels): 164 | """ 165 | Desc: 166 | transform the mention (start, end) position pairs to a span matrix gold_span_sequence_labels. 167 | matrix[i][j]: whether the subtokens between the position $i$ to $j$ can be a mention. 168 | if matrix[i][j] == 0: from $i$ to $j$ is not a mention. 169 | if matrix[i][j] == 1: from $i$ to $j$ is a mention. 170 | Args: 171 | gold_start_index_labels: a tf.int32 Tensor with mention start position index in the original document. 172 | gold_end_index_labels: a tf.int32 Tensor with mention end position index in the original document. 
173 | expect_length_of_labels: a scalar, should be the same with num_subtoken_in_doc 174 | """ 175 | gold_span_index_labels = tf.stack([gold_start_index_labels, gold_end_index_labels], axis=1) # (num_of_mention, 2) 176 | gold_span_value = tf.reshape(tf.ones_like(gold_start_index_labels, tf.int32), [-1]) # (num_of_mention) 177 | gold_span_label_shape = tf.Variable([expect_length_of_labels, expect_length_of_labels]) 178 | gold_span_label_shape = tf.reshape(gold_span_label_shape, [-1]) 179 | 180 | gold_span_sequence_labels = tf.cast(tf.scatter_nd(gold_span_index_labels, gold_span_value, gold_span_label_shape), tf.int32) # (num_subtoken_in_doc, num_subtoken_in_doc) 181 | return gold_span_sequence_labels 182 | 183 | 184 | def compute_score_and_loss(self, pred_sequence_logits, gold_sequence_labels, loss_mask=None): 185 | """ 186 | Desc: 187 | compute the unifrom start/end loss and probabilities. 188 | Args: 189 | pred_sequence_logits: (input_shape, 1) 190 | gold_sequence_labels: (input_shape, 1) 191 | loss_mask: [optional] if is not None, it should be (input_shape). should be tf.int32 0/1 tensor. 192 | FOR start/end score and loss, input_shape should be num_subtoken_in_doc. 193 | FOR span score and loss, input_shape should be num_subtoken_in_doc * num_subtoken_in_doc. 194 | """ 195 | pred_sequence_probabilities = tf.cast(tf.reshape(tf.sigmoid(pred_sequence_logits), [-1]),tf.float32) # (input_shape) 196 | expand_pred_sequence_scores = tf.stack([(1 - pred_sequence_probabilities), pred_sequence_probabilities], axis=-1) # (input_shape, 2) 197 | expand_gold_sequence_labels = tf.cast(tf.one_hot(tf.reshape(gold_sequence_labels, [-1]), 2, axis=-1), tf.float32) # (input_shape, 2) 198 | 199 | loss = tf.keras.losses.binary_crossentropy(expand_gold_sequence_labels, expand_pred_sequence_scores) 200 | # loss -> shape is (input_shape) 201 | 202 | if loss_mask is not None: 203 | loss = tf.multiply(loss, tf.cast(loss_mask, tf.float32)) 204 | 205 | total_loss = tf.reduce_mean(loss) 206 | # total_loss -> a scalar 207 | 208 | return total_loss, pred_sequence_probabilities 209 | 210 | 211 | def transform_overlap_windows_to_original_doc(self, doc_overlap_window_embs, doc_overlap_input_mask): 212 | """ 213 | Desc: 214 | hidden_size should be equal to embeddding_size. 215 | Args: 216 | doc_overlap_window_embs: (num_window, window_size, hidden_size). 217 | the output of (num_window, window_size) input_ids forward into BERT model. 218 | doc_overlap_input_mask: (num_window, window_size). A tf.int32 Tensor contains 0/1. 219 | 0 represents token in this position should be neglected. 1 represents token in this position should be reserved. 
220 | """ 221 | ones_input_mask = tf.ones_like(doc_overlap_input_mask, tf.int32) # (num_window, window_size) 222 | cumsum_input_mask = tf.math.cumsum(ones_input_mask, axis=1) # (num_window, window_size) 223 | offset_input_mask = tf.tile(tf.expand_dims(tf.range(self.config.num_window) * self.config.window_size, 1), [1, self.config.window_size]) # (num_window, window_size) 224 | offset_cumsum_input_mask = offset_input_mask + cumsum_input_mask # (num_window, window_size) 225 | global_input_mask = tf.math.multiply(ones_input_mask, offset_cumsum_input_mask) # (num_window, window_size) 226 | global_input_mask = tf.reshape(global_input_mask, [-1]) # (num_window * window_size) 227 | global_input_mask_index = self.boolean_mask_1d(global_input_mask, tf.math.greater(global_input_mask, tf.zeros_like(global_input_mask, tf.int32))) # (num_subtoken_in_doc) 228 | 229 | doc_overlap_window_embs = tf.reshape(doc_overlap_window_embs, [-1, self.config.hidden_size]) # (num_window * window_size, hidden_size) 230 | original_doc_embs = tf.gather(doc_overlap_window_embs, global_input_mask_index) # (num_subtoken_in_doc, hidden_size) 231 | 232 | return original_doc_embs 233 | 234 | 235 | def ffnn(self, inputs, hidden_size, output_size, dropout=None, name_scope="fully-conntected-neural-network", 236 | hidden_initializer=tf.truncated_normal_initializer(stddev=0.02)): 237 | """ 238 | Desc: 239 | fully-connected neural network. 240 | transform non-linearly the [input] tensor with [hidden_size] to a fix [output_size] size. 241 | Args: 242 | hidden_size: should be the size of last dimension of [inputs]. 243 | """ 244 | with tf.variable_scope(name_scope, reuse=tf.AUTO_REUSE): 245 | hidden_weights = tf.get_variable("hidden_weights", [hidden_size, output_size], 246 | initializer=hidden_initializer) 247 | hidden_bias = tf.get_variable("hidden_bias", [output_size], initializer=tf.zeros_initializer()) 248 | outputs = tf.nn.relu(tf.nn.xw_plus_b(inputs, hidden_weights, hidden_bias)) 249 | 250 | if dropout is not None: 251 | outputs = tf.nn.dropout(outputs, dropout) 252 | 253 | return outputs 254 | 255 | 256 | def get_dropout(self, dropout_rate, is_training): 257 | return 1 - (tf.to_float(is_training) * dropout_rate) 258 | 259 | 260 | def get_shape(self, x, dim): 261 | """ 262 | Desc: 263 | return the size of input x in DIM. 264 | """ 265 | return x.get_shape()[dim].value or tf.shape(x)[dim] 266 | 267 | 268 | def boolean_mask_1d(self, itemtensor, boolmask_indicator, name_scope="boolean_mask1d", use_tpu=False): 269 | """ 270 | Desc: 271 | the same functionality of tf.boolean_mask. 272 | The tf.boolean_mask operation is not available on the cloud TPU. 273 | Args: 274 | itemtensor : a Tensor contains [tf.int32, tf.float32] numbers. Should be 1-Rank. 275 | boolmask_indicator : a tf.bool Tensor. Should be 1-Rank. 276 | scope : name scope for the operation. 277 | use_tpu : if False, return tf.boolean_mask. 
278 | """ 279 | with tf.name_scope(name_scope): 280 | if not use_tpu: 281 | return tf.boolean_mask(itemtensor, boolmask_indicator) 282 | 283 | boolmask_sum = tf.reduce_sum(tf.cast(boolmask_indicator, tf.int32)) 284 | selected_positions = tf.cast(boolmask_indicator, dtype=tf.float32) 285 | indexed_positions = tf.cast(tf.multiply(tf.cumsum(selected_positions), selected_positions),dtype=tf.int32) 286 | one_hot_selector = tf.one_hot(indexed_positions - 1, boolmask_sum, dtype=tf.float32) 287 | sampled_indices = tf.cast(tf.tensordot(tf.cast(tf.range(tf.shape(boolmask_indicator)[0]), dtype=tf.float32), 288 | one_hot_selector,axes=[0, 0]),dtype=tf.int32) 289 | sampled_indices = tf.reshape(sampled_indices, [-1]) 290 | mask_itemtensor = tf.gather(itemtensor, sampled_indices) 291 | 292 | return mask_itemtensor 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /logs/corefqa_log.txt: -------------------------------------------------------------------------------- 1 | /home/xiaoyli1110/venv/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. 2 | from ._conv import register_converters as _register_converters 3 | /home/xiaoyli1110/xiaoya/Coref-tf 4 | W0713 13:36:28.637916 139854454523648 module_wrapper.py:139] From /home/xiaoyli1110/xiaoya/Coref-tf/run/train_corefqa.py:308: The name tf.app.run is deprecated. Please use tf.compat.v1.app.run instead. 5 | 6 | loading experiments_tpu.conf ... 7 | W0713 13:36:28.752999 139854454523648 module_wrapper.py:139] From /home/xiaoyli1110/xiaoya/Coref-tf/utils/util.py:41: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead. 
8 | I0716 13:36:28.753216 139854454523648 util.py:41] %*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*% 9 | I0716 13:36:28.753291 139854454523648 util.py:42] %*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*% 10 | I0716 13:36:28.753346 139854454523648 util.py:43] %%%%%%%% Configs are showed as follows : %%%%%%%% 11 | I0716 13:36:28.753432 139854454523648 util.py:45] max_top_antecedents : 60 12 | I0716 13:36:28.753505 139854454523648 util.py:45] max_training_sentences : 3 13 | I0716 13:36:28.753576 139854454523648 util.py:45] top_span_ratio : 0.3 14 | I0716 13:36:28.753644 139854454523648 util.py:45] max_num_speakers : 20 15 | I0716 13:36:28.753710 139854454523648 util.py:45] max_segment_len : 384 16 | I0716 13:36:28.753773 139854454523648 util.py:45] max_cluster_num : 30 17 | I0716 13:36:28.753848 139854454523648 util.py:45] tpu : True 18 | I0716 13:36:28.753910 139854454523648 util.py:45] max_query_len : 150 19 | I0716 13:36:28.753972 139854454523648 util.py:45] max_context_len : 150 20 | I0716 13:36:28.754034 139854454523648 util.py:45] max_qa_len : 300 21 | I0716 13:36:28.754097 139854454523648 util.py:45] hidden_size : 1024 22 | I0716 13:36:28.754161 139854454523648 util.py:45] max_candidate_mentions : 60 23 | I0716 13:36:28.754229 139854454523648 util.py:45] learning_rate : 8e-06 24 | I0716 13:36:28.754301 139854454523648 util.py:45] num_docs : 5604 25 | I0716 13:36:28.754365 139854454523648 util.py:45] start_ratio : 0.8 26 | I0716 13:36:28.754428 139854454523648 util.py:45] end_ratio : 0.8 27 | I0716 13:36:28.754492 139854454523648 util.py:45] mention_ratio : 1.0 28 | I0716 13:36:28.754556 139854454523648 util.py:45] corefqa_loss_ratio : 0.9 29 | I0716 13:36:28.754620 139854454523648 util.py:45] score_ratio : 0.5 30 | I0716 13:36:28.754682 139854454523648 util.py:45] run : estimator 31 | I0716 13:36:28.754746 139854454523648 util.py:45] threshold : 0.5 32 | I0716 13:36:28.754814 139854454523648 util.py:45] dropout_rate : 0.3 33 | I0716 13:36:28.754945 139854454523648 util.py:45] ffnn_size : 1024 34 | I0716 13:36:28.755008 139854454523648 util.py:45] ffnn_depth : 1 35 | I0716 13:36:28.755071 139854454523648 util.py:45] num_epochs : 8 36 | I0716 13:36:28.755135 139854454523648 util.py:45] max_span_width : 30 37 | I0716 13:36:28.755199 139854454523648 util.py:45] use_segment_distance : True 38 | I0716 13:36:28.755261 139854454523648 util.py:45] model_heads : True 39 | I0716 13:36:28.755324 139854454523648 util.py:45] coref_depth : 2 40 | I0716 13:36:28.755383 139854454523648 util.py:45] corefqa_only_concate : False 41 | I0716 13:36:28.755445 139854454523648 util.py:45] train_path : gs://xiaoy-data-europe/overlap_384_3/train.128.english.tfrecord 42 | I0716 13:36:28.755507 139854454523648 util.py:45] eval_path : test.english.jsonlines 43 | I0716 13:36:28.755571 139854454523648 util.py:45] conll_eval_path : gs://corefqa-europe/spanbert_large_overlap_384_3_out 44 | put_2e-5/test.english.v4_gold_conll 45 | I0716 13:36:28.755634 139854454523648 util.py:45] single_example : False 46 | I0716 13:36:28.755702 139854454523648 util.py:45] genres : ['bc', 'bn', 'mz', 'nw', 'pt', 'tc', 'wb'] 47 | I0716 13:36:28.755765 139854454523648 util.py:45] log_root : gs://corefqa-europe/spanbert_large_overlap_384_3_output_2e-5 48 | I0716 13:36:28.755842 139854454523648 util.py:45] save_checkpoints_steps : 1000 49 | I0716 13:36:28.755906 139854454523648 util.py:45] dev_path : gs://xiaoy-data-europe/overlap_384_3/dev.256.english.tfrecord 50 | I0716 13:36:28.755968 139854454523648 util.py:45] 
test_path : gs://xiaoy-data-europe/overlap_384_3/test.256.english.tfrecord 51 | I0716 13:36:28.756030 139854454523648 util.py:45] bert_config_file : gs://xiaoy-data-europe/spanbert_large_tf/bert_config.json 52 | I0716 13:36:28.756093 139854454523648 util.py:45] vocab_file : gs://xiaoy-data-europe/spanbert_large_tf/vocab.txt 53 | I0716 13:36:28.756155 139854454523648 util.py:45] tf_checkpoint : gs://xiaoy-data-europe/spanbert_large_tf/bert_model.ckpt 54 | I0716 13:36:28.756217 139854454523648 util.py:45] init_checkpoint : gs://xiaoy-data-europe/spanbert_large_tf/bert_model.ckpt 55 | I0716 13:36:28.756279 139854454523648 util.py:45] eval_checkpoint : gs://corefqa-europe/spanbert_large_overlap_384_3_out 56 | put_1e-5_0.3_8/model.ckpt-20 57 | I0716 13:36:28.756341 139854454523648 util.py:45] output_path : gs://corefqa-europe/spanbert_large_overlap_384_3_output_ 58 | 1e-5_0.3_8 59 | I0716 13:36:28.756391 139854454523648 util.py:47] %*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*%%*% 60 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 61 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-500 62 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 63 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6219, recall: 0.5093, f1: 0.56 64 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 65 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.6413, recall: 0.5428, f1: 0.588 66 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 67 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-1000 68 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 69 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6227, recall: 0.5841, f1: 0.6028 70 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 71 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.6744, recall: 0.6126, f1: 0.6149 72 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 73 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-1500 74 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 75 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6906, recall: 0.5288, f1: 0.599 76 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 77 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.6966, recall: 0.506, f1: 0.5862 78 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 79 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-2000 80 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 81 | I0716 14:17:19.214575 139854454523648 
tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6712, recall: 0.5717, f1: 0.6174 82 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 83 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.6696, recall: 0.541, f1: 0.5985 84 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 85 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-2500 86 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 87 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8091, recall: 0.5967, f1: 0.6868 88 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 89 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8018, recall: 0.5692, f1: 0.6657 90 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 91 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-3000 92 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 93 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6675, recall: 0.6382, f1: 0.6525 94 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 95 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-3500 96 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 97 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7155, recall: 0.7515, f1: 0.7331 98 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 99 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.7239, recall: 0.7711, f1: 0.7468 100 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 101 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-4000 102 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 103 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7762, recall: 0.5819, f1: 0.6651 104 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 105 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-4500 106 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 107 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.6661, recall: 0.6236, f1: 0.6442 108 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 109 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-5000 110 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** 
EVAL ON DEV SET ***** 111 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7814, recall: 0.7246, f1: 0.7519 112 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 113 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.7155, recall: 0.7515, f1: 0.7331 114 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 115 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-5500 116 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 117 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8042, recall: 0.7417, f1: 0.7717 118 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 119 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8439, recall: 0.6328, f1: 0.7232 120 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 121 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-6000 122 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 123 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7942, recall: 0.7217, f1: 0.7417 124 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 125 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-6500 126 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 127 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7816, recall: 0.7831, f1: 0.7823 128 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 129 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.7893, recall: 0.8075, f1: 0.7983 130 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 131 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-7000 132 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 133 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8103, recall: 0.814, f1: 0.8121 134 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 135 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8022, recall: 0.7838, f1: 0.7929 136 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 137 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-7500 138 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 139 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.7935, recall: 0.8292, f1: 0.8109 140 | 
I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 141 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-8000 142 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 143 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8997, recall: 0.7292, f1: 0.8056 144 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 145 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-8500 146 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 147 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8737, recall: 0.7444, f1: 0.8039 148 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 149 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-9000 150 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 151 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8107, recall: 0.814, f1: 0.8124 152 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 153 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-10500 154 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 155 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8439, recall: 0.7952, f1: 0.8188 156 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 157 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8228, recall: 0.8093, f1: 0.8107 158 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 159 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-11000 160 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 161 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8386, recall: 0.8147, f1: 0.8273 162 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 163 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8239, recall: 0.8104, f1: 0.8143 164 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 165 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-12000 166 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 167 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8398, recall: 0.8201, f1: 0.8369 168 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 169 | I0716 14:17:19.214575 139854454523648 
tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-12500 170 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 171 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8666, recall: 0.7766, f1: 0.8192 172 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 173 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-13000 174 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 175 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8397, recall: 0.7892, f1: 0.8056 176 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 177 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-13500 178 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 179 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8357, recall: 0.8169, f1: 0.8262 180 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON TEST SET ***** 181 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [TEST EVAL] ***** : precision: 0.8327, recall: 0.8288, f1: 0.8322 182 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 183 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-14000 184 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 185 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8269, recall: 0.8108, f1: 0.8201 186 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 187 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-14500 188 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 189 | I0716 14:17:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8344, recall: 0.8104, f1: 0.8215 190 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] ***** Current ckpt path is ***** 191 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-15000 192 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] ***** EVAL ON DEV SET ***** 193 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] ***** [DEV EVAL] ***** : precision: 0.8287, recall: 0.8177, f1: 0.8223 194 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] ************************* 195 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] - @@@@@ the path to the BEST DEV result is : gs://corefqa-europe-europe/spanbert_large_overlap_384_3_output_8e6_0.2_8/model.ckpt-13500 196 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] - @@@@@ BEST DEV F1 : 0.8262, Precision : 0.8357, Recall : 0.8169 197 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] - @@@@@ TEST when DEV best 
F1 : 0.8322, Precision : 0.8327, Recall : 0.8288 198 | I0716 14:46:19.214575 139854454523648 tpu_estimator.py:2307] - @@@@@ mention_proposal_only_concate False -------------------------------------------------------------------------------- /run/build_dataset_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # author: xiaoy li 7 | # description: 8 | # generate tfrecord for train/dev/test set for the model. 9 | # TODO (xiaoya): need to add help description for args 10 | 11 | 12 | 13 | import os 14 | import sys 15 | import re 16 | import json 17 | import argparse 18 | import numpy as np 19 | import tensorflow as tf 20 | from collections import OrderedDict 21 | 22 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2]) 23 | if REPO_PATH not in sys.path: 24 | sys.path.insert(0, REPO_PATH) 25 | 26 | from data_utils import conll 27 | from bert.tokenization import FullTokenizer 28 | 29 | SPEAKER_START = '[unused19]' 30 | SPEAKER_END = '[unused73]' 31 | subtoken_maps = {} 32 | gold = {} 33 | 34 | 35 | 36 | """ 37 | Desc: 38 | a single training/test example for the squad dataset. 39 | suppose origin input_tokens are : 40 | ['[unused19]', 'speaker', '#', '1', '[unused73]', '-', '-', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties', '.', 41 | 'To', 'express', 'its', 'determination', ',', 'the', 'Chinese', 'securities', 'regulatory', 'department', 'compares', 'this', 'stock', 'reform', 'to', 'a', 'die', 'that', 42 | 'has', 'been', 'cast', '.', 'It', 'takes', 'time', 'to', 'prove', 'whether', 'the', 'stock', 'reform', 'can', 'really', 'meet', 'expectations', ',', 'and', 'whether', 'any', 43 | 'de', '##viation', '##s', 'that', 'arise', 'during', 'the', 'stock', 'reform', 'can', 'be', 'promptly', 'corrected', '.', '[unused19]', 'Xu', '_', 'l', '##i', '[unused73]', 44 | 'Dear', 'viewers', ',', 'the', 'China', 'News', 'program', 'will', 'end', 'here', '.', 'This', 'is', 'Xu', 'Li', '.', 'Thank', 'you', 'everyone', 'for', 'watching', '.', 'Coming', 45 | 'up', 'is', 'the', 'Focus', 'Today', 'program', 'hosted', 'by', 'Wang', 'Shi', '##lin', '.', 'Good', '-', 'bye', ',', 'dear', 'viewers', '.'] 46 | IF sliding window size is 50. 47 | Args: 48 | doc_idx: a string: cctv/bn/0001 49 | sentence_map: 50 | e.g. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7] 51 | subtoken_map: 52 | e.g. [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 53, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 97, 98, 99, 99, 99, 100, 101, 102, 103] 53 | flattened_window_input_ids: [num-window, window-size] 54 | e.g. 
before bert_tokenizer convert subtokens into ids: 55 | [['[CLS]', '[unused19]', 'speaker', '#', '1', '[unused73]', '-', '-', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties', '.', 'To', 'express', 'its', 'determination', ',', 'the', 'Chinese', 'securities', 'regulatory', 'department', 'compares', 'this', 'stock', 'reform', 'to', 'a', 'die', 'that', 'has', 'been', 'cast', '.', 'It', 'takes', 'time', 'to', 'prove', 'whether', '[SEP]'], 56 | ['[CLS]', ',', 'the', 'Chinese', 'securities', 'regulatory', 'department', 'compares', 'this', 'stock', 'reform', 'to', 'a', 'die', 'that', 'has', 'been', 'cast', '.', 'It', 'takes', 'time', 'to', 'prove', 'whether', 'the', 'stock', 'reform', 'can', 'really', 'meet', 'expectations', ',', 'and', 'whether', 'any', 'de', '##viation', '##s', 'that', 'arise', 'during', 'the', 'stock', 'reform', 'can', 'be', 'promptly', 'corrected', '[SEP]'], 57 | ['[CLS]', 'the', 'stock', 'reform', 'can', 'really', 'meet', 'expectations', ',', 'and', 'whether', 'any', 'de', '##viation', '##s', 'that', 'arise', 'during', 'the', 'stock', 'reform', 'can', 'be', 'promptly', 'corrected', '.', '[unused19]', 'Xu', '_', 'l', '##i', '[unused73]', 'Dear', 'viewers', ',', 'the', 'China', 'News', 'program', 'will', 'end', 'here', '.', 'This', 'is', 'Xu', 'Li', '.', 'Thank', '[SEP]'], 58 | ['[CLS]', '.', '[unused19]', 'Xu', '_', 'l', '##i', '[unused73]', 'Dear', 'viewers', ',', 'the', 'China', 'News', 'program', 'will', 'end', 'here', '.', 'This', 'is', 'Xu', 'Li', '.', 'Thank', 'you', 'everyone', 'for', 'watching', '.', 'Coming', 'up', 'is', 'the', 'Focus', 'Today', 'program', 'hosted', 'by', 'Wang', 'Shi', '##lin', '.', 'Good', '-', 'bye', ',', 'dear', 'viewers', '[SEP]'], 59 | ['[CLS]', 'you', 'everyone', 'for', 'watching', '.', 'Coming', 'up', 'is', 'the', 'Focus', 'Today', 'program', 'hosted', 'by', 'Wang', 'Shi', '##lin', '.', 'Good', '-', 'bye', ',', 'dear', 'viewers', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']] 60 | flattened_window_masked_ids: [num-window, window-size] 61 | e.g.: before bert_tokenizer ids: 62 | [[-3, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -3], 63 | [-3, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -3], 64 | [-3, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, -1, -1, -1, -1, -1, -1, 68, 69, 70, 71, 72, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -3], 65 | [-3, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -3], 66 | [-3, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -3, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4]] 67 | span_start: 68 | e.g.: mention start indices in the original document 69 | [17, 20, 26, 43, 60, 85, 86] 70 | span_end: 71 | e.g.: mention end indices in the original document 72 | cluster_ids: 73 | e.g.: cluster ids for the 
(span_start, span_end) pairs 74 | [1, 1, 2, 2, 2, 3, 3] 75 | check the mention in the subword list: 76 | 1. ['its'] 77 | 1. ['the', 'Chinese', 'securities', 'regulatory', 'department'] 78 | 2. ['this', 'stock', 'reform'] 79 | 2. ['the', 'stock', 'reform'] 80 | 2. ['the', 'stock', 'reform'] 81 | 3. ['you'] 82 | 3. ['everyone'] 83 | """ 84 | 85 | 86 | 87 | def prepare_train_dataset(input_file, output_data_dir, output_filename, window_size, num_window, 88 | tokenizer=None, vocab_file=None, language="english", max_doc_length=None, genres=None, 89 | max_num_mention=10, max_num_cluster=30, demo=False, lowercase=False): 90 | 91 | if vocab_file is None: 92 | if not lowercase: 93 | vocab_file = os.path.join(REPO_PATH, "data_utils", "uppercase_vocab.txt") 94 | else: 95 | vocab_file = os.path.join(REPO_PATH, "data_utils", "lowercase_vocab.txt") 96 | 97 | if tokenizer is None: 98 | tokenizer = FullTokenizer(vocab_file=vocab_file, do_lower_case=lowercase) 99 | 100 | writer = tf.python_io.TFRecordWriter(os.path.join(output_data_dir, "{}.{}.tfrecord".format(output_filename, language))) 101 | doc_map = {} 102 | documents = read_conll_file(input_file) 103 | for doc_idx, document in enumerate(documents): 104 | doc_info = parse_document(document, language) 105 | tokenized_document = tokenize_document(genres, doc_info, tokenizer) 106 | doc_key = tokenized_document['doc_key'] 107 | token_windows, mask_windows, text_len = convert_to_sliding_window(tokenized_document, window_size) 108 | input_id_windows = [tokenizer.convert_tokens_to_ids(tokens) for tokens in token_windows] 109 | span_start, span_end, mention_span, cluster_ids = flatten_clusters(tokenized_document['clusters']) 110 | 111 | tmp_speaker_ids = tokenized_document["speakers"] 112 | tmp_speaker_ids = [[0]*130]* num_window 113 | instance = (input_id_windows, mask_windows, text_len, tmp_speaker_ids, tokenized_document["genre"], span_start, span_end, cluster_ids, tokenized_document['sentence_map']) 114 | write_instance_to_example_file(writer, instance, doc_key, window_size=window_size, num_window=num_window, 115 | max_num_mention=max_num_mention, max_num_cluster=max_num_cluster) 116 | doc_map[doc_idx] = doc_key 117 | if demo and doc_idx > 3: 118 | break 119 | with open(os.path.join(output_data_dir, "{}.{}.map".format(output_filename, language)), 'w') as fo: 120 | json.dump(doc_map, fo, indent=2) 121 | 122 | 123 | 124 | def write_instance_to_example_file(writer, instance, doc_key, window_size=64, num_window=5, max_num_mention=20, 125 | max_num_cluster=30, pad_idx=-1): 126 | 127 | input_ids, input_mask, text_len, speaker_ids, genre, gold_starts, gold_ends, cluster_ids, sentence_map = instance 128 | input_id_windows = input_ids 129 | mask_windows = input_mask 130 | flattened_input_ids = [i for j in input_id_windows for i in j] 131 | flattened_input_mask = [i for j in mask_windows for i in j] 132 | cluster_ids = [int(tmp) for tmp in cluster_ids] 133 | 134 | max_sequence_len = int(num_window) 135 | max_seg_len = int(window_size) 136 | 137 | sentence_map = clip_or_pad(sentence_map, max_sequence_len*max_seg_len, pad_idx=pad_idx) 138 | text_len = clip_or_pad(text_len, max_sequence_len, pad_idx=pad_idx) 139 | tmp_subtoken_maps = clip_or_pad(subtoken_maps[doc_key], max_sequence_len*max_seg_len, pad_idx=pad_idx) 140 | 141 | tmp_speaker_ids = clip_or_pad(speaker_ids[0], max_sequence_len*max_seg_len, pad_idx=pad_idx) 142 | 143 | flattened_input_ids = clip_or_pad(flattened_input_ids, max_sequence_len*max_seg_len, pad_idx=pad_idx) 144 | flattened_input_mask = 
clip_or_pad(flattened_input_mask, max_sequence_len*max_seg_len, pad_idx=pad_idx) 145 | gold_starts = clip_or_pad(gold_starts, max_num_mention, pad_idx=pad_idx) 146 | gold_ends = clip_or_pad(gold_ends, max_num_mention, pad_idx=pad_idx) 147 | cluster_ids = clip_or_pad(cluster_ids, max_num_cluster, pad_idx=pad_idx) 148 | 149 | features = OrderedDict() 150 | features['sentence_map'] = create_int_feature(sentence_map) 151 | features['text_len'] = create_int_feature(text_len) 152 | features['subtoken_map'] = create_int_feature(tmp_subtoken_maps) 153 | features['speaker_ids'] = create_int_feature(tmp_speaker_ids) 154 | features['flattened_input_ids'] = create_int_feature(flattened_input_ids) 155 | features['flattened_input_mask'] = create_int_feature(flattened_input_mask) 156 | features['span_starts'] = create_int_feature(gold_starts) 157 | features['span_ends'] = create_int_feature(gold_ends) 158 | features['cluster_ids'] = create_int_feature(cluster_ids) 159 | 160 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 161 | writer.write(tf_example.SerializeToString()) 162 | 163 | 164 | def create_int_feature(values): 165 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 166 | return feature 167 | 168 | 169 | def clip_or_pad(var, max_var_len, pad_idx=-1): 170 | 171 | if len(var) >= max_var_len: 172 | return var[:max_var_len] 173 | else: 174 | pad_var = (max_var_len - len(var)) * [pad_idx] 175 | var = list(var) + list(pad_var) 176 | return var 177 | 178 | 179 | def flatten_clusters(clusters): 180 | 181 | span_starts = [] 182 | span_ends = [] 183 | cluster_ids = [] 184 | mention_span = [] 185 | for cluster_id, cluster in enumerate(clusters): 186 | for start, end in cluster: 187 | span_starts.append(start) 188 | span_ends.append(end) 189 | mention_span.append((start, end)) 190 | cluster_ids.append(cluster_id + 1) 191 | return span_starts, span_ends, mention_span, cluster_ids 192 | 193 | 194 | def read_conll_file(conll_file_path): 195 | documents = [] 196 | with open(conll_file_path, "r", encoding="utf-8") as fi: 197 | for line in fi: 198 | begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line) 199 | if begin_document_match: 200 | doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2)) 201 | documents.append((doc_key, [])) 202 | elif line.startswith("#end document"): 203 | continue 204 | else: 205 | documents[-1][1].append(line.strip()) 206 | return documents 207 | 208 | 209 | def parse_document(document, language): 210 | """ 211 | get basic information from one document annotation. 
212 | :param document: 213 | :param language: english, chinese or arabic 214 | :return: 215 | """ 216 | doc_key = document[0] 217 | sentences = [[]] 218 | speakers = [] 219 | coreferences = [] 220 | word_idx = -1 221 | last_speaker = '' 222 | for line_id, line in enumerate(document[1]): 223 | row = line.split() 224 | sentence_end = len(row) == 0 225 | if not sentence_end: 226 | assert len(row) >= 12 227 | word_idx += 1 228 | word = normalize_word(row[3], language) 229 | sentences[-1].append(word) 230 | speaker = row[9] 231 | if speaker != last_speaker: 232 | speakers.append((word_idx, speaker)) 233 | last_speaker = speaker 234 | coreferences.append(row[-1]) 235 | else: 236 | sentences.append([]) 237 | clusters = coreference_annotations_to_clusters(coreferences) 238 | doc_info = {'doc_key': doc_key, 'sentences': sentences[: -1], 'speakers': speakers, 'clusters': clusters} 239 | return doc_info 240 | 241 | 242 | def normalize_word(word, language): 243 | if language == "arabic": 244 | word = word[:word.find("#")] 245 | if word == "/." or word == "/?": 246 | return word[1:] 247 | else: 248 | return word 249 | 250 | 251 | def coreference_annotations_to_clusters(annotations): 252 | """ 253 | convert coreference information to clusters 254 | :param annotations: 255 | :return: 256 | """ 257 | clusters = OrderedDict() 258 | coref_stack = OrderedDict() 259 | for word_idx, annotation in enumerate(annotations): 260 | if annotation == '-': 261 | continue 262 | for ann in annotation.split('|'): 263 | cluster_id = int(ann.replace('(', '').replace(')', '')) 264 | if ann[0] == '(' and ann[-1] == ')': 265 | if cluster_id not in clusters.keys(): 266 | clusters[cluster_id] = [(word_idx, word_idx)] 267 | else: 268 | clusters[cluster_id].append((word_idx, word_idx)) 269 | elif ann[0] == '(': 270 | if cluster_id not in coref_stack.keys(): 271 | coref_stack[cluster_id] = [word_idx] 272 | else: 273 | coref_stack[cluster_id].append(word_idx) 274 | elif ann[-1] == ')': 275 | span_start = coref_stack[cluster_id].pop() 276 | if cluster_id not in clusters.keys(): 277 | clusters[cluster_id] = [(span_start, word_idx)] 278 | else: 279 | clusters[cluster_id].append((span_start, word_idx)) 280 | else: 281 | raise NotImplementedError 282 | assert all([len(starts) == 0 for starts in coref_stack.values()]) 283 | return list(clusters.values()) 284 | 285 | 286 | def checkout_clusters(doc_info): 287 | words = [i for j in doc_info['sentences'] for i in j] 288 | clusters = [[' '.join(words[start: end + 1]) for start, end in cluster] for cluster in doc_info['clusters']] 289 | print(clusters) 290 | 291 | 292 | def tokenize_document(genres, doc_info, tokenizer): 293 | """ 294 | tokenize into sub tokens 295 | :param doc_info: 296 | :param tokenizer: 297 | max_doc_length: pad to max_doc_length 298 | :return: 299 | """ 300 | genres = {g: i for i, g in enumerate(genres)} 301 | sub_tokens = [] # all sub tokens of a document 302 | sentence_map = [] # collected tokenized tokens -> sentence id 303 | subtoken_map = [] # collected tokenized tokens -> original token id 304 | 305 | word_idx = -1 306 | 307 | for sentence_id, sentence in enumerate(doc_info['sentences']): 308 | for token in sentence: 309 | word_idx += 1 310 | word_tokens = tokenizer.tokenize(token) 311 | sub_tokens.extend(word_tokens) 312 | sentence_map.extend([sentence_id] * len(word_tokens)) 313 | subtoken_map.extend([word_idx] * len(word_tokens)) 314 | 315 | 316 | subtoken_maps[doc_info['doc_key']] = subtoken_map 317 | genre = genres.get(doc_info['doc_key'][:2], 0) 318 | speakers 
= {subtoken_map.index(word_index): tokenizer.tokenize(speaker) 319 | for word_index, speaker in doc_info['speakers']} 320 | clusters = [[(subtoken_map.index(start), len(subtoken_map) - 1 - subtoken_map[::-1].index(end)) 321 | for start, end in cluster] for cluster in doc_info['clusters']] 322 | tokenized_document = {'sub_tokens': sub_tokens, 'sentence_map': sentence_map, 'subtoken_map': subtoken_map, 323 | 'speakers': speakers, 'clusters': clusters, 'doc_key': doc_info['doc_key'], 324 | "genre": genre} 325 | return tokenized_document 326 | 327 | 328 | def convert_to_sliding_window(tokenized_document, sliding_window_size): 329 | """ 330 | construct sliding windows, allocate tokens and masks into each window 331 | :param tokenized_document: 332 | :param sliding_window_size: 333 | :return: 334 | """ 335 | expanded_tokens, expanded_masks = expand_with_speakers(tokenized_document) 336 | sliding_windows = construct_sliding_windows(len(expanded_tokens), sliding_window_size - 2) 337 | token_windows = [] # expanded tokens to sliding window 338 | mask_windows = [] # expanded masks to sliding window 339 | text_len = [] 340 | 341 | for window_start, window_end, window_mask in sliding_windows: 342 | original_tokens = expanded_tokens[window_start: window_end] 343 | original_masks = expanded_masks[window_start: window_end] 344 | window_masks = [-2 if w == 0 else o for w, o in zip(window_mask, original_masks)] 345 | one_window_token = ['[CLS]'] + original_tokens + ['[SEP]'] + ['[PAD]'] * ( 346 | sliding_window_size - 2 - len(original_tokens)) 347 | one_window_mask = [-3] + window_masks + [-3] + [-4] * (sliding_window_size - 2 - len(original_tokens)) 348 | token_calculate = [tmp for tmp in one_window_mask if tmp >= 0] 349 | text_len.append(len(token_calculate)) 350 | assert len(one_window_token) == sliding_window_size 351 | assert len(one_window_mask) == sliding_window_size 352 | token_windows.append(one_window_token) 353 | mask_windows.append(one_window_mask) 354 | assert len(tokenized_document['sentence_map']) == sum([i >= 0 for j in mask_windows for i in j]) 355 | 356 | text_len = np.array(text_len) 357 | return token_windows, mask_windows, text_len 358 | 359 | 360 | def expand_with_speakers(tokenized_document): 361 | """ 362 | add speaker name information 363 | :param tokenized_document: tokenized document information 364 | :return: 365 | """ 366 | expanded_tokens = [] 367 | expanded_masks = [] 368 | for token_idx, token in enumerate(tokenized_document['sub_tokens']): 369 | if token_idx in tokenized_document['speakers']: 370 | speaker = [SPEAKER_START] + tokenized_document['speakers'][token_idx] + [SPEAKER_END] 371 | expanded_tokens.extend(speaker) 372 | expanded_masks.extend([-1] * len(speaker)) 373 | expanded_tokens.append(token) 374 | expanded_masks.append(token_idx) 375 | return expanded_tokens, expanded_masks 376 | 377 | 378 | def construct_sliding_windows(sequence_length, sliding_window_size): 379 | """ 380 | construct sliding windows for BERT processing 381 | :param sequence_length: e.g. 9 382 | :param sliding_window_size: e.g. 
4 383 | :return: [(0, 4, [1, 1, 1, 0]), (2, 6, [0, 1, 1, 0]), (4, 8, [0, 1, 1, 0]), (6, 9, [0, 1, 1])] 384 | """ 385 | sliding_windows = [] 386 | stride = int(sliding_window_size / 2) 387 | start_index = 0 388 | end_index = 0 389 | while end_index < sequence_length: 390 | end_index = min(start_index + sliding_window_size, sequence_length) 391 | left_value = 1 if start_index == 0 else 0 392 | right_value = 1 if end_index == sequence_length else 0 393 | mask = [left_value] * int(sliding_window_size / 4) + [1] * int(sliding_window_size / 2) \ 394 | + [right_value] * (sliding_window_size - int(sliding_window_size / 2) - int(sliding_window_size / 4)) 395 | mask = mask[: end_index - start_index] 396 | sliding_windows.append((start_index, end_index, mask)) 397 | start_index += stride 398 | assert sum([sum(window[2]) for window in sliding_windows]) == sequence_length 399 | return sliding_windows 400 | 401 | 402 | 403 | def parse_args(): 404 | parser = argparse.ArgumentParser() 405 | parser.add_argument("--source_files_dir", default="/home/lixiaoya/data", type=str, required=True) 406 | parser.add_argument("--target_output_dir", default="/home/lixiaoya/tfrecord_data", type=str, required=True) 407 | parser.add_argument("--num_window", default=5, type=int, required=True) 408 | parser.add_argument("--window_size", default=64, type=int, required=True) 409 | parser.add_argument("--max_num_mention", default=30, type=int) 410 | parser.add_argument("--max_num_cluster", default=20, type=int) 411 | parser.add_argument("--vocab_file", default="/home/lixiaoya/spanbert_large_cased/vocab.txt", type=str) 412 | parser.add_argument("--language", default="english", type=str) 413 | parser.add_argument("--max_doc_length", default=600, type=int) 414 | parser.add_argument("--lowercase", help="Whether or not to lowercase the datasets.", action="store_true") 415 | parser.add_argument("--demo", help="Whether to generate a small dataset for testing the code.", action="store_true") 416 | parser.add_argument('--genres', default=["bc","bn","mz","nw","pt","tc","wb"]) 417 | parser.add_argument("--seed", default=2333, type=int) 418 | 419 | args = parser.parse_args() 420 | 421 | os.makedirs(args.target_output_dir, exist_ok=True) 422 | np.random.seed(args.seed) 423 | tf.set_random_seed(args.seed) 424 | 425 | return args 426 | 427 | 428 | def main(): 429 | args_config = parse_args() 430 | 431 | print("*"*60) 432 | print("***** ***** show configs ***** ***** ") 433 | print("window_size : {}".format(str(args_config.window_size))) 434 | print("num_window : {}".format(str(args_config.num_window))) 435 | print("*"*60) 436 | 437 | for data_sign in ["train", "dev", "test"]: 438 | source_data_file = os.path.join(args_config.source_files_dir, "{}.{}.v4_gold_conll".format(data_sign, args_config.language)) 439 | output_filename = "{}.overlap.corefqa".format(data_sign) 440 | 441 | if args_config.demo: 442 | if args_config.lowercase: 443 | output_filename="demo.lowercase.{}.overlap.corefqa".format(data_sign) 444 | else: 445 | output_filename="demo.{}.overlap.corefqa".format(data_sign) 446 | 447 | print("$"*60) 448 | print("generate {}/{}".format(args_config.target_output_dir, output_filename)) 449 | prepare_train_dataset(source_data_file, args_config.target_output_dir, output_filename, args_config.window_size, 450 | args_config.num_window, vocab_file=args_config.vocab_file, language=args_config.language, 451 | max_doc_length=args_config.max_doc_length, genres=args_config.genres, max_num_mention=args_config.max_num_mention, 452 |
max_num_cluster=args_config.max_num_cluster, demo=args_config.demo, lowercase=args_config.lowercase) 453 | 454 | 455 | 456 | 457 | if __name__ == "__main__": 458 | main() 459 | 460 | # please refer to ${REPO_PATH}/scripts/data/generate_tfrecord_dataset.sh 461 | # 462 | # to generate tfrecord datasets 463 | # 464 | # python3 build_dataset_to_tfrecord.py \ 465 | # --source_files_dir /xiaoya/data \ 466 | # --target_output_dir /xiaoya/corefqa_data/overlap_64_2 \ 467 | # --num_window 2 \ 468 | # --window_size 64 \ 469 | # --max_num_mention 50 \ 470 | # --max_num_cluster 30 \ 471 | # --vocab_file /xiaoya/pretrain_ckpt/cased_L-12_H-768_A-12/vocab.txt \ 472 | # --language english \ 473 | # --max_doc_length 600 474 | # 475 | 476 | 477 | 478 | --------------------------------------------------------------------------------
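Note: the snippet below is a minimal, standalone sketch and not a file from this repository. It illustrates the overlapping sliding-window scheme documented in run/build_dataset_to_tfrecord.py: every subtoken position is "owned" by exactly one window, overlap positions are kept only as context, and the per-window mask values follow the convention shown in the module docstring (-1 for inserted speaker tokens, -2 for overlap-only positions, -3 for [CLS]/[SEP], -4 for [PAD]). The helper mirrors construct_sliding_windows(); the toy sizes (sequence_length=9, sliding_window_size=4) are taken from its docstring, so the printed output should match the documented return value.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Standalone illustration only; mirrors construct_sliding_windows() from
# run/build_dataset_to_tfrecord.py so the windowing behaviour can be checked
# without TensorFlow or the CoNLL data.


def construct_sliding_windows(sequence_length, sliding_window_size):
    """Return (start, end, ownership_mask) triples covering the sequence.

    A mask entry of 1 means this window owns the token at that offset;
    0 means the window only sees that token as overlapping context.
    """
    sliding_windows = []
    stride = sliding_window_size // 2
    start_index = 0
    end_index = 0
    while end_index < sequence_length:
        end_index = min(start_index + sliding_window_size, sequence_length)
        left_value = 1 if start_index == 0 else 0
        right_value = 1 if end_index == sequence_length else 0
        mask = [left_value] * (sliding_window_size // 4) \
            + [1] * (sliding_window_size // 2) \
            + [right_value] * (sliding_window_size - sliding_window_size // 2 - sliding_window_size // 4)
        mask = mask[: end_index - start_index]
        sliding_windows.append((start_index, end_index, mask))
        start_index += stride
    # every token position is owned by exactly one window
    assert sum(sum(mask) for _, _, mask in sliding_windows) == sequence_length
    return sliding_windows


if __name__ == "__main__":
    windows = construct_sliding_windows(sequence_length=9, sliding_window_size=4)
    print(windows)
    # Expected (matches the docstring in build_dataset_to_tfrecord.py):
    # [(0, 4, [1, 1, 1, 0]), (2, 6, [0, 1, 1, 0]), (4, 8, [0, 1, 1, 0]), (6, 9, [0, 1, 1])]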