├── requirements.txt ├── .gitignore ├── setup_pretrained.sh ├── extract_bert_features.sh ├── coref_ops.py ├── ps.py ├── train_mgpu.sh ├── evaluate.py ├── setup_all.sh ├── get_char_vocab.py ├── filter_embeddings.py ├── setup_training.sh ├── predict.py ├── demo.py ├── continuous_evaluate.py ├── cache_elmo.py ├── experiments.conf ├── README.md ├── train.py ├── worker.py ├── conll.py ├── metrics.py ├── coref_kernels.cc ├── data.py ├── optimization.py ├── minimize.py ├── LICENSE ├── prepare_bert_data.py ├── extract_features.py ├── tokenization.py ├── util.py └── modeling.py /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.8.0 2 | nltk==3.4 3 | numpy==1.14.5 4 | pyhocon==0.3.48 5 | tensorflow-gpu==1.10.1 6 | termcolor==1.1.0 7 | tqdm==4.31.1 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | *.jsonlines 4 | logs 5 | conll-2012 6 | char_vocab*.txt 7 | glove*.txt 8 | glove*.txt.filtered 9 | *.v*_*_conll 10 | *.hdf5 11 | .python-version 12 | -------------------------------------------------------------------------------- /setup_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | curl -O http://lsz-gpu-01.cs.washington.edu/resources/coref/char_vocab.english.txt 4 | 5 | ckpt_file=c2f_final.tgz 6 | curl -O http://lsz-gpu-01.cs.washington.edu/resources/coref/$ckpt_file 7 | mkdir -p logs 8 | tar -xzvf $ckpt_file -C logs 9 | rm $ckpt_file 10 | -------------------------------------------------------------------------------- /extract_bert_features.sh: -------------------------------------------------------------------------------- 1 | export BERT_MODEL_PATH="PATH TO BERT MODEL cased_L-24_H-1024_A-16" 2 | PYTHONPATH=. 
python extract_features.py --input_file=./ --output_file=./bert_features.hdf5 --bert_config_file $BERT_MODEL_PATH/bert_config.json --init_checkpoint $BERT_MODEL_PATH/bert_model.ckpt --vocab_file $BERT_MODEL_PATH/vocab.txt --do_lower_case=False --stride 1 --window_size 129 3 | -------------------------------------------------------------------------------- /coref_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | from tensorflow.python import pywrap_tensorflow 7 | 8 | coref_op_library = tf.load_op_library("./coref_kernels.so") 9 | 10 | extract_spans = coref_op_library.extract_spans 11 | tf.NotDifferentiable("ExtractSpans") 12 | -------------------------------------------------------------------------------- /ps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | 5 | import tensorflow as tf 6 | import util 7 | 8 | if __name__ == "__main__": 9 | args = util.get_args() 10 | config = util.initialize_from_env(args.experiment, args.logdir) 11 | report_frequency = config["report_frequency"] 12 | cluster_config = util.get_cluster_config() 13 | util.set_gpus() 14 | cluster = tf.train.ClusterSpec(cluster_config["addresses"]) 15 | server = tf.train.Server(cluster, job_name="ps", task_index=0) 16 | server.join() 17 | -------------------------------------------------------------------------------- /train_mgpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | TRAIN_GPUS=$1 3 | EVAL_GPU=$2 4 | NAME=$3 5 | 6 | IFS=',' read -r -a TRAIN_GPUS_ARRAY <<< "$TRAIN_GPUS" 7 | tmux new-session -s "$NAME" -n "ps" -d "bash" 8 | tmux send-keys -t "$NAME:ps" "GPUS=$TRAIN_GPUS python ps.py $NAME" Enter 9 | I=0 10 | for GPU in ${TRAIN_GPUS_ARRAY[@]} 11 | do 12 | tmux new-window -n "worker $I" "bash" 13 | tmux send-keys -t "$NAME:worker $I" "GPUS=$TRAIN_GPUS TASK=$I python worker.py $NAME" Enter 14 | I=$((I+1)) 15 | done 16 | tmux new-window -n "eval" "bash" 17 | tmux send-keys -t "$NAME:eval" "GPU=$EVAL_GPU python continuous_evaluate.py $NAME" Enter 18 | tmux -2 attach-session -d 19 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | 8 | import util 9 | import coref_model as cm 10 | 11 | if __name__ == "__main__": 12 | args = util.get_args() 13 | config = util.initialize_from_env(args.experiment, args.logdir) 14 | config["eval_path"] = "test.english.jsonlines" 15 | config["conll_eval_path"] = "test.english.v4_gold_conll" 16 | config["context_embeddings"]["path"] = "glove.840B.300d.txt" 17 | 18 | model = cm.CorefModel(config, eval_mode=True) 19 | with tf.Session() as session: 20 | model.restore(session, args.latest_checkpoint) 21 | model.evaluate(session, official_stdout=True, pprint=False, test=True) 22 | -------------------------------------------------------------------------------- /setup_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download pretrained embeddings. 
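# glove_50_300_2.txt is referenced as glove_300d_2w (head_embeddings) in experiments.conf;
# glove.840B.300d.txt provides the context embeddings and is later reduced to
# glove.840B.300d.txt.filtered by filter_embeddings.py (see setup_training.sh).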
4 | curl -O http://lsz-gpu-01.cs.washington.edu/resources/glove_50_300_2.txt 5 | curl -O http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip 6 | unzip glove.840B.300d.zip 7 | rm glove.840B.300d.zip 8 | 9 | # Build custom kernels. 10 | TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 11 | TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) 12 | 13 | # Linux (pip) 14 | g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 15 | 16 | # Linux (build from source) 17 | #g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 18 | 19 | # Mac 20 | #g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -I -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -undefined dynamic_lookup 21 | -------------------------------------------------------------------------------- /get_char_vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import json 7 | 8 | def get_char_vocab(input_filenames, output_filename): 9 | vocab = set() 10 | for filename in input_filenames: 11 | with open(filename) as f: 12 | for line in f.readlines(): 13 | for sentence in json.loads(line)["sentences"]: 14 | for word in sentence: 15 | vocab.update(word) 16 | vocab = sorted(list(vocab)) 17 | with open(output_filename, "w") as f: 18 | for char in vocab: 19 | f.write(u"{}\n".format(char).encode("utf8")) 20 | print("Wrote {} characters to {}".format(len(vocab), output_filename)) 21 | 22 | def get_char_vocab_language(language): 23 | get_char_vocab(["{}.{}.jsonlines".format(partition, language) for partition in ("train", "dev", "test")], "char_vocab.{}.txt".format(language)) 24 | 25 | get_char_vocab_language("english") 26 | get_char_vocab_language("chinese") 27 | get_char_vocab_language("arabic") 28 | -------------------------------------------------------------------------------- /filter_embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import json 7 | 8 | if __name__ == "__main__": 9 | if len(sys.argv) < 3: 10 | sys.exit("Usage: {} ...".format(sys.argv[0])) 11 | 12 | words_to_keep = set() 13 | for json_filename in sys.argv[2:]: 14 | with open(json_filename) as json_file: 15 | for line in json_file.readlines(): 16 | for sentence in json.loads(line)["sentences"]: 17 | words_to_keep.update(sentence) 18 | 19 | print("Found {} words in {} dataset(s).".format(len(words_to_keep), len(sys.argv) - 2)) 20 | 21 | total_lines = 0 22 | kept_lines = 0 23 | out_filename = "{}.filtered".format(sys.argv[1]) 24 | with open(sys.argv[1]) as in_file: 25 | with open(out_filename, "w") as out_file: 26 | for line in in_file.readlines(): 27 | total_lines += 1 28 | word = line.split()[0] 29 | if word in words_to_keep: 30 | kept_lines += 1 31 | out_file.write(line) 32 | 33 | print("Kept {} out of {} lines.".format(kept_lines, total_lines)) 34 | print("Wrote result to {}.".format(out_filename)) 35 | -------------------------------------------------------------------------------- /setup_training.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dlx() { 4 | wget $1/$2 5 | tar -xvzf $2 6 | rm $2 7 | } 8 | 9 | conll_url=http://conll.cemantix.org/2012/download 10 | dlx $conll_url conll-2012-train.v4.tar.gz 11 | dlx $conll_url conll-2012-development.v4.tar.gz 12 | dlx $conll_url/test conll-2012-test-key.tar.gz 13 | dlx $conll_url/test conll-2012-test-official.v9.tar.gz 14 | 15 | dlx $conll_url conll-2012-scripts.v3.tar.gz 16 | 17 | dlx http://conll.cemantix.org/download reference-coreference-scorers.v8.01.tar.gz 18 | mv reference-coreference-scorers conll-2012/scorer 19 | 20 | ontonotes_path=/projects/WebWare6/ontonotes-release-5.0 21 | bash conll-2012/v3/scripts/skeleton2conll.sh -D $ontonotes_path/data/files/data conll-2012 22 | 23 | function compile_partition() { 24 | rm -f $2.$5.$3$4 25 | cat conll-2012/$3/data/$1/data/$5/annotations/*/*/*/*.$3$4 >> $2.$5.$3$4 26 | } 27 | 28 | function compile_language() { 29 | compile_partition development dev v4 _gold_conll $1 30 | compile_partition train train v4 _gold_conll $1 31 | compile_partition test test v4 _gold_conll $1 32 | } 33 | 34 | compile_language english 35 | compile_language chinese 36 | compile_language arabic 37 | 38 | python minimize.py 39 | python get_char_vocab.py 40 | 41 | python filter_embeddings.py glove.840B.300d.txt train.english.jsonlines dev.english.jsonlines 42 | python cache_elmo.py train.english.jsonlines dev.english.jsonlines 43 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import json 7 | 8 | import tensorflow as tf 9 | import coref_model as cm 10 | import util 11 | 12 | if __name__ == "__main__": 13 | config = util.initialize_from_env() 14 | 15 | # Input file in .jsonlines format. 16 | input_filename = sys.argv[2] 17 | 18 | # Predictions will be written to this file in .jsonlines format. 
19 | output_filename = sys.argv[3] 20 | 21 | model = cm.CorefModel(config) 22 | 23 | with tf.Session() as session: 24 | model.restore(session) 25 | 26 | with open(output_filename, "w") as output_file: 27 | with open(input_filename) as input_file: 28 | for example_num, line in enumerate(input_file.readlines()): 29 | example = json.loads(line) 30 | tensorized_example = model.tensorize_example(example, is_training=False) 31 | feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)} 32 | _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict) 33 | predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores) 34 | example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents) 35 | 36 | output_file.write(json.dumps(example)) 37 | output_file.write("\n") 38 | if example_num % 100 == 0: 39 | print("Decoded {} examples.".format(example_num + 1)) 40 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from six.moves import input 6 | import tensorflow as tf 7 | import coref_model as cm 8 | import util 9 | 10 | import nltk 11 | nltk.download("punkt") 12 | from nltk.tokenize import sent_tokenize, word_tokenize 13 | 14 | def create_example(text): 15 | raw_sentences = sent_tokenize(text) 16 | sentences = [word_tokenize(s) for s in raw_sentences] 17 | speakers = [["" for _ in sentence] for sentence in sentences] 18 | return { 19 | "doc_key": "nw", 20 | "clusters": [], 21 | "sentences": sentences, 22 | "speakers": speakers, 23 | } 24 | 25 | def print_predictions(example): 26 | words = util.flatten(example["sentences"]) 27 | for cluster in example["predicted_clusters"]: 28 | print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster])) 29 | 30 | def make_predictions(text, model): 31 | example = create_example(text) 32 | tensorized_example = model.tensorize_example(example, is_training=False) 33 | feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)} 34 | _, _, _, mention_starts, mention_ends, antecedents, antecedent_scores, head_scores = session.run(model.predictions + [model.head_scores], feed_dict=feed_dict) 35 | 36 | predicted_antecedents = model.get_predicted_antecedents(antecedents, antecedent_scores) 37 | 38 | example["predicted_clusters"], _ = model.get_predicted_clusters(mention_starts, mention_ends, predicted_antecedents) 39 | example["top_spans"] = zip((int(i) for i in mention_starts), (int(i) for i in mention_ends)) 40 | example["head_scores"] = head_scores.tolist() 41 | return example 42 | 43 | if __name__ == "__main__": 44 | args = util.get_args() 45 | config = util.initialize_from_env(args.experiment, args.logdir) 46 | model = cm.CorefModel(config) 47 | with tf.Session() as session: 48 | model.restore(session) 49 | while True: 50 | text = input("Document text: ") 51 | print_predictions(make_predictions(text, model)) 52 | -------------------------------------------------------------------------------- /continuous_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from 
__future__ import print_function 5 | 6 | import os 7 | import re 8 | import time 9 | import shutil 10 | 11 | import tensorflow as tf 12 | import coref_model as cm 13 | import util 14 | 15 | def copy_checkpoint(source, target): 16 | for ext in (".index", ".data-00000-of-00001"): 17 | shutil.copyfile(source + ext, target + ext) 18 | 19 | if __name__ == "__main__": 20 | args = util.get_args() 21 | config = util.initialize_from_env(args.experiment, args.logdir) 22 | model = cm.CorefModel(config, eval_mode=True) 23 | 24 | log_dir = config["log_dir"] 25 | writer = tf.summary.FileWriter(log_dir, flush_secs=20) 26 | evaluated_checkpoints = set() 27 | max_f1 = 0 28 | checkpoint_pattern = re.compile(".*model.ckpt-([0-9]*)\Z") 29 | 30 | with tf.Session() as session: 31 | while True: 32 | ckpt = tf.train.get_checkpoint_state(log_dir) 33 | if ckpt and ckpt.model_checkpoint_path and ckpt.model_checkpoint_path not in evaluated_checkpoints: 34 | print("Evaluating {}".format(ckpt.model_checkpoint_path)) 35 | 36 | # Move it to a temporary location to avoid being deleted by the training supervisor. 37 | tmp_checkpoint_path = os.path.join(log_dir, "model.tmp.ckpt") 38 | copy_checkpoint(ckpt.model_checkpoint_path, tmp_checkpoint_path) 39 | 40 | global_step = int(checkpoint_pattern.match(ckpt.model_checkpoint_path).group(1)) 41 | model.restore(session, latest_checkpoint=True) 42 | 43 | eval_summary, f1 = model.evaluate(session) 44 | 45 | if f1 > max_f1: 46 | max_f1 = f1 47 | copy_checkpoint(tmp_checkpoint_path, os.path.join(log_dir, "model.max.ckpt")) 48 | 49 | print("Current max F1: {:.2f}".format(max_f1)) 50 | 51 | writer.add_summary(eval_summary, global_step) 52 | print("Evaluation written to {} at step {}".format(log_dir, global_step)) 53 | 54 | evaluated_checkpoints.add(ckpt.model_checkpoint_path) 55 | sleep_time = 60 56 | else: 57 | sleep_time = 10 58 | print("Waiting for {} seconds before looking for next checkpoint.".format(sleep_time)) 59 | time.sleep(sleep_time) 60 | -------------------------------------------------------------------------------- /cache_elmo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_hub as hub 8 | import h5py 9 | import json 10 | import sys 11 | 12 | def build_elmo(): 13 | token_ph = tf.placeholder(tf.string, [None, None]) 14 | len_ph = tf.placeholder(tf.int32, [None]) 15 | elmo_module = hub.Module("https://tfhub.dev/google/elmo/2") 16 | lm_embeddings = elmo_module( 17 | inputs={"tokens": token_ph, "sequence_len": len_ph}, 18 | signature="tokens", as_dict=True) 19 | word_emb = lm_embeddings["word_emb"] 20 | lm_emb = tf.stack([tf.concat([word_emb, word_emb], -1), 21 | lm_embeddings["lstm_outputs1"], 22 | lm_embeddings["lstm_outputs2"]], -1) 23 | return token_ph, len_ph, lm_emb 24 | 25 | def cache_dataset(data_path, session, token_ph, len_ph, lm_emb, out_file): 26 | with open(data_path) as in_file: 27 | for doc_num, line in enumerate(in_file.readlines()): 28 | example = json.loads(line) 29 | sentences = example["sentences"] 30 | max_sentence_length = max(len(s) for s in sentences) 31 | tokens = [[""] * max_sentence_length for _ in sentences] 32 | text_len = np.array([len(s) for s in sentences]) 33 | for i, sentence in enumerate(sentences): 34 | for j, word in enumerate(sentence): 35 | tokens[i][j] = word 36 | tokens = np.array(tokens) 37 | tf_lm_emb 
= session.run(lm_emb, feed_dict={ 38 | token_ph: tokens, 39 | len_ph: text_len 40 | }) 41 | file_key = example["doc_key"].replace("/", ":") 42 | group = out_file.create_group(file_key) 43 | for i, (e, l) in enumerate(zip(tf_lm_emb, text_len)): 44 | e = e[:l, :, :] 45 | group[str(i)] = e 46 | if doc_num % 10 == 0: 47 | print("Cached {} documents in {}".format(doc_num + 1, data_path)) 48 | 49 | if __name__ == "__main__": 50 | token_ph, len_ph, lm_emb = build_elmo() 51 | with tf.Session() as session: 52 | session.run(tf.global_variables_initializer()) 53 | with h5py.File("elmo_cache.hdf5", "w") as out_file: 54 | for json_filename in sys.argv[1:]: 55 | cache_dataset(json_filename, session, token_ph, len_ph, lm_emb, out_file) 56 | -------------------------------------------------------------------------------- /experiments.conf: -------------------------------------------------------------------------------- 1 | # Word embeddings. 2 | glove_300d { 3 | path = glove.840B.300d.txt 4 | size = 300 5 | } 6 | glove_300d_filtered { 7 | path = glove.840B.300d.txt.filtered 8 | size = 300 9 | } 10 | glove_300d_2w { 11 | path = glove_50_300_2.txt 12 | size = 300 13 | } 14 | 15 | # Distributed training configurations. 16 | two_local_gpus { 17 | addresses { 18 | ps = [localhost:2222] 19 | worker = [localhost:2223, localhost:2224, localhost:2225, localhost:2226] 20 | } 21 | gpus = [0, 1, 2, 3] 22 | } 23 | 24 | # Main configuration. 25 | best { 26 | # Computation limits. 27 | max_top_antecedents = 50 28 | max_training_sentences = 50 29 | top_span_ratio = 0.4 30 | 31 | # Model hyperparameters. 32 | filter_widths = [3, 4, 5] 33 | filter_size = 50 34 | char_embedding_size = 8 35 | char_vocab_path = "char_vocab.english.txt" 36 | context_embeddings = ${glove_300d_filtered} 37 | head_embeddings = ${glove_300d_2w} 38 | contextualization_size = 200 39 | contextualization_layers = 3 40 | ffnn_size = 150 41 | ffnn_depth = 2 42 | feature_size = 20 43 | max_span_width = 30 44 | use_metadata = true 45 | use_features = true 46 | model_heads = true 47 | coref_depth = 2 48 | lm_layers = 4 49 | lm_size = 1024 50 | coarse_to_fine = true 51 | refinement_sharing = false 52 | 53 | # Learning hyperparameters. 54 | max_gradient_norm = 5.0 55 | lstm_dropout_rate = 0.4 56 | lexical_dropout_rate = 0.5 57 | dropout_rate = 0.2 58 | optimizer = adam 59 | learning_rate = 0.001 60 | decay_rate = 1.0 61 | decay_frequency = 100 62 | ema_decay = 0.9999 63 | 64 | # Other. 
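# The *.jsonlines and *_gold_conll paths below are produced by setup_training.sh
# (via minimize.py); lm_path is the HDF5 file written by extract_bert_features.sh.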
65 | train_path = train.english.jsonlines 66 | eval_path = dev.english.jsonlines 67 | conll_eval_path = dev.english.v4_gold_conll 68 | lm_path = bert_features.hdf5 69 | genres = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"] 70 | eval_frequency = 5000 71 | # eval_frequency = 1 72 | report_frequency = 100 73 | log_root = logs 74 | cluster = ${two_local_gpus} 75 | multi_gpu = false 76 | gold_loss = false 77 | b3_loss = false 78 | mention_loss = false 79 | antecedent_loss = true 80 | 81 | # Entity Equalization 82 | entity_equalization = true 83 | antecedent_averaging = false 84 | use_cluster_size = true 85 | entity_average = false 86 | } 87 | 88 | entity_equalization = ${best} 89 | 90 | baseline = ${best} { 91 | decay_rate = 0.999 92 | entity_equalization = false 93 | antecedent_averaging = true 94 | ema_decay = 1.0 95 | refinement_sharing = true 96 | } 97 | 98 | antecedent_averaging = ${baseline} 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coreference Resolution with Entity Equalization 2 | 3 | ## Introduction 4 | This repository contains the code for replicating results from 5 | 6 | * [Coreference Resolution with Entity Equalization](https://www.aclweb.org/anthology/P19-1066) 7 | * In ACL 2019 8 | * The baseline model is from the paper [Higher-order Coreference Resolution with Coarse-to-fine Inference](https://arxiv.org/abs/1804.05392) 9 | * Code for baseline model: [https://github.com/kentonl/e2e-coref](https://github.com/kentonl/e2e-coref) 10 | 11 | ## Getting Started 12 | 13 | * Install python (either 2 or 3) requirements: `pip install -r requirements.txt` 14 | * Download GloVe embeddings and build custom kernels by running `setup_all.sh`. 15 | * There are 3 platform-dependent ways to build custom TensorFlow kernels. Please comment/uncomment the appropriate lines in the script. 16 | * To train your own models, run `setup_training.sh`and `extract_bert_features.sh` 17 | * This assumes access to OntoNotes 5.0. Please edit the `ontonotes_path` variable. 18 | 19 | ## Training Instructions 20 | 21 | * Experiment configurations are found in `experiments.conf` 22 | * Choose an experiment that you would like to run, e.g. `best` 23 | * Training: `python train.py ` 24 | * Results are stored in the `logs` directory and can be viewed via TensorBoard. 25 | * Evaluation: `python evaluate.py ` 26 | 27 | ## Demo Instructions 28 | 29 | * Command-line demo: `python demo.py final` 30 | * To run the demo with other experiments, replace `final` with your configuration name. 31 | 32 | ## Batched Prediction Instructions 33 | 34 | * Create a file where each line is in the following json format (make sure to strip the newlines so each line is well-formed json): 35 | ``` 36 | { 37 | "clusters": [], 38 | "doc_key": "nw", 39 | "sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]], 40 | "speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]] 41 | } 42 | ``` 43 | * `clusters` should be left empty and is only used for evaluation purposes. 44 | * `doc_key` indicates the genre, which can be one of the following: `"bc", "bn", "mz", "nw", "pt", "tc", "wb"` 45 | * `speakers` indicates the speaker of each word. These can be all empty strings if there is only one known speaker. 46 | * Run `python predict.py `, which outputs the input jsonlines with predicted clusters. 
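* To build such a file from raw text, the snippet below is a minimal sketch that mirrors `create_example` in `demo.py`: it tokenizes with NLTK, leaves `clusters` empty and `speakers` blank, and uses placeholder values for the output filename (`input.jsonlines`) and the `nw` genre; adjust these for your data.
```
# Minimal sketch: one plain-text document per input file, written in predict.py's
# expected .jsonlines format. Placeholder assumptions: output file "input.jsonlines"
# and genre "nw"; requires nltk (punkt).
import json
import sys

import nltk
nltk.download("punkt")
from nltk.tokenize import sent_tokenize, word_tokenize

with open("input.jsonlines", "w") as out:
    for path in sys.argv[1:]:
        with open(path) as f:
            sentences = [word_tokenize(s) for s in sent_tokenize(f.read())]
        example = {
            "doc_key": "nw",  # genre prefix: one of bc, bn, mz, nw, pt, tc, wb
            "clusters": [],   # left empty; only used for evaluation
            "sentences": sentences,
            "speakers": [["" for _ in sentence] for sentence in sentences],
        }
        out.write(json.dumps(example) + "\n")
```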
47 | 48 | ## Other Quirks 49 | 50 | * It does not use GPUs by default. Instead, it looks for the `GPU` environment variable, which the code treats as shorthand for `CUDA_VISIBLE_DEVICES`. 51 | * The training runs indefinitely and needs to be terminated manually. The model generally converges at about 400k steps. 52 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import util 12 | import coref_model as cm 13 | 14 | if __name__ == "__main__": 15 | # tf.logging.set_verbosity(tf.logging.INFO) 16 | 17 | args = util.get_args() 18 | config = util.initialize_from_env(args.experiment, args.logdir) 19 | 20 | report_frequency = config["report_frequency"] 21 | eval_frequency = config["eval_frequency"] 22 | 23 | model = cm.CorefModel(config) 24 | 25 | print('# parameters:', np.sum([np.prod(v.get_shape().as_list()) for v in model.trainable_variables])) 26 | saver = tf.train.Saver() 27 | initial_step = 0 28 | 29 | log_dir = config["log_dir"] 30 | writer = tf.summary.FileWriter(log_dir, flush_secs=20) 31 | 32 | max_f1 = 0 33 | 34 | with tf.Session() as session: 35 | session.run(tf.global_variables_initializer()) 36 | model.start_enqueue_thread(session) 37 | accumulated_loss = 0.0 38 | 39 | ckpt = tf.train.get_checkpoint_state(log_dir) 40 | if ckpt and ckpt.model_checkpoint_path: 41 | print("Restoring from: {}".format(ckpt.model_checkpoint_path)) 42 | saver.restore(session, ckpt.model_checkpoint_path) 43 | initial_step = session.run(model.global_step) 44 | 45 | initial_time = time.time() 46 | while True: 47 | tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op]) 48 | accumulated_loss += tf_loss 49 | 50 | if tf_global_step % report_frequency == 0: 51 | steps_per_second = (tf_global_step - initial_step) / (time.time() - initial_time) 52 | 53 | average_loss = accumulated_loss / report_frequency 54 | print("[{}] loss={:.2f}, steps/s={:.2f}".format(tf_global_step, 55 | average_loss, 56 | steps_per_second)) 57 | writer.add_summary(util.make_summary({"loss": average_loss, 58 | "learning_rate": session.run(model.learning_rate)}), tf_global_step) 59 | accumulated_loss = 0.0 60 | initial_time = time.time() 61 | initial_step = tf_global_step 62 | 63 | if tf_global_step % eval_frequency == 0: 64 | saver.save(session, os.path.join(log_dir, "model"), global_step=tf_global_step) 65 | eval_summary, eval_f1 = model.evaluate(session) 66 | 67 | if eval_f1 > max_f1: 68 | max_f1 = eval_f1 69 | util.copy_checkpoint(os.path.join(log_dir, "model-{}".format(tf_global_step)), os.path.join(log_dir, "model.max.ckpt")) 70 | 71 | writer.add_summary(eval_summary, tf_global_step) 72 | writer.add_summary(util.make_summary({"max_eval_f1": max_f1}), tf_global_step) 73 | 74 | print("[{}] evaL_f1={:.2f}, max_f1={:.2f}".format(tf_global_step, eval_f1, max_f1)) 75 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import time 9 | 10 | import 
tensorflow as tf 11 | import coref_model as cm 12 | import util 13 | 14 | if __name__ == "__main__": 15 | args = util.get_args() 16 | config = util.initialize_from_env(args.experiment, args.logdir) 17 | task_index = int(os.environ["TASK"]) 18 | 19 | report_frequency = config["report_frequency"] 20 | cluster_config = util.get_cluster_config() 21 | 22 | util.set_gpus(cluster_config["gpus"][task_index]) 23 | 24 | cluster = tf.train.ClusterSpec(cluster_config["addresses"]) 25 | server = tf.train.Server(cluster, 26 | job_name="worker", 27 | task_index=task_index) 28 | 29 | # Assigns ops to the local worker by default. 30 | with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index, cluster=cluster)): 31 | model = cm.CorefModel(config) 32 | saver = tf.train.Saver() 33 | init_op = tf.global_variables_initializer() 34 | 35 | log_dir = config["log_dir"] 36 | writer = tf.summary.FileWriter(os.path.join(log_dir, "w{}".format(task_index)), flush_secs=20) 37 | 38 | is_chief = (task_index == 0) 39 | 40 | # Create a "supervisor", which oversees the training process. 41 | sv = tf.train.Supervisor(is_chief=is_chief, 42 | logdir=log_dir, 43 | init_op=init_op, 44 | saver=saver, 45 | global_step=model.global_step, 46 | save_model_secs=120) 47 | 48 | # The supervisor takes care of session initialization, restoring from 49 | # a checkpoint, and closing when done or an error occurs. 50 | with sv.managed_session(server.target) as session: 51 | model.start_enqueue_thread(session) 52 | accumulated_loss = 0.0 53 | local_steps = 0 54 | prev_report_global_steps = session.run(model.global_step) 55 | prev_report_time = time.time() 56 | while not sv.should_stop(): 57 | tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op]) 58 | accumulated_loss += tf_loss 59 | local_steps += 1 60 | 61 | if local_steps == report_frequency: 62 | total_time = time.time() - prev_report_time 63 | steps_per_second = (tf_global_step - prev_report_global_steps) / total_time 64 | 65 | average_loss = accumulated_loss / report_frequency 66 | print("[{}] loss={:.2f}, steps/s={:.2f}".format(tf_global_step, average_loss, steps_per_second)) 67 | accumulated_loss = 0.0 68 | local_steps = 0 69 | prev_report_global_steps = tf_global_step 70 | prev_report_time = time.time() 71 | writer.add_summary(util.make_summary({ 72 | "loss": average_loss, 73 | "Steps per second": steps_per_second 74 | }), global_step=tf_global_step) 75 | 76 | # Ask for all the services to stop. 
77 | sv.stop() 78 | -------------------------------------------------------------------------------- /conll.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import operator 6 | 7 | import collections 8 | import re 9 | import subprocess 10 | import tempfile 11 | 12 | BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") 13 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL) 14 | 15 | def get_doc_key(doc_id, part): 16 | return "{}_{}".format(doc_id, int(part)) 17 | 18 | def output_conll(input_file, output_file, predictions): 19 | prediction_map = {} 20 | for doc_key, clusters in predictions.items(): 21 | start_map = collections.defaultdict(list) 22 | end_map = collections.defaultdict(list) 23 | word_map = collections.defaultdict(list) 24 | for cluster_id, mentions in enumerate(clusters): 25 | for start, end in mentions: 26 | if start == end: 27 | word_map[start].append(cluster_id) 28 | else: 29 | start_map[start].append((cluster_id, end)) 30 | end_map[end].append((cluster_id, start)) 31 | for k,v in start_map.items(): 32 | start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)] 33 | for k,v in end_map.items(): 34 | end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)] 35 | prediction_map[doc_key] = (start_map, end_map, word_map) 36 | 37 | word_index = 0 38 | for line in input_file.readlines(): 39 | row = line.split() 40 | if len(row) == 0: 41 | output_file.write("\n") 42 | elif row[0].startswith("#"): 43 | begin_match = re.match(BEGIN_DOCUMENT_REGEX, line) 44 | if begin_match: 45 | doc_key = get_doc_key(begin_match.group(1), begin_match.group(2)) 46 | start_map, end_map, word_map = prediction_map[doc_key] 47 | word_index = 0 48 | output_file.write(line) 49 | output_file.write("\n") 50 | else: 51 | assert get_doc_key(row[0], row[1]) == doc_key 52 | coref_list = [] 53 | if word_index in end_map: 54 | for cluster_id in end_map[word_index]: 55 | coref_list.append("{})".format(cluster_id)) 56 | if word_index in word_map: 57 | for cluster_id in word_map[word_index]: 58 | coref_list.append("({})".format(cluster_id)) 59 | if word_index in start_map: 60 | for cluster_id in start_map[word_index]: 61 | coref_list.append("({}".format(cluster_id)) 62 | 63 | if len(coref_list) == 0: 64 | row[-1] = "-" 65 | else: 66 | row[-1] = "|".join(coref_list) 67 | 68 | output_file.write(" ".join(row)) 69 | output_file.write("\n") 70 | word_index += 1 71 | 72 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=False): 73 | cmd = ["conll-2012/scorer/v8.01/scorer.pl", metric, gold_path, predicted_path, "none"] 74 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 75 | stdout, stderr = process.communicate() 76 | process.wait() 77 | 78 | stdout = stdout.decode("utf-8") 79 | if stderr is not None: 80 | print(stderr) 81 | 82 | if official_stdout: 83 | print("Official result for {}".format(metric)) 84 | print(stdout) 85 | 86 | coref_results_match = re.match(COREF_RESULTS_REGEX, stdout) 87 | recall = float(coref_results_match.group(1)) 88 | precision = float(coref_results_match.group(2)) 89 | f1 = float(coref_results_match.group(3)) 90 | return { "r": recall, "p": precision, "f": f1 } 91 | 92 | def 
evaluate_conll(gold_path, predictions, official_stdout=False): 93 | with tempfile.NamedTemporaryFile(delete=False, mode="w") as prediction_file: 94 | with open(gold_path, "r") as gold_file: 95 | output_conll(gold_file, prediction_file, predictions) 96 | print("Predicted conll file: {}".format(prediction_file.name)) 97 | return { m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe") } 98 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from collections import Counter 7 | from sklearn.utils.linear_assignment_ import linear_assignment 8 | 9 | 10 | def f1(p_num, p_den, r_num, r_den, beta=1): 11 | p = 0 if p_den == 0 else p_num / float(p_den) 12 | r = 0 if r_den == 0 else r_num / float(r_den) 13 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 14 | 15 | class CorefEvaluator(object): 16 | def __init__(self): 17 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 18 | 19 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 20 | for e in self.evaluators: 21 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 22 | 23 | def get_f1(self): 24 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 25 | 26 | def get_recall(self): 27 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 28 | 29 | def get_precision(self): 30 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 31 | 32 | def get_prf(self): 33 | return self.get_precision(), self.get_recall(), self.get_f1() 34 | 35 | class Evaluator(object): 36 | def __init__(self, metric, beta=1): 37 | self.p_num = 0 38 | self.p_den = 0 39 | self.r_num = 0 40 | self.r_den = 0 41 | self.metric = metric 42 | self.beta = beta 43 | 44 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 45 | if self.metric == ceafe: 46 | pn, pd, rn, rd = self.metric(predicted, gold) 47 | else: 48 | pn, pd = self.metric(predicted, mention_to_gold) 49 | rn, rd = self.metric(gold, mention_to_predicted) 50 | self.p_num += pn 51 | self.p_den += pd 52 | self.r_num += rn 53 | self.r_den += rd 54 | 55 | def get_f1(self): 56 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 57 | 58 | def get_recall(self): 59 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 60 | 61 | def get_precision(self): 62 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 63 | 64 | def get_prf(self): 65 | return self.get_precision(), self.get_recall(), self.get_f1() 66 | 67 | def get_counts(self): 68 | return self.p_num, self.p_den, self.r_num, self.r_den 69 | 70 | 71 | def evaluate_documents(documents, metric, beta=1): 72 | evaluator = Evaluator(metric, beta=beta) 73 | for document in documents: 74 | evaluator.update(document) 75 | return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1() 76 | 77 | 78 | def b_cubed(clusters, mention_to_gold): 79 | num, dem = 0, 0 80 | 81 | for c in clusters: 82 | if len(c) == 1: 83 | continue 84 | 85 | gold_counts = Counter() 86 | correct = 0 87 | for m in c: 88 | if m in mention_to_gold: 89 | gold_counts[tuple(mention_to_gold[m])] += 1 90 | for c2, count in gold_counts.items(): 91 | if len(c2) != 1: 
92 | correct += count * count 93 | 94 | num += correct / float(len(c)) 95 | dem += len(c) 96 | 97 | return num, dem 98 | 99 | 100 | def muc(clusters, mention_to_gold): 101 | tp, p = 0, 0 102 | for c in clusters: 103 | p += len(c) - 1 104 | tp += len(c) 105 | linked = set() 106 | for m in c: 107 | if m in mention_to_gold: 108 | linked.add(mention_to_gold[m]) 109 | else: 110 | tp -= 1 111 | tp -= len(linked) 112 | return tp, p 113 | 114 | 115 | def phi4(c1, c2): 116 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) 117 | 118 | 119 | def ceafe_matching(clusters, gold_clusters): 120 | clusters = [c for c in clusters if len(c) != 1] 121 | scores = np.zeros((len(gold_clusters), len(clusters))) 122 | for i in range(len(gold_clusters)): 123 | for j in range(len(clusters)): 124 | scores[i, j] = phi4(gold_clusters[i], clusters[j]) 125 | return linear_assignment(-scores), scores 126 | 127 | 128 | def ceafe(clusters, gold_clusters): 129 | matching, scores = ceafe_matching(clusters, gold_clusters) 130 | similarity = sum(scores[matching[:, 0], matching[:, 1]]) 131 | return similarity, len(clusters), similarity, len(gold_clusters) 132 | 133 | 134 | def lea(clusters, mention_to_gold): 135 | num, dem = 0, 0 136 | 137 | for c in clusters: 138 | if len(c) == 1: 139 | continue 140 | 141 | common_links = 0 142 | all_links = len(c) * (len(c) - 1) / 2.0 143 | for i, m in enumerate(c): 144 | if m in mention_to_gold: 145 | for m2 in c[i + 1:]: 146 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: 147 | common_links += 1 148 | 149 | num += len(c) * common_links / float(all_links) 150 | dem += len(c) 151 | 152 | return num, dem 153 | -------------------------------------------------------------------------------- /coref_kernels.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "tensorflow/core/framework/op.h" 4 | #include "tensorflow/core/framework/shape_inference.h" 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | 7 | using namespace tensorflow; 8 | 9 | REGISTER_OP("ExtractSpans") 10 | .Input("span_scores: float32") 11 | .Input("candidate_starts: int32") 12 | .Input("candidate_ends: int32") 13 | .Input("num_output_spans: int32") 14 | .Input("max_sentence_length: int32") 15 | .Attr("sort_spans: bool") 16 | .Output("output_span_indices: int32"); 17 | 18 | class ExtractSpansOp : public OpKernel { 19 | public: 20 | explicit ExtractSpansOp(OpKernelConstruction* context) : OpKernel(context) { 21 | OP_REQUIRES_OK(context, context->GetAttr("sort_spans", &_sort_spans)); 22 | } 23 | 24 | void Compute(OpKernelContext* context) override { 25 | TTypes::ConstMatrix span_scores = context->input(0).matrix(); 26 | TTypes::ConstMatrix candidate_starts = context->input(1).matrix(); 27 | TTypes::ConstMatrix candidate_ends = context->input(2).matrix(); 28 | TTypes::ConstVec num_output_spans = context->input(3).vec(); 29 | int max_sentence_length = context->input(4).scalar()(); 30 | 31 | int num_sentences = span_scores.dimension(0); 32 | int num_input_spans = span_scores.dimension(1); 33 | int max_num_output_spans = 0; 34 | for (int i = 0; i < num_sentences; i++) { 35 | if (num_output_spans(i) > max_num_output_spans) { 36 | max_num_output_spans = num_output_spans(i); 37 | } 38 | } 39 | 40 | Tensor* output_span_indices_tensor = nullptr; 41 | TensorShape output_span_indices_shape({num_sentences, max_num_output_spans}); 42 | OP_REQUIRES_OK(context, context->allocate_output(0, output_span_indices_shape, 43 | 
&output_span_indices_tensor)); 44 | TTypes::Matrix output_span_indices = output_span_indices_tensor->matrix(); 45 | 46 | std::vector> sorted_input_span_indices(num_sentences, 47 | std::vector(num_input_spans)); 48 | for (int i = 0; i < num_sentences; i++) { 49 | std::iota(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 0); 50 | std::sort(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 51 | [&span_scores, &i](int j1, int j2) { 52 | return span_scores(i, j2) < span_scores(i, j1); 53 | }); 54 | } 55 | 56 | for (int l = 0; l < num_sentences; l++) { 57 | std::vector top_span_indices; 58 | std::unordered_map end_to_earliest_start; 59 | std::unordered_map start_to_latest_end; 60 | 61 | int current_span_index = 0, 62 | num_selected_spans = 0; 63 | while (num_selected_spans < num_output_spans(l) && current_span_index < num_input_spans) { 64 | int i = sorted_input_span_indices[l][current_span_index]; 65 | bool any_crossing = false; 66 | const int start = candidate_starts(l, i); 67 | const int end = candidate_ends(l, i); 68 | for (int j = start; j <= end; ++j) { 69 | auto latest_end_iter = start_to_latest_end.find(j); 70 | if (latest_end_iter != start_to_latest_end.end() && j > start && latest_end_iter->second > end) { 71 | // Given (), exists [], such that ( [ ) ] 72 | any_crossing = true; 73 | break; 74 | } 75 | auto earliest_start_iter = end_to_earliest_start.find(j); 76 | if (earliest_start_iter != end_to_earliest_start.end() && j < end && earliest_start_iter->second < start) { 77 | // Given (), exists [], such that [ ( ] ) 78 | any_crossing = true; 79 | break; 80 | } 81 | } 82 | if (!any_crossing) { 83 | if (_sort_spans) { 84 | top_span_indices.push_back(i); 85 | } else { 86 | output_span_indices(l, num_selected_spans) = i; 87 | } 88 | ++num_selected_spans; 89 | // Update data struct. 90 | auto latest_end_iter = start_to_latest_end.find(start); 91 | if (latest_end_iter == start_to_latest_end.end() || end > latest_end_iter->second) { 92 | start_to_latest_end[start] = end; 93 | } 94 | auto earliest_start_iter = end_to_earliest_start.find(end); 95 | if (earliest_start_iter == end_to_earliest_start.end() || start < earliest_start_iter->second) { 96 | end_to_earliest_start[end] = start; 97 | } 98 | } 99 | ++current_span_index; 100 | } 101 | // Sort and populate selected span indices. 102 | if (_sort_spans) { 103 | std::sort(top_span_indices.begin(), top_span_indices.end(), 104 | [&candidate_starts, &candidate_ends, &l] (int i1, int i2) { 105 | if (candidate_starts(l, i1) < candidate_starts(l, i2)) { 106 | return true; 107 | } else if (candidate_starts(l, i1) > candidate_starts(l, i2)) { 108 | return false; 109 | } else if (candidate_ends(l, i1) < candidate_ends(l, i2)) { 110 | return true; 111 | } else if (candidate_ends(l, i1) > candidate_ends(l, i2)) { 112 | return false; 113 | } else { 114 | return i1 < i2; 115 | } 116 | }); 117 | for (int i = 0; i < num_output_spans(l); ++i) { 118 | output_span_indices(l, i) = top_span_indices[i]; 119 | } 120 | } 121 | // Pad with the first span index. 
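// Rows with fewer than max_num_output_spans selected spans repeat the index stored in
// column 0, so that every row of the output matrix stays fully populated.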
122 | for (int i = num_selected_spans; i < max_num_output_spans; ++i) { 123 | output_span_indices(l, i) = output_span_indices(l, 0); 124 | } 125 | } 126 | } 127 | private: 128 | bool _sort_spans; 129 | }; 130 | 131 | REGISTER_KERNEL_BUILDER(Name("ExtractSpans").Device(DEVICE_CPU), ExtractSpansOp); 132 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import random 5 | 6 | import tokenization 7 | 8 | 9 | class Example(object): 10 | def __init__(self, doc_key, tokens, sentence_tokens, gold_starts, gold_ends, speaker_ids, cluster_ids, genre, document_index, 11 | offset=0, bert_to_orig_map=None): 12 | assert len(tokens) == len(speaker_ids) 13 | 14 | self.doc_key = doc_key 15 | self.tokens = tokens 16 | self.sentence_tokens = sentence_tokens 17 | self.gold_starts = gold_starts 18 | self.gold_ends = gold_ends 19 | self.speaker_ids = speaker_ids 20 | self.cluster_ids = cluster_ids 21 | self.genre = genre 22 | self.document_index = document_index 23 | self.offset = offset 24 | self.bert_to_orig_map = bert_to_orig_map 25 | 26 | def truncate(self, start, size): 27 | # don't truncate in the middle of a mention 28 | for mention in zip(self.gold_starts, self.gold_ends): 29 | if index_in_mention(start, mention): 30 | start = mention[0] 31 | 32 | if index_in_mention(start + size, mention): 33 | size -= start + size - mention[0] 34 | end = start + size 35 | 36 | tokens = self.tokens[start:end] 37 | sentence_tokens = None 38 | speaker_ids = self.speaker_ids[start:end] 39 | gold_spans = np.logical_and(self.gold_starts >= start, self.gold_ends < end) 40 | gold_starts = self.gold_starts[gold_spans] - start 41 | gold_ends = self.gold_ends[gold_spans] - start 42 | cluster_ids = self.cluster_ids[gold_spans] 43 | 44 | return Example(self.doc_key, tokens, sentence_tokens, gold_starts, gold_ends, speaker_ids, cluster_ids, 45 | self.genre, self.document_index, start) 46 | 47 | def bertify(self, tokenizer): 48 | assert self.offset == 0 49 | 50 | bert_tokens = [] 51 | orig_to_bert_map = [] 52 | orig_to_bert_end_map = [] 53 | bert_speaker_ids = [] 54 | for t, s in zip(self.tokens, self.speaker_ids): 55 | bert_t = tokenizer.tokenize(t) 56 | orig_to_bert_map.append(len(bert_tokens)) 57 | orig_to_bert_end_map.append(len(bert_tokens) + len(bert_t) - 1) 58 | bert_tokens.extend(bert_t) 59 | bert_speaker_ids.extend([s] * len(bert_t)) 60 | 61 | bert_sentence_tokens = [tokenizer.tokenize(' '.join(s)) for s in self.sentence_tokens] 62 | 63 | bert_to_orig_map = [-1] * len(bert_tokens) 64 | for i, bert_i in enumerate(orig_to_bert_map): 65 | bert_to_orig_map[bert_i] = i 66 | 67 | orig_to_bert_map = np.array(orig_to_bert_map) 68 | orig_to_bert_end_map = np.array(orig_to_bert_end_map) 69 | if len(self.gold_starts): 70 | gold_starts = orig_to_bert_map[self.gold_starts] 71 | gold_ends = orig_to_bert_end_map[self.gold_ends] 72 | else: 73 | gold_starts = self.gold_starts 74 | gold_ends = self.gold_ends 75 | 76 | return Example(self.doc_key, bert_tokens, bert_sentence_tokens, gold_starts, gold_ends, bert_speaker_ids, 77 | self.cluster_ids, self.genre, self.document_index, bert_to_orig_map=bert_to_orig_map) 78 | 79 | def unravel_token_index(self, token_index): 80 | prev_sentences_len = 0 81 | for i, s in enumerate(self.sentence_tokens): 82 | if token_index < prev_sentences_len + len(s): 83 | token_index_in_sentence = token_index - 
prev_sentences_len 84 | return i, token_index_in_sentence 85 | prev_sentences_len += len(s) 86 | 87 | raise ValueError('token_index is out of range ({} >= {})', token_index, len(self.tokens)) 88 | 89 | 90 | def index_in_mention(index, mention): 91 | return mention[0] <= index and mention[1] >= index 92 | 93 | 94 | def mention_contains(mention1, mention2): 95 | return mention1[0] <= mention2[0] and mention1[1] >= mention2[1] 96 | 97 | 98 | def filter_embedded_mentions(mentions): 99 | """ 100 | Filter out mentions embedded in other mentions 101 | """ 102 | filtered = [] 103 | for i, m in enumerate(mentions): 104 | other_mentions = mentions[:i] + mentions[i + 1:] 105 | if any(mention_contains(other_m, m) for other_m in other_mentions): 106 | continue 107 | filtered.append(m) 108 | return filtered 109 | 110 | 111 | def filter_overlapping_mentions(mentions): 112 | start_to_mentions = defaultdict(list) 113 | for m in mentions: 114 | start_to_mentions[m[0]].append(m) 115 | 116 | filtered_mentions = [] 117 | for ms in start_to_mentions.values(): 118 | if len(ms) > 1: 119 | pass 120 | max_mention = np.argmax([m[1] - m[0] for m in ms]) 121 | filtered_mentions.append(ms[max_mention]) 122 | 123 | return filtered_mentions 124 | 125 | 126 | def flatten(l): 127 | return [item for sublist in l for item in sublist] 128 | 129 | 130 | def tensorize_mentions(mentions): 131 | if len(mentions) > 0: 132 | starts, ends = zip(*mentions) 133 | else: 134 | starts, ends = [], [] 135 | return np.array(starts), np.array(ends) 136 | 137 | 138 | genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} 139 | 140 | def process_example(example, index, should_filter_embedded_mentions=False): 141 | clusters = example["clusters"] 142 | 143 | gold_mentions = sorted(tuple(m) for m in flatten(clusters)) 144 | if should_filter_embedded_mentions: 145 | gold_mentions = filter_overlapping_mentions(gold_mentions) 146 | # gold_mentions = filter_embedded_mentions(gold_mentions) 147 | gold_mention_map = {m: i for i, m in enumerate(gold_mentions)} 148 | cluster_ids = np.zeros(len(gold_mentions)) 149 | for cluster_id, cluster in enumerate(clusters): 150 | for mention in cluster: 151 | if tuple(mention) in gold_mention_map: 152 | cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1 153 | 154 | sentences = example["sentences"] 155 | num_words = sum(len(s) for s in sentences) 156 | speakers = flatten(example["speakers"]) 157 | 158 | assert num_words == len(speakers) 159 | 160 | sentence_tokens = [[tokenization.convert_to_unicode(w) for w in s] for s in sentences] 161 | 162 | tokens = sum(sentence_tokens, []) 163 | 164 | speaker_dict = {s: i for i, s in enumerate(set(speakers))} 165 | speaker_ids = np.array([speaker_dict[s] for s in speakers]) 166 | 167 | # TODO: genre 168 | doc_key = example["doc_key"] 169 | genre = genres[doc_key[:2]] 170 | 171 | gold_starts, gold_ends = tensorize_mentions(sorted(gold_mentions)) 172 | 173 | return Example(doc_key, tokens, sentence_tokens, gold_starts, gold_ends, speaker_ids, cluster_ids, genre, index) 174 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, tvars=None, global_step=None): 26 | """Creates an optimizer training op.""" 27 | if global_step is None: 28 | global_step = tf.train.get_or_create_global_step() 29 | 30 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 31 | 32 | # Implements linear decay of the learning rate. 33 | learning_rate = tf.train.polynomial_decay( 34 | learning_rate, 35 | global_step, 36 | num_train_steps, 37 | end_learning_rate=0.0, 38 | power=1.0, 39 | cycle=False) 40 | 41 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 42 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 43 | if num_warmup_steps: 44 | global_steps_int = tf.cast(global_step, tf.int32) 45 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 46 | 47 | global_steps_float = tf.cast(global_steps_int, tf.float32) 48 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 49 | 50 | warmup_percent_done = global_steps_float / warmup_steps_float 51 | warmup_learning_rate = init_lr * warmup_percent_done 52 | 53 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 54 | learning_rate = ( 55 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 56 | 57 | # It is recommended that you use this optimizer for fine tuning, since this 58 | # is how the model was trained (note that the Adam m/v variables are NOT 59 | # loaded from init_checkpoint.) 60 | optimizer = AdamWeightDecayOptimizer( 61 | learning_rate=learning_rate, 62 | weight_decay_rate=0.01, 63 | beta_1=0.9, 64 | beta_2=0.999, 65 | epsilon=1e-6, 66 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 67 | 68 | if tvars is None: 69 | tvars = tf.trainable_variables() 70 | grads = tf.gradients(loss, tvars) 71 | 72 | # This is how the model was pre-trained. 73 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 74 | 75 | train_op = optimizer.apply_gradients( 76 | zip(grads, tvars), global_step=global_step) 77 | 78 | # Normally the global step update is done inside of `apply_gradients`. 79 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 80 | # a different optimizer, you should probably take this line out. 
81 | new_global_step = global_step + 1 82 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 83 | return train_op, learning_rate 84 | 85 | 86 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 87 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 88 | 89 | def __init__(self, 90 | learning_rate, 91 | weight_decay_rate=0.0, 92 | beta_1=0.9, 93 | beta_2=0.999, 94 | epsilon=1e-6, 95 | exclude_from_weight_decay=None, 96 | name="AdamWeightDecayOptimizer"): 97 | """Constructs a AdamWeightDecayOptimizer.""" 98 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 99 | 100 | self.learning_rate = learning_rate 101 | self.weight_decay_rate = weight_decay_rate 102 | self.beta_1 = beta_1 103 | self.beta_2 = beta_2 104 | self.epsilon = epsilon 105 | self.exclude_from_weight_decay = exclude_from_weight_decay 106 | 107 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 108 | """See base class.""" 109 | assignments = [] 110 | for (grad, param) in grads_and_vars: 111 | if grad is None or param is None: 112 | continue 113 | 114 | param_name = self._get_variable_name(param.name) 115 | 116 | m = tf.get_variable( 117 | name=param_name + "/adam_m", 118 | shape=param.shape.as_list(), 119 | dtype=tf.float32, 120 | trainable=False, 121 | initializer=tf.zeros_initializer()) 122 | v = tf.get_variable( 123 | name=param_name + "/adam_v", 124 | shape=param.shape.as_list(), 125 | dtype=tf.float32, 126 | trainable=False, 127 | initializer=tf.zeros_initializer()) 128 | 129 | # Standard Adam update. 130 | next_m = ( 131 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 132 | next_v = ( 133 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 134 | tf.square(grad))) 135 | 136 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 137 | 138 | # Just adding the square of the weights to the loss function is *not* 139 | # the correct way of using L2 regularization/weight decay with Adam, 140 | # since that will interact with the m and v parameters in strange ways. 141 | # 142 | # Instead we want ot decay the weights in a manner that doesn't interact 143 | # with the m/v parameters. This is equivalent to adding the square 144 | # of the weights to the loss with plain (non-momentum) SGD. 
145 | if self._do_use_weight_decay(param_name): 146 | update += self.weight_decay_rate * param 147 | 148 | update_with_lr = self.learning_rate * update 149 | 150 | next_param = param - update_with_lr 151 | 152 | assignments.extend( 153 | [param.assign(next_param), 154 | m.assign(next_m), 155 | v.assign(next_v)]) 156 | return tf.group(*assignments, name=name) 157 | 158 | def _do_use_weight_decay(self, param_name): 159 | """Whether to use L2 weight decay for `param_name`.""" 160 | if not self.weight_decay_rate: 161 | return False 162 | if self.exclude_from_weight_decay: 163 | for r in self.exclude_from_weight_decay: 164 | if re.search(r, param_name) is not None: 165 | return False 166 | return True 167 | 168 | def _get_variable_name(self, param_name): 169 | """Get the variable name from the tensor name.""" 170 | m = re.match("^(.*):\\d+$", param_name) 171 | if m is not None: 172 | param_name = m.group(1) 173 | return param_name 174 | -------------------------------------------------------------------------------- /minimize.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import re 6 | import os 7 | import sys 8 | import json 9 | import tempfile 10 | import subprocess 11 | import collections 12 | 13 | import util 14 | import conll 15 | 16 | class DocumentState(object): 17 | def __init__(self): 18 | self.doc_key = None 19 | self.text = [] 20 | self.text_speakers = [] 21 | self.speakers = [] 22 | self.sentences = [] 23 | self.constituents = {} 24 | self.const_stack = [] 25 | self.ner = {} 26 | self.ner_stack = [] 27 | self.clusters = collections.defaultdict(list) 28 | self.coref_stacks = collections.defaultdict(list) 29 | 30 | def assert_empty(self): 31 | assert self.doc_key is None 32 | assert len(self.text) == 0 33 | assert len(self.text_speakers) == 0 34 | assert len(self.speakers) == 0 35 | assert len(self.sentences) == 0 36 | assert len(self.constituents) == 0 37 | assert len(self.const_stack) == 0 38 | assert len(self.ner) == 0 39 | assert len(self.ner_stack) == 0 40 | assert len(self.coref_stacks) == 0 41 | assert len(self.clusters) == 0 42 | 43 | def assert_finalizable(self): 44 | assert self.doc_key is not None 45 | assert len(self.text) == 0 46 | assert len(self.text_speakers) == 0 47 | assert len(self.speakers) > 0 48 | assert len(self.sentences) > 0 49 | assert len(self.constituents) > 0 50 | assert len(self.const_stack) == 0 51 | assert len(self.ner_stack) == 0 52 | assert all(len(s) == 0 for s in self.coref_stacks.values()) 53 | 54 | def span_dict_to_list(self, span_dict): 55 | return [(s,e,l) for (s,e),l in span_dict.items()] 56 | 57 | def finalize(self): 58 | merged_clusters = [] 59 | for c1 in self.clusters.values(): 60 | existing = None 61 | for m in c1: 62 | for c2 in merged_clusters: 63 | if m in c2: 64 | existing = c2 65 | break 66 | if existing is not None: 67 | break 68 | if existing is not None: 69 | print("Merging clusters (shouldn't happen very often.)") 70 | existing.update(c1) 71 | else: 72 | merged_clusters.append(set(c1)) 73 | merged_clusters = [list(c) for c in merged_clusters] 74 | all_mentions = util.flatten(merged_clusters) 75 | assert len(all_mentions) == len(set(all_mentions)) 76 | 77 | return { 78 | "doc_key": self.doc_key, 79 | "sentences": self.sentences, 80 | "speakers": self.speakers, 81 | "constituents": self.span_dict_to_list(self.constituents), 82 | "ner": self.span_dict_to_list(self.ner), 
83 | "clusters": merged_clusters 84 | } 85 | 86 | def normalize_word(word, language): 87 | if language == "arabic": 88 | word = word[:word.find("#")] 89 | if word == "/." or word == "/?": 90 | return word[1:] 91 | else: 92 | return word 93 | 94 | def handle_bit(word_index, bit, stack, spans): 95 | asterisk_idx = bit.find("*") 96 | if asterisk_idx >= 0: 97 | open_parens = bit[:asterisk_idx] 98 | close_parens = bit[asterisk_idx + 1:] 99 | else: 100 | open_parens = bit[:-1] 101 | close_parens = bit[-1] 102 | 103 | current_idx = open_parens.find("(") 104 | while current_idx >= 0: 105 | next_idx = open_parens.find("(", current_idx + 1) 106 | if next_idx >= 0: 107 | label = open_parens[current_idx + 1:next_idx] 108 | else: 109 | label = open_parens[current_idx + 1:] 110 | stack.append((word_index, label)) 111 | current_idx = next_idx 112 | 113 | for c in close_parens: 114 | assert c == ")" 115 | open_index, label = stack.pop() 116 | current_span = (open_index, word_index) 117 | """ 118 | if current_span in spans: 119 | spans[current_span] += "_" + label 120 | else: 121 | spans[current_span] = label 122 | """ 123 | spans[current_span] = label 124 | 125 | def handle_line(line, document_state, language, labels, stats): 126 | begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line) 127 | if len(document_state.text) == 0 and len(line.strip()) == 0: 128 | return None 129 | if begin_document_match: 130 | document_state.assert_empty() 131 | document_state.doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2)) 132 | return None 133 | elif line.startswith("#end document"): 134 | document_state.assert_finalizable() 135 | finalized_state = document_state.finalize() 136 | stats["num_clusters"] += len(finalized_state["clusters"]) 137 | stats["num_mentions"] += sum(len(c) for c in finalized_state["clusters"]) 138 | labels["{}_const_labels".format(language)].update(l for _, _, l in finalized_state["constituents"]) 139 | labels["ner"].update(l for _, _, l in finalized_state["ner"]) 140 | return finalized_state 141 | else: 142 | row = line.split() 143 | if len(row) == 0: 144 | stats["max_sent_len_{}".format(language)] = max(len(document_state.text), stats["max_sent_len_{}".format(language)]) 145 | stats["num_sents_{}".format(language)] += 1 146 | document_state.sentences.append(tuple(document_state.text)) 147 | del document_state.text[:] 148 | document_state.speakers.append(tuple(document_state.text_speakers)) 149 | del document_state.text_speakers[:] 150 | return None 151 | assert len(row) >= 12 152 | 153 | doc_key = conll.get_doc_key(row[0], row[1]) 154 | word = normalize_word(row[3], language) 155 | parse = row[5] 156 | speaker = row[9] 157 | ner = row[10] 158 | coref = row[-1] 159 | 160 | word_index = len(document_state.text) + sum(len(s) for s in document_state.sentences) 161 | document_state.text.append(word) 162 | document_state.text_speakers.append(speaker) 163 | 164 | handle_bit(word_index, parse, document_state.const_stack, document_state.constituents) 165 | handle_bit(word_index, ner, document_state.ner_stack, document_state.ner) 166 | 167 | if coref != "-": 168 | for segment in coref.split("|"): 169 | if segment[0] == "(": 170 | if segment[-1] == ")": 171 | cluster_id = int(segment[1:-1]) 172 | document_state.clusters[cluster_id].append((word_index, word_index)) 173 | else: 174 | cluster_id = int(segment[1:]) 175 | document_state.coref_stacks[cluster_id].append(word_index) 176 | else: 177 | cluster_id = int(segment[:-1]) 178 | start = 
document_state.coref_stacks[cluster_id].pop() 179 | document_state.clusters[cluster_id].append((start, word_index)) 180 | return None 181 | 182 | def minimize_partition(name, language, extension, labels, stats): 183 | input_path = "{}.{}.{}".format(name, language, extension) 184 | output_path = "{}.{}.jsonlines".format(name, language) 185 | minimize_file(input_path, language, labels, stats, output_path) 186 | 187 | def minimize_file(input_path, language, labels, stats, output_path=None): 188 | if output_path is None: 189 | output_path = "{}.jsonlines".format(input_path) 190 | 191 | count = 0 192 | print("Minimizing {}".format(input_path)) 193 | with open(input_path, "r") as input_file: 194 | with open(output_path, "w") as output_file: 195 | document_state = DocumentState() 196 | for line in input_file.readlines(): 197 | document = handle_line(line, document_state, language, labels, stats) 198 | if document is not None: 199 | output_file.write(json.dumps(document)) 200 | output_file.write("\n") 201 | count += 1 202 | document_state = DocumentState() 203 | print("Wrote {} documents to {}".format(count, output_path)) 204 | 205 | def minimize_language(language, labels, stats): 206 | minimize_partition("dev", language, "v4_gold_conll", labels, stats) 207 | minimize_partition("train", language, "v4_gold_conll", labels, stats) 208 | minimize_partition("test", language, "v4_gold_conll", labels, stats) 209 | 210 | 211 | def main(): 212 | labels = collections.defaultdict(set) 213 | stats = collections.defaultdict(int) 214 | 215 | if args.input: 216 | minimize_file(args.input, "english", labels, stats) 217 | return 218 | 219 | minimize_language("english", labels, stats) 220 | minimize_language("chinese", labels, stats) 221 | minimize_language("arabic", labels, stats) 222 | for k, v in labels.items(): 223 | print("{} = [{}]".format(k, ", ".join("\"{}\"".format(label) for label in v))) 224 | for k, v in stats.items(): 225 | print("{} = {}".format(k, v)) 226 | 227 | if __name__ == "__main__": 228 | import argparse 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument('-i', '--input') 231 | args = parser.parse_args() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 Kenton Lee 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /prepare_bert_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import json 23 | import os 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | import h5py 28 | 29 | import modeling 30 | import tokenization 31 | from data import process_example 32 | from tqdm import tqdm 33 | 34 | flags = tf.flags 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_string( 39 | "bert_config_file", None, 40 | "The config json file corresponding to the pre-trained BERT model. " 41 | "This specifies the model architecture.") 42 | 43 | flags.DEFINE_string("input_file", None, "") 44 | 45 | flags.DEFINE_string("output_file", None, "") 46 | 47 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 48 | 49 | flags.DEFINE_integer( 50 | "window_size", 511, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded.") 54 | 55 | flags.DEFINE_integer( 56 | "stride", 127, 57 | "The maximum total input sequence length after WordPiece tokenization. " 58 | "Sequences longer than this will be truncated, and sequences shorter " 59 | "than this will be padded.") 60 | 61 | flags.DEFINE_string("vocab_file", None, 62 | "The vocabulary file that the BERT model was trained on.") 63 | 64 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 65 | 66 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 67 | 68 | flags.DEFINE_string("master", None, 69 | "If using a TPU, the address of the master.") 70 | 71 | flags.DEFINE_integer( 72 | "num_tpu_cores", 8, 73 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 74 | 75 | flags.DEFINE_bool( 76 | "use_one_hot_embeddings", False, 77 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 78 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 79 | "since it is much faster.") 80 | 81 | 82 | def _convert_example_to_features(example, window_start, window_end, tokens_ids_to_extract, tokenizer, seq_length): 83 | window_tokens = example.tokens[window_start:window_end] 84 | 85 | tokens = [] 86 | segment_ids = [] 87 | for token in window_tokens: 88 | tokens.append(token) 89 | segment_ids.append(0) 90 | 91 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 92 | 93 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 94 | # tokens are attended to. 95 | input_mask = [1] * len(input_ids) 96 | 97 | # Zero-pad up to the sequence length. 
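  # Illustrative sketch (added, not in the original file): a 3-token window
  # padded out to seq_length = 5 would look roughly like
  #   input_ids   = [1037, 2158, 3899, 0, 0]   # hypothetical vocab indices
  #   input_mask  = [   1,    1,    1, 0, 0]
  #   segment_ids = [   0,    0,    0, 0, 0]
  # Only positions with input_mask == 1 are attended to; the zero padding just
  # gives every window the same fixed length.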
98 | while len(input_ids) < seq_length: 99 | input_ids.append(0) 100 | input_mask.append(0) 101 | segment_ids.append(0) 102 | 103 | extract_mask = [0] * seq_length 104 | for i in tokens_ids_to_extract: 105 | extract_mask[i - window_start] = 1 106 | 107 | assert len(input_ids) == seq_length 108 | assert len(input_mask) == seq_length 109 | assert len(segment_ids) == seq_length 110 | 111 | return dict(doc_index=example.document_index, 112 | input_ids=input_ids, 113 | input_mask=input_mask, 114 | segment_ids=segment_ids, 115 | extract_mask=extract_mask, 116 | tokens=tokens) 117 | 118 | 119 | def convert_examples_to_features(bert_examples, orig_examples, window_size, stride, tokenizer): 120 | """Loads a data file into a list of `InputBatch`s.""" 121 | 122 | assert window_size % 2 == 1 123 | assert stride % 2 == 1 124 | 125 | for bert_example, orig_example in zip(bert_examples, orig_examples): 126 | current_example_features = [] 127 | for i in range(0, len(bert_example.tokens), stride): 128 | window_center = i + window_size // 2 129 | token_ids_to_extract = [] 130 | extract_start = int(np.clip(window_center - stride // 2, 0, len(bert_example.tokens))) 131 | extract_end = int(np.clip(window_center + stride // 2 + 1, extract_start, len(bert_example.tokens))) 132 | 133 | if i == 0: 134 | token_ids_to_extract.extend(range(extract_start)) 135 | 136 | token_ids_to_extract.extend(range(extract_start, extract_end)) 137 | 138 | if i + window_size >= len(bert_example.tokens): 139 | token_ids_to_extract.extend(range(extract_end, len(bert_example.tokens))) 140 | 141 | token_ids_to_extract = [t for t in token_ids_to_extract if bert_example.bert_to_orig_map[t] >= 0] 142 | 143 | features = _convert_example_to_features(bert_example, 144 | i, 145 | min(i + window_size, len(bert_example.tokens)), 146 | token_ids_to_extract, 147 | tokenizer, 148 | window_size) 149 | 150 | current_example_features.append(features) 151 | 152 | if i + window_size >= len(bert_example.tokens): 153 | break 154 | 155 | current_example_features = {k: np.array([c[k] for c in current_example_features]) for k in current_example_features[0]} 156 | 157 | max_sentence_len = max(len(s) for s in orig_example.sentence_tokens) 158 | extract_sentences = np.zeros((len(orig_example.sentence_tokens), max_sentence_len), np.int32) 159 | extract_mask = current_example_features['extract_mask'] 160 | extract_idxs = extract_mask.cumsum().reshape(extract_mask.shape) 161 | for c in range(extract_mask.shape[0]): 162 | for i in range(extract_mask.shape[1]): 163 | if extract_mask[c, i]: 164 | si, sj = orig_example.unravel_token_index(extract_idxs[c, i] - 1) 165 | extract_sentences[si, sj] = c * extract_mask.shape[1] + i + 1 166 | current_example_features['extract_sentences'] = extract_sentences 167 | 168 | yield current_example_features 169 | 170 | 171 | def main(_): 172 | tf.logging.set_verbosity(tf.logging.INFO) 173 | 174 | tokenizer = tokenization.FullTokenizer( 175 | vocab_file=FLAGS.vocab_file, do_lower_case=False) 176 | 177 | json_examples = [] 178 | for x in ['test', 'train', 'dev']: 179 | # for x in ['test']: 180 | with open(os.path.join(FLAGS.input_file, x + '.english.jsonlines')) as f: 181 | json_examples.extend((json.loads(jsonline) for jsonline in f.readlines())) 182 | 183 | orig_examples = [] 184 | bert_examples = [] 185 | for i, json_e in enumerate(json_examples): 186 | e = process_example(json_e, i, should_filter_embedded_mentions=True) 187 | orig_examples.append(e) 188 | bert_examples.append(e.bertify(tokenizer)) 189 | 190 | writer = 
h5py.File(FLAGS.output_file, 'w') 191 | for data in tqdm(convert_examples_to_features(bert_examples, 192 | orig_examples, 193 | FLAGS.window_size, 194 | FLAGS.stride, 195 | tokenizer), total=len(json_examples)): 196 | document_index = int(data["doc_index"][0]) 197 | bert_example = bert_examples[document_index] 198 | dataset_key = bert_example.doc_key.replace('/', ':') 199 | 200 | sentences = [] 201 | for sentence_indices in data['extract_sentences']: 202 | cur_sentence = [] 203 | for i in sentence_indices: 204 | tokens_flattened = sum([list(ts) for ts in data['tokens']], []) 205 | if i > 0: 206 | cur_sentence.append(tokens_flattened[i - 1]) 207 | sentences.append(cur_sentence) 208 | assert [len(s) for s in sentences] == [len(s) for s in orig_examples[document_index].sentence_tokens] 209 | sentences_flattened = sum(sentences, []) 210 | expected = [t for i, t in enumerate(bert_example.tokens) if bert_example.bert_to_orig_map[i] >= 0] 211 | assert sentences_flattened == expected 212 | 213 | writer.create_dataset('{}/input_ids'.format(dataset_key), data=data['input_ids']) 214 | writer.create_dataset('{}/input_mask'.format(dataset_key), data=data['input_mask']) 215 | writer.create_dataset('{}/segment_ids'.format(dataset_key), data=data['segment_ids']) 216 | writer.create_dataset('{}/extract_mask'.format(dataset_key), data=data['extract_mask']) 217 | writer.create_dataset('{}/extract_sentences'.format(dataset_key), data=data['extract_sentences']) 218 | # for i, s in enumerate(data['tokens']): 219 | # tokens_dset = writer.create_dataset('{}/tokens/{}'.format(dataset_key, i), 220 | # (len(s),), 221 | # dtype=h5py.special_dtype(vlen=unicode)) 222 | # for j, w in enumerate(s): 223 | # tokens_dset[j] = w 224 | writer.close() 225 | 226 | # dataset = tf.data.Dataset.from_generator(functools.partial(convert_examples_to_features, 227 | # bert_examples, 228 | # FLAGS.window_size, 229 | # FLAGS.stride, 230 | # tokenizer), 231 | # dict(doc_index=tf.int32, 232 | # input_ids=tf.int32, 233 | # input_mask=tf.int32, 234 | # segment_ids=tf.int32, 235 | # extract_mask=tf.int32), 236 | # dict(doc_index=tf.TensorShape([None]), 237 | # input_ids=tf.TensorShape([None, FLAGS.window_size]), 238 | # input_mask=tf.TensorShape([None, FLAGS.window_size]), 239 | # segment_ids=tf.TensorShape([None, FLAGS.window_size]), 240 | # extract_mask=tf.TensorShape([None, FLAGS.window_size])) 241 | # ) 242 | # inputs = dataset.make_one_shot_iterator().get_next() 243 | # 244 | # bert_inputs = dict( 245 | # input_ids=tf.expand_dims(inputs['input_ids'], 1), 246 | # input_mask=tf.expand_dims(inputs['input_mask'], 1), 247 | # segment_ids=tf.expand_dims(inputs['segment_ids'], 1)) 248 | # 249 | # bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 250 | # bert_outputs = tf.map_fn( 251 | # lambda x: modeling.BertModel(config=bert_config, 252 | # is_training=True, 253 | # input_ids=x['input_ids'], 254 | # input_mask=x['input_mask'], 255 | # token_type_ids=x['segment_ids'], 256 | # use_one_hot_embeddings=FLAGS.use_one_hot_embeddings).sequence_output, 257 | # bert_inputs, dtype=tf.float32, parallel_iterations=1, swap_memory=True) 258 | # # repeat = 5 259 | # # bert_outputs = modeling.BertModel(config=bert_config, 260 | # # is_training=True, 261 | # # input_ids=tf.tile(inputs['input_ids'], [repeat, 1]), 262 | # # input_mask=tf.tile(inputs['input_mask'], [repeat, 1]), 263 | # # token_type_ids=tf.tile(inputs['segment_ids'], [repeat, 1]), 264 | # # use_one_hot_embeddings=FLAGS.use_one_hot_embeddings).sequence_output 265 | # loss = 
tf.nn.sigmoid(tf.reduce_mean(bert_outputs)) 266 | # train_op = tf.train.AdamOptimizer().minimize(loss) 267 | # # bert_outputs_masked = tf.squeeze(bert_outputs, 1) 268 | # # bert_outputs_masked = tf.boolean_mask(bert_outputs_masked, inputs['extract_mask']) 269 | # 270 | # with tf.Session() as sess: 271 | # tf.global_variables_initializer().run() 272 | # while True: 273 | # # features_raw, features, doc_index = sess.run([bert_outputs, bert_outputs_masked, inputs['doc_index']]) 274 | # # print(len(features_raw), features.shape, len(bert_examples[doc_index[0]].tokens)) 275 | # l, _ = sess.run([loss, train_op]) 276 | # print(l) 277 | 278 | 279 | if __name__ == "__main__": 280 | flags.mark_flag_as_required("input_file") 281 | flags.mark_flag_as_required("vocab_file") 282 | flags.mark_flag_as_required("output_file") 283 | tf.app.run() 284 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | 23 | import os 24 | import codecs 25 | import collections 26 | 27 | import h5py 28 | import json 29 | import re 30 | from tqdm import tqdm 31 | 32 | import modeling 33 | import tokenization 34 | import tensorflow as tf 35 | import numpy as np 36 | 37 | from data import process_example 38 | 39 | flags = tf.flags 40 | 41 | FLAGS = flags.FLAGS 42 | 43 | flags.DEFINE_string("input_file", None, "") 44 | 45 | flags.DEFINE_string("output_file", None, "") 46 | 47 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 48 | 49 | flags.DEFINE_string( 50 | "bert_config_file", None, 51 | "The config json file corresponding to the pre-trained BERT model. " 52 | "This specifies the model architecture.") 53 | 54 | flags.DEFINE_integer( 55 | "window_size", 511, 56 | "The maximum total input sequence length after WordPiece tokenization. " 57 | "Sequences longer than this will be truncated, and sequences shorter " 58 | "than this will be padded.") 59 | 60 | flags.DEFINE_integer( 61 | "stride", 127, 62 | "The maximum total input sequence length after WordPiece tokenization. " 63 | "Sequences longer than this will be truncated, and sequences shorter " 64 | "than this will be padded.") 65 | 66 | flags.DEFINE_string( 67 | "init_checkpoint", None, 68 | "Initial checkpoint (usually from a pre-trained BERT model).") 69 | 70 | flags.DEFINE_string("vocab_file", None, 71 | "The vocabulary file that the BERT model was trained on.") 72 | 73 | flags.DEFINE_bool( 74 | "do_lower_case", True, 75 | "Whether to lower case the input text. 
Should be True for uncased " 76 | "models and False for cased models.") 77 | 78 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 79 | 80 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 81 | 82 | flags.DEFINE_string("master", None, 83 | "If using a TPU, the address of the master.") 84 | 85 | flags.DEFINE_integer( 86 | "num_tpu_cores", 8, 87 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 88 | 89 | flags.DEFINE_bool( 90 | "use_one_hot_embeddings", False, 91 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 92 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 93 | "since it is much faster.") 94 | 95 | 96 | def input_fn_builder(examples, window_size, stride, tokenizer): 97 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 98 | 99 | def input_fn(params): 100 | """The actual input function.""" 101 | batch_size = params["batch_size"] 102 | 103 | d = tf.data.Dataset.from_generator( 104 | functools.partial(convert_examples_to_features, 105 | examples=examples, 106 | window_size=window_size, 107 | stride=stride, 108 | tokenizer=tokenizer), 109 | dict(unique_ids=tf.int32, 110 | input_ids=tf.int32, 111 | input_mask=tf.int32, 112 | input_type_ids=tf.int32, 113 | extract_indices=tf.int32), 114 | dict(unique_ids=tf.TensorShape([]), 115 | input_ids=tf.TensorShape([window_size]), 116 | input_mask=tf.TensorShape([window_size]), 117 | input_type_ids=tf.TensorShape([window_size]), 118 | extract_indices=tf.TensorShape([window_size]))) 119 | 120 | d = d.batch(batch_size=batch_size, drop_remainder=False) 121 | return d 122 | 123 | return input_fn 124 | 125 | 126 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 127 | use_one_hot_embeddings): 128 | """Returns `model_fn` closure for TPUEstimator.""" 129 | 130 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 131 | """The `model_fn` for TPUEstimator.""" 132 | 133 | unique_ids = features["unique_ids"] 134 | input_ids = features["input_ids"] 135 | input_mask = features["input_mask"] 136 | input_type_ids = features["input_type_ids"] 137 | extract_indices = features["extract_indices"] 138 | 139 | model = modeling.BertModel( 140 | config=bert_config, 141 | is_training=False, 142 | input_ids=input_ids, 143 | input_mask=input_mask, 144 | token_type_ids=input_type_ids, 145 | use_one_hot_embeddings=use_one_hot_embeddings) 146 | 147 | if mode != tf.estimator.ModeKeys.PREDICT: 148 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 149 | 150 | tvars = tf.trainable_variables() 151 | scaffold_fn = None 152 | (assignment_map, 153 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 154 | tvars, init_checkpoint) 155 | if use_tpu: 156 | 157 | def tpu_scaffold(): 158 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 159 | return tf.train.Scaffold() 160 | 161 | scaffold_fn = tpu_scaffold 162 | else: 163 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 164 | 165 | tf.logging.info("**** Trainable Variables ****") 166 | for var in tvars: 167 | init_string = "" 168 | if var.name in initialized_variable_names: 169 | init_string = ", *INIT_FROM_CKPT*" 170 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 171 | init_string) 172 | 173 | all_layers = model.get_all_encoder_layers() 174 | 175 | predictions = { 176 | "unique_ids": unique_ids, 177 | "extract_indices": extract_indices 178 | } 179 | 180 | for (i, layer_index) 
in enumerate(layer_indexes): 181 | predictions["layer_output_%d" % i] = all_layers[layer_index] 182 | 183 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 184 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 185 | return output_spec 186 | 187 | return model_fn 188 | 189 | 190 | def _convert_example_to_features(example, window_start, window_end, tokens_ids_to_extract, tokenizer, seq_length): 191 | window_tokens = example.tokens[window_start:window_end] 192 | 193 | tokens = [] 194 | input_type_ids = [] 195 | for token in window_tokens: 196 | tokens.append(token) 197 | input_type_ids.append(0) 198 | 199 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 200 | 201 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 202 | # tokens are attended to. 203 | input_mask = [1] * len(input_ids) 204 | 205 | # Zero-pad up to the sequence length. 206 | while len(input_ids) < seq_length: 207 | input_ids.append(0) 208 | input_mask.append(0) 209 | input_type_ids.append(0) 210 | 211 | extract_indices = [-1] * seq_length 212 | for i in tokens_ids_to_extract: 213 | assert i - window_start >= 0 214 | extract_indices[i - window_start] = i 215 | 216 | assert len(input_ids) == seq_length 217 | assert len(input_mask) == seq_length 218 | assert len(input_type_ids) == seq_length 219 | 220 | return dict(unique_ids=example.document_index, 221 | input_ids=input_ids, 222 | input_mask=input_mask, 223 | input_type_ids=input_type_ids, 224 | extract_indices=extract_indices) 225 | 226 | 227 | def convert_examples_to_features(examples, window_size, stride, tokenizer): 228 | """Loads a data file into a list of `InputBatch`s.""" 229 | 230 | assert window_size % 2 == 1 231 | assert stride % 2 == 1 232 | 233 | for example in examples: 234 | for i in range(0, len(example.tokens), stride): 235 | window_center = i + window_size // 2 236 | token_ids_to_extract = [] 237 | extract_start = int(np.clip(window_center - stride // 2, 0, len(example.tokens))) 238 | extract_end = int(np.clip(window_center + stride // 2 + 1, extract_start, len(example.tokens))) 239 | 240 | if i == 0: 241 | token_ids_to_extract.extend(range(extract_start)) 242 | 243 | token_ids_to_extract.extend(range(extract_start, extract_end)) 244 | 245 | if i + stride >= len(example.tokens): 246 | token_ids_to_extract.extend(range(extract_end, len(example.tokens))) 247 | 248 | token_ids_to_extract = [t for t in token_ids_to_extract if example.bert_to_orig_map[t] >= 0] 249 | 250 | yield _convert_example_to_features(example, 251 | i, 252 | min(i + window_size, len(example.tokens)), 253 | token_ids_to_extract, 254 | tokenizer, 255 | window_size) 256 | 257 | 258 | def main(_): 259 | tf.logging.set_verbosity(tf.logging.INFO) 260 | 261 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 262 | 263 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 264 | 265 | tokenizer = tokenization.FullTokenizer( 266 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 267 | 268 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 269 | run_config = tf.contrib.tpu.RunConfig( 270 | master=FLAGS.master, 271 | tpu_config=tf.contrib.tpu.TPUConfig( 272 | num_shards=FLAGS.num_tpu_cores, 273 | per_host_input_for_training=is_per_host)) 274 | 275 | # examples = read_examples(FLAGS.input_file) 276 | json_examples = [] 277 | for x in ['test', 'train', 'dev']: 278 | with open(os.path.join(FLAGS.input_file, x + '.english.jsonlines')) as f: 279 | json_examples.extend((json.loads(jsonline) for jsonline in f.readlines())) 
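  # Illustrative note (added, not in the original file): each line of the
  # *.english.jsonlines files is one JSON document as produced by minimize.py,
  # roughly of the form
  #   {"doc_key": "bc/cctv/00/cctv_0000_0",
  #    "sentences": [["Hello", "world", "."], ...],
  #    "speakers":  [["spk1", "spk1", "spk1"], ...],
  #    "clusters":  [[[0, 1], [5, 5]], ...]}
  # plus "constituents" and "ner" span lists; the doc_key and values shown here
  # are hypothetical.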
280 | 281 | orig_examples = [] 282 | bert_examples = [] 283 | for i, json_e in enumerate(json_examples): 284 | e = process_example(json_e, i, should_filter_embedded_mentions=True) 285 | orig_examples.append(e) 286 | bert_examples.append(e.bertify(tokenizer)) 287 | 288 | model_fn = model_fn_builder( 289 | bert_config=bert_config, 290 | init_checkpoint=FLAGS.init_checkpoint, 291 | layer_indexes=layer_indexes, 292 | use_tpu=FLAGS.use_tpu, 293 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 294 | 295 | # If TPU is not available, this will fall back to normal Estimator on CPU 296 | # or GPU. 297 | estimator = tf.contrib.tpu.TPUEstimator( 298 | use_tpu=FLAGS.use_tpu, 299 | model_fn=model_fn, 300 | config=run_config, 301 | predict_batch_size=FLAGS.batch_size) 302 | 303 | input_fn = input_fn_builder( 304 | examples=bert_examples, window_size=FLAGS.window_size, stride=FLAGS.stride, tokenizer=tokenizer) 305 | 306 | writer = h5py.File(FLAGS.output_file, 'w') 307 | with tqdm(total=sum(len(e.tokens) for e in orig_examples)) as t: 308 | for result in estimator.predict(input_fn, yield_single_examples=True): 309 | document_index = int(result["unique_ids"]) 310 | bert_example = bert_examples[document_index] 311 | orig_example = orig_examples[document_index] 312 | file_key = bert_example.doc_key.replace('/', ':') 313 | 314 | t.update(n=(result['extract_indices'] >= 0).sum()) 315 | 316 | for output_index, bert_token_index in enumerate(result['extract_indices']): 317 | if bert_token_index < 0: 318 | continue 319 | 320 | token_index = bert_example.bert_to_orig_map[bert_token_index] 321 | sentence_index, token_index = orig_example.unravel_token_index(token_index) 322 | 323 | dataset_key ="{}/{}".format(file_key, sentence_index) 324 | if dataset_key not in writer: 325 | writer.create_dataset(dataset_key, 326 | (len(orig_example.sentence_tokens[sentence_index]), bert_config.hidden_size, len(layer_indexes)), 327 | dtype=np.float32) 328 | 329 | dset = writer[dataset_key] 330 | for j, layer_index in enumerate(layer_indexes): 331 | layer_output = result["layer_output_%d" % j] 332 | dset[token_index, :, j] = layer_output[output_index] 333 | writer.close() 334 | 335 | 336 | if __name__ == "__main__": 337 | flags.mark_flag_as_required("input_file") 338 | flags.mark_flag_as_required("vocab_file") 339 | flags.mark_flag_as_required("bert_config_file") 340 | flags.mark_flag_as_required("init_checkpoint") 341 | flags.mark_flag_as_required("output_file") 342 | tf.app.run() 343 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
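# Illustrative usage sketch (added, not part of the original file), assuming a
# vocab file path of "vocab.txt":
#   tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=False)
#   tokenizer.tokenize("unaffable")           # -> roughly ["un", "##aff", "##able"]
#   tokenizer.convert_tokens_to_ids(["un"])   # -> [vocab index of "un"]
# BasicTokenizer handles cleaning, lower-casing, and punctuation splitting;
# WordpieceTokenizer then applies greedy longest-match-first subword
# segmentation against the vocabulary (exact output depends on the vocab used).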
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
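    # Illustrative example (added, not in the original file): the call below
    # turns "ab中c" into "ab 中 c", so each CJK character becomes its own
    # whitespace-separated token before punctuation splitting is applied.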
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import errno 7 | import codecs 8 | import collections 9 | import math 10 | import shutil 11 | 12 | import numpy as np 13 | import six 14 | import tensorflow as tf 15 | import pyhocon 16 | from argparse import ArgumentParser 17 | from colorama import Back, Style 18 | from pip._vendor.colorama import Fore 19 | 20 | 21 | def get_cluster_config(): 22 | # # Distributed training configurations. 
23 | # two_local_gpus { 24 | # addresses { 25 | # ps = [localhost:2222] 26 | # worker = [localhost:2223, localhost:2224, localhost:2225, localhost:2226] 27 | # } 28 | # gpus = [0, 1, 2, 3] 29 | # } 30 | if "GPUS" not in os.environ: 31 | raise ValueError("Need to set GPU environment variable") 32 | gpus = list(map(int, os.environ["GPUS"].split(','))) 33 | 34 | workers = ['localhost:{}'.format(port) for port in range(2223, 2223 + len(gpus))] 35 | cluster_config = {'addresses': {'ps': ['localhost:2222'], 36 | 'worker': workers}, 37 | 'gpus': gpus} 38 | 39 | return cluster_config 40 | 41 | 42 | def initialize_from_env(experiment, logdir=None): 43 | if "GPU" in os.environ: 44 | set_gpus(int(os.environ["GPU"])) 45 | else: 46 | set_gpus() 47 | 48 | print("Running experiment: {}".format(experiment)) 49 | 50 | config = pyhocon.ConfigFactory.parse_file("experiments.conf")[experiment] 51 | 52 | if logdir is None: 53 | logdir = experiment 54 | 55 | config["log_dir"] = mkdirs(os.path.join(config["log_root"], logdir)) 56 | 57 | print(pyhocon.HOCONConverter.convert(config, "hocon")) 58 | return config 59 | 60 | 61 | def get_args(): 62 | parser = ArgumentParser() 63 | parser.add_argument('experiment') 64 | parser.add_argument('-l', '--logdir') 65 | parser.add_argument('--latest-checkpoint', action='store_true') 66 | return parser.parse_args() 67 | 68 | 69 | def copy_checkpoint(source, target): 70 | for ext in (".index", ".data-00000-of-00001"): 71 | shutil.copyfile(source + ext, target + ext) 72 | 73 | 74 | def make_summary(value_dict): 75 | return tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=v) for k, v in value_dict.items()]) 76 | 77 | 78 | def flatten(l): 79 | return [item for sublist in l for item in sublist] 80 | 81 | 82 | def set_gpus(*gpus): 83 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus) 84 | print("Setting CUDA_VISIBLE_DEVICES to: {}".format(os.environ["CUDA_VISIBLE_DEVICES"])) 85 | 86 | 87 | def mkdirs(path): 88 | try: 89 | os.makedirs(path) 90 | except OSError as exception: 91 | if exception.errno != errno.EEXIST: 92 | raise 93 | return path 94 | 95 | 96 | def load_char_dict(char_vocab_path): 97 | vocab = [u""] 98 | with codecs.open(char_vocab_path, encoding="utf-8") as f: 99 | vocab.extend(l.strip() for l in f.readlines()) 100 | char_dict = collections.defaultdict(int) 101 | char_dict.update({c: i for i, c in enumerate(vocab)}) 102 | return char_dict 103 | 104 | 105 | def maybe_divide(x, y): 106 | return 0 if y == 0 else x / float(y) 107 | 108 | 109 | def projection(inputs, output_size, initializer=None): 110 | return ffnn(inputs, 0, -1, output_size, dropout=None, output_weights_initializer=initializer) 111 | 112 | 113 | def highway(inputs, num_layers, dropout): 114 | for i in range(num_layers): 115 | with tf.variable_scope("highway_{}".format(i)): 116 | j, f = tf.split(projection(inputs, 2 * shape(inputs, -1)), 2, -1) 117 | f = tf.sigmoid(f) 118 | j = tf.nn.relu(j) 119 | if dropout is not None: 120 | j = tf.nn.dropout(j, dropout) 121 | inputs = f * j + (1 - f) * inputs 122 | return inputs 123 | 124 | 125 | def shape(x, dim): 126 | return x.get_shape()[dim].value or tf.shape(x)[dim] 127 | 128 | 129 | def ffnn(inputs, num_hidden_layers, hidden_size, output_size, dropout, output_weights_initializer=None): 130 | if len(inputs.get_shape()) > 3: 131 | raise ValueError("FFNN with rank {} not supported".format(len(inputs.get_shape()))) 132 | 133 | if len(inputs.get_shape()) == 3: 134 | batch_size = shape(inputs, 0) 135 | seqlen = shape(inputs, 1) 136 | emb_size = 
shape(inputs, 2) 137 | current_inputs = tf.reshape(inputs, [batch_size * seqlen, emb_size]) 138 | else: 139 | current_inputs = inputs 140 | 141 | for i in range(num_hidden_layers): 142 | hidden_weights = tf.get_variable("hidden_weights_{}".format(i), [shape(current_inputs, 1), hidden_size], 143 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 144 | hidden_bias = tf.get_variable("hidden_bias_{}".format(i), [hidden_size], 145 | initializer=tf.zeros_initializer()) 146 | current_outputs = tf.nn.relu(tf.nn.xw_plus_b(current_inputs, hidden_weights, hidden_bias)) 147 | 148 | if dropout is not None: 149 | current_outputs = tf.nn.dropout(current_outputs, dropout) 150 | current_inputs = current_outputs 151 | 152 | output_weights = tf.get_variable("output_weights", [shape(current_inputs, 1), output_size], 153 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 154 | output_bias = tf.get_variable("output_bias", [output_size], 155 | initializer=tf.zeros_initializer()) 156 | outputs = tf.nn.xw_plus_b(current_inputs, output_weights, output_bias) 157 | 158 | if len(inputs.get_shape()) == 3: 159 | outputs = tf.reshape(outputs, [batch_size, seqlen, output_size]) 160 | return outputs 161 | 162 | 163 | def cnn(inputs, filter_sizes, num_filters): 164 | num_words = shape(inputs, 0) 165 | num_chars = shape(inputs, 1) 166 | input_size = shape(inputs, 2) 167 | outputs = [] 168 | for i, filter_size in enumerate(filter_sizes): 169 | with tf.variable_scope("conv_{}".format(i)): 170 | w = tf.get_variable("w", [filter_size, input_size, num_filters]) 171 | b = tf.get_variable("b", [num_filters]) 172 | conv = tf.nn.conv1d(inputs, w, stride=1, padding="VALID") # [num_words, num_chars - filter_size, num_filters] 173 | h = tf.nn.relu(tf.nn.bias_add(conv, b)) # [num_words, num_chars - filter_size, num_filters] 174 | pooled = tf.reduce_max(h, 1) # [num_words, num_filters] 175 | outputs.append(pooled) 176 | return tf.concat(outputs, 1) # [num_words, num_filters * len(filter_sizes)] 177 | 178 | 179 | def batch_gather(emb, indices): 180 | batch_size = shape(emb, 0) 181 | seqlen = shape(emb, 1) 182 | if len(emb.get_shape()) > 2: 183 | emb_size = shape(emb, 2) 184 | else: 185 | emb_size = 1 186 | flattened_emb = tf.reshape(emb, [batch_size * seqlen, emb_size]) # [batch_size * seqlen, emb] 187 | offset = tf.expand_dims(tf.range(batch_size) * seqlen, 1) # [batch_size, 1] 188 | gathered = tf.gather(flattened_emb, indices + offset) # [batch_size, num_indices, emb] 189 | if len(emb.get_shape()) == 2: 190 | gathered = tf.squeeze(gathered, 2) # [batch_size, num_indices] 191 | return gathered 192 | 193 | 194 | def assert_rank(tensor, expected_rank, name=None): 195 | """Raises an exception if the tensor rank is not of the expected rank. 196 | 197 | Args: 198 | tensor: A tf.Tensor to check the rank of. 199 | expected_rank: Python integer or list of integers, expected rank. 200 | name: Optional name of the tensor for the error message. 201 | 202 | Raises: 203 | ValueError: If the expected shape doesn't match the actual shape. 
204 | """ 205 | if name is None: 206 | name = tensor.name 207 | 208 | expected_rank_dict = {} 209 | if isinstance(expected_rank, six.integer_types): 210 | expected_rank_dict[expected_rank] = True 211 | else: 212 | for x in expected_rank: 213 | expected_rank_dict[x] = True 214 | 215 | actual_rank = tensor.shape.ndims 216 | if actual_rank not in expected_rank_dict: 217 | scope_name = tf.get_variable_scope().name 218 | raise ValueError( 219 | "For the tensor `%s` in scope `%s`, the actual rank " 220 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 221 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 222 | 223 | 224 | def get_shape_list(tensor, expected_rank=None, name=None): 225 | """Returns a list of the shape of tensor, preferring static dimensions. 226 | 227 | Args: 228 | tensor: A tf.Tensor object to find the shape of. 229 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 230 | specified and the `tensor` has a different rank, and exception will be 231 | thrown. 232 | name: Optional name of the tensor for the error message. 233 | 234 | Returns: 235 | A list of dimensions of the shape of tensor. All static dimensions will 236 | be returned as python integers, and dynamic dimensions will be returned 237 | as tf.Tensor scalars. 238 | """ 239 | if name is None: 240 | name = tensor.name 241 | 242 | if expected_rank is not None: 243 | assert_rank(tensor, expected_rank, name) 244 | 245 | shape = tensor.shape.as_list() 246 | 247 | non_static_indexes = [] 248 | for (index, dim) in enumerate(shape): 249 | if dim is None: 250 | non_static_indexes.append(index) 251 | 252 | if not non_static_indexes: 253 | return shape 254 | 255 | dyn_shape = tf.shape(tensor) 256 | for index in non_static_indexes: 257 | shape[index] = dyn_shape[index] 258 | return shape 259 | 260 | 261 | def reshape_to_matrix(input_tensor): 262 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 263 | ndims = input_tensor.shape.ndims 264 | if ndims < 2: 265 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 266 | (input_tensor.shape)) 267 | if ndims == 2: 268 | return input_tensor 269 | 270 | width = input_tensor.shape[-1] 271 | output_tensor = tf.reshape(input_tensor, [-1, width]) 272 | return output_tensor 273 | 274 | 275 | def create_initializer(initializer_range=0.02): 276 | """Creates a `truncated_normal_initializer` with the given range.""" 277 | return tf.truncated_normal_initializer(stddev=initializer_range) 278 | 279 | 280 | def dropout(input_tensor, dropout_prob): 281 | """Perform dropout. 282 | 283 | Args: 284 | input_tensor: float Tensor. 285 | dropout_prob: Python float. The probability of dropping out a value (NOT of 286 | *keeping* a dimension as in `tf.nn.dropout`). 287 | 288 | Returns: 289 | A version of `input_tensor` with dropout applied. 290 | """ 291 | if dropout_prob is None or dropout_prob == 0.0: 292 | return input_tensor 293 | 294 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 295 | return output 296 | 297 | 298 | def attention_scores_layer(from_tensor, 299 | to_tensor, 300 | attention_mask=None, 301 | num_attention_heads=1, 302 | size_per_head=512, 303 | query_act=None, 304 | key_act=None, 305 | initializer_range=0.02, 306 | batch_size=None, 307 | from_seq_length=None, 308 | to_seq_length=None, 309 | query_equals_key=False, 310 | return_features=False): 311 | """Calculate multi-headed attention probabilities from `from_tensor` to `to_tensor`. 
312 | 313 | This computes the attention scores used by multi-headed attention, as in 314 | "Attention Is All You Need". If `from_tensor` and `to_tensor` are the same, then 315 | this is self-attention. Each timestep in `from_tensor` attends to the 316 | corresponding sequence in `to_tensor`, producing one score per position pair. 317 | 318 | This function first projects `from_tensor` into a "query" tensor and 319 | `to_tensor` into a "key" tensor. These are (effectively) a list 320 | of tensors of length `num_attention_heads`, where each tensor is of shape 321 | [batch_size, seq_length, size_per_head]. 322 | 323 | Then, the query and key tensors are dot-producted and scaled, and the 324 | attention mask (if given) is applied. The resulting raw (pre-softmax) scores 325 | are returned; the softmax and the value projection are handled by the 326 | caller (see `attention_layer`). 327 | 328 | In practice, the multi-headed attention is done with transposes and 329 | reshapes rather than actual separate tensors. 330 | 331 | Args: 332 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 333 | from_width]. 334 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 335 | attention_mask: (optional) int32 Tensor of shape [batch_size, 336 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 337 | attention scores will effectively be set to -infinity for any positions in 338 | the mask that are 0, and will be unchanged for positions that are 1. 339 | num_attention_heads: int. Number of attention heads. 340 | size_per_head: int. Size of each attention head. 341 | query_act: (optional) Activation function for the query transform. 342 | key_act: (optional) Activation function for the key transform. 343 | initializer_range: float. Range of the weight initializer. 344 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 345 | of the 3D version of the `from_tensor` and `to_tensor`. 346 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 347 | of the 3D version of the `from_tensor`. 348 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 349 | of the 3D version of the `to_tensor`. 350 | 351 | Returns: 352 | float Tensor of shape [batch_size, num_attention_heads, from_seq_length, to_seq_length]. 353 | 354 | Raises: 355 | ValueError: Any of the arguments or tensor shapes are invalid.
356 | """ 357 | 358 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 359 | seq_length, width): 360 | output_tensor = tf.reshape( 361 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 362 | 363 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 364 | return output_tensor 365 | 366 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 367 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 368 | 369 | if len(from_shape) != len(to_shape): 370 | raise ValueError( 371 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 372 | 373 | if len(from_shape) == 3: 374 | batch_size = from_shape[0] 375 | from_seq_length = from_shape[1] 376 | to_seq_length = to_shape[1] 377 | elif len(from_shape) == 2: 378 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 379 | raise ValueError( 380 | "When passing in rank 2 tensors to attention_layer, the values " 381 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 382 | "must all be specified.") 383 | 384 | # Scalar dimensions referenced here: 385 | # B = batch size (number of sequences) 386 | # F = `from_tensor` sequence length 387 | # T = `to_tensor` sequence length 388 | # N = `num_attention_heads` 389 | # H = `size_per_head` 390 | 391 | from_tensor_2d = reshape_to_matrix(from_tensor) 392 | to_tensor_2d = reshape_to_matrix(to_tensor) 393 | 394 | # `query_layer` = [B*F, N*H] 395 | query_layer = tf.layers.dense( 396 | from_tensor_2d, 397 | num_attention_heads * size_per_head, 398 | activation=query_act, 399 | name="query", 400 | kernel_initializer=create_initializer(initializer_range)) 401 | 402 | # `key_layer` = [B*T, N*H] 403 | if query_equals_key: 404 | key_layer = query_layer 405 | else: 406 | key_layer = tf.layers.dense( 407 | to_tensor_2d, 408 | num_attention_heads * size_per_head, 409 | activation=key_act, 410 | name="key", 411 | kernel_initializer=create_initializer(initializer_range)) 412 | 413 | # `query_layer` = [B, N, F, H] 414 | query_layer = transpose_for_scores(query_layer, batch_size, 415 | num_attention_heads, from_seq_length, 416 | size_per_head) 417 | 418 | # `key_layer` = [B, N, T, H] 419 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 420 | to_seq_length, size_per_head) 421 | 422 | # Take the dot product between "query" and "key" to get the raw 423 | # attention scores. 424 | # `attention_scores` = [B, N, F, T] 425 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 426 | attention_scores = tf.multiply(attention_scores, 427 | 1.0 / math.sqrt(float(size_per_head))) 428 | 429 | if attention_mask is not None: 430 | # `attention_mask` = [B, 1, F, T] 431 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 432 | 433 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 434 | # masked positions, this operation will create a tensor which is 0.0 for 435 | # positions we want to attend and -10000.0 for masked positions. 436 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 437 | 438 | # Since we are adding it to the raw scores before the softmax, this is 439 | # effectively the same as removing these entirely. 
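# A quick worked example of the masking above (illustrative values only): for
# one head and one query position with raw scores [2.0, 1.0, 3.0] and
# attention_mask [1, 1, 0], the adder is [0.0, 0.0, -10000.0], so the line
# below yields [2.0, 1.0, -10000.0] and the downstream softmax assigns the
# masked position ~0 probability.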
440 | attention_scores = attention_scores * tf.cast(attention_mask, tf.float32) + adder 441 | 442 | if return_features: 443 | return attention_scores, query_layer, key_layer 444 | else: 445 | return attention_scores 446 | 447 | 448 | def attention_layer(from_tensor, 449 | to_tensor, 450 | attention_mask=None, 451 | num_attention_heads=1, 452 | size_per_head=512, 453 | query_act=None, 454 | key_act=None, 455 | value_act=None, 456 | attention_probs_dropout_prob=0.0, 457 | initializer_range=0.02, 458 | do_return_2d_tensor=False, 459 | batch_size=None, 460 | from_seq_length=None, 461 | to_seq_length=None): 462 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 463 | 464 | This is an implementation of multi-headed attention based on "Attention 465 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 466 | this is self-attention. Each timestep in `from_tensor` attends to the 467 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 468 | 469 | This function first projects `from_tensor` into a "query" tensor and 470 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 471 | of tensors of length `num_attention_heads`, where each tensor is of shape 472 | [batch_size, seq_length, size_per_head]. 473 | 474 | Then, the query and key tensors are dot-producted and scaled. These are 475 | softmaxed to obtain attention probabilities. The value tensors are then 476 | interpolated by these probabilities, then concatenated back to a single 477 | tensor and returned. 478 | 479 | In practice, the multi-headed attention are done with transposes and 480 | reshapes rather than actual separate tensors. 481 | 482 | Args: 483 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 484 | from_width]. 485 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 486 | attention_mask: (optional) int32 Tensor of shape [batch_size, 487 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 488 | attention scores will effectively be set to -infinity for any positions in 489 | the mask that are 0, and will be unchanged for positions that are 1. 490 | num_attention_heads: int. Number of attention heads. 491 | size_per_head: int. Size of each attention head. 492 | query_act: (optional) Activation function for the query transform. 493 | key_act: (optional) Activation function for the key transform. 494 | value_act: (optional) Activation function for the value transform. 495 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 496 | attention probabilities. 497 | initializer_range: float. Range of the weight initializer. 498 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 499 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 500 | output will be of shape [batch_size, from_seq_length, num_attention_heads 501 | * size_per_head]. 502 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 503 | of the 3D version of the `from_tensor` and `to_tensor`. 504 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 505 | of the 3D version of the `from_tensor`. 506 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 507 | of the 3D version of the `to_tensor`. 508 | 509 | Returns: 510 | float Tensor of shape [batch_size, from_seq_length, 511 | num_attention_heads * size_per_head]. 
(If `do_return_2d_tensor` is 512 | true, this will be of shape [batch_size * from_seq_length, 513 | num_attention_heads * size_per_head]). 514 | 515 | Raises: 516 | ValueError: Any of the arguments or tensor shapes are invalid. 517 | """ 518 | 519 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 520 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 521 | 522 | if len(from_shape) != len(to_shape): 523 | raise ValueError( 524 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 525 | 526 | if len(from_shape) == 3: 527 | batch_size = from_shape[0] 528 | from_seq_length = from_shape[1] 529 | to_seq_length = to_shape[1] 530 | elif len(from_shape) == 2: 531 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 532 | raise ValueError( 533 | "When passing in rank 2 tensors to attention_layer, the values " 534 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 535 | "must all be specified.") 536 | 537 | to_tensor_2d = reshape_to_matrix(to_tensor) 538 | 539 | # `value_layer` = [B*T, N*H] 540 | value_layer = tf.layers.dense( 541 | to_tensor_2d, 542 | num_attention_heads * size_per_head, 543 | activation=value_act, 544 | name="value", 545 | kernel_initializer=create_initializer(initializer_range)) 546 | 547 | # Normalize the attention scores to probabilities. 548 | # `attention_probs` = [B, N, F, T] 549 | attention_scores = attention_scores_layer(from_tensor, 550 | to_tensor, 551 | attention_mask, 552 | num_attention_heads, 553 | size_per_head, 554 | query_act, 555 | key_act, 556 | initializer_range, 557 | batch_size, 558 | from_seq_length, 559 | to_seq_length) 560 | 561 | # Normalize the attention scores to probabilities. 562 | # `attention_probs` = [B, N, F, T] 563 | attention_probs = tf.nn.softmax(attention_scores) 564 | 565 | # This is actually dropping out entire tokens to attend to, which might 566 | # seem a bit unusual, but is taken from the original Transformer paper. 
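# Illustrative note: if the caller passes, say, attention_probs_dropout_prob =
# 0.1, then dropout() zeroes roughly 10% of the entries of the [B, N, F, T]
# probability tensor and rescales the survivors by 1 / 0.9 (via tf.nn.dropout),
# so individual keys are randomly hidden from each query while that dropout
# probability is non-zero.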
567 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 568 | 569 | # `value_layer` = [B, T, N, H] 570 | value_layer = tf.reshape( 571 | value_layer, 572 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 573 | 574 | # `value_layer` = [B, N, T, H] 575 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 576 | 577 | # `context_layer` = [B, N, F, H] 578 | context_layer = tf.matmul(attention_probs, value_layer) 579 | 580 | # `context_layer` = [B, F, N, H] 581 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 582 | 583 | if do_return_2d_tensor: 584 | # `context_layer` = [B*F, N*H] 585 | context_layer = tf.reshape( 586 | context_layer, 587 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 588 | else: 589 | # `context_layer` = [B, F, N*H] 590 | context_layer = tf.reshape( 591 | context_layer, 592 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 593 | 594 | return context_layer 595 | 596 | 597 | FORES = [Fore.BLUE, 598 | Fore.CYAN, 599 | Fore.GREEN, 600 | Fore.MAGENTA, 601 | Fore.RED, 602 | Fore.YELLOW] 603 | BACKS = [Back.BLUE, 604 | Back.CYAN, 605 | Back.GREEN, 606 | Back.MAGENTA, 607 | Back.RED, 608 | Back.YELLOW] 609 | COLOR_WHEEL = FORES + [f + b for f in FORES for b in BACKS] 610 | 611 | 612 | def coref_pprint(tokens, clusters): 613 | clusters = [tuple(tuple(m) for m in c) for c in clusters] 614 | cluster_to_color = {c: i % len(COLOR_WHEEL) for i, c in enumerate(clusters)} 615 | pretty_str = '' 616 | color_stack = [] 617 | for i, t in enumerate(tokens): 618 | for c in clusters: 619 | for start, end in sorted(c, key=lambda m: m[1]): 620 | if i == start: 621 | cluster_color = cluster_to_color[c] 622 | pretty_str += Style.BRIGHT + COLOR_WHEEL[cluster_color] 623 | color_stack.append(cluster_color) 624 | 625 | pretty_str += t + u' ' 626 | 627 | for c in clusters: 628 | for start, end in c: 629 | if i == end: 630 | pretty_str += Style.RESET_ALL 631 | color_stack.pop(-1) 632 | if color_stack: 633 | pretty_str += Style.BRIGHT + COLOR_WHEEL[color_stack[-1]] 634 | 635 | print(pretty_str) 636 | 637 | class RetrievalEvaluator(object): 638 | def __init__(self): 639 | self._num_correct = 0 640 | self._num_gold = 0 641 | self._num_predicted = 0 642 | 643 | def update(self, gold_set, predicted_set): 644 | self._num_correct += len(gold_set & predicted_set) 645 | self._num_gold += len(gold_set) 646 | self._num_predicted += len(predicted_set) 647 | 648 | def recall(self): 649 | return maybe_divide(self._num_correct, self._num_gold) 650 | 651 | def precision(self): 652 | return maybe_divide(self._num_correct, self._num_predicted) 653 | 654 | def metrics(self): 655 | recall = self.recall() 656 | precision = self.precision() 657 | f1 = maybe_divide(2 * recall * precision, precision + recall) 658 | return recall, precision, f1 659 | 660 | 661 | class EmbeddingDictionary(object): 662 | def __init__(self, info, normalize=True, maybe_cache=None): 663 | self._size = info["size"] 664 | self._normalize = normalize 665 | self._path = info["path"] 666 | if maybe_cache is not None and maybe_cache._path == self._path: 667 | assert self._size == maybe_cache._size 668 | self._embeddings = maybe_cache._embeddings 669 | else: 670 | self._embeddings = self.load_embedding_dict(self._path) 671 | 672 | @property 673 | def size(self): 674 | return self._size 675 | 676 | def load_embedding_dict(self, path): 677 | print("Loading word embeddings from {}...".format(path)) 678 | default_embedding = np.zeros(self.size) 679 | embedding_dict = 
collections.defaultdict(lambda: default_embedding) 680 | if len(path) > 0: 681 | vocab_size = None 682 | with open(path) as f: 683 | for i, line in enumerate(f.readlines()): 684 | word_end = line.find(" ") 685 | word = line[:word_end] 686 | embedding = np.fromstring(line[word_end + 1:], np.float32, sep=" ") 687 | assert len(embedding) == self.size 688 | embedding_dict[word] = embedding 689 | if vocab_size is not None: 690 | assert vocab_size == len(embedding_dict) 691 | print("Done loading word embeddings.") 692 | return embedding_dict 693 | 694 | def __getitem__(self, key): 695 | embedding = self._embeddings[key] 696 | if self._normalize: 697 | embedding = self.normalize(embedding) 698 | return embedding 699 | 700 | def normalize(self, v): 701 | norm = np.linalg.norm(v) 702 | if norm > 0: 703 | return v / norm 704 | else: 705 | return v 706 | 707 | 708 | class CustomLSTMCell(tf.contrib.rnn.RNNCell): 709 | def __init__(self, num_units, batch_size, dropout): 710 | self._num_units = num_units 711 | self._dropout = dropout 712 | self._dropout_mask = tf.nn.dropout(tf.ones([batch_size, self.output_size]), dropout) 713 | self._initializer = self._block_orthonormal_initializer([self.output_size] * 3) 714 | initial_cell_state = tf.get_variable("lstm_initial_cell_state", [1, self.output_size]) 715 | initial_hidden_state = tf.get_variable("lstm_initial_hidden_state", [1, self.output_size]) 716 | self._initial_state = tf.contrib.rnn.LSTMStateTuple(initial_cell_state, initial_hidden_state) 717 | 718 | @property 719 | def state_size(self): 720 | return tf.contrib.rnn.LSTMStateTuple(self.output_size, self.output_size) 721 | 722 | @property 723 | def output_size(self): 724 | return self._num_units 725 | 726 | @property 727 | def initial_state(self): 728 | return self._initial_state 729 | 730 | def __call__(self, inputs, state, scope=None): 731 | """Long short-term memory cell (LSTM).""" 732 | with tf.variable_scope(scope or type(self).__name__): # "CustomLSTMCell" 733 | c, h = state 734 | h *= self._dropout_mask 735 | concat = projection(tf.concat([inputs, h], 1), 3 * self.output_size, initializer=self._initializer) 736 | i, j, o = tf.split(concat, num_or_size_splits=3, axis=1) 737 | i = tf.sigmoid(i) 738 | new_c = (1 - i) * c + i * tf.tanh(j) 739 | new_h = tf.tanh(new_c) * tf.sigmoid(o) 740 | new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 741 | return new_h, new_state 742 | 743 | def _orthonormal_initializer(self, scale=1.0): 744 | def _initializer(shape, dtype=tf.float32, partition_info=None): 745 | M1 = np.random.randn(shape[0], shape[0]).astype(np.float32) 746 | M2 = np.random.randn(shape[1], shape[1]).astype(np.float32) 747 | Q1, R1 = np.linalg.qr(M1) 748 | Q2, R2 = np.linalg.qr(M2) 749 | Q1 = Q1 * np.sign(np.diag(R1)) 750 | Q2 = Q2 * np.sign(np.diag(R2)) 751 | n_min = min(shape[0], shape[1]) 752 | params = np.dot(Q1[:, :n_min], Q2[:n_min, :]) * scale 753 | return params 754 | 755 | return _initializer 756 | 757 | def _block_orthonormal_initializer(self, output_sizes): 758 | def _initializer(shape, dtype=np.float32, partition_info=None): 759 | assert len(shape) == 2 760 | assert sum(output_sizes) == shape[1] 761 | initializer = self._orthonormal_initializer() 762 | params = np.concatenate([initializer([shape[0], o], dtype, partition_info) for o in output_sizes], 1) 763 | return params 764 | 765 | return _initializer 766 | 767 | 768 | def softmax(X, theta = 1.0, axis = None): 769 | """ 770 | Compute the softmax of each element along an axis of X. 
771 | 772 | Parameters 773 | ---------- 774 | X: ND-Array. Probably should be floats. 775 | theta (optional): float parameter, used as a multiplier 776 | prior to exponentiation. Default = 1.0 777 | axis (optional): axis to compute values along. Default is the 778 | first non-singleton axis. 779 | 780 | Returns an array the same size as X. The result will sum to 1 781 | along the specified axis. 782 | """ 783 | 784 | # make X at least 2d 785 | y = np.atleast_2d(X) 786 | 787 | # find axis 788 | if axis is None: 789 | axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1) 790 | 791 | # multiply y against the theta parameter, 792 | y = y * float(theta) 793 | 794 | # subtract the max for numerical stability 795 | y = y - np.expand_dims(np.max(y, axis = axis), axis) 796 | 797 | # exponentiate y 798 | y = np.exp(y) 799 | 800 | # take the sum along the specified axis 801 | ax_sum = np.expand_dims(np.sum(y, axis = axis), axis) 802 | 803 | # finally: divide elementwise 804 | p = y / ax_sum 805 | 806 | # flatten if X was 1D 807 | if len(X.shape) == 1: p = p.flatten() 808 | 809 | return p 810 | 811 | 812 | def compute_p_m_entity(p_m_link, k): 813 | p_m_entity = tf.concat([[[1.]], tf.zeros([1, k - 1])], 1) 814 | 815 | def _time_step(i, p_m_entity): 816 | p_m_e = p_m_entity[:, :i] # [i, i] x[i, j] = p(m_i \in E_j) 817 | p_m_link_i = p_m_link[i:i + 1, :i] # [1, i] x[0, j] = p(a_i = j) 818 | p_m_e_i = tf.matmul(p_m_link_i, p_m_e) # [1, i] x[0, j] = \sum_k (p(a_i = k) * p(m_k \in E_j)) 819 | p_m_e_i = tf.concat([p_m_e_i, p_m_link[i:i + 1, i:i + 1]], 1) 820 | p_m_e_i = tf.pad(p_m_e_i, [[0, 0], [0, k - i - 1]], mode='CONSTANT') 821 | p_m_entity = tf.concat([p_m_entity, p_m_e_i], 0) 822 | return i + 1, p_m_entity 823 | 824 | _, p_m_entity = tf.while_loop(cond=lambda i, *_: tf.less(i, k), 825 | body=_time_step, 826 | loop_vars=(tf.constant(1), p_m_entity), 827 | shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]))) 828 | 829 | return p_m_entity 830 | 831 | def compute_b3_lost(p_m_entity, x_gold_class_cluster_ids_supgen, k, beta=2.0): 832 | # remove singleton entities 833 | gold_entities = tf.reduce_sum(x_gold_class_cluster_ids_supgen, 0) > 1.2 834 | 835 | sys_m_e = tf.one_hot(tf.argmax(p_m_entity, 1), k) 836 | sys_entities = tf.reduce_sum(sys_m_e, 0) > 1.2 837 | 838 | gold_entity_filter = tf.reshape(tf.where(gold_entities), [-1]) 839 | gold_cluster = tf.gather(tf.transpose(x_gold_class_cluster_ids_supgen), gold_entity_filter) 840 | 841 | sys_entity_filter, merge = tf.cond(pred=tf.reduce_any(sys_entities & gold_entities), 842 | fn1=lambda: (tf.reshape(tf.where(sys_entities), [-1]), tf.constant(0)), 843 | fn2=lambda: ( 844 | tf.reshape(tf.where(sys_entities | gold_entities), [-1]), tf.constant(1))) 845 | system_cluster = tf.gather(tf.transpose(p_m_entity), sys_entity_filter) 846 | 847 | # compute intersections 848 | gold_sys_intersect = tf.pow(tf.matmul(gold_cluster, system_cluster, transpose_b=True), 2) 849 | r_num = tf.reduce_sum(tf.reduce_sum(gold_sys_intersect, 1) / tf.reduce_sum(gold_cluster, 1)) 850 | r_den = tf.reduce_sum(gold_cluster) 851 | recall = tf.reshape(r_num / r_den, []) 852 | 853 | sys_gold_intersection = tf.transpose(gold_sys_intersect) 854 | p_num = tf.reduce_sum(tf.reduce_sum(sys_gold_intersection, 1) / tf.reduce_sum(system_cluster, 1)) 855 | p_den = tf.reduce_sum(system_cluster) 856 | prec = tf.reshape(p_num / p_den, []) 857 | 858 | beta_2 = beta ** 2 859 | f_beta = (1 + beta_2) * prec * recall / (beta_2 * prec + recall) 860 | 861 | lost = -f_beta 862 | # lost = 
tf.Print(lost, [merge, 863 | # r_num, r_den, p_num, p_den, 864 | # gold_entity_filter, sys_entity_filter, # tf.reduce_sum(p_m_entity, 0), 865 | # beta, recall, prec, f_beta], summarize=1000) 866 | 867 | return tf.cond(pred=tf.reduce_all([r_num > .1, p_num > .1, r_den > .1, p_den > .1]), 868 | fn1=lambda: lost, 869 | fn2=lambda: tf.stop_gradient(tf.constant(0.))) 870 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 
68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already been converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. true for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 149 | it is much faster if this is True, on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid. 
156 | """ 157 | config = copy.deepcopy(config) 158 | # if not is_training: 159 | # config.hidden_dropout_prob = 0.0 160 | # config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob, 196 | is_training=is_training) 197 | 198 | with tf.variable_scope("encoder"): 199 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 200 | # mask of shape [batch_size, seq_length, seq_length] which is used 201 | # for the attention scores. 202 | attention_mask = create_attention_mask_from_input_mask( 203 | input_ids, input_mask) 204 | 205 | # Run the stacked transformer. 206 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 207 | self.all_encoder_layers = transformer_model( 208 | input_tensor=self.embedding_output, 209 | attention_mask=attention_mask, 210 | hidden_size=config.hidden_size, 211 | num_hidden_layers=config.num_hidden_layers, 212 | num_attention_heads=config.num_attention_heads, 213 | intermediate_size=config.intermediate_size, 214 | intermediate_act_fn=get_activation(config.hidden_act), 215 | hidden_dropout_prob=config.hidden_dropout_prob, 216 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 217 | initializer_range=config.initializer_range, 218 | do_return_all_layers=True, 219 | is_training=is_training) 220 | 221 | self.sequence_output = self.all_encoder_layers[-1] 222 | # The "pooler" converts the encoded sequence tensor of shape 223 | # [batch_size, seq_length, hidden_size] to a tensor of shape 224 | # [batch_size, hidden_size]. This is necessary for segment-level 225 | # (or segment-pair-level) classification tasks where we need a fixed 226 | # dimensional representation of the segment. 227 | with tf.variable_scope("pooler"): 228 | # We "pool" the model by simply taking the hidden state corresponding 229 | # to the first token. 
We assume that this has been pre-trained 230 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 231 | self.pooled_output = tf.layers.dense( 232 | first_token_tensor, 233 | config.hidden_size, 234 | activation=tf.tanh, 235 | kernel_initializer=create_initializer(config.initializer_range)) 236 | 237 | def get_pooled_output(self): 238 | return self.pooled_output 239 | 240 | def get_sequence_output(self): 241 | """Gets final hidden layer of encoder. 242 | 243 | Returns: 244 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 245 | to the final hidden of the transformer encoder. 246 | """ 247 | return self.sequence_output 248 | 249 | def get_all_encoder_layers(self): 250 | return self.all_encoder_layers 251 | 252 | def get_embedding_output(self): 253 | """Gets output of the embedding lookup (i.e., input to the transformer). 254 | 255 | Returns: 256 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 257 | to the output of the embedding layer, after summing the word 258 | embeddings with the positional embeddings and the token type embeddings, 259 | then performing layer normalization. This is the input to the transformer. 260 | """ 261 | return self.embedding_output 262 | 263 | def get_embedding_table(self): 264 | return self.embedding_table 265 | 266 | 267 | def gelu(input_tensor): 268 | """Gaussian Error Linear Unit. 269 | 270 | This is a smoother version of the RELU. 271 | Original paper: https://arxiv.org/abs/1606.08415 272 | 273 | Args: 274 | input_tensor: float Tensor to perform activation. 275 | 276 | Returns: 277 | `input_tensor` with the GELU activation applied. 278 | """ 279 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 280 | return input_tensor * cdf 281 | 282 | 283 | def get_activation(activation_string): 284 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 285 | 286 | Args: 287 | activation_string: String name of the activation function. 288 | 289 | Returns: 290 | A Python function corresponding to the activation function. If 291 | `activation_string` is None, empty, or "linear", this will return None. 292 | If `activation_string` is not a string, it will return `activation_string`. 293 | 294 | Raises: 295 | ValueError: The `activation_string` does not correspond to a known 296 | activation. 297 | """ 298 | 299 | # We assume that anything that"s not a string is already an activation 300 | # function, so we just return it. 
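# For example: get_activation("gelu") returns the gelu function defined above,
# get_activation("RELU") returns tf.nn.relu (matching is case-insensitive), and
# get_activation("linear"), get_activation(""), and get_activation(None) all
# return None, meaning no activation is applied.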
301 | if not isinstance(activation_string, six.string_types): 302 | return activation_string 303 | 304 | if not activation_string: 305 | return None 306 | 307 | act = activation_string.lower() 308 | if act == "linear": 309 | return None 310 | elif act == "relu": 311 | return tf.nn.relu 312 | elif act == "gelu": 313 | return gelu 314 | elif act == "tanh": 315 | return tf.tanh 316 | else: 317 | raise ValueError("Unsupported activation: %s" % act) 318 | 319 | 320 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 321 | """Compute the union of the current variables and checkpoint variables.""" 322 | assignment_map = {} 323 | initialized_variable_names = {} 324 | 325 | name_to_variable = collections.OrderedDict() 326 | for var in tvars: 327 | name = var.name 328 | m = re.match("^(.*):\\d+$", name) 329 | if m is not None: 330 | name = m.group(1) 331 | name_to_variable[name] = var 332 | 333 | init_vars = tf.train.list_variables(init_checkpoint) 334 | 335 | assignment_map = collections.OrderedDict() 336 | for x in init_vars: 337 | (name, var) = (x[0], x[1]) 338 | if name not in name_to_variable: 339 | continue 340 | assignment_map[name] = name 341 | initialized_variable_names[name] = 1 342 | initialized_variable_names[name + ":0"] = 1 343 | 344 | return (assignment_map, initialized_variable_names) 345 | 346 | 347 | def dropout(input_tensor, dropout_prob): 348 | """Perform dropout. 349 | 350 | Args: 351 | input_tensor: float Tensor. 352 | dropout_prob: Python float. The probability of dropping out a value (NOT of 353 | *keeping* a dimension as in `tf.nn.dropout`). 354 | 355 | Returns: 356 | A version of `input_tensor` with dropout applied. 357 | """ 358 | if dropout_prob is None or dropout_prob == 0.0: 359 | return input_tensor 360 | 361 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 362 | return output 363 | 364 | 365 | def layer_norm(input_tensor, name=None): 366 | """Run layer normalization on the last dimension of the tensor.""" 367 | return tf.contrib.layers.layer_norm( 368 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 369 | 370 | 371 | def layer_norm_and_dropout(input_tensor, dropout_prob, is_training, name=None): 372 | """Runs layer normalization followed by dropout.""" 373 | output_tensor = layer_norm(input_tensor, name) 374 | output_tensor = tf.layers.dropout(output_tensor, dropout_prob, training=is_training) 375 | return output_tensor 376 | 377 | 378 | def create_initializer(initializer_range=0.02): 379 | """Creates a `truncated_normal_initializer` with the given range.""" 380 | return tf.truncated_normal_initializer(stddev=initializer_range) 381 | 382 | 383 | def embedding_lookup(input_ids, 384 | vocab_size, 385 | embedding_size=128, 386 | initializer_range=0.02, 387 | word_embedding_name="word_embeddings", 388 | use_one_hot_embeddings=False): 389 | """Looks up words embeddings for id tensor. 390 | 391 | Args: 392 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 393 | ids. 394 | vocab_size: int. Size of the embedding vocabulary. 395 | embedding_size: int. Width of the word embeddings. 396 | initializer_range: float. Embedding initialization range. 397 | word_embedding_name: string. Name of the embedding table. 398 | use_one_hot_embeddings: bool. If True, use one-hot method for word 399 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 400 | for TPUs. 401 | 402 | Returns: 403 | float Tensor of shape [batch_size, seq_length, embedding_size]. 
404 | """ 405 | # This function assumes that the input is of shape [batch_size, seq_length, 406 | # num_inputs]. 407 | # 408 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 409 | # reshape to [batch_size, seq_length, 1]. 410 | if input_ids.shape.ndims == 2: 411 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 412 | 413 | embedding_table = tf.get_variable( 414 | name=word_embedding_name, 415 | shape=[vocab_size, embedding_size], 416 | initializer=create_initializer(initializer_range)) 417 | 418 | if use_one_hot_embeddings: 419 | flat_input_ids = tf.reshape(input_ids, [-1]) 420 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 421 | output = tf.matmul(one_hot_input_ids, embedding_table) 422 | else: 423 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 424 | 425 | input_shape = get_shape_list(input_ids) 426 | 427 | output = tf.reshape(output, 428 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 429 | return (output, embedding_table) 430 | 431 | 432 | def embedding_postprocessor(input_tensor, 433 | use_token_type=False, 434 | token_type_ids=None, 435 | token_type_vocab_size=16, 436 | token_type_embedding_name="token_type_embeddings", 437 | use_position_embeddings=True, 438 | position_embedding_name="position_embeddings", 439 | initializer_range=0.02, 440 | max_position_embeddings=512, 441 | dropout_prob=0.1, 442 | is_training=True): 443 | """Performs various post-processing on a word embedding tensor. 444 | 445 | Args: 446 | input_tensor: float Tensor of shape [batch_size, seq_length, 447 | embedding_size]. 448 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 449 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 450 | Must be specified if `use_token_type` is True. 451 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 452 | token_type_embedding_name: string. The name of the embedding table variable 453 | for token type ids. 454 | use_position_embeddings: bool. Whether to add position embeddings for the 455 | position of each token in the sequence. 456 | position_embedding_name: string. The name of the embedding table variable 457 | for positional embeddings. 458 | initializer_range: float. Range of the weight initialization. 459 | max_position_embeddings: int. Maximum sequence length that might ever be 460 | used with this model. This can be longer than the sequence length of 461 | input_tensor, but cannot be shorter. 462 | dropout_prob: float. Dropout probability applied to the final output tensor. 463 | 464 | Returns: 465 | float tensor with same shape as `input_tensor`. 466 | 467 | Raises: 468 | ValueError: One of the tensor shapes or input values is invalid. 469 | """ 470 | input_shape = get_shape_list(input_tensor, expected_rank=3) 471 | batch_size = input_shape[0] 472 | seq_length = input_shape[1] 473 | width = input_shape[2] 474 | 475 | output = input_tensor 476 | 477 | if use_token_type: 478 | if token_type_ids is None: 479 | raise ValueError("`token_type_ids` must be specified if" 480 | "`use_token_type` is True.") 481 | token_type_table = tf.get_variable( 482 | name=token_type_embedding_name, 483 | shape=[token_type_vocab_size, width], 484 | initializer=create_initializer(initializer_range)) 485 | # This vocab will be small so we always do one-hot here, since it is always 486 | # faster for a small vocabulary. 
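# (The one-hot matmul below is equivalent to
# tf.nn.embedding_lookup(token_type_table, token_type_ids); it is written this
# way because, per the comment above, a dense matmul over the tiny
# `token_type_vocab_size` table is faster than a gather for such a small
# vocabulary.)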
487 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 488 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 489 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 490 | token_type_embeddings = tf.reshape(token_type_embeddings, 491 | [batch_size, seq_length, width]) 492 | output += token_type_embeddings 493 | 494 | if use_position_embeddings: 495 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 496 | with tf.control_dependencies([assert_op]): 497 | full_position_embeddings = tf.get_variable( 498 | name=position_embedding_name, 499 | shape=[max_position_embeddings, width], 500 | initializer=create_initializer(initializer_range)) 501 | # Since the position embedding table is a learned variable, we create it 502 | # using a (long) sequence length `max_position_embeddings`. The actual 503 | # sequence length might be shorter than this, for faster training of 504 | # tasks that do not have long sequences. 505 | # 506 | # So `full_position_embeddings` is effectively an embedding table 507 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 508 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 509 | # perform a slice. 510 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 511 | [seq_length, -1]) 512 | num_dims = len(output.shape.as_list()) 513 | 514 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 515 | # we broadcast among the first dimensions, which is typically just 516 | # the batch size. 517 | position_broadcast_shape = [] 518 | for _ in range(num_dims - 2): 519 | position_broadcast_shape.append(1) 520 | position_broadcast_shape.extend([seq_length, width]) 521 | position_embeddings = tf.reshape(position_embeddings, 522 | position_broadcast_shape) 523 | output += position_embeddings 524 | 525 | output = layer_norm_and_dropout(output, dropout_prob, is_training=is_training) 526 | return output 527 | 528 | 529 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 530 | """Create 3D attention mask from a 2D tensor mask. 531 | 532 | Args: 533 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 534 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 535 | 536 | Returns: 537 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 538 | """ 539 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 540 | batch_size = from_shape[0] 541 | from_seq_length = from_shape[1] 542 | 543 | to_shape = get_shape_list(to_mask, expected_rank=2) 544 | to_seq_length = to_shape[1] 545 | 546 | to_mask = tf.cast( 547 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 548 | 549 | # We don't assume that `from_tensor` is a mask (although it could be). We 550 | # don't actually care if we attend *from* padding tokens (only *to* padding) 551 | # tokens so we create a tensor of all ones. 552 | # 553 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 554 | broadcast_ones = tf.ones( 555 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 556 | 557 | # Here we broadcast along two dimensions to create the mask. 
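# For example (illustrative values): with to_mask = [[1, 1, 0]] (batch_size =
# 1, to_seq_length = 3) and from_seq_length = 2, broadcast_ones has shape
# [1, 2, 1] and the product below is [[[1., 1., 0.], [1., 1., 0.]]]: every
# "from" position may attend to the first two "to" positions but not to the
# padded third one.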
558 | mask = broadcast_ones * to_mask 559 | 560 | return mask 561 | 562 | 563 | def attention_scores_layer(from_tensor, 564 | to_tensor, 565 | attention_mask=None, 566 | num_attention_heads=1, 567 | size_per_head=512, 568 | query_act=None, 569 | key_act=None, 570 | initializer_range=0.02, 571 | batch_size=None, 572 | from_seq_length=None, 573 | to_seq_length=None, 574 | query_equals_key=False, 575 | return_features=False): 576 | """Calculate multi-headed attention scores from `from_tensor` to `to_tensor`. 577 | 578 | This computes the attention scores used by multi-headed attention, as in 579 | "Attention Is All You Need". If `from_tensor` and `to_tensor` are the same, then 580 | this is self-attention. Each timestep in `from_tensor` attends to the 581 | corresponding sequence in `to_tensor`, producing one score per position pair. 582 | 583 | This function first projects `from_tensor` into a "query" tensor and 584 | `to_tensor` into a "key" tensor. These are (effectively) a list 585 | of tensors of length `num_attention_heads`, where each tensor is of shape 586 | [batch_size, seq_length, size_per_head]. 587 | 588 | Then, the query and key tensors are dot-producted and scaled, and the 589 | attention mask (if given) is applied. The resulting raw (pre-softmax) scores 590 | are returned; the softmax and the value projection are handled by the 591 | caller (see `attention_layer`). 592 | 593 | In practice, the multi-headed attention is done with transposes and 594 | reshapes rather than actual separate tensors. 595 | 596 | Args: 597 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 598 | from_width]. 599 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 600 | attention_mask: (optional) int32 Tensor of shape [batch_size, 601 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 602 | attention scores will effectively be set to -infinity for any positions in 603 | the mask that are 0, and will be unchanged for positions that are 1. 604 | num_attention_heads: int. Number of attention heads. 605 | size_per_head: int. Size of each attention head. 606 | query_act: (optional) Activation function for the query transform. 607 | key_act: (optional) Activation function for the key transform. 608 | initializer_range: float. Range of the weight initializer. 609 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 610 | of the 3D version of the `from_tensor` and `to_tensor`. 611 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 612 | of the 3D version of the `from_tensor`. 613 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 614 | of the 3D version of the `to_tensor`. 615 | 616 | Returns: 617 | float Tensor of shape [batch_size, num_attention_heads, from_seq_length, to_seq_length]. 618 | 619 | Raises: 620 | ValueError: Any of the arguments or tensor shapes are invalid.
621 | """ 622 | 623 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 624 | seq_length, width): 625 | output_tensor = tf.reshape( 626 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 627 | 628 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 629 | return output_tensor 630 | 631 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 632 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 633 | 634 | if len(from_shape) != len(to_shape): 635 | raise ValueError( 636 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 637 | 638 | if len(from_shape) == 3: 639 | batch_size = from_shape[0] 640 | from_seq_length = from_shape[1] 641 | to_seq_length = to_shape[1] 642 | elif len(from_shape) == 2: 643 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 644 | raise ValueError( 645 | "When passing in rank 2 tensors to attention_layer, the values " 646 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 647 | "must all be specified.") 648 | 649 | # Scalar dimensions referenced here: 650 | # B = batch size (number of sequences) 651 | # F = `from_tensor` sequence length 652 | # T = `to_tensor` sequence length 653 | # N = `num_attention_heads` 654 | # H = `size_per_head` 655 | 656 | from_tensor_2d = reshape_to_matrix(from_tensor) 657 | to_tensor_2d = reshape_to_matrix(to_tensor) 658 | 659 | # `query_layer` = [B*F, N*H] 660 | query_layer = tf.layers.dense( 661 | from_tensor_2d, 662 | num_attention_heads * size_per_head, 663 | activation=query_act, 664 | name="query", 665 | kernel_initializer=create_initializer(initializer_range)) 666 | 667 | # `key_layer` = [B*T, N*H] 668 | if query_equals_key: 669 | key_layer = query_layer 670 | else: 671 | key_layer = tf.layers.dense( 672 | to_tensor_2d, 673 | num_attention_heads * size_per_head, 674 | activation=key_act, 675 | name="key", 676 | kernel_initializer=create_initializer(initializer_range)) 677 | 678 | # `query_layer` = [B, N, F, H] 679 | query_layer = transpose_for_scores(query_layer, batch_size, 680 | num_attention_heads, from_seq_length, 681 | size_per_head) 682 | 683 | # `key_layer` = [B, N, T, H] 684 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 685 | to_seq_length, size_per_head) 686 | 687 | # Take the dot product between "query" and "key" to get the raw 688 | # attention scores. 689 | # `attention_scores` = [B, N, F, T] 690 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 691 | attention_scores = tf.multiply(attention_scores, 692 | 1.0 / math.sqrt(float(size_per_head))) 693 | 694 | if attention_mask is not None: 695 | # `attention_mask` = [B, 1, F, T] 696 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 697 | 698 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 699 | # masked positions, this operation will create a tensor which is 0.0 for 700 | # positions we want to attend and -10000.0 for masked positions. 701 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 702 | 703 | # Since we are adding it to the raw scores before the softmax, this is 704 | # effectively the same as removing these entirely. 
705 | attention_scores = attention_scores * tf.cast(attention_mask, tf.float32) + adder 706 | 707 | if return_features: 708 | return attention_scores, query_layer, key_layer 709 | else: 710 | return attention_scores 711 | 712 | 713 | def attention_layer(from_tensor, 714 | to_tensor, 715 | attention_mask=None, 716 | num_attention_heads=1, 717 | size_per_head=512, 718 | query_act=None, 719 | key_act=None, 720 | value_act=None, 721 | attention_probs_dropout_prob=0.0, 722 | initializer_range=0.02, 723 | do_return_2d_tensor=False, 724 | batch_size=None, 725 | from_seq_length=None, 726 | to_seq_length=None, 727 | is_training=True): 728 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 729 | 730 | This is an implementation of multi-headed attention based on "Attention 731 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 732 | this is self-attention. Each timestep in `from_tensor` attends to the 733 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 734 | 735 | This function first projects `from_tensor` into a "query" tensor and 736 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 737 | of tensors of length `num_attention_heads`, where each tensor is of shape 738 | [batch_size, seq_length, size_per_head]. 739 | 740 | Then, the query and key tensors are dot-producted and scaled. These are 741 | softmaxed to obtain attention probabilities. The value tensors are then 742 | interpolated by these probabilities, then concatenated back to a single 743 | tensor and returned. 744 | 745 | In practice, the multi-headed attention are done with transposes and 746 | reshapes rather than actual separate tensors. 747 | 748 | Args: 749 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 750 | from_width]. 751 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 752 | attention_mask: (optional) int32 Tensor of shape [batch_size, 753 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 754 | attention scores will effectively be set to -infinity for any positions in 755 | the mask that are 0, and will be unchanged for positions that are 1. 756 | num_attention_heads: int. Number of attention heads. 757 | size_per_head: int. Size of each attention head. 758 | query_act: (optional) Activation function for the query transform. 759 | key_act: (optional) Activation function for the key transform. 760 | value_act: (optional) Activation function for the value transform. 761 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 762 | attention probabilities. 763 | initializer_range: float. Range of the weight initializer. 764 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 765 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 766 | output will be of shape [batch_size, from_seq_length, num_attention_heads 767 | * size_per_head]. 768 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 769 | of the 3D version of the `from_tensor` and `to_tensor`. 770 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 771 | of the 3D version of the `from_tensor`. 772 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 773 | of the 3D version of the `to_tensor`. 774 | 775 | Returns: 776 | float Tensor of shape [batch_size, from_seq_length, 777 | num_attention_heads * size_per_head]. 
778 |       true, this will be of shape [batch_size * from_seq_length,
779 |       num_attention_heads * size_per_head]).
780 | 
781 |   Raises:
782 |     ValueError: Any of the arguments or tensor shapes are invalid.
783 |   """
784 | 
785 |   to_tensor_2d = reshape_to_matrix(to_tensor)
786 | 
787 |   # `value_layer` = [B*T, N*H]
788 |   value_layer = tf.layers.dense(
789 |       to_tensor_2d,
790 |       num_attention_heads * size_per_head,
791 |       activation=value_act,
792 |       name="value",
793 |       kernel_initializer=create_initializer(initializer_range))
794 | 
795 |   # Compute the raw attention scores (scaled dot-products, with the mask applied).
796 |   # `attention_scores` = [B, N, F, T]
797 |   attention_scores = attention_scores_layer(from_tensor,
798 |                                             to_tensor,
799 |                                             attention_mask,
800 |                                             num_attention_heads,
801 |                                             size_per_head,
802 |                                             query_act,
803 |                                             key_act,
804 |                                             initializer_range,
805 |                                             batch_size,
806 |                                             from_seq_length,
807 |                                             to_seq_length)
808 | 
809 |   # Normalize the attention scores to probabilities.
810 |   # `attention_probs` = [B, N, F, T]
811 |   attention_probs = tf.nn.softmax(attention_scores)
812 | 
813 |   # This is actually dropping out entire tokens to attend to, which might
814 |   # seem a bit unusual, but is taken from the original Transformer paper.
815 |   attention_probs = tf.layers.dropout(attention_probs, attention_probs_dropout_prob, training=is_training)
816 | 
817 |   # `value_layer` = [B, T, N, H]
818 |   value_layer = tf.reshape(
819 |       value_layer,
820 |       [batch_size, to_seq_length, num_attention_heads, size_per_head])
821 | 
822 |   # `value_layer` = [B, N, T, H]
823 |   value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
824 | 
825 |   # `context_layer` = [B, N, F, H]
826 |   context_layer = tf.matmul(attention_probs, value_layer)
827 | 
828 |   # `context_layer` = [B, F, N, H]
829 |   context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
830 | 
831 |   if do_return_2d_tensor:
832 |     # `context_layer` = [B*F, N*H]
833 |     context_layer = tf.reshape(
834 |         context_layer,
835 |         [batch_size * from_seq_length, num_attention_heads * size_per_head])
836 |   else:
837 |     # `context_layer` = [B, F, N*H]
838 |     context_layer = tf.reshape(
839 |         context_layer,
840 |         [batch_size, from_seq_length, num_attention_heads * size_per_head])
841 | 
842 |   return context_layer
843 | 
844 | 
845 | def transformer_model(input_tensor,
846 |                       attention_mask=None,
847 |                       hidden_size=768,
848 |                       num_hidden_layers=12,
849 |                       num_attention_heads=12,
850 |                       intermediate_size=3072,
851 |                       intermediate_act_fn=gelu,
852 |                       hidden_dropout_prob=0.1,
853 |                       attention_probs_dropout_prob=0.1,
854 |                       initializer_range=0.02,
855 |                       do_return_all_layers=False,
856 |                       is_training=True):
857 |   """Multi-headed, multi-layer Transformer from "Attention Is All You Need".
858 | 
859 |   This is almost an exact implementation of the original Transformer encoder.
860 | 
861 |   See the original paper:
862 |   https://arxiv.org/abs/1706.03762
863 | 
864 |   Also see:
865 |   https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
866 | 
867 |   Args:
868 |     input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
869 |     attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
870 |       seq_length], with 1 for positions that can be attended to and 0 in
871 |       positions that should not be.
872 |     hidden_size: int. Hidden size of the Transformer.
873 |     num_hidden_layers: int. Number of layers (blocks) in the Transformer.
874 |     num_attention_heads: int. Number of attention heads in the Transformer.
875 |     intermediate_size: int. The size of the "intermediate" (a.k.a., feed
876 |       forward) layer.
877 |     intermediate_act_fn: function. The non-linear activation function to apply
878 |       to the output of the intermediate/feed-forward layer.
879 |     hidden_dropout_prob: float. Dropout probability for the hidden layers.
880 |     attention_probs_dropout_prob: float. Dropout probability of the attention
881 |       probabilities.
882 |     initializer_range: float. Range of the initializer (stddev of truncated
883 |       normal).
884 |     do_return_all_layers: Whether to also return all layers or just the final
885 |       layer.
886 | 
887 |   Returns:
888 |     float Tensor of shape [batch_size, seq_length, hidden_size], the final
889 |     hidden layer of the Transformer.
890 | 
891 |   Raises:
892 |     ValueError: A Tensor shape or parameter is invalid.
893 |   """
894 |   if hidden_size % num_attention_heads != 0:
895 |     raise ValueError(
896 |         "The hidden size (%d) is not a multiple of the number of attention "
897 |         "heads (%d)" % (hidden_size, num_attention_heads))
898 | 
899 |   attention_head_size = int(hidden_size / num_attention_heads)
900 |   input_shape = get_shape_list(input_tensor, expected_rank=3)
901 |   batch_size = input_shape[0]
902 |   seq_length = input_shape[1]
903 |   input_width = input_shape[2]
904 | 
905 |   # The Transformer adds residual connections at every layer, so the input
906 |   # width needs to match the hidden size.
907 |   if input_width != hidden_size:
908 |     raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
909 |                      (input_width, hidden_size))
910 | 
911 |   # We keep the representation as a 2D tensor to avoid re-shaping it back and
912 |   # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
913 |   # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
914 |   # help the optimizer.
915 |   prev_output = reshape_to_matrix(input_tensor)
916 | 
917 |   all_layer_outputs = []
918 |   for layer_idx in range(num_hidden_layers):
919 |     with tf.variable_scope("layer_%d" % layer_idx):
920 |       layer_input = prev_output
921 | 
922 |       with tf.variable_scope("attention"):
923 |         attention_heads = []
924 |         with tf.variable_scope("self"):
925 |           attention_head = attention_layer(
926 |               from_tensor=layer_input,
927 |               to_tensor=layer_input,
928 |               attention_mask=attention_mask,
929 |               num_attention_heads=num_attention_heads,
930 |               size_per_head=attention_head_size,
931 |               attention_probs_dropout_prob=attention_probs_dropout_prob,
932 |               initializer_range=initializer_range,
933 |               do_return_2d_tensor=True,
934 |               batch_size=batch_size,
935 |               from_seq_length=seq_length,
936 |               to_seq_length=seq_length,
937 |               is_training=is_training)
938 |           attention_heads.append(attention_head)
939 | 
940 |         attention_output = None
941 |         if len(attention_heads) == 1:
942 |           attention_output = attention_heads[0]
943 |         else:
944 |           # In the case where we have other sequences, we just concatenate
945 |           # them to the self-attention head before the projection.
946 |           attention_output = tf.concat(attention_heads, axis=-1)
947 | 
948 |         # Run a linear projection of `hidden_size` then add a residual
949 |         # with `layer_input`.
950 |         with tf.variable_scope("output"):
951 |           attention_output = tf.layers.dense(
952 |               attention_output,
953 |               hidden_size,
954 |               kernel_initializer=create_initializer(initializer_range))
955 |           attention_output = tf.layers.dropout(attention_output, hidden_dropout_prob, training=is_training)
956 |           attention_output = layer_norm(attention_output + layer_input)
957 | 
958 |       # The activation is only applied to the "intermediate" hidden layer.
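      # With the default sizes in the signature above (hidden_size=768,
      # intermediate_size=3072), each position's 768-dim vector is expanded to
      # 3072 dims under `intermediate_act_fn` (gelu by default), then projected
      # back to 768 dims and combined with the residual in the "output" scope
      # below.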
959 |       with tf.variable_scope("intermediate"):
960 |         intermediate_output = tf.layers.dense(
961 |             attention_output,
962 |             intermediate_size,
963 |             activation=intermediate_act_fn,
964 |             kernel_initializer=create_initializer(initializer_range))
965 | 
966 |       # Down-project back to `hidden_size` then add the residual.
967 |       with tf.variable_scope("output"):
968 |         layer_output = tf.layers.dense(
969 |             intermediate_output,
970 |             hidden_size,
971 |             kernel_initializer=create_initializer(initializer_range))
972 |         layer_output = tf.layers.dropout(layer_output, hidden_dropout_prob, training=is_training)
973 |         layer_output = layer_norm(layer_output + attention_output)
974 |         prev_output = layer_output
975 |         all_layer_outputs.append(layer_output)
976 | 
977 |   if do_return_all_layers:
978 |     final_outputs = []
979 |     for layer_output in all_layer_outputs:
980 |       final_output = reshape_from_matrix(layer_output, input_shape)
981 |       final_outputs.append(final_output)
982 |     return final_outputs
983 |   else:
984 |     final_output = reshape_from_matrix(prev_output, input_shape)
985 |     return final_output
986 | 
987 | 
988 | def get_shape_list(tensor, expected_rank=None, name=None):
989 |   """Returns a list of the shape of tensor, preferring static dimensions.
990 | 
991 |   Args:
992 |     tensor: A tf.Tensor object to find the shape of.
993 |     expected_rank: (optional) int. The expected rank of `tensor`. If this is
994 |       specified and the `tensor` has a different rank, an exception will be
995 |       thrown.
996 |     name: Optional name of the tensor for the error message.
997 | 
998 |   Returns:
999 |     A list of dimensions of the shape of tensor. All static dimensions will
1000 |     be returned as python integers, and dynamic dimensions will be returned
1001 |     as tf.Tensor scalars.
1002 |   """
1003 |   if name is None:
1004 |     name = tensor.name
1005 | 
1006 |   if expected_rank is not None:
1007 |     assert_rank(tensor, expected_rank, name)
1008 | 
1009 |   shape = tensor.shape.as_list()
1010 | 
1011 |   non_static_indexes = []
1012 |   for (index, dim) in enumerate(shape):
1013 |     if dim is None:
1014 |       non_static_indexes.append(index)
1015 | 
1016 |   if not non_static_indexes:
1017 |     return shape
1018 | 
1019 |   dyn_shape = tf.shape(tensor)
1020 |   for index in non_static_indexes:
1021 |     shape[index] = dyn_shape[index]
1022 |   return shape
1023 | 
1024 | 
1025 | def reshape_to_matrix(input_tensor):
1026 |   """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
1027 |   ndims = input_tensor.shape.ndims
1028 |   if ndims < 2:
1029 |     raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
1030 |                      (input_tensor.shape))
1031 |   if ndims == 2:
1032 |     return input_tensor
1033 | 
1034 |   width = input_tensor.shape[-1]
1035 |   output_tensor = tf.reshape(input_tensor, [-1, width])
1036 |   return output_tensor
1037 | 
1038 | 
1039 | def reshape_from_matrix(output_tensor, orig_shape_list):
1040 |   """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
1041 |   if len(orig_shape_list) == 2:
1042 |     return output_tensor
1043 | 
1044 |   output_shape = get_shape_list(output_tensor)
1045 | 
1046 |   orig_dims = orig_shape_list[0:-1]
1047 |   width = output_shape[-1]
1048 | 
1049 |   return tf.reshape(output_tensor, orig_dims + [width])
1050 | 
1051 | 
1052 | def assert_rank(tensor, expected_rank, name=None):
1053 |   """Raises an exception if the tensor's rank does not match the expected rank.
1054 | 
1055 |   Args:
1056 |     tensor: A tf.Tensor to check the rank of.
1057 |     expected_rank: Python integer or list of integers, expected rank.
1058 |     name: Optional name of the tensor for the error message.
1059 | 
1060 |   Raises:
1061 |     ValueError: If the actual rank does not match the expected rank.
1062 |   """
1063 |   if name is None:
1064 |     name = tensor.name
1065 | 
1066 |   expected_rank_dict = {}
1067 |   if isinstance(expected_rank, six.integer_types):
1068 |     expected_rank_dict[expected_rank] = True
1069 |   else:
1070 |     for x in expected_rank:
1071 |       expected_rank_dict[x] = True
1072 | 
1073 |   actual_rank = tensor.shape.ndims
1074 |   if actual_rank not in expected_rank_dict:
1075 |     scope_name = tf.get_variable_scope().name
1076 |     raise ValueError(
1077 |         "For the tensor `%s` in scope `%s`, the actual rank "
1078 |         "`%d` (shape = %s) is not equal to the expected rank `%s`" %
1079 |         (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
1080 | 
--------------------------------------------------------------------------------
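A minimal usage sketch of the Transformer helpers above, for orientation only: the module name `modeling`, the placeholder shapes, and the session-free setup are illustrative assumptions rather than part of the repository; it assumes TensorFlow 1.x and that modeling.py is importable as `modeling`.

# Illustrative sketch only; shapes and names below are assumptions.
import tensorflow as tf
import modeling

# [batch, seq_length, hidden_size]; the input width must equal `hidden_size`.
inputs = tf.placeholder(tf.float32, [None, 128, 768])
# [batch, from_seq_length, to_seq_length]; 1 = may attend, 0 = masked.
mask = tf.placeholder(tf.int32, [None, 128, 128])

encoded = modeling.transformer_model(
    input_tensor=inputs,
    attention_mask=mask,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    is_training=False)  # disables both dropout sites shown above

# get_shape_list prefers static dimensions: only the batch size is dynamic
# here, so this prints [<tf.Tensor ...>, 128, 768].
print(modeling.get_shape_list(inputs, expected_rank=3))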