├── README.md ├── create_datasets.sh ├── data └── config │ ├── config.knn.conll2003.json │ ├── config.knn.genia.json │ ├── config.span.conll2003.json │ └── config.span.genia.json ├── models ├── __init__.py ├── base_model.py ├── decoders.py ├── knn_models.py ├── network_components.py └── span_models.py ├── requirements.txt ├── run_knn_models.py ├── run_span_models.py ├── scripts ├── convert_conll03_to_json.py ├── convert_genia_to_json.py └── retrieve_knn_sents_with_glove.py ├── train_knn_models.py ├── train_span_models.py └── utils ├── __init__.py ├── batchers ├── __init__.py ├── base_batchers.py ├── knn_batchers.py └── span_batchers.py ├── common.py ├── data_utils.py ├── logger.py └── preprocessors ├── __init__.py ├── base_preprocessors.py ├── knn_preprocessors.py └── span_preprocessors.py /README.md: -------------------------------------------------------------------------------- 1 | # Instance-Based Named Entity Recognizer 2 | 3 | This codebase is partially based on [neural_sequence_labeling](https://github.com/IsaacChanghau/neural_sequence_labeling) 4 | 5 | ## Citation 6 | * Instance-Based Learning of Span Representations: A Case Study through Named Entity Recognition 7 | * Hiroki Ouchi, Jun Suzuki, Sosuke Kobayashi, Sho Yokoi, Tatsuki Kuribayashi, Ryuto Konno, Kentaro Inui 8 | * In ACL 2020 9 | * Conference paper: https://www.aclweb.org/anthology/2020.acl-main.575/ 10 | * arXiv version: https://arxiv.org/abs/2004.14514 11 | 12 | ``` 13 | @inproceedings{ouchi-etal-2020-instance, 14 | title = "Instance-Based Learning of Span Representations: A Case Study through Named Entity Recognition", 15 | author = "Ouchi, Hiroki and 16 | Suzuki, Jun and 17 | Kobayashi, Sosuke and 18 | Yokoi, Sho and 19 | Kuribayashi, Tatsuki and 20 | Konno, Ryuto and 21 | Inui, Kentaro", 22 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 23 | month = jul, 24 | year = "2020", 25 | address = "Online", 26 | publisher = "Association for Computational Linguistics", 27 | url = "https://www.aclweb.org/anthology/2020.acl-main.575", 28 | pages = "6452--6459", 29 | abstract = "Interpretable rationales for model predictions play a critical role in practical applications. In this study, we develop models possessing interpretable inference process for structured prediction. Specifically, we present a method of instance-based learning that learns similarities between spans. At inference time, each span is assigned a class label based on its similar spans in the training set, where it is easy to understand how much each training instance contributes to the predictions. 
Through empirical analysis on named entity recognition, we demonstrate that our method enables to build models that have high interpretability without sacrificing performance.", 30 | } 31 | ``` 32 | 33 | ## Prerequisites 34 | * [python3](https://www.python.org/downloads/) 35 | * [TensorFlow](https://www.tensorflow.org/) 36 | * [h5py](https://www.h5py.org/) 37 | 38 | ## Installation 39 | - CPU 40 | ``` 41 | conda create -n instance-based-ner python=3.6 42 | source activate instance-based-ner 43 | conda install -c conda-forge tensorflow 44 | pip install ujson tqdm 45 | git clone https://github.com/cl-tohoku/instance-based-ner_dev.git 46 | ``` 47 | - GPU 48 | ``` 49 | conda create -n instance-based-ner python=3.6 50 | source activate instance-based-ner 51 | pip install tensorflow-gpu==1.10 ujson tqdm 52 | git clone https://github.com/cl-tohoku/instance-based-ner_dev.git 53 | ``` 54 | 55 | ## Data Preparation 56 | `./create_datasets.sh` 57 | 58 | ## Pretrained Models 59 | * [Instance-based span model](https://drive.google.com/open?id=1d_KzED0UKEVnorymxiylzEOHpXoFF8TN) 60 | * [Classifier-based span model](https://drive.google.com/open?id=16MFR1IQ5mPx0bFAXMxdEbnTEn8zO2RNX) 61 | 62 | ## Get Started 63 | `python run_knn_models.py --mode cmd --config_file checkpoint_knn_conll2003_lstm-minus_batch8_keep07_0/config.json` 64 | 65 | ## Usage 66 | ### Instance-based span model 67 | * Training: `python train_knn_models.py --config_file data/config/config.knn.conll2003.json` 68 | * Predicting with random training sentences: `python run_knn_models.py --config_file checkpoint_knn/conll2003/config.json --knn_sampling random --data_path data/conll2003/valid.json` 69 | * Predicting with nearest training sentences: `python run_knn_models.py --config_file checkpoint_knn/conll2003/config.json --knn_sampling knn --data_path data/conll2003/valid.glove.50-nn.json` 70 | ### Classifier-based span model 71 | * Training: `python train_span_models.py --config_file data/config/config.span.conll2003.json` 72 | * Predicting: `python run_span_models.py --config_file checkpoint_span/conll2003/config.json --data_path data/conll2003/valid.json` 73 | 74 | ## LICENSE 75 | MIT License 76 | -------------------------------------------------------------------------------- /create_datasets.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | # Download the GloVe embeddings 4 | wget http://nlp.stanford.edu/data/glove.6B.zip 5 | mkdir data/emb 6 | mv glove.6B/* data/emb/ 7 | 8 | # Download the CoNLL-2003 dataset 9 | git clone https://github.com/IsaacChanghau/neural_sequence_labeling.git 10 | mkdir data/conll2003 11 | mv neural_sequence_labeling/data/raw/conll2003/raw/* data/conll2003/ 12 | python scripts/convert_conll03_to_json.py --input_file data/conll2003/train.txt --output_file data/conll2003/train.json --remove_duplicates 13 | python scripts/convert_conll03_to_json.py --input_file data/conll2003/valid.txt --output_file data/conll2003/valid.json 14 | python scripts/convert_conll03_to_json.py --input_file data/conll2003/test.txt --output_file data/conll2003/test.json 15 | # Valid/Test sets with nearest training sentences 16 | python scripts/retrieve_knn_sents_with_glove.py --train_json data/conll2003/train.json --test_json data/conll2003/valid.json --glove data/emb/glove.6B.100d.txt --k 50 --output_file data/conll2003/valid.glove.50-nn.json 17 | python scripts/retrieve_knn_sents_with_glove.py --train_json data/conll2003/train.json --test_json data/conll2003/test.json --glove data/emb/glove.6B.100d.txt --k 50 --output_file data/conll2003/test.glove.50-nn.json 18 | 19 | # Download the GENIA dataset 20 | git clone https://github.com/thecharm/boundary-aware-nested-ner.git 21 | mkdir data/genia 22 | mv boundary-aware-nested-ner/Our_boundary-aware_model/data/genia/* data/genia/ 23 | python scripts/convert_genia_to_json.py --input_file data/genia/genia.train.iob2 --output_file data/genia/train.json --remove_duplicates 24 | python scripts/convert_genia_to_json.py --input_file data/genia/genia.dev.iob2 --output_file data/genia/valid.json 25 | python scripts/convert_genia_to_json.py --input_file data/genia/genia.test.iob2 --output_file data/genia/test.json 26 | # Valid/Test sets with nearest training sentences 27 | python scripts/retrieve_knn_sents_with_glove.py --train_json data/genia/train.json --test_json data/genia/valid.json --glove data/emb/glove.6B.100d.txt --k 50 --output_file data/genia/valid.glove.50-nn.json 28 | python scripts/retrieve_knn_sents_with_glove.py --train_json data/genia/train.json --test_json data/genia/test.json --glove data/emb/glove.6B.100d.txt --k 50 --output_file data/genia/test.glove.50-nn.json 29 | -------------------------------------------------------------------------------- /data/config/config.knn.conll2003.json: -------------------------------------------------------------------------------- 1 | { "task_name": "ner", 2 | "model_name": "ner_knn_model", 3 | "model": "knn", 4 | "raw_path": "data/conll2003", 5 | "save_path": "data/dataset/conll2003_knn", 6 | "checkpoint_path": "checkpoint_knn/conll2003", 7 | "summary_path": "checkpoint_knn/conll2003/summary", 8 | "glove_path": "data/emb/glove.6B.100d.txt", 9 | "glove_name": "6B", 10 | "vocab": "data/dataset/conll2003_knn/vocab.json", 11 | "train_set": "data/dataset/conll2003_knn/train.json", 12 | "valid_set": "data/dataset/conll2003_knn/valid.json", 13 | "pretrained_emb": "data/dataset/conll2003_knn/glove_emb.npz", 14 | "data_size": 1000000000, 15 | "max_sent_len": 1000000, 16 | "k": 50, 17 | "knn_sampling": "random", 18 | "predict": "max_margin", 19 | "cell_type": "lstm", 20 | "char_lowercase": false, 21 | "emb_dim": 100, 22 | "char_emb_dim": 30, 23 | "char_proj_dim": 100, 24 | "filter_sizes": [30, 30], 25 | "channel_sizes": [3, 3], 26 | "highway_layers": 2, 27 | "num_units": 100, 28 | "num_layers": 2, 29 | "bilstm_type": "minus", 30 | 
"tuning_emb": false, 31 | "use_stack_rnn": true, 32 | "use_pretrained": true, 33 | "use_chars": true, 34 | "use_highway": true, 35 | "max_span_len": 6, 36 | "max_n_spans": 0, 37 | "optimizer": "adam", 38 | "lr": 0.001, 39 | "use_lr_decay": true, 40 | "lr_decay": 0.05, 41 | "minimal_lr": 1e-5, 42 | "grad_clip": 5.0, 43 | "keep_prob": 0.7, 44 | "batch_size": 8, 45 | "epochs": 100, 46 | "max_to_keep": 2 47 | } 48 | 49 | -------------------------------------------------------------------------------- /data/config/config.knn.genia.json: -------------------------------------------------------------------------------- 1 | { "task_name": "ner", 2 | "model_name": "ner_knn_model", 3 | "model": "knn", 4 | "raw_path": "data/genia", 5 | "save_path": "data/dataset/genia_knn", 6 | "checkpoint_path": "checkpoint_knn/genia", 7 | "summary_path": "checkpoint_knn/genia/summary", 8 | "glove_path": "data/emb/glove.6B.100d.txt", 9 | "glove_name": "6B", 10 | "vocab": "data/dataset/genia_knn/vocab.json", 11 | "train_set": "data/dataset/genia_knn/train.json", 12 | "valid_set": "data/dataset/genia_knn/valid.json", 13 | "pretrained_emb": "data/dataset/genia_knn/glove_emb.npz", 14 | "data_size": 1000000000, 15 | "max_sent_len": 1000000, 16 | "k": 10, 17 | "knn_sampling": "random", 18 | "predict": "max_margin", 19 | "cell_type": "lstm", 20 | "char_lowercase": false, 21 | "emb_dim": 100, 22 | "char_emb_dim": 30, 23 | "char_proj_dim": 100, 24 | "filter_sizes": [30, 30], 25 | "channel_sizes": [3, 3], 26 | "highway_layers": 2, 27 | "num_units": 100, 28 | "num_layers": 2, 29 | "bilstm_type": "minus", 30 | "tuning_emb": false, 31 | "use_stack_rnn": true, 32 | "use_pretrained": true, 33 | "use_chars": true, 34 | "use_highway": true, 35 | "max_span_len": 6, 36 | "max_n_spans": 0, 37 | "optimizer": "adam", 38 | "lr": 0.001, 39 | "use_lr_decay": true, 40 | "lr_decay": 0.05, 41 | "minimal_lr": 1e-5, 42 | "grad_clip": 5.0, 43 | "keep_prob": 0.9, 44 | "batch_size": 8, 45 | "epochs": 100, 46 | "max_to_keep": 2 47 | } 48 | 49 | -------------------------------------------------------------------------------- /data/config/config.span.conll2003.json: -------------------------------------------------------------------------------- 1 | { "task_name": "ner", 2 | "model_name": "ner_span_model", 3 | "model": "span", 4 | "raw_path": "data/conll2003", 5 | "save_path": "data/dataset/conll2003_span", 6 | "checkpoint_path": "checkpoint_span/conll2003", 7 | "summary_path": "checkpoint_span/conll2003/summary", 8 | "glove_path": "data/emb/glove.6B.100d.txt", 9 | "glove_name": "6B", 10 | "vocab": "data/dataset/conll2003_span/vocab.json", 11 | "train_set": "data/dataset/conll2003_span/train.json", 12 | "valid_set": "data/dataset/conll2003_span/valid.json", 13 | "pretrained_emb": "data/dataset/conll2003_span/glove_emb.npz", 14 | "data_size": 1000000000, 15 | "max_sent_len": 1000000, 16 | "cell_type": "lstm", 17 | "char_lowercase": false, 18 | "emb_dim": 100, 19 | "char_emb_dim": 30, 20 | "char_proj_dim": 100, 21 | "filter_sizes": [30, 30], 22 | "channel_sizes": [3, 3], 23 | "highway_layers": 2, 24 | "num_units": 100, 25 | "num_layers": 2, 26 | "bilstm_type": "minus", 27 | "tuning_emb": false, 28 | "use_stack_rnn": true, 29 | "use_pretrained": true, 30 | "use_chars": true, 31 | "use_highway": true, 32 | "max_span_len": 6, 33 | "max_n_spans": 0, 34 | "optimizer": "adam", 35 | "lr": 0.001, 36 | "use_lr_decay": true, 37 | "lr_decay": 0.05, 38 | "minimal_lr": 1e-5, 39 | "grad_clip": 5.0, 40 | "keep_prob": 0.7, 41 | "batch_size": 8, 42 | "epochs": 100, 43 | 
"max_to_keep": 2 44 | } 45 | 46 | -------------------------------------------------------------------------------- /data/config/config.span.genia.json: -------------------------------------------------------------------------------- 1 | { "task_name": "ner", 2 | "model_name": "ner_span_model", 3 | "model": "span", 4 | "raw_path": "data/genia", 5 | "save_path": "data/dataset/genia_span", 6 | "checkpoint_path": "checkpoint/genia_span", 7 | "summary_path": "checkpoint/genia_span/summary", 8 | "glove_path": "data/emb/glove.6B.100d.txt", 9 | "glove_name": "6B", 10 | "vocab": "data/dataset/genia_span/vocab.json", 11 | "train_set": "data/dataset/genia_span/train.json", 12 | "valid_set": "data/dataset/genia_span/valid.json", 13 | "pretrained_emb": "data/dataset/genia_span/glove_emb.npz", 14 | "data_size": 1000000000, 15 | "max_sent_len": 1000000, 16 | "cell_type": "lstm", 17 | "char_lowercase": false, 18 | "emb_dim": 100, 19 | "char_emb_dim": 30, 20 | "char_proj_dim": 100, 21 | "filter_sizes": [30, 30], 22 | "channel_sizes": [3, 3], 23 | "highway_layers": 2, 24 | "num_units": 100, 25 | "num_layers": 2, 26 | "bilstm_type": "minus", 27 | "tuning_emb": false, 28 | "use_stack_rnn": true, 29 | "use_pretrained": true, 30 | "use_chars": true, 31 | "use_highway": true, 32 | "max_span_len": 6, 33 | "max_n_spans": 0, 34 | "optimizer": "adam", 35 | "lr": 0.001, 36 | "use_lr_decay": true, 37 | "lr_decay": 0.05, 38 | "minimal_lr": 1e-5, 39 | "grad_clip": 5.0, 40 | "keep_prob": 0.7, 41 | "batch_size": 8, 42 | "epochs": 100, 43 | "max_to_keep": 2 44 | } 45 | 46 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.base_model import * 2 | from models.network_components import * 3 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import os 8 | 9 | import numpy as np 10 | 11 | import tensorflow as tf 12 | from tensorflow.python.ops.rnn_cell import LSTMCell, GRUCell, MultiRNNCell 13 | 14 | from utils import get_logger 15 | from utils.common import load_json, word_convert, UNK 16 | 17 | 18 | class BaseModel(object): 19 | 20 | def __init__(self, config, batcher, is_train=True): 21 | self.cfg = config 22 | self.batcher = batcher 23 | self.sess = None 24 | self.saver = None 25 | 26 | self._initialize_config() 27 | self._add_placeholders() 28 | self._build_embedding_op() 29 | self._build_model_op() 30 | if is_train: 31 | self._build_loss_op() 32 | self._build_train_op() 33 | self._build_predict_op() 34 | print('Num. 
params: {}'.format( 35 | np.sum([np.prod(v.get_shape().as_list()) 36 | for v in tf.trainable_variables()]))) 37 | self.initialize_session() 38 | 39 | def _initialize_config(self): 40 | # create folders and logger 41 | os.makedirs(self.cfg["checkpoint_path"], exist_ok=True) 42 | os.makedirs(os.path.join(self.cfg["summary_path"]), exist_ok=True) 43 | self.logger = get_logger( 44 | os.path.join(self.cfg["checkpoint_path"], "log.txt")) 45 | 46 | # load dictionary 47 | dict_data = load_json(self.cfg["vocab"]) 48 | self.word_dict = dict_data["word_dict"] 49 | self.char_dict = dict_data["char_dict"] 50 | self.tag_dict = dict_data["tag_dict"] 51 | del dict_data 52 | self.word_vocab_size = len(self.word_dict) 53 | self.char_vocab_size = len(self.char_dict) 54 | self.tag_vocab_size = len(self.tag_dict) 55 | self.rev_word_dict = dict([(idx, word) 56 | for word, idx in self.word_dict.items()]) 57 | self.rev_char_dict = dict([(idx, char) 58 | for char, idx in self.char_dict.items()]) 59 | self.rev_tag_dict = dict([(idx, tag) 60 | for tag, idx in self.tag_dict.items()]) 61 | 62 | def initialize_session(self): 63 | sess_config = tf.ConfigProto() 64 | sess_config.gpu_options.allow_growth = True 65 | self.sess = tf.Session(config=sess_config) 66 | self.saver = tf.train.Saver(max_to_keep=self.cfg["max_to_keep"]) 67 | self.sess.run(tf.global_variables_initializer()) 68 | 69 | def restore_last_session(self, ckpt_path=None): 70 | if ckpt_path is not None: 71 | ckpt = tf.train.get_checkpoint_state(ckpt_path) 72 | else: 73 | ckpt = tf.train.get_checkpoint_state(self.cfg["checkpoint_path"]) 74 | if ckpt and ckpt.model_checkpoint_path: # restore session 75 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 76 | 77 | def log_trainable_variables(self): 78 | self.logger.info("\nTrainable variable") 79 | for v in tf.trainable_variables(): 80 | self.logger.info("-- {}: shape:{}".format( 81 | v.name, v.get_shape().as_list())) 82 | 83 | def save_session(self, epoch): 84 | self.saver.save(self.sess, 85 | os.path.join(self.cfg["checkpoint_path"], 86 | self.cfg["model_name"]), 87 | global_step=epoch) 88 | 89 | def close_session(self): 90 | self.sess.close() 91 | 92 | def _add_summary(self): 93 | self.summary = tf.summary.merge_all() 94 | self.train_writer = tf.summary.FileWriter( 95 | os.path.join(self.cfg["summary_path"], "train"), 96 | self.sess.graph) 97 | self.test_writer = tf.summary.FileWriter( 98 | os.path.join(self.cfg["summary_path"], "test")) 99 | 100 | def reinitialize_weights(self, scope_name=None): 101 | """Reinitialize parameters in a scope""" 102 | if scope_name is None: 103 | self.sess.run(tf.global_variables_initializer()) 104 | else: 105 | variables = tf.contrib.framework.get_variables(scope_name) 106 | self.sess.run(tf.variables_initializer(variables)) 107 | 108 | @staticmethod 109 | def variable_summaries(variable, name=None): 110 | with tf.name_scope(name or "summary"): 111 | mean = tf.reduce_mean(variable) 112 | tf.summary.scalar("mean", mean) # add mean value 113 | stddev = tf.sqrt(tf.reduce_mean(tf.square(variable - mean))) 114 | tf.summary.scalar("stddev", stddev) # add standard deviation value 115 | tf.summary.scalar("max", tf.reduce_max(variable)) # add maximal value 116 | tf.summary.scalar("min", tf.reduce_min(variable)) # add minimal value 117 | tf.summary.histogram("histogram", variable) # add histogram 118 | 119 | def _create_single_rnn_cell(self, num_units): 120 | return GRUCell(num_units) \ 121 | if self.cfg["cell_type"] == "gru" else LSTMCell(num_units) 122 | 123 | def 
_create_rnn_cell(self): 124 | if self.cfg["num_layers"] is None or self.cfg["num_layers"] <= 1: 125 | return self._create_single_rnn_cell(self.cfg["num_units"]) 126 | else: 127 | MultiRNNCell([self._create_single_rnn_cell(self.cfg["num_units"]) 128 | for _ in range(self.cfg["num_layers"])]) 129 | 130 | def _add_placeholders(self): 131 | raise NotImplementedError("To be implemented...") 132 | 133 | def _get_feed_dict(self, data): 134 | raise NotImplementedError("To be implemented...") 135 | 136 | def _build_embedding_op(self): 137 | raise NotImplementedError("To be implemented...") 138 | 139 | def _build_model_op(self): 140 | raise NotImplementedError("To be implemented...") 141 | 142 | def _build_loss_op(self): 143 | raise NotImplementedError("To be implemented...") 144 | 145 | def _build_train_op(self): 146 | raise NotImplementedError("To be implemented...") 147 | 148 | def _build_predict_op(self): 149 | raise NotImplementedError("To be implemented...") 150 | 151 | def train_epoch(self, **kwargs): 152 | raise NotImplementedError("To be implemented...") 153 | 154 | def train(self, **kwargs): 155 | raise NotImplementedError("To be implemented...") 156 | 157 | def words_to_indices(self, words): 158 | chars_idx = [] 159 | for word in words: 160 | chars = [self.char_dict[char] 161 | if char in self.char_dict else self.char_dict[UNK] 162 | for char in word] 163 | chars_idx.append(chars) 164 | words = [word_convert(word) for word in words] 165 | words_idx = [self.word_dict[word] 166 | if word in self.word_dict else self.word_dict[UNK] 167 | for word in words] 168 | return self.batcher.make_each_batch([words_idx], [chars_idx]) 169 | -------------------------------------------------------------------------------- /models/decoders.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | 4 | 5 | def get_span_indices(n_words, max_span_len): 6 | ones = np.ones(shape=(n_words, n_words), dtype='int32') 7 | mask = np.triu(ones, k=np.minimum(n_words, max_span_len)) 8 | indices = np.triu(ones) - mask 9 | return np.nonzero(indices) 10 | 11 | 12 | def get_scored_spans(scores, n_words, max_span_len) -> List[List[tuple]]: 13 | indx_i, indx_j = get_span_indices(n_words, max_span_len) 14 | assert len(scores) == len(indx_i) == len(indx_j), "%d %d %d" % (len(scores), 15 | len(indx_i), 16 | len(indx_j)) 17 | spans = [] 18 | for scores_each_span, i, j in zip(scores, indx_i, indx_j): 19 | spans.append( 20 | [(r, i, j, score) for r, score in enumerate(scores_each_span)]) 21 | return spans 22 | 23 | 24 | def get_labeled_spans(labels, n_words, max_span_len, 25 | null_label_id) -> List[tuple]: 26 | indx_i, indx_j = get_span_indices(n_words, max_span_len) 27 | assert len(labels) == len(indx_i) == len(indx_j), "%d %d %d" % (len(labels), 28 | len(indx_i), 29 | len(indx_j)) 30 | return [(r, i, j) for r, i, j in zip(labels, indx_i, indx_j) 31 | if r != null_label_id] 32 | 33 | 34 | def get_scores_and_spans(spans, scores, sent_id, indx_i, indx_j) -> List[List]: 35 | scored_spans = [] 36 | span_boundaries = [(i, j) for (_, i, j) in spans] 37 | for i, j, score in zip(indx_i, indx_j, scores): 38 | if (i, j) in span_boundaries: 39 | index = span_boundaries.index((i, j)) 40 | r = spans[index][0] 41 | else: 42 | r = "O" 43 | scored_spans.append([r, sent_id, i, j, score]) 44 | return scored_spans 45 | 46 | 47 | def get_batch_labeled_spans(batch_labels, n_words, max_span_len, 48 | null_label_id) -> List[List[List[int]]]: 49 | indx_i, indx_j = 
get_span_indices(n_words, max_span_len) 50 | return [[[r, i, j] for r, i, j in zip(labels, indx_i, indx_j) if 51 | r != null_label_id] 52 | for labels in batch_labels] 53 | 54 | 55 | def get_pred_spans_with_proba(scores, n_words, max_span_len) -> dict: 56 | indx_i, indx_j = get_span_indices(n_words, max_span_len) 57 | assert len(scores) == len(indx_i) == len(indx_j), "%d %d %d" % (len(scores), 58 | len(indx_i), 59 | len(indx_j)) 60 | spans = {} 61 | for label_scores, i, j in zip(scores, indx_i, indx_j): 62 | spans['%d,%d' % (i, j)] = [float(score) for score in label_scores] 63 | return spans 64 | 65 | 66 | def sort_scored_spans(scored_spans, null_label_id) -> List[tuple]: 67 | sorted_spans = [] 68 | for spans in scored_spans: 69 | r, i, j, null_score = spans[null_label_id] 70 | for (r, i, j, score) in spans: 71 | if null_score < score: 72 | sorted_spans.append((r, i, j, score)) 73 | sorted_spans.sort(key=lambda span: span[-1], reverse=True) 74 | return sorted_spans 75 | 76 | 77 | def greedy_search(scores, n_words, max_span_len, 78 | null_label_id) -> List[List[int]]: 79 | triples = [] 80 | used_words = np.zeros(n_words, 'int32') 81 | scored_spans = get_scored_spans(scores, n_words, max_span_len) 82 | sorted_spans = sort_scored_spans(scored_spans, null_label_id) 83 | 84 | for (r, i, j, _) in sorted_spans: 85 | if sum(used_words[i: j + 1]) > 0: 86 | continue 87 | triples.append([r, i, j]) 88 | used_words[i: j + 1] = 1 89 | return triples 90 | -------------------------------------------------------------------------------- /models/knn_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import codecs 8 | import os 9 | import math 10 | import random 11 | import time 12 | 13 | import h5py 14 | import numpy as np 15 | from numpy.linalg import norm 16 | import tensorflow as tf 17 | from tqdm import tqdm 18 | 19 | from models.decoders import get_scores_and_spans, get_span_indices, \ 20 | greedy_search 21 | from models.span_models import MaskSpanModel 22 | from utils.common import load_json, write_json, word_convert, UNK 23 | from utils.data_utils import f_score, count_gold_spans, \ 24 | count_gold_and_system_outputs, span2bio 25 | 26 | NULL_LABEL_ID = 0 27 | 28 | 29 | class KnnModel(MaskSpanModel): 30 | 31 | def __init__(self, config, batcher, is_train=True): 32 | self.knn_ids = None 33 | self.gold_label_proba = None 34 | self.max_n_spans = config["max_n_spans"] 35 | super(KnnModel, self).__init__(config, batcher, is_train) 36 | 37 | def _add_placeholders(self): 38 | self.words = tf.placeholder(tf.int32, shape=[None, None], name="words") 39 | self.tags = tf.placeholder(tf.int32, shape=[None, None], name="tags") 40 | self.seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len") 41 | self.neighbor_reps = tf.placeholder(tf.float32, shape=[None, None], 42 | name="neighbor_reps") 43 | self.neighbor_tags = tf.placeholder(tf.float32, shape=[None], 44 | name="neighbor_tags") 45 | self.neighbor_tag_one_hots = tf.placeholder(tf.float32, shape=[None, None], 46 | name="neighbor_tag_one_hots") 47 | if self.cfg["use_chars"]: 48 | self.chars = tf.placeholder(tf.int32, shape=[None, None, None], 49 | name="chars") 50 | # hyperparameters 51 | self.is_train = tf.placeholder(tf.bool, name="is_train") 52 | self.keep_prob = tf.placeholder(tf.float32, name="rnn_keep_probability") 53 | 
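    # The neighbor_reps / neighbor_tags / neighbor_tag_one_hots placeholders
    # above are filled with span representations and span labels of the k
    # retrieved training sentences, which are computed by a separate forward
    # pass over those sentences (see _add_neighbor_instances_to_batch and
    # _make_batch_and_sample_sent_ids below).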
self.drop_rate = tf.placeholder(tf.float32, name="dropout_rate") 54 | self.lr = tf.placeholder(tf.float32, name="learning_rate") 55 | 56 | def _get_feed_dict(self, batch, keep_prob=1.0, is_train=False, lr=None): 57 | feed_dict = {self.words: batch["words"], self.seq_len: batch["seq_len"]} 58 | if "tags" in batch: 59 | feed_dict[self.tags] = batch["tags"] 60 | if self.cfg["use_chars"]: 61 | feed_dict[self.chars] = batch["chars"] 62 | if "neighbor_reps" in batch: 63 | feed_dict[self.neighbor_reps] = batch["neighbor_reps"] 64 | if "neighbor_tags" in batch: 65 | feed_dict[self.neighbor_tags] = batch["neighbor_tags"] 66 | if "neighbor_tag_one_hots" in batch: 67 | feed_dict[self.neighbor_tag_one_hots] = batch["neighbor_tag_one_hots"] 68 | feed_dict[self.keep_prob] = keep_prob 69 | feed_dict[self.drop_rate] = 1.0 - keep_prob 70 | feed_dict[self.is_train] = is_train 71 | if lr is not None: 72 | feed_dict[self.lr] = lr 73 | return feed_dict 74 | 75 | def _build_neighbor_similarity_op(self): 76 | with tf.name_scope("similarity"): 77 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances 78 | self.similarity = tf.tensordot(self.span_rep, self.neighbor_reps, 79 | axes=[-1, -1]) 80 | 81 | def _build_neighbor_proba_op(self): 82 | with tf.name_scope("neighbor_prob"): 83 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances 84 | self.neighbor_proba = tf.nn.softmax(self.similarity, axis=-1) 85 | 86 | def _build_marginal_proba_op(self): 87 | with tf.name_scope("gold_label_prob"): 88 | # 1D: batch_size, 2D: max_num_spans, 3D: 1 89 | tags = tf.expand_dims(tf.cast(self.tags, dtype=tf.float32), axis=2) 90 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances 91 | gold_label_mask = tf.cast( 92 | tf.equal(self.neighbor_tags, tags), dtype=tf.float32) 93 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances 94 | proba = self.neighbor_proba * gold_label_mask 95 | # 1D: batch_size, 2D: max_num_spans 96 | self.gold_label_proba = tf.reduce_sum( 97 | tf.clip_by_value(proba, 1e-10, 1.0), axis=2) 98 | 99 | def _build_knn_loss_op(self): 100 | with tf.name_scope("loss"): 101 | # 1D: batch_size, 2D: max_num_spans 102 | self.losses = tf.math.log(self.gold_label_proba) 103 | self.loss = - tf.reduce_mean(tf.reduce_sum(self.losses, axis=-1)) 104 | tf.summary.scalar("loss", self.loss) 105 | 106 | def _build_one_nn_predict_op(self): 107 | with tf.name_scope("prediction"): 108 | neighbor_indices = tf.argmax(self.similarity, axis=2) 109 | knn_predicts = tf.gather(self.neighbor_tags, neighbor_indices) 110 | self.predicts = tf.reshape(knn_predicts, 111 | shape=(tf.shape(self.words)[0], -1)) 112 | 113 | def _build_max_marginal_predict_op(self): 114 | with tf.name_scope("prediction"): 115 | # 1D: 1, 2D: 1, 3D: num_instances, 4D: num_tags 116 | one_hot_tags = tf.reshape(self.neighbor_tag_one_hots, 117 | shape=[1, 1, -1, self.tag_vocab_size]) 118 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances, 4D: 1 119 | proba = tf.expand_dims(self.neighbor_proba, axis=3) 120 | # 1D: batch_size, 2D: max_num_spans, 3D: num_instances, 4D: num_tags 121 | proba = proba * one_hot_tags 122 | # 1D: batch_size, 2D: max_num_spans, 3D: num_tags 123 | self.marginal_proba = tf.reduce_sum(proba, axis=2) 124 | self.predicts = tf.argmax(self.marginal_proba, axis=2) 125 | 126 | def _build_model_op(self): 127 | self._build_rnn_op() 128 | self._make_span_indices() 129 | if self.cfg["bilstm_type"] == "minus": 130 | self._build_span_minus_op() 131 | else: 132 | self._build_span_add_and_minus_op() 133 | self._build_span_projection_op() 134 | 
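    # Instance-based classification: each projected span representation is
    # scored against the neighbor span representations by dot product
    # (_build_neighbor_similarity_op), and the scores are normalized into a
    # distribution over neighbor spans with a softmax (_build_neighbor_proba_op).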
self._build_neighbor_similarity_op() 135 | self._build_neighbor_proba_op() 136 | 137 | def _build_loss_op(self): 138 | self._build_marginal_proba_op() 139 | self._build_knn_loss_op() 140 | 141 | def _build_predict_op(self): 142 | if self.cfg["predict"] == "one_nn": 143 | self._build_one_nn_predict_op() 144 | else: 145 | self._build_max_marginal_predict_op() 146 | 147 | def get_neighbor_batch(self, train_sents, train_sent_ids): 148 | return self.batcher.batchnize_neighbor_train_sents( 149 | train_sents, train_sent_ids, self.max_span_len, self.max_n_spans) 150 | 151 | def get_neighbor_reps_and_tags(self, span_reps, batch): 152 | return self.batcher.batchnize_span_reps_and_tags( 153 | span_reps, batch["tags"], batch["masks"]) 154 | 155 | def get_neighbor_reps_and_tag_one_hots(self, span_reps, batch): 156 | return self.batcher.batchnize_span_reps_and_tag_one_hots( 157 | span_reps, batch["tags"], batch["masks"], self.tag_vocab_size) 158 | 159 | def make_one_batch_for_target(self, data, sent_id, add_tags=True): 160 | return self.batcher.make_each_batch_for_targets( 161 | batch_words=[data["words"]], 162 | batch_chars=[data["chars"]], 163 | batch_ids=[sent_id], 164 | max_span_len=self.max_span_len, 165 | max_n_spans=0, 166 | batch_tags=[data["tags"]] if add_tags else None) 167 | 168 | def _add_neighbor_instances_to_batch(self, batch, train_sents, 169 | train_sent_ids, is_train): 170 | if train_sent_ids: 171 | if is_train: 172 | train_sent_ids = list(set(train_sent_ids) - set(batch["instance_ids"])) 173 | random.shuffle(train_sent_ids) 174 | else: 175 | train_sent_ids = batch["train_sent_ids"] 176 | 177 | neighbor_batch = self.get_neighbor_batch(train_sents, 178 | train_sent_ids[:self.cfg["k"]]) 179 | feed_dict = self._get_feed_dict(neighbor_batch) 180 | span_reps = self.sess.run([self.span_rep], feed_dict)[0] 181 | 182 | if is_train or self.cfg["predict"] == "one_nn": 183 | rep_list, tag_list = self.get_neighbor_reps_and_tags( 184 | span_reps, neighbor_batch) 185 | batch["neighbor_reps"] = rep_list 186 | batch["neighbor_tags"] = tag_list 187 | else: 188 | rep_list, tag_list = self.get_neighbor_reps_and_tag_one_hots( 189 | span_reps, neighbor_batch) 190 | batch["neighbor_reps"] = rep_list 191 | batch["neighbor_tag_one_hots"] = tag_list 192 | 193 | return batch 194 | 195 | def _make_batch_and_sample_sent_ids(self, batch, valid_record, train_sents, 196 | train_sent_ids): 197 | if train_sent_ids: 198 | random.shuffle(train_sent_ids) 199 | sampled_train_sent_ids = train_sent_ids[:self.cfg["k"]] 200 | else: 201 | sampled_train_sent_ids = valid_record["train_sent_ids"][:self.cfg["k"]] 202 | 203 | train_batch = self.batcher.make_batch_from_sent_ids(train_sents, 204 | sampled_train_sent_ids) 205 | feed_dict = self._get_feed_dict(train_batch) 206 | span_reps = self.sess.run([self.span_rep], feed_dict)[0] 207 | rep_list, tag_list = self.get_neighbor_reps_and_tag_one_hots(span_reps, 208 | train_batch) 209 | batch["neighbor_reps"] = rep_list 210 | batch["neighbor_tag_one_hots"] = tag_list 211 | 212 | return batch, sampled_train_sent_ids 213 | 214 | def train_knn_epoch(self, batches, name): 215 | loss_total = 0. 
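    # One training epoch: k neighbor training sentences are attached to each
    # mini-batch (under random sampling, the batch's own sentences are
    # excluded), their span representations are computed and fed as neighbors,
    # and the loss is the negative log marginal probability assigned to
    # neighbor spans carrying the gold label, i.e. roughly
    #   loss = -mean_b sum_s log( sum_{n: tag_n == gold_s} softmax_n(rep_s . rep_n) )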
216 | num_batches = 0 217 | start_time = time.time() 218 | train_sents = load_json(self.cfg["train_set"]) 219 | if self.cfg["knn_sampling"] == "random": 220 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 221 | else: 222 | train_sent_ids = None 223 | 224 | for batch in batches: 225 | num_batches += 1 226 | if num_batches % 100 == 0: 227 | print("%d" % num_batches, flush=True, end=" ") 228 | 229 | # Setup a batch 230 | batch = self._add_neighbor_instances_to_batch(batch, 231 | train_sents, 232 | train_sent_ids, 233 | is_train=True) 234 | # Convert a batch to the input format 235 | feed_dict = self._get_feed_dict(batch, 236 | is_train=True, 237 | keep_prob=self.cfg["keep_prob"], 238 | lr=self.cfg["lr"]) 239 | # Train a model 240 | _, train_loss = self.sess.run([self.train_op, self.loss], 241 | feed_dict) 242 | 243 | if math.isnan(train_loss): 244 | self.logger.info("\n\n\nNAN: Index: %d\n" % num_batches) 245 | exit() 246 | 247 | loss_total += train_loss 248 | 249 | avg_loss = loss_total / num_batches 250 | self.logger.info("-- Time: %f seconds" % (time.time() - start_time)) 251 | self.logger.info( 252 | "-- Averaged loss: %f(%f/%d)" % (avg_loss, loss_total, num_batches)) 253 | return avg_loss, loss_total 254 | 255 | def evaluate_knn_epoch(self, batches, name): 256 | correct = 0 257 | p_total = 0 258 | num_batches = 0 259 | start_time = time.time() 260 | train_sents = load_json(self.cfg["train_set"]) 261 | if self.cfg["knn_sampling"] == "random": 262 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 263 | else: 264 | train_sent_ids = None 265 | 266 | for batch in batches: 267 | num_batches += 1 268 | if num_batches % 100 == 0: 269 | print("%d" % num_batches, flush=True, end=" ") 270 | 271 | # Setup a batch 272 | batch = self._add_neighbor_instances_to_batch(batch, 273 | train_sents, 274 | train_sent_ids, 275 | is_train=False) 276 | # Convert a batch to the input format 277 | feed_dict = self._get_feed_dict(batch) 278 | # Classify spans 279 | predicted_tags = self.sess.run([self.predicts], feed_dict)[0] 280 | 281 | crr_i, p_total_i = count_gold_and_system_outputs(batch["tags"], 282 | predicted_tags, 283 | NULL_LABEL_ID) 284 | correct += crr_i 285 | p_total += p_total_i 286 | 287 | p, r, f = f_score(correct, p_total, self.n_gold_spans) 288 | self.logger.info("-- Time: %f seconds" % (time.time() - start_time)) 289 | self.logger.info( 290 | "-- {} set\tF:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})" 291 | .format(name, f, p, correct, p_total, r, correct, self.n_gold_spans)) 292 | return f, p, r, correct, p_total, self.n_gold_spans 293 | 294 | def train(self): 295 | self.logger.info(str(self.cfg)) 296 | 297 | config_path = os.path.join(self.cfg["checkpoint_path"], "config.json") 298 | write_json(config_path, self.cfg) 299 | 300 | batch_size = self.cfg["batch_size"] 301 | epochs = self.cfg["epochs"] 302 | train_path = self.cfg["train_set"] 303 | valid_path = self.cfg["valid_set"] 304 | self.n_gold_spans = count_gold_spans(valid_path) 305 | 306 | if self.cfg["knn_sampling"] == "knn": 307 | self.knn_ids = h5py.File( 308 | os.path.join(self.cfg["raw_path"], "knn_ids.hdf5"), "r") 309 | valid_batch_size = 1 310 | shuffle = False 311 | else: 312 | valid_batch_size = batch_size 313 | shuffle = True 314 | 315 | valid_set = list( 316 | self.batcher.batchnize_dataset(data=valid_path, 317 | data_name="valid", 318 | batch_size=valid_batch_size, 319 | shuffle=shuffle)) 320 | best_f1 = -np.inf 321 | init_lr = self.cfg["lr"] 322 | 323 | 
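    # Training schedule used below: when use_lr_decay is enabled, the learning
    # rate after epoch t is max(init_lr / (1 + lr_decay * t), minimal_lr); the
    # model is checkpointed only when the validation F1 improves on the best
    # score so far.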
self.log_trainable_variables() 324 | self.logger.info("Start training...") 325 | self._add_summary() 326 | 327 | for epoch in range(1, epochs + 1): 328 | self.logger.info('Epoch {}/{}:'.format(epoch, epochs)) 329 | 330 | train_set = self.batcher.batchnize_dataset(data=train_path, 331 | data_name="train", 332 | batch_size=batch_size, 333 | shuffle=True) 334 | _ = self.train_knn_epoch(train_set, "train") 335 | 336 | if self.cfg["use_lr_decay"]: # learning rate decay 337 | self.cfg["lr"] = max(init_lr / (1.0 + self.cfg["lr_decay"] * epoch), 338 | self.cfg["minimal_lr"]) 339 | 340 | eval_metrics = self.evaluate_knn_epoch(valid_set, "valid") 341 | cur_valid_f1 = eval_metrics[0] 342 | 343 | if cur_valid_f1 > best_f1: 344 | best_f1 = cur_valid_f1 345 | self.save_session(epoch) 346 | self.logger.info( 347 | '-- new BEST F1 on valid set: {:>7.2%}'.format(best_f1)) 348 | 349 | self.train_writer.close() 350 | self.test_writer.close() 351 | 352 | def eval(self, preprocessor): 353 | self.logger.info(str(self.cfg)) 354 | 355 | ######################## 356 | # Load validation data # 357 | ######################## 358 | valid_data = preprocessor.load_dataset( 359 | self.cfg["data_path"], keep_number=True, 360 | lowercase=self.cfg["char_lowercase"]) 361 | valid_data = valid_data[:self.cfg["data_size"]] 362 | dataset = preprocessor.build_dataset(valid_data, 363 | self.word_dict, 364 | self.char_dict, 365 | self.tag_dict) 366 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 367 | write_json(dataset_path, dataset) 368 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 369 | self.n_gold_spans = count_gold_spans(dataset_path) 370 | 371 | ###################### 372 | # Load training data # 373 | ###################### 374 | train_sents = load_json(self.cfg["train_set"]) 375 | if self.cfg["knn_sampling"] == "random": 376 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 377 | else: 378 | train_sent_ids = None 379 | self.logger.info("Train sentences: {:>7}".format(len(train_sents))) 380 | 381 | ############# 382 | # Main loop # 383 | ############# 384 | correct = 0 385 | p_total = 0 386 | start_time = time.time() 387 | 388 | print("PREDICTION START") 389 | for record, data in zip(valid_data, dataset): 390 | valid_sent_id = record["sent_id"] 391 | 392 | if (valid_sent_id + 1) % 100 == 0: 393 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 394 | 395 | batch = self.make_one_batch_for_target(data, valid_sent_id) 396 | 397 | ##################### 398 | # Sentence sampling # 399 | ##################### 400 | batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids( 401 | batch, record, train_sents, train_sent_ids) 402 | 403 | ############## 404 | # Prediction # 405 | ############## 406 | feed_dict = self._get_feed_dict(batch) 407 | batch_sims, batch_preds = self.sess.run( 408 | [self.similarity, self.predicts], feed_dict) 409 | 410 | crr_i, p_total_i = count_gold_and_system_outputs( 411 | batch["tags"], batch_preds, NULL_LABEL_ID) 412 | correct += crr_i 413 | p_total += p_total_i 414 | 415 | ############## 416 | # Evaluation # 417 | ############## 418 | p, r, f = f_score(correct, p_total, self.n_gold_spans) 419 | self.logger.info("-- Time: %f seconds" % (time.time() - start_time)) 420 | self.logger.info( 421 | "-- F:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})" 422 | .format(f, p, correct, p_total, r, correct, self.n_gold_spans)) 423 | 424 | def save_predicted_spans(self, data_name, preprocessor): 425 | self.logger.info(str(self.cfg)) 426 | 
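    # Predicts spans for every sentence in data_path and writes them to
    # <checkpoint_path>/<data_name>.predicted_spans.json, one record per
    # sentence of the form (illustrative values):
    #   {"sent_id": 3, "words": [...], "spans": [["ORG", 0, 0], ...],
    #    "train_sent_ids": [57, 802, ...]}
    # where each span is [label, start, end] with inclusive word indices.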
427 | ######################## 428 | # Load validation data # 429 | ######################## 430 | valid_data = preprocessor.load_dataset( 431 | self.cfg["data_path"], keep_number=True, 432 | lowercase=self.cfg["char_lowercase"]) 433 | valid_data = valid_data[:self.cfg["data_size"]] 434 | dataset = preprocessor.build_dataset(valid_data, 435 | self.word_dict, 436 | self.char_dict, 437 | self.tag_dict) 438 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 439 | write_json(dataset_path, dataset) 440 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 441 | 442 | ###################### 443 | # Load training data # 444 | ###################### 445 | train_sents = load_json(self.cfg["train_set"]) 446 | if self.cfg["knn_sampling"] == "random": 447 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 448 | else: 449 | train_sent_ids = None 450 | self.logger.info("Train sentences: {:>7}".format(len(train_sents))) 451 | 452 | ############# 453 | # Main loop # 454 | ############# 455 | start_time = time.time() 456 | results = [] 457 | print("PREDICTION START") 458 | for record, data in zip(valid_data, dataset): 459 | valid_sent_id = record["sent_id"] 460 | batch = self.make_one_batch_for_target(data, valid_sent_id, 461 | add_tags=False) 462 | if (valid_sent_id + 1) % 100 == 0: 463 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 464 | 465 | ##################### 466 | # Sentence sampling # 467 | ##################### 468 | batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids( 469 | batch, record, train_sents, train_sent_ids) 470 | 471 | ############### 472 | # KNN predict # 473 | ############### 474 | feed_dict = self._get_feed_dict(batch) 475 | batch_preds = self.sess.run([self.predicts], feed_dict)[0] 476 | preds = batch_preds[0] 477 | 478 | ######################## 479 | # Make predicted spans # 480 | ######################## 481 | indx_i, indx_j = get_span_indices(n_words=len(record["words"]), 482 | max_span_len=self.max_span_len) 483 | assert len(preds) == len(indx_i) == len(indx_j) 484 | pred_spans = [[self.rev_tag_dict[pred_label_id], int(i), int(j)] 485 | for pred_label_id, i, j in zip(preds, indx_i, indx_j) 486 | if pred_label_id != NULL_LABEL_ID] 487 | 488 | ################## 489 | # Add the result # 490 | ################## 491 | results.append({"sent_id": valid_sent_id, 492 | "words": record["words"], 493 | "spans": pred_spans, 494 | "train_sent_ids": sampled_sent_ids}) 495 | 496 | path = os.path.join(self.cfg["checkpoint_path"], 497 | "%s.predicted_spans.json" % data_name) 498 | write_json(path, results) 499 | self.logger.info( 500 | "-- Time: %f seconds\nFINISHED." 
% (time.time() - start_time)) 501 | 502 | def save_predicted_bio_tags(self, data_name, preprocessor): 503 | self.logger.info(str(self.cfg)) 504 | 505 | ######################## 506 | # Load validation data # 507 | ######################## 508 | valid_data = preprocessor.load_dataset( 509 | self.cfg["data_path"], keep_number=True, 510 | lowercase=self.cfg["char_lowercase"]) 511 | valid_data = valid_data[:self.cfg["data_size"]] 512 | dataset = preprocessor.build_dataset(valid_data, 513 | self.word_dict, 514 | self.char_dict, 515 | self.tag_dict) 516 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 517 | write_json(dataset_path, dataset) 518 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 519 | 520 | ###################### 521 | # Load training data # 522 | ###################### 523 | train_sents = load_json(self.cfg["train_set"]) 524 | if self.cfg["knn_sampling"] == "random": 525 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 526 | else: 527 | train_sent_ids = None 528 | self.logger.info("Train sentences: {:>7}".format(len(train_sents))) 529 | 530 | ############# 531 | # Main loop # 532 | ############# 533 | start_time = time.time() 534 | path = os.path.join(self.cfg["checkpoint_path"], "%s.bio.txt" % data_name) 535 | fout_txt = open(path, "w") 536 | print("PREDICTION START") 537 | for record, data in zip(valid_data, dataset): 538 | valid_sent_id = record["sent_id"] 539 | batch = self.make_one_batch_for_target(data, valid_sent_id, 540 | add_tags=False) 541 | if (valid_sent_id + 1) % 100 == 0: 542 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 543 | 544 | ##################### 545 | # Sentence sampling # 546 | ##################### 547 | batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids( 548 | batch, record, train_sents, train_sent_ids) 549 | 550 | ############### 551 | # KNN predict # 552 | ############### 553 | feed_dict = self._get_feed_dict(batch) 554 | proba = self.sess.run([self.marginal_proba], feed_dict)[0][0] 555 | 556 | ######################## 557 | # Make predicted spans # 558 | ######################## 559 | words = record["words"] 560 | triples = greedy_search(proba, 561 | n_words=len(words), 562 | max_span_len=self.max_span_len, 563 | null_label_id=NULL_LABEL_ID) 564 | pred_bio_tags = span2bio(spans=triples, 565 | n_words=len(words), 566 | tag_dict=self.rev_tag_dict) 567 | gold_bio_tags = span2bio(spans=record["tags"], 568 | n_words=len(words)) 569 | assert len(words) == len(pred_bio_tags) == len(gold_bio_tags) 570 | 571 | #################### 572 | # Write the result # 573 | #################### 574 | for word, gold_tag, pred_tag in zip(words, gold_bio_tags, pred_bio_tags): 575 | fout_txt.write("%s _ %s %s\n" % (word, gold_tag, pred_tag)) 576 | fout_txt.write("\n") 577 | 578 | self.logger.info( 579 | "-- Time: %f seconds\nFINISHED." 
% (time.time() - start_time)) 580 | 581 | def save_nearest_spans(self, data_name, preprocessor, print_knn): 582 | self.logger.info(str(self.cfg)) 583 | 584 | ######################## 585 | # Load validation data # 586 | ######################## 587 | valid_data = preprocessor.load_dataset( 588 | self.cfg["data_path"], keep_number=True, 589 | lowercase=self.cfg["char_lowercase"]) 590 | valid_data = valid_data[:self.cfg["data_size"]] 591 | dataset = preprocessor.build_dataset(valid_data, 592 | self.word_dict, 593 | self.char_dict, 594 | self.tag_dict) 595 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 596 | write_json(dataset_path, dataset) 597 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 598 | self.n_gold_spans = count_gold_spans(dataset_path) 599 | 600 | ###################### 601 | # Load training data # 602 | ###################### 603 | train_sents = load_json(self.cfg["train_set"]) 604 | if self.cfg["knn_sampling"] == "random": 605 | train_sent_ids = [sent_id for sent_id in range(len(train_sents))] 606 | else: 607 | train_sent_ids = None 608 | train_data = preprocessor.load_dataset( 609 | os.path.join(self.cfg["raw_path"], "train.json"), 610 | keep_number=True, lowercase=False) 611 | self.logger.info("Train sentences: {:>7}".format(len(train_sents))) 612 | 613 | ############# 614 | # Main loop # 615 | ############# 616 | correct = 0 617 | p_total = 0 618 | start_time = time.time() 619 | file_path = os.path.join(self.cfg["checkpoint_path"], 620 | "%s.nearest_spans.txt" % data_name) 621 | fout_txt = open(file_path, "w") 622 | print("PREDICTION START") 623 | for record, data in zip(valid_data, dataset): 624 | valid_sent_id = record["sent_id"] 625 | batch = self.make_one_batch_for_target(data, valid_sent_id) 626 | 627 | if (valid_sent_id + 1) % 100 == 0: 628 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 629 | 630 | ##################### 631 | # Sentence sampling # 632 | ##################### 633 | batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids( 634 | batch, record, train_sents, train_sent_ids) 635 | 636 | ############## 637 | # Prediction # 638 | ############## 639 | feed_dict = self._get_feed_dict(batch) 640 | batch_sims, batch_preds = self.sess.run( 641 | [self.similarity, self.predicts], feed_dict) 642 | 643 | crr_i, p_total_i = count_gold_and_system_outputs( 644 | batch["tags"], batch_preds, NULL_LABEL_ID) 645 | correct += crr_i 646 | p_total += p_total_i 647 | 648 | #################### 649 | # Write the result # 650 | #################### 651 | self._write_predictions(fout_txt, record) 652 | self._write_nearest_spans( 653 | fout_txt, record, train_data, sampled_sent_ids, batch_sims, 654 | batch_preds, print_knn) 655 | 656 | fout_txt.close() 657 | 658 | p, r, f = f_score(correct, p_total, self.n_gold_spans) 659 | self.logger.info("-- Time: %f seconds" % (time.time() - start_time)) 660 | self.logger.info( 661 | "-- {} set\tF:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})" 662 | .format(data_name, f, p, correct, p_total, r, correct, 663 | self.n_gold_spans)) 664 | 665 | @staticmethod 666 | def _write_predictions(fout_txt, record): 667 | fout_txt.write("-SENT:%d || %s || %s\n" % ( 668 | record["sent_id"], 669 | " ".join(record["words"]), 670 | " ".join(["(%s,%d,%d)" % (r, i, j) for (r, i, j) in record["tags"]]))) 671 | 672 | def _write_nearest_spans(self, fout_txt, record, train_data, 673 | sampled_sent_ids, batch_sims, batch_preds, 674 | print_knn): 675 | 676 | def _write_train_sents(_sampled_train_sents): 677 
| for _train_record in _sampled_train_sents: 678 | fout_txt.write("--kNN:%d || %s || %s\n" % ( 679 | _train_record["sent_id"], 680 | " ".join(_train_record["words"]), 681 | " ".join(["(%s,%d,%d)" % (r, i, j) 682 | for (r, i, j) in _train_record["tags"]]))) 683 | 684 | def _write_gold_and_pred_spans(_record, _pred_label_id, _span_boundaries): 685 | if (i, j) in _span_boundaries: 686 | _index = _span_boundaries.index((i, j)) 687 | gold_label = _record["tags"][_index][0] 688 | else: 689 | gold_label = "O" 690 | 691 | pred_label = self.rev_tag_dict[_pred_label_id] 692 | fout_txt.write("##(%d,%d) || %s || %s || %s\n" % ( 693 | i, j, " ".join(record["words"][i: j + 1]), pred_label, gold_label)) 694 | 695 | def _get_nearest_spans(_sampled_train_sents): 696 | _nearest_spans = [] 697 | _prev_indx = 0 698 | _temp_indx = 0 699 | for _record in _sampled_train_sents: 700 | _indx_i, _indx_j = get_span_indices(n_words=len(_record["words"]), 701 | max_span_len=self.max_span_len) 702 | _temp_indx += len(_indx_i) 703 | _temp_scores = scores[_prev_indx: _temp_indx] 704 | assert len(_temp_scores) == len(_indx_i) == len(_indx_j) 705 | _nearest_spans.extend( 706 | get_scores_and_spans(spans=_record["tags"], 707 | scores=_temp_scores, 708 | sent_id=_record["sent_id"], 709 | indx_i=_indx_i, 710 | indx_j=_indx_j)) 711 | _prev_indx = _temp_indx 712 | return _nearest_spans 713 | 714 | def _write_nearest_spans_for_each_span(_sampled_train_sents): 715 | nearest_spans = _get_nearest_spans(_sampled_train_sents) 716 | nearest_spans.sort(key=lambda span: span[-1], reverse=True) 717 | for rank, (r, sent_id, i, j, score) in enumerate(nearest_spans[:10]): 718 | mention = " ".join(train_data[sent_id]["words"][i: j + 1]) 719 | text = "{} || {} || sent:{} || ({},{}) || {:.3g}".format( 720 | r, mention, sent_id, i, j, score) 721 | fout_txt.write("####RANK:%d %s\n" % (rank, text)) 722 | 723 | sampled_train_sents = [train_data[sent_id] 724 | for sent_id in sampled_sent_ids] 725 | if print_knn: 726 | _write_train_sents(sampled_train_sents) 727 | 728 | sims = batch_sims[0] # 1D: n_spans, 2D: n_instances 729 | preds = batch_preds[0] # 1D: n_spans 730 | indx_i, indx_j = get_span_indices(n_words=len(record["words"]), 731 | max_span_len=self.max_span_len) 732 | span_boundaries = [(i, j) for _, i, j in record["tags"]] 733 | 734 | assert len(sims) == len(preds) == len(indx_i) == len(indx_j) 735 | for scores, pred_label_id, i, j in zip(sims, preds, indx_i, indx_j): 736 | if pred_label_id == NULL_LABEL_ID and (i, j) not in span_boundaries: 737 | continue 738 | _write_gold_and_pred_spans(record, pred_label_id, span_boundaries) 739 | _write_nearest_spans_for_each_span(sampled_train_sents) 740 | 741 | fout_txt.write("\n") 742 | 743 | def save_span_representation(self, data_name, preprocessor): 744 | self.logger.info(str(self.cfg)) 745 | 746 | ######################## 747 | # Load validation data # 748 | ######################## 749 | valid_data = preprocessor.load_dataset( 750 | self.cfg["data_path"], keep_number=True, 751 | lowercase=self.cfg["char_lowercase"]) 752 | valid_data = valid_data[:self.cfg["data_size"]] 753 | dataset = preprocessor.build_dataset(valid_data, 754 | self.word_dict, 755 | self.char_dict, 756 | self.tag_dict) 757 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 758 | 759 | ############# 760 | # Main loop # 761 | ############# 762 | start_time = time.time() 763 | gold_labels = {} 764 | fout_path = os.path.join(self.cfg["checkpoint_path"], 765 | "%s.span_reps.hdf5" % data_name) 766 | fout = 
h5py.File(fout_path, 'w') 767 | 768 | print("PREDICTION START") 769 | for record, data in zip(valid_data, dataset): 770 | valid_sent_id = record["sent_id"] 771 | batch = self.make_one_batch_for_target(data, valid_sent_id) 772 | 773 | if (valid_sent_id + 1) % 100 == 0: 774 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 775 | 776 | ############## 777 | # Prediction # 778 | ############## 779 | feed_dict = self._get_feed_dict(batch) 780 | span_reps = self.sess.run([self.span_rep], feed_dict)[0][0] 781 | span_tags = batch["tags"][0] 782 | assert len(span_reps) == len(span_tags) 783 | 784 | ################## 785 | # Add the result # 786 | ################## 787 | fout.create_dataset( 788 | name='{}'.format(valid_sent_id), 789 | dtype='float32', 790 | data=span_reps) 791 | gold_labels[valid_sent_id] = [self.rev_tag_dict[int(tag)] 792 | for tag in span_tags] 793 | fout.close() 794 | path = os.path.join(self.cfg["checkpoint_path"], 795 | "%s.gold_labels.json" % data_name) 796 | write_json(path, gold_labels) 797 | self.logger.info( 798 | "-- Time: %f seconds\nFINISHED." % (time.time() - start_time)) 799 | 800 | def predict_on_command_line(self, preprocessor): 801 | 802 | def _load_glove(glove_path): 803 | vocab = {} 804 | vectors = [] 805 | total = int(4e5) 806 | with codecs.open(glove_path, mode='r', encoding='utf-8') as f: 807 | for line in tqdm(f, total=total, desc="Load glove"): 808 | line = line.lstrip().rstrip().split(" ") 809 | vocab[line[0]] = len(vocab) 810 | vectors.append([float(x) for x in line[1:]]) 811 | assert len(vocab) == len(vectors) 812 | return vocab, np.asarray(vectors) 813 | 814 | def _mean_vectors(sents, emb, vocab): 815 | unk_vec = np.zeros(emb.shape[1]) 816 | mean_vecs = [] 817 | for words in sents: 818 | vecs = [] 819 | for word in words: 820 | word = word.lower() 821 | if word in vocab: 822 | vec = emb[vocab[word]] 823 | else: 824 | vec = unk_vec 825 | vecs.append(vec) 826 | mean_vecs.append(np.mean(vecs, axis=0)) 827 | return mean_vecs 828 | 829 | def _cosine_sim(p0, p1): 830 | d = (norm(p0) * norm(p1)) 831 | if d > 0: 832 | return np.dot(p0, p1) / d 833 | return 0.0 834 | 835 | def _setup_repository(_train_sents, _train_data=None): 836 | if self.cfg["knn_sampling"] == "random": 837 | _train_sent_ids = [_sent_id for _sent_id in range(len(_train_sents))] 838 | _vocab = _glove = _train_embs = None 839 | else: 840 | _train_sent_ids = None 841 | _vocab, _glove = _load_glove("data/emb/glove.6B.100d.txt") 842 | _train_words = [[w.lower() for w in _train_record["words"]] 843 | for _train_record in _train_data] 844 | _train_embs = _mean_vectors(_train_words, _glove, _vocab) 845 | return _train_sent_ids, _train_embs, _vocab, _glove 846 | 847 | def _make_ids(_words): 848 | _char_ids = [] 849 | _word_ids = [] 850 | for word in _words: 851 | _char_ids.append([self.char_dict[char] 852 | if char in self.char_dict else self.char_dict[UNK] 853 | for char in word]) 854 | word = word_convert(word, keep_number=False, lowercase=True) 855 | _word_ids.append(self.word_dict[word] 856 | if word in self.word_dict else self.word_dict[UNK]) 857 | return _char_ids, _word_ids 858 | 859 | def _retrieve_knn_train_sents(_record, _train_embs, _vocab, _glove): 860 | test_words = [w.lower() for w in _record["words"]] 861 | test_emb = _mean_vectors([test_words], _glove, _vocab)[0] 862 | sim = [_cosine_sim(train_emb, test_emb) for train_emb in _train_embs] 863 | arg_sort = np.argsort(sim)[::-1][:self.cfg["k"]] 864 | _record["train_sent_ids"] = [int(arg) for arg in arg_sort] 865 | return _record 
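    # _retrieve_knn_train_sents (above) ranks all training sentences by cosine
    # similarity between the mean GloVe vector of the input sentence and that
    # of each training sentence, keeping the top-k ids; it is only called when
    # knn_sampling is set to "knn".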
866 | 867 | def _get_nearest_spans(_sampled_train_sents, _scores): 868 | _nearest_spans = [] 869 | _prev_indx = 0 870 | _temp_indx = 0 871 | for _record in _sampled_train_sents: 872 | _indx_i, _indx_j = get_span_indices(n_words=len(_record["words"]), 873 | max_span_len=self.max_span_len) 874 | _temp_indx += len(_indx_i) 875 | _temp_scores = _scores[_prev_indx: _temp_indx] 876 | assert len(_temp_scores) == len(_indx_i) == len(_indx_j) 877 | _nearest_spans.extend( 878 | get_scores_and_spans(spans=_record["tags"], 879 | scores=_temp_scores, 880 | sent_id=_record["sent_id"], 881 | indx_i=_indx_i, 882 | indx_j=_indx_j)) 883 | _prev_indx = _temp_indx 884 | _nearest_spans.sort(key=lambda span: span[-1], reverse=True) 885 | return _nearest_spans 886 | 887 | ###################### 888 | # Load training data # 889 | ###################### 890 | train_sents = load_json(self.cfg["train_set"]) 891 | train_data = preprocessor.load_dataset( 892 | os.path.join(self.cfg["raw_path"], "train.json"), 893 | keep_number=True, lowercase=False) 894 | train_sent_ids, train_embs, vocab, glove = _setup_repository( 895 | train_sents, train_data) 896 | 897 | ######################################## 898 | # Load each sentence from command line # 899 | ######################################## 900 | print("\nPREDICTION START\n") 901 | while True: 902 | sentence = input('\nEnter a tokenized sentence: ') 903 | words = sentence.split() 904 | char_ids, word_ids = _make_ids(words) 905 | data = {"words": word_ids, "chars": char_ids} 906 | record = {"sent_id": 0, "words": words, "train_sent_ids": None} 907 | batch = self.make_one_batch_for_target(data, sent_id=0, add_tags=False) 908 | 909 | ##################### 910 | # Sentence sampling # 911 | ##################### 912 | if self.cfg["knn_sampling"] == "knn": 913 | record = _retrieve_knn_train_sents(record, train_embs, vocab, glove) 914 | batch, sampled_sent_ids = self._make_batch_and_sample_sent_ids( 915 | batch, record, train_sents, train_sent_ids) 916 | 917 | ############## 918 | # Prediction # 919 | ############## 920 | feed_dict = self._get_feed_dict(batch) 921 | batch_sims, batch_preds = self.sess.run( 922 | [self.similarity, self.predicts], feed_dict) 923 | 924 | #################### 925 | # Write the result # 926 | #################### 927 | sims = batch_sims[0] # 1D: n_spans, 2D: n_instances 928 | preds = batch_preds[0] # 1D: n_spans 929 | indx_i, indx_j = get_span_indices(n_words=len(record["words"]), 930 | max_span_len=self.max_span_len) 931 | 932 | assert len(sims) == len(preds) == len(indx_i) == len(indx_j) 933 | sampled_train_sents = [train_data[sent_id] 934 | for sent_id in sampled_sent_ids] 935 | 936 | for scores, pred_label_id, i, j in zip(sims, preds, indx_i, indx_j): 937 | if pred_label_id == NULL_LABEL_ID: 938 | continue 939 | pred_label = self.rev_tag_dict[pred_label_id] 940 | print("#(%d,%d) || %s || %s" % ( 941 | i, j, " ".join(record["words"][i: j + 1]), pred_label)) 942 | 943 | nearest_spans = _get_nearest_spans(sampled_train_sents, scores) 944 | for k, (r, _sent_id, a, b, _score) in enumerate(nearest_spans[:5]): 945 | train_words = train_data[_sent_id]["words"] 946 | if a - 5 < 0: 947 | left_context = "" 948 | else: 949 | left_context = " ".join(train_words[a - 5: a]) 950 | left_context = "... " + left_context 951 | right_context = " ".join(train_words[b + 1: b + 6]) 952 | if b + 6 < len(train_words): 953 | right_context = right_context + " ..." 
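            # Print up to five context tokens on each side of the retrieved
            # training span; "..." marks a truncated context window.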
954 | mention = " ".join(train_words[a: b + 1]) 955 | text = "{}: {} [{}] {}".format( 956 | r, left_context, mention, right_context) 957 | print("## %d %s" % (k, text)) 958 | -------------------------------------------------------------------------------- /models/network_components.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def highway_layer(inputs, use_bias=True, bias_init=0.0, keep_prob=1.0, 11 | is_train=False, scope=None): 12 | with tf.variable_scope(scope or "highway_layer"): 13 | hidden = inputs.get_shape().as_list()[-1] 14 | with tf.variable_scope("trans"): 15 | trans = tf.layers.dropout(inputs, rate=1.0 - keep_prob, 16 | training=is_train) 17 | trans = tf.layers.dense(trans, units=hidden, use_bias=use_bias, 18 | bias_initializer=tf.constant_initializer( 19 | bias_init), activation=None) 20 | trans = tf.nn.relu(trans) 21 | with tf.variable_scope("gate"): 22 | gate = tf.layers.dropout(inputs, rate=1.0 - keep_prob, training=is_train) 23 | gate = tf.layers.dense(gate, units=hidden, use_bias=use_bias, 24 | bias_initializer=tf.constant_initializer( 25 | bias_init), activation=None) 26 | gate = tf.nn.sigmoid(gate) 27 | outputs = gate * trans + (1 - gate) * inputs 28 | return outputs 29 | 30 | 31 | def highway_network(inputs, highway_layers=2, use_bias=True, bias_init=0.0, 32 | keep_prob=1.0, is_train=False, scope=None): 33 | with tf.variable_scope(scope or "highway_network"): 34 | prev = inputs 35 | cur = None 36 | for idx in range(highway_layers): 37 | cur = highway_layer(prev, use_bias, bias_init, keep_prob, is_train, 38 | scope="highway_layer_{}".format(idx)) 39 | prev = cur 40 | return cur 41 | 42 | 43 | def conv1d(in_, filter_size, height, padding, is_train=True, drop_rate=0.0, 44 | scope=None): 45 | with tf.variable_scope(scope or "conv1d"): 46 | num_channels = in_.get_shape()[-1] 47 | filter_ = tf.get_variable("filter", 48 | shape=[1, height, num_channels, filter_size], 49 | dtype=tf.float32) 50 | bias = tf.get_variable("bias", shape=[filter_size], dtype=tf.float32) 51 | strides = [1, 1, 1, 1] 52 | in_ = tf.layers.dropout(in_, rate=drop_rate, training=is_train) 53 | # [batch, max_len_sent, max_len_word / filter_stride, char output size] 54 | xxc = tf.nn.conv2d(in_, filter_, strides, padding) + bias 55 | out = tf.reduce_max(tf.nn.relu(xxc), axis=2) 56 | return out 57 | 58 | 59 | def multi_conv1d(in_, filter_sizes, heights, padding="VALID", is_train=True, 60 | drop_rate=0.0, scope=None): 61 | with tf.variable_scope(scope or "multi_conv1d"): 62 | assert len(filter_sizes) == len(heights) 63 | outs = [] 64 | for i, (filter_size, height) in enumerate(zip(filter_sizes, heights)): 65 | if filter_size == 0: 66 | continue 67 | out = conv1d(in_, 68 | filter_size, 69 | height, 70 | padding, 71 | is_train=is_train, 72 | drop_rate=drop_rate, 73 | scope="conv1d_{}".format(i)) 74 | outs.append(out) 75 | concat_out = tf.concat(axis=2, values=outs) 76 | return concat_out 77 | -------------------------------------------------------------------------------- /models/span_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import 
unicode_literals 6 | 7 | import os 8 | import time 9 | 10 | import h5py 11 | import numpy as np 12 | 13 | import tensorflow as tf 14 | from tensorflow.python.ops.rnn_cell import MultiRNNCell 15 | from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn 16 | from tensorflow.contrib.rnn.python.ops.rnn import \ 17 | stack_bidirectional_dynamic_rnn 18 | 19 | from models import BaseModel 20 | from models.decoders import get_span_indices, greedy_search 21 | from models.network_components import multi_conv1d, highway_network 22 | from utils.common import write_json 23 | from utils.data_utils import metrics_for_multi_class_spans, f_score, \ 24 | count_gold_spans, count_gold_and_system_outputs, span2bio 25 | 26 | NULL_LABEL_ID = 0 27 | 28 | 29 | class SpanModel(BaseModel): 30 | 31 | def __init__(self, config, batcher, is_train=True): 32 | self.max_span_len = config["max_span_len"] 33 | self.n_gold_spans = None 34 | self.proba = None 35 | super(SpanModel, self).__init__(config, batcher, is_train) 36 | 37 | def _add_placeholders(self): 38 | self.words = tf.placeholder(tf.int32, shape=[None, None], name="words") 39 | self.tags = tf.placeholder(tf.int32, shape=[None, None], name="tags") 40 | self.seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len") 41 | if self.cfg["use_chars"]: 42 | self.chars = tf.placeholder(tf.int32, shape=[None, None, None], 43 | name="chars") 44 | # hyperparameters 45 | self.is_train = tf.placeholder(tf.bool, name="is_train") 46 | self.keep_prob = tf.placeholder(tf.float32, name="rnn_keep_probability") 47 | self.drop_rate = tf.placeholder(tf.float32, name="dropout_rate") 48 | self.lr = tf.placeholder(tf.float32, name="learning_rate") 49 | 50 | def _get_feed_dict(self, batch, keep_prob=1.0, is_train=False, lr=None): 51 | feed_dict = {self.words: batch["words"], self.seq_len: batch["seq_len"]} 52 | if "tags" in batch: 53 | feed_dict[self.tags] = batch["tags"] 54 | if self.cfg["use_chars"]: 55 | feed_dict[self.chars] = batch["chars"] 56 | feed_dict[self.keep_prob] = keep_prob 57 | feed_dict[self.drop_rate] = 1.0 - keep_prob 58 | feed_dict[self.is_train] = is_train 59 | if lr is not None: 60 | feed_dict[self.lr] = lr 61 | return feed_dict 62 | 63 | def _create_rnn_cell(self): 64 | if self.cfg["num_layers"] is None or self.cfg["num_layers"] <= 1: 65 | return self._create_single_rnn_cell(self.cfg["num_units"]) 66 | else: 67 | if self.cfg["use_stack_rnn"]: 68 | lstm_cells = [] 69 | for i in range(self.cfg["num_layers"]): 70 | cell = tf.nn.rnn_cell.LSTMCell(self.cfg["num_units"], 71 | initializer=tf.initializers.orthogonal 72 | ) 73 | cell = tf.contrib.rnn.DropoutWrapper(cell, 74 | state_keep_prob=self.keep_prob, 75 | input_keep_prob=self.keep_prob, 76 | dtype=tf.float32) 77 | lstm_cells.append(cell) 78 | return lstm_cells 79 | else: 80 | return MultiRNNCell( 81 | [self._create_single_rnn_cell(self.cfg["num_units"]) 82 | for _ in range(self.cfg["num_layers"])]) 83 | 84 | def _build_embedding_op(self): 85 | with tf.variable_scope("embeddings"): 86 | if not self.cfg["use_pretrained"]: 87 | self.word_embeddings = tf.get_variable(name="emb", 88 | dtype=tf.float32, 89 | trainable=True, 90 | shape=[self.word_vocab_size, 91 | self.cfg["emb_dim"]]) 92 | else: 93 | padding_token_emb = tf.get_variable(name="padding_emb", 94 | dtype=tf.float32, 95 | trainable=False, 96 | shape=[1, self.cfg["emb_dim"]]) 97 | special_token_emb = tf.get_variable(name="spacial_emb", 98 | dtype=tf.float32, 99 | trainable=True, 100 | shape=[2, self.cfg["emb_dim"]]) 101 | token_emb = tf.Variable( 102 | 
np.load(self.cfg["pretrained_emb"])["embeddings"], 103 | name="emb", dtype=tf.float32, trainable=self.cfg["tuning_emb"]) 104 | self.word_embeddings = tf.concat( 105 | [padding_token_emb, special_token_emb, token_emb], axis=0) 106 | 107 | word_emb = tf.nn.embedding_lookup(self.word_embeddings, self.words, 108 | name="words_emb") 109 | print("word embedding shape: {}".format(word_emb.get_shape().as_list())) 110 | 111 | if self.cfg["use_chars"]: 112 | self.char_embeddings = tf.get_variable(name="char_emb", 113 | dtype=tf.float32, 114 | trainable=True, 115 | shape=[self.char_vocab_size, 116 | self.cfg["char_emb_dim"]] 117 | ) 118 | char_emb = tf.nn.embedding_lookup(self.char_embeddings, self.chars, 119 | name="chars_emb") 120 | char_represent = multi_conv1d(char_emb, self.cfg["filter_sizes"], 121 | self.cfg["channel_sizes"], 122 | drop_rate=self.drop_rate, 123 | is_train=self.is_train) 124 | print("chars representation shape: {}".format( 125 | char_represent.get_shape().as_list())) 126 | word_emb = tf.concat([word_emb, char_represent], axis=-1) 127 | 128 | if self.cfg["use_highway"]: 129 | self.word_emb = highway_network(word_emb, self.cfg["highway_layers"], 130 | use_bias=True, bias_init=0.0, 131 | keep_prob=self.keep_prob, 132 | is_train=self.is_train) 133 | else: 134 | self.word_emb = tf.layers.dropout(word_emb, rate=self.drop_rate, 135 | training=self.is_train) 136 | print("word and chars concatenation shape: {}".format( 137 | self.word_emb.get_shape().as_list())) 138 | 139 | def _build_rnn_op(self): 140 | with tf.variable_scope("bi_directional_rnn"): 141 | cell_fw = self._create_rnn_cell() 142 | cell_bw = self._create_rnn_cell() 143 | 144 | if self.cfg["use_stack_rnn"]: 145 | rnn_outs, *_ = stack_bidirectional_dynamic_rnn( 146 | cell_fw, cell_bw, self.word_emb, dtype=tf.float32) 147 | else: 148 | rnn_outs, *_ = bidirectional_dynamic_rnn( 149 | cell_fw, cell_bw, self.word_emb, dtype=tf.float32) 150 | rnn_outs = tf.concat(rnn_outs, axis=-1) 151 | rnn_outs = tf.layers.dropout(rnn_outs, rate=self.drop_rate, 152 | training=self.is_train) 153 | self.rnn_outs = rnn_outs 154 | print("rnn output shape: {}".format(rnn_outs.get_shape().as_list())) 155 | 156 | def _make_span_indices(self): 157 | with tf.name_scope("span_indices"): 158 | n_words = tf.shape(self.rnn_outs)[1] 159 | n_spans = tf.cast(n_words * (n_words + 1) / 2, dtype=tf.int32) 160 | ones = tf.contrib.distributions.fill_triangular(tf.ones(shape=[n_spans]), 161 | upper=True) 162 | num_upper = tf.minimum(n_words, self.cfg["max_span_len"] - 1) 163 | ones = tf.linalg.band_part(ones, num_lower=tf.cast(0, dtype=tf.int32), 164 | num_upper=num_upper) 165 | self.span_indices = tf.transpose( 166 | tf.where(tf.not_equal(ones, tf.constant(0, dtype=tf.float32)))) 167 | 168 | def _build_span_minus_op(self): 169 | with tf.variable_scope("rnn_span_rep"): 170 | i = self.span_indices[0] 171 | j = self.span_indices[1] 172 | batch_size = tf.shape(self.rnn_outs)[0] 173 | dim = self.cfg["num_units"] 174 | x_fw = self.rnn_outs[:, :, :dim] 175 | x_bw = self.rnn_outs[:, :, dim:] 176 | 177 | pad = tf.zeros(shape=(batch_size, 1, dim), dtype=tf.float32) 178 | x_fw_pad = tf.concat([pad, x_fw], axis=1) 179 | x_bw_pad = tf.concat([x_bw, pad], axis=1) 180 | 181 | h_fw_i = tf.gather(x_fw_pad, i, axis=1) 182 | h_fw_j = tf.gather(x_fw, j, axis=1) 183 | h_bw_i = tf.gather(x_bw, i, axis=1) 184 | h_bw_j = tf.gather(x_bw_pad, j + 1, axis=1) 185 | 186 | span_fw = h_fw_j - h_fw_i 187 | span_bw = h_bw_i - h_bw_j 188 | self.rnn_span_rep = tf.concat([span_fw, span_bw], axis=-1) 
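# Editor's note: this is the LSTM-minus span representation. With the zero
# padding added at the sequence boundaries, the forward feature of a span
# (i, j) is h_fw[j] - h_fw[i-1] and the backward feature is h_bw[i] - h_bw[j+1],
# i.e. the difference of BiLSTM states just inside and just outside the span.
# For a single-token span with i == j == 0 (and a sentence longer than one
# token), span_fw reduces to h_fw[0] and span_bw to h_bw[0] - h_bw[1].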
189 | 190 | print("rnn span rep shape: {}".format( 191 | self.rnn_span_rep.get_shape().as_list())) 192 | 193 | def _build_span_add_and_minus_op(self): 194 | with tf.variable_scope("rnn_span_rep"): 195 | i = self.span_indices[0] 196 | j = self.span_indices[1] 197 | batch_size = tf.shape(self.rnn_outs)[0] 198 | dim = self.cfg["num_units"] 199 | x_fw = self.rnn_outs[:, :, :dim] 200 | x_bw = self.rnn_outs[:, :, dim:] 201 | 202 | pad = tf.zeros(shape=(batch_size, 1, dim), dtype=tf.float32) 203 | x_fw_pad = tf.concat([pad, x_fw], axis=1) 204 | x_bw_pad = tf.concat([x_bw, pad], axis=1) 205 | 206 | h_fw_i = tf.gather(x_fw, i, axis=1) 207 | h_fw_i_pad = tf.gather(x_fw_pad, i, axis=1) 208 | h_fw_j = tf.gather(x_fw, j, axis=1) 209 | h_bw_i = tf.gather(x_bw, i, axis=1) 210 | h_bw_j = tf.gather(x_bw, j, axis=1) 211 | h_bw_j_pad = tf.gather(x_bw_pad, j + 1, axis=1) 212 | 213 | span_add_fw = h_fw_i + h_fw_j 214 | span_add_bw = h_bw_i + h_bw_j 215 | span_minus_fw = h_fw_j - h_fw_i_pad 216 | span_minus_bw = h_bw_i - h_bw_j_pad 217 | self.rnn_span_rep = tf.concat( 218 | [span_add_fw, span_add_bw, span_minus_fw, span_minus_bw], axis=-1) 219 | 220 | print("rnn span rep shape: {}".format( 221 | self.rnn_span_rep.get_shape().as_list())) 222 | 223 | def _build_span_projection_op(self): 224 | with tf.variable_scope("span_projection"): 225 | span_rep = tf.layers.dense(self.rnn_span_rep, 226 | units=self.cfg["num_units"], 227 | use_bias=True) 228 | self.span_rep = tf.layers.dropout(span_rep, 229 | rate=self.drop_rate, 230 | training=self.is_train) 231 | print("span rep shape: {}".format(self.span_rep.get_shape().as_list())) 232 | 233 | def _build_label_projection_with_null_zero_op(self): 234 | with tf.variable_scope("label_projection"): 235 | null_label_emb = tf.get_variable(name="null_label_emb", 236 | trainable=False, 237 | shape=[1, self.cfg["num_units"]]) 238 | label_emb = tf.get_variable(name="label_emb", 239 | dtype=tf.float32, 240 | trainable=True, 241 | shape=[self.tag_vocab_size - 1, 242 | self.cfg["num_units"]]) 243 | self.label_embeddings = tf.concat([null_label_emb, label_emb], axis=0) 244 | self.logits = tf.tensordot(self.span_rep, self.label_embeddings, 245 | axes=[-1, -1]) 246 | print("logits shape: {}".format(self.logits.get_shape().as_list())) 247 | 248 | def _build_model_op(self): 249 | self._build_rnn_op() 250 | self._make_span_indices() 251 | 252 | if self.cfg["bilstm_type"] == "minus": 253 | self._build_span_minus_op() 254 | else: 255 | self._build_span_add_and_minus_op() 256 | 257 | self._build_span_projection_op() 258 | self._build_label_projection_with_null_zero_op() 259 | 260 | def _build_loss_op(self): 261 | self.losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 262 | logits=self.logits, labels=self.tags) 263 | self.loss = tf.reduce_mean(tf.reduce_sum(self.losses, axis=-1)) 264 | tf.summary.scalar("loss", self.loss) 265 | 266 | def _build_train_op(self): 267 | with tf.variable_scope("train_step"): 268 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 269 | if self.cfg["grad_clip"] is not None and self.cfg["grad_clip"] > 0: 270 | grads, vs = zip(*optimizer.compute_gradients(self.loss)) 271 | grads, _ = tf.clip_by_global_norm(grads, self.cfg["grad_clip"]) 272 | self.train_op = optimizer.apply_gradients(zip(grads, vs)) 273 | else: 274 | self.train_op = optimizer.minimize(self.loss) 275 | 276 | def _build_predict_op(self): 277 | self.predicts = tf.cast(tf.argmax(self.logits, axis=-1), tf.int32) 278 | 279 | def build_proba_op(self): 280 | self.proba = tf.nn.softmax(self.logits) 
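  # Editor's note: in this repository `build_proba_op` is invoked from
  # run_span_models.py (mode "bio") before `save_predicted_bio_tags`.
  # `self.proba` has shape [batch, n_spans, n_tags]; the BIO writer below feeds
  # it to decoders.greedy_search, which greedily selects labeled spans from the
  # probabilities, and then converts the selected spans to BIO tags via span2bio.
  # A rough, hedged usage sketch (all names are taken from this repository):
  #
  #   proba = sess.run(model.proba, feed_dict)[0]            # [n_spans, n_tags]
  #   spans = greedy_search(proba, n_words=len(words),
  #                         max_span_len=model.max_span_len,
  #                         null_label_id=NULL_LABEL_ID)
  #   bio_tags = span2bio(spans=spans, n_words=len(words),
  #                       tag_dict=model.rev_tag_dict)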
281 | 282 | def train_epoch(self, batches): 283 | loss_total = 0. 284 | correct = 0 285 | p_total = 0 286 | r_total = 0 287 | num_batches = 0 288 | start_time = time.time() 289 | 290 | for batch in batches: 291 | num_batches += 1 292 | if num_batches % 100 == 0: 293 | print("%d" % num_batches, flush=True, end=" ") 294 | 295 | feed_dict = self._get_feed_dict(batch, is_train=True, 296 | keep_prob=self.cfg["keep_prob"], 297 | lr=self.cfg["lr"]) 298 | outputs = self.sess.run([self.train_op, self.loss, self.predicts], 299 | feed_dict) 300 | _, train_loss, predicts = outputs 301 | 302 | loss_total += train_loss 303 | crr_i, p_total_i, r_total_i = metrics_for_multi_class_spans( 304 | batch["tags"], predicts, NULL_LABEL_ID) 305 | correct += crr_i 306 | p_total += p_total_i 307 | r_total += r_total_i 308 | 309 | avg_loss = loss_total / num_batches 310 | p, r, f = f_score(correct, p_total, r_total) 311 | 312 | self.logger.info("-- Time: %f seconds" % (time.time() - start_time)) 313 | self.logger.info( 314 | "-- Averaged loss: %f(%f/%d)" % (avg_loss, loss_total, num_batches)) 315 | self.logger.info( 316 | "-- {} set\tF:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})" 317 | .format("train", f, p, correct, p_total, r, correct, r_total)) 318 | return avg_loss, loss_total 319 | 320 | def evaluate_epoch(self, batches, name): 321 | correct = 0 322 | p_total = 0 323 | num_batches = 0 324 | start_time = time.time() 325 | 326 | for batch in batches: 327 | num_batches += 1 328 | if num_batches % 100 == 0: 329 | print("%d" % num_batches, flush=True, end=" ") 330 | 331 | feed_dict = self._get_feed_dict(batch) 332 | predicts = self.sess.run(self.predicts, feed_dict) 333 | crr_i, p_total_i = count_gold_and_system_outputs( 334 | batch["tags"], predicts, NULL_LABEL_ID) 335 | correct += crr_i 336 | p_total += p_total_i 337 | 338 | p, r, f = f_score(correct, p_total, self.n_gold_spans) 339 | self.logger.info('-- Time: %f seconds' % (time.time() - start_time)) 340 | self.logger.info( 341 | "-- {} set\tF:{:>7.2%} P:{:>7.2%} ({:>5}/{:>5}) R:{:>7.2%} ({:>5}/{:>5})" 342 | .format(name, f, p, correct, p_total, r, correct, self.n_gold_spans)) 343 | return f, p, r, correct, p_total, self.n_gold_spans 344 | 345 | def train(self): 346 | self.logger.info(str(self.cfg)) 347 | write_json(os.path.join(self.cfg["checkpoint_path"], "config.json"), 348 | self.cfg) 349 | 350 | batch_size = self.cfg["batch_size"] 351 | epochs = self.cfg["epochs"] 352 | train_path = self.cfg["train_set"] 353 | valid_path = self.cfg["valid_set"] 354 | self.n_gold_spans = count_gold_spans(valid_path) 355 | valid_set = list( 356 | self.batcher.batchnize_dataset(valid_path, batch_size, shuffle=True)) 357 | 358 | best_f1 = -np.inf 359 | init_lr = self.cfg["lr"] 360 | 361 | self.log_trainable_variables() 362 | self.logger.info("Start training...") 363 | self._add_summary() 364 | for epoch in range(1, epochs + 1): 365 | self.logger.info('Epoch {}/{}:'.format(epoch, epochs)) 366 | 367 | train_set = self.batcher.batchnize_dataset(train_path, batch_size, 368 | shuffle=True) 369 | _ = self.train_epoch(train_set) 370 | 371 | if self.cfg["use_lr_decay"]: # learning rate decay 372 | self.cfg["lr"] = max(init_lr / (1.0 + self.cfg["lr_decay"] * epoch), 373 | self.cfg["minimal_lr"]) 374 | 375 | eval_metrics = self.evaluate_epoch(valid_set, "valid") 376 | cur_valid_f1 = eval_metrics[0] 377 | 378 | if cur_valid_f1 > best_f1: 379 | best_f1 = cur_valid_f1 380 | self.save_session(epoch) 381 | self.logger.info( 382 | "-- new BEST F1 on valid set: 
{:>7.2%}".format(best_f1)) 383 | 384 | self.train_writer.close() 385 | self.test_writer.close() 386 | 387 | def eval(self, preprocessor): 388 | self.logger.info(str(self.cfg)) 389 | data = preprocessor.load_dataset(self.cfg["data_path"], 390 | keep_number=True, 391 | lowercase=self.cfg["char_lowercase"]) 392 | data = data[:self.cfg["data_size"]] 393 | dataset = preprocessor.build_dataset(data, self.word_dict, 394 | self.char_dict, self.tag_dict) 395 | write_json(os.path.join(self.cfg["save_path"], "tmp.json"), dataset) 396 | self.n_gold_spans = count_gold_spans( 397 | os.path.join(self.cfg["save_path"], "tmp.json")) 398 | self.logger.info("Target data: %s sentences" % len(dataset)) 399 | del dataset 400 | 401 | batches = list(self.batcher.batchnize_dataset( 402 | os.path.join(self.cfg["save_path"], "tmp.json"), 403 | batch_size=self.cfg["batch_size"], shuffle=True)) 404 | self.logger.info("Target data: %s batches" % len(batches)) 405 | _ = self.evaluate_epoch(batches, "valid") 406 | 407 | def make_one_batch(self, data, add_tags=True): 408 | return self.batcher.make_each_batch( 409 | batch_words=[data["words"]], 410 | batch_chars=[data["chars"]], 411 | max_span_len=self.max_span_len, 412 | batch_tags=[data["tags"]] if add_tags else None) 413 | 414 | def save_predicted_spans(self, data_name, preprocessor): 415 | self.logger.info(str(self.cfg)) 416 | 417 | ######################## 418 | # Load validation data # 419 | ######################## 420 | valid_data = preprocessor.load_dataset( 421 | self.cfg["data_path"], keep_number=True, 422 | lowercase=self.cfg["char_lowercase"]) 423 | valid_data = valid_data[:self.cfg["data_size"]] 424 | dataset = preprocessor.build_dataset(valid_data, 425 | self.word_dict, 426 | self.char_dict, 427 | self.tag_dict) 428 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 429 | write_json(dataset_path, dataset) 430 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 431 | 432 | ############# 433 | # Main loop # 434 | ############# 435 | start_time = time.time() 436 | results = [] 437 | print("PREDICTION START") 438 | for record, data in zip(valid_data, dataset): 439 | valid_sent_id = record["sent_id"] 440 | batch = self.batcher.make_each_batch( 441 | batch_words=[data["words"]], batch_chars=[data["chars"]], 442 | max_span_len=self.max_span_len) 443 | 444 | if (valid_sent_id + 1) % 100 == 0: 445 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 446 | 447 | ################# 448 | # Predict spans # 449 | ################# 450 | feed_dict = self._get_feed_dict(batch) 451 | batch_preds = self.sess.run([self.predicts], feed_dict)[0] 452 | preds = batch_preds[0] 453 | 454 | ######################## 455 | # Make predicted spans # 456 | ######################## 457 | indx_i, indx_j = get_span_indices(n_words=len(record["words"]), 458 | max_span_len=self.max_span_len) 459 | assert len(preds) == len(indx_i) == len(indx_j) 460 | pred_spans = [[self.rev_tag_dict[pred_label_id], int(i), int(j)] 461 | for pred_label_id, i, j in zip(preds, indx_i, indx_j) 462 | if pred_label_id != NULL_LABEL_ID] 463 | 464 | ################## 465 | # Add the result # 466 | ################## 467 | results.append({"sent_id": valid_sent_id, 468 | "words": record["words"], 469 | "spans": pred_spans}) 470 | 471 | path = os.path.join(self.cfg["checkpoint_path"], 472 | "%s.predicted_spans.json" % data_name) 473 | write_json(path, results) 474 | self.logger.info( 475 | "-- Time: %f seconds\nFINISHED." 
% (time.time() - start_time)) 476 | 477 | def save_predicted_bio_tags(self, data_name, preprocessor): 478 | self.logger.info(str(self.cfg)) 479 | 480 | ######################## 481 | # Load validation data # 482 | ######################## 483 | valid_data = preprocessor.load_dataset( 484 | self.cfg["data_path"], keep_number=True, 485 | lowercase=self.cfg["char_lowercase"]) 486 | valid_data = valid_data[:self.cfg["data_size"]] 487 | dataset = preprocessor.build_dataset(valid_data, 488 | self.word_dict, 489 | self.char_dict, 490 | self.tag_dict) 491 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 492 | write_json(dataset_path, dataset) 493 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 494 | 495 | ############# 496 | # Main loop # 497 | ############# 498 | start_time = time.time() 499 | path = os.path.join(self.cfg["checkpoint_path"], "%s.bio.txt" % data_name) 500 | fout_txt = open(path, "w") 501 | print("PREDICTION START") 502 | for record, data in zip(valid_data, dataset): 503 | valid_sent_id = record["sent_id"] 504 | batch = self.make_one_batch(data, add_tags=False) 505 | 506 | if (valid_sent_id + 1) % 100 == 0: 507 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 508 | 509 | ################# 510 | # Predict spans # 511 | ################# 512 | feed_dict = self._get_feed_dict(batch) 513 | proba = self.sess.run([self.proba], feed_dict)[0][0] 514 | 515 | ######################## 516 | # Make predicted spans # 517 | ######################## 518 | words = record["words"] 519 | triples = greedy_search(proba, 520 | n_words=len(words), 521 | max_span_len=self.max_span_len, 522 | null_label_id=NULL_LABEL_ID) 523 | pred_bio_tags = span2bio(spans=triples, 524 | n_words=len(words), 525 | tag_dict=self.rev_tag_dict) 526 | gold_bio_tags = span2bio(spans=record["tags"], 527 | n_words=len(words)) 528 | assert len(words) == len(pred_bio_tags) == len(gold_bio_tags) 529 | 530 | #################### 531 | # Write the result # 532 | #################### 533 | for word, gold_tag, pred_tag in zip(words, gold_bio_tags, pred_bio_tags): 534 | fout_txt.write("%s _ %s %s\n" % (word, gold_tag, pred_tag)) 535 | fout_txt.write("\n") 536 | 537 | self.logger.info( 538 | "-- Time: %f seconds\nFINISHED." 
% (time.time() - start_time)) 539 | 540 | def save_span_representation(self, data_name, preprocessor): 541 | self.logger.info(str(self.cfg)) 542 | 543 | ######################## 544 | # Load validation data # 545 | ######################## 546 | valid_data = preprocessor.load_dataset( 547 | self.cfg["data_path"], keep_number=True, 548 | lowercase=self.cfg["char_lowercase"]) 549 | valid_data = valid_data[:self.cfg["data_size"]] 550 | dataset = preprocessor.build_dataset(valid_data, self.word_dict, 551 | self.char_dict, self.tag_dict) 552 | dataset_path = os.path.join(self.cfg["save_path"], "tmp.json") 553 | write_json(dataset_path, dataset) 554 | self.logger.info("Valid sentences: {:>7}".format(len(dataset))) 555 | 556 | ############# 557 | # Main loop # 558 | ############# 559 | start_time = time.time() 560 | results = [] 561 | fout_hdf5 = h5py.File(os.path.join(self.cfg["checkpoint_path"], 562 | "%s.span_reps.hdf5" % data_name), 'w') 563 | print("PREDICTION START") 564 | for record, data in zip(valid_data, dataset): 565 | valid_sent_id = record["sent_id"] 566 | batch = self.batcher.make_each_batch( 567 | batch_words=[data["words"]], batch_chars=[data["chars"]], 568 | max_span_len=self.max_span_len, batch_tags=[data["tags"]]) 569 | 570 | if (valid_sent_id + 1) % 100 == 0: 571 | print("%d" % (valid_sent_id + 1), flush=True, end=" ") 572 | 573 | ################# 574 | # Predict spans # 575 | ################# 576 | feed_dict = self._get_feed_dict(batch) 577 | preds, span_reps = self.sess.run([self.predicts, self.span_rep], 578 | feed_dict=feed_dict) 579 | golds = batch["tags"][0] 580 | preds = preds[0] 581 | span_reps = span_reps[0] 582 | assert len(span_reps) == len(golds) == len(preds) 583 | 584 | ######################## 585 | # Make predicted spans # 586 | ######################## 587 | indx_i, indx_j = get_span_indices(n_words=len(record["words"]), 588 | max_span_len=self.max_span_len) 589 | assert len(preds) == len(indx_i) == len(indx_j) 590 | pred_spans = [[self.rev_tag_dict[label_id], int(i), int(j)] 591 | for label_id, i, j in zip(preds, indx_i, indx_j)] 592 | gold_spans = [[self.rev_tag_dict[label_id], int(i), int(j)] 593 | for label_id, i, j in zip(golds, indx_i, indx_j)] 594 | 595 | #################### 596 | # Write the result # 597 | #################### 598 | fout_hdf5.create_dataset( 599 | name='{}'.format(valid_sent_id), 600 | dtype='float32', 601 | data=span_reps) 602 | results.append({"sent_id": valid_sent_id, 603 | "words": record["words"], 604 | "gold_spans": gold_spans, 605 | "pred_spans": pred_spans}) 606 | fout_hdf5.close() 607 | write_json(os.path.join(self.cfg["checkpoint_path"], 608 | "%s.spans.json" % data_name), results) 609 | self.logger.info( 610 | "-- Time: %f seconds\nFINISHED." 
% (time.time() - start_time)) 611 | 612 | 613 | class MaskSpanModel(SpanModel): 614 | 615 | def _add_placeholders(self): 616 | self.words = tf.placeholder(tf.int32, shape=[None, None], name="words") 617 | self.tags = tf.placeholder(tf.int32, shape=[None, None], name="tags") 618 | self.seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len") 619 | self.masks = tf.placeholder(tf.float32, shape=[None, None], name="mask") 620 | if self.cfg["use_chars"]: 621 | self.chars = tf.placeholder(tf.int32, shape=[None, None, None], 622 | name="chars") 623 | # hyperparameters 624 | self.is_train = tf.placeholder(tf.bool, name="is_train") 625 | self.keep_prob = tf.placeholder(tf.float32, name="rnn_keep_probability") 626 | self.drop_rate = tf.placeholder(tf.float32, name="dropout_rate") 627 | self.lr = tf.placeholder(tf.float32, name="learning_rate") 628 | 629 | def _get_feed_dict(self, batch, keep_prob=1.0, is_train=False, lr=None): 630 | feed_dict = {self.words: batch["words"], 631 | self.seq_len: batch["seq_len"], 632 | self.masks: batch["masks"]} 633 | if "tags" in batch: 634 | feed_dict[self.tags] = batch["tags"] 635 | if self.cfg["use_chars"]: 636 | feed_dict[self.chars] = batch["chars"] 637 | feed_dict[self.keep_prob] = keep_prob 638 | feed_dict[self.drop_rate] = 1.0 - keep_prob 639 | feed_dict[self.is_train] = is_train 640 | if lr is not None: 641 | feed_dict[self.lr] = lr 642 | return feed_dict 643 | 644 | def _build_rnn_op(self): 645 | with tf.variable_scope("bi_directional_rnn"): 646 | cell_fw = self._create_rnn_cell() 647 | cell_bw = self._create_rnn_cell() 648 | 649 | if self.cfg["use_stack_rnn"]: 650 | rnn_outs, *_ = stack_bidirectional_dynamic_rnn( 651 | cell_fw, cell_bw, self.word_emb, 652 | dtype=tf.float32, sequence_length=self.seq_len) 653 | else: 654 | rnn_outs, *_ = bidirectional_dynamic_rnn( 655 | cell_fw, cell_bw, self.word_emb, 656 | dtype=tf.float32, sequence_length=self.seq_len) 657 | rnn_outs = tf.concat(rnn_outs, axis=-1) 658 | rnn_outs = tf.layers.dropout(rnn_outs, 659 | rate=self.drop_rate, 660 | training=self.is_train) 661 | self.rnn_outs = rnn_outs 662 | print("rnn output shape: {}".format(rnn_outs.get_shape().as_list())) 663 | 664 | def _build_loss_op(self): 665 | self.losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 666 | logits=self.logits, labels=self.tags) 667 | self.losses = self.losses * self.masks 668 | self.loss = tf.reduce_mean(tf.reduce_sum(self.losses, axis=-1)) 669 | tf.summary.scalar("loss", self.loss) 670 | 671 | def _build_predict_op(self): 672 | self.predicts = tf.cast(tf.argmax(self.logits, axis=-1), 673 | tf.int32) * tf.cast(self.masks, tf.int32) 674 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.8.0 3 | certifi==2019.3.9 4 | et-xmlfile==1.0.1 5 | gast==0.2.2 6 | grpcio==1.21.1 7 | h5py==2.10.0 8 | jdcal==1.4.1 9 | Markdown==3.1.1 10 | numpy==1.14.5 11 | openpyxl==3.0.2 12 | protobuf==3.8.0 13 | six==1.12.0 14 | tensorboard==1.10.0 15 | tensorflow-gpu==1.10.0 16 | termcolor==1.1.0 17 | tqdm==4.32.1 18 | ujson==1.35 19 | Werkzeug==0.15.4 20 | -------------------------------------------------------------------------------- /run_knn_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from 
__future__ import unicode_literals 6 | 7 | import argparse 8 | import os 9 | import json 10 | 11 | from models.knn_models import KnnModel 12 | from utils.batchers.knn_batchers import BaseKnnBatcher 13 | from utils.preprocessors.knn_preprocessors import KnnPreprocessor 14 | 15 | 16 | def set_config(args, config): 17 | if args.raw_path: 18 | config["raw_path"] = args.raw_path 19 | if args.save_path: 20 | config["save_path"] = args.save_path 21 | if args.data_path: 22 | config["data_path"] = args.data_path 23 | if args.checkpoint_path: 24 | config["checkpoint_path"] = args.checkpoint_path 25 | config["summary_path"] = os.path.join(args.checkpoint_path, "summary") 26 | if args.summary_path: 27 | config["summary_path"] = args.summary_path 28 | if args.model_name: 29 | config["model_name"] = args.model_name 30 | if args.batch_size: 31 | config["batch_size"] = args.batch_size 32 | if args.data_size: 33 | config["data_size"] = args.data_size 34 | if args.bilstm_type: 35 | config["bilstm_type"] = args.bilstm_type 36 | if args.k: 37 | config["k"] = args.k 38 | if args.predict: 39 | config["predict"] = args.predict 40 | if args.max_span_len: 41 | config["max_span_len"] = args.max_span_len 42 | if args.knn_sampling: 43 | config["knn_sampling"] = args.knn_sampling 44 | return config 45 | 46 | 47 | def main(args): 48 | config = json.load(open(args.config_file)) 49 | config = set_config(args, config) 50 | os.makedirs(config["save_path"], exist_ok=True) 51 | 52 | print("Build a knn span model...") 53 | model = KnnModel(config, BaseKnnBatcher(config), is_train=False) 54 | preprocessor = KnnPreprocessor(config) 55 | model.restore_last_session(config["checkpoint_path"]) 56 | 57 | if args.mode == "eval": 58 | model.eval(preprocessor) 59 | elif args.mode == "span": 60 | model.save_predicted_spans(args.data_name, preprocessor) 61 | elif args.mode == "bio": 62 | model.save_predicted_bio_tags(args.data_name, preprocessor) 63 | elif args.mode == "nearest_span": 64 | model.save_nearest_spans(args.data_name, preprocessor, args.print_knn) 65 | elif args.mode == "cmd": 66 | model.predict_on_command_line(preprocessor) 67 | else: 68 | model.save_span_representation(args.data_name, preprocessor) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--mode', 74 | default='eval', 75 | help='eval/span/bio/nearest_span/span_rep') 76 | parser.add_argument('--config_file', 77 | default='checkpoint/config.json', 78 | help='Configuration file') 79 | parser.add_argument('--data_name', 80 | default='valid', 81 | help='Data to be processed') 82 | parser.add_argument('--data_path', 83 | default=None, 84 | help='Path to data') 85 | parser.add_argument('--bilstm_type', 86 | default=None, 87 | help='bilstm type') 88 | parser.add_argument('--raw_path', 89 | default=None, 90 | help='Raw data directory') 91 | parser.add_argument('--save_path', 92 | default=None, 93 | help='Save directory') 94 | parser.add_argument('--checkpoint_path', 95 | default=None, 96 | help='Checkpoint directory') 97 | parser.add_argument('--summary_path', 98 | default=None, 99 | help='Summary directory') 100 | parser.add_argument('--model_name', 101 | default=None, 102 | help='Model name') 103 | parser.add_argument('--batch_size', 104 | default=None, 105 | type=int, 106 | help='Batch size') 107 | parser.add_argument('--data_size', 108 | default=None, 109 | type=int, 110 | help='Data size') 111 | parser.add_argument('--k', 112 | default=None, 113 | type=int, 114 | help='k-NN sentences') 115 | 
parser.add_argument('--predict', 116 | default='max_margin', 117 | help='prediction methods') 118 | parser.add_argument('--max_span_len', 119 | default=None, 120 | type=int, 121 | help='max span length') 122 | parser.add_argument('--knn_sampling', 123 | default=None, 124 | help='k-NN sentence sampling') 125 | parser.add_argument('--print_knn', 126 | action='store_true', 127 | default=False, 128 | help='print knn sentences') 129 | main(parser.parse_args()) 130 | -------------------------------------------------------------------------------- /run_span_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from models.span_models import SpanModel 6 | from utils.batchers.span_batchers import BaseSpanBatcher 7 | from utils.preprocessors.span_preprocessors import SpanPreprocessor 8 | 9 | 10 | def set_config(args, config): 11 | if args.raw_path: 12 | config["raw_path"] = args.raw_path 13 | if args.save_path: 14 | config["save_path"] = args.save_path 15 | if args.data_path: 16 | config["data_path"] = args.data_path 17 | if args.checkpoint_path: 18 | config["checkpoint_path"] = args.checkpoint_path 19 | config["summary_path"] = os.path.join(args.checkpoint_path, "summary") 20 | if args.summary_path: 21 | config["summary_path"] = args.summary_path 22 | if args.model_name: 23 | config["model_name"] = args.model_name 24 | if args.batch_size: 25 | config["batch_size"] = args.batch_size 26 | if args.data_size: 27 | config["data_size"] = args.data_size 28 | if args.bilstm_type: 29 | config["bilstm_type"] = args.bilstm_type 30 | if args.max_span_len: 31 | config["max_span_len"] = args.max_span_len 32 | return config 33 | 34 | 35 | def main(args): 36 | config = json.load(open(args.config_file)) 37 | config = set_config(args, config) 38 | os.makedirs(config["save_path"], exist_ok=True) 39 | 40 | print("Build models...") 41 | model = SpanModel(config, BaseSpanBatcher(config), is_train=False) 42 | preprocessor = SpanPreprocessor(config) 43 | model.restore_last_session(config["checkpoint_path"]) 44 | 45 | if args.mode == "eval": 46 | model.eval(preprocessor) 47 | elif args.mode == "span": 48 | model.save_predicted_spans(args.data_name, preprocessor) 49 | elif args.mode == "bio": 50 | model.build_proba_op() 51 | model.save_predicted_bio_tags(args.data_name, preprocessor) 52 | else: 53 | model.save_span_representation(args.data_name, preprocessor) 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--mode', 59 | default="eval", 60 | help='eval/span/bio/proba/span_rep') 61 | parser.add_argument('--config_file', 62 | default='checkpoint/config.json', 63 | help='Configuration file') 64 | parser.add_argument('--data_name', 65 | default="valid", 66 | help='Data to be processed') 67 | parser.add_argument('--data_path', 68 | required=True, 69 | default=None, 70 | help='Path to data') 71 | parser.add_argument('--raw_path', 72 | default=None, 73 | help='Raw data directory') 74 | parser.add_argument('--save_path', 75 | default=None, 76 | help='Save directory') 77 | parser.add_argument('--checkpoint_path', 78 | default=None, 79 | help='Checkpoint directory') 80 | parser.add_argument('--summary_path', 81 | default=None, 82 | help='Summary directory') 83 | parser.add_argument('--model_name', 84 | default=None, 85 | help='Model name') 86 | parser.add_argument('--batch_size', 87 | default=None, 88 | type=int, 89 | help='Batch size') 90 | parser.add_argument('--data_size', 91 | 
default=None, 92 | type=int, 93 | help='Data size') 94 | parser.add_argument('--bilstm_type', 95 | default=None, 96 | help='standard/interleave') 97 | parser.add_argument('--max_span_len', 98 | default=None, 99 | type=int, 100 | help='max span length') 101 | main(parser.parse_args()) 102 | -------------------------------------------------------------------------------- /scripts/convert_conll03_to_json.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import argparse 8 | import codecs 9 | import random 10 | import ujson 11 | 12 | 13 | def load(filename): 14 | with codecs.open(filename, mode="r", encoding="utf-8") as f: 15 | words, tags = [], [] 16 | for line in f: 17 | line = line.lstrip().rstrip() 18 | if line.startswith("-DOCSTART-"): 19 | continue 20 | if len(line) == 0: 21 | if len(words) != 0: 22 | yield words, tags 23 | words, tags = [], [] 24 | else: 25 | line = line.split() 26 | words.append(line[0]) 27 | tags.append(line[-1]) 28 | 29 | 30 | def write_json(filename, data): 31 | with codecs.open(filename, mode="w", encoding="utf-8") as f: 32 | ujson.dump(data, f, ensure_ascii=False) 33 | 34 | 35 | def remove_duplicate_sents(sents): 36 | new_sents = [] 37 | for i, (words1, tags1) in enumerate(sents): 38 | for (words2, _) in sents[i + 1:]: 39 | if words1 == words2: 40 | break 41 | else: 42 | new_sents.append((words1, tags1)) 43 | return new_sents 44 | 45 | 46 | def bio2span(labels): 47 | spans = [] 48 | span = [] 49 | for w_i, label in enumerate(labels): 50 | if label.startswith('B-'): 51 | if span: 52 | spans.append(span) 53 | span = [label[2:], w_i, w_i] 54 | elif label.startswith('I-'): 55 | if span: 56 | if label[2:] == span[0]: 57 | span[2] = w_i 58 | else: 59 | spans.append(span) 60 | span = [label[2:], w_i, w_i] 61 | else: 62 | span = [label[2:], w_i, w_i] 63 | else: 64 | if span: 65 | spans.append(span) 66 | span = [] 67 | if span: 68 | spans.append(span) 69 | return spans 70 | 71 | 72 | def main(argv): 73 | sents = list(load(argv.input_file)) 74 | print("Sents:%d" % len(sents)) 75 | if argv.remove_duplicates: 76 | sents = remove_duplicate_sents(sents) 77 | print("Sents (removed duplicates): %d" % len(sents)) 78 | 79 | data = [] 80 | n_sents = 0 81 | n_words = 0 82 | n_spans = 0 83 | for words, bio_labels in sents: 84 | spans = bio2span(bio_labels) 85 | data.append({"sent_id": n_sents, 86 | "words": words, 87 | "bio_labels": bio_labels, 88 | "spans": spans}) 89 | n_sents += 1 90 | n_words += len(words) 91 | n_spans += len(spans) 92 | 93 | if argv.split > 1: 94 | split_size = int(len(data) / argv.split) 95 | random.shuffle(data) 96 | data = data[:split_size] 97 | n_sents = len(data) 98 | n_words = 0 99 | n_spans = 0 100 | for record in data: 101 | n_words += len(record["words"]) 102 | n_spans += len(record["spans"]) 103 | 104 | if argv.output_file.endswith(".json"): 105 | path = argv.output_file 106 | else: 107 | path = argv.output_file + ".json" 108 | write_json(path, data) 109 | print("Sents:%d\tWords:%d\tEntities:%d" % (n_sents, n_words, n_spans)) 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = argparse.ArgumentParser(description='SCRIPT') 114 | parser.add_argument('--input_file', 115 | help='path to conll2003') 116 | parser.add_argument('--output_file', 117 | default="output", 118 | help='output file name') 119 | 
parser.add_argument('--remove_duplicates', 120 | action='store_true', 121 | default=False, 122 | help='remove duplicates') 123 | parser.add_argument('--split', 124 | default=1, 125 | type=int, 126 | help='split size of the data') 127 | main(parser.parse_args()) 128 | -------------------------------------------------------------------------------- /scripts/convert_genia_to_json.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import argparse 3 | import codecs 4 | import ujson 5 | 6 | 7 | def load(filename): 8 | with codecs.open(filename, mode="r", encoding="utf-8") as f: 9 | words, tags = [], [] 10 | for line in f: 11 | line = line.lstrip().rstrip() 12 | if len(line) == 0: 13 | if len(words) != 0: 14 | yield words, tags 15 | words, tags = [], [] 16 | else: 17 | line = line.split("\t") 18 | words.append(line[0]) 19 | tags.append(line[1:]) 20 | 21 | 22 | def write_json(filename, data): 23 | with codecs.open(filename, mode="w", encoding="utf-8") as f: 24 | ujson.dump(data, f, ensure_ascii=False) 25 | 26 | 27 | def remove_duplicate_sents(sents): 28 | new_sents = [] 29 | for i, (words1, tags1) in enumerate(sents): 30 | for (words2, _) in sents[i + 1:]: 31 | if words1 == words2: 32 | break 33 | else: 34 | new_sents.append((words1, tags1)) 35 | return new_sents 36 | 37 | 38 | def bio2span(labels): 39 | spans = [] 40 | span = [] 41 | for w_i, label in enumerate(labels): 42 | if label.startswith('B-'): 43 | if span: 44 | spans.append(span) 45 | span = [label[2:], w_i, w_i] 46 | elif label.startswith('I-'): 47 | if span: 48 | if label[2:] == span[0]: 49 | span[2] = w_i 50 | else: 51 | spans.append(span) 52 | span = [label[2:], w_i, w_i] 53 | else: 54 | span = [label[2:], w_i, w_i] 55 | else: 56 | if span: 57 | spans.append(span) 58 | span = [] 59 | if span: 60 | spans.append(span) 61 | return spans 62 | 63 | 64 | def main(argv): 65 | sents = list(load(argv.input_file)) 66 | print("Sents: %d" % len(sents)) 67 | if argv.remove_duplicates: 68 | sents = remove_duplicate_sents(sents) 69 | print("Sents (removed duplicates): %d" % len(sents)) 70 | 71 | data = [] 72 | n_sents = 0 73 | n_words = 0 74 | n_spans = 0 75 | for words, multi_layered_labels in sents: 76 | spans = [] 77 | for bio_labels in zip(*multi_layered_labels): 78 | spans += bio2span(bio_labels) 79 | data.append({"sent_id": n_sents, 80 | "words": words, 81 | "spans": spans}) 82 | n_sents += 1 83 | n_words += len(words) 84 | n_spans += len(spans) 85 | 86 | if argv.output_file.endswith(".json"): 87 | path = argv.output_file 88 | else: 89 | path = argv.output_file + ".json" 90 | write_json(path, data) 91 | print("Sents:%d\tWords:%d\tEntities:%d" % (n_sents, n_words, n_spans)) 92 | 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser(description='SCRIPT') 96 | parser.add_argument('--input_file', 97 | help='path to genia corpus') 98 | parser.add_argument('--output_file', 99 | default="output", 100 | help='output file name') 101 | parser.add_argument('--remove_duplicates', 102 | action='store_true', 103 | default=False, 104 | help='remove duplicates') 105 | main(parser.parse_args()) 106 | -------------------------------------------------------------------------------- /scripts/retrieve_knn_sents_with_glove.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 
6 | 7 | import argparse 8 | import codecs 9 | import ujson 10 | import numpy as np 11 | from numpy.linalg import norm 12 | from tqdm import tqdm 13 | 14 | glove_sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6)} 15 | 16 | 17 | def load_json(filename): 18 | with codecs.open(filename, mode='r', encoding='utf-8') as f: 19 | dataset = ujson.load(f) 20 | return dataset 21 | 22 | 23 | def write_json(filename, data): 24 | with codecs.open(filename, mode="w", encoding="utf-8") as f: 25 | ujson.dump(data, f, ensure_ascii=False) 26 | 27 | 28 | def load_glove(glove_path, glove_name="6B"): 29 | vocab = {} 30 | vectors = [] 31 | total = glove_sizes[glove_name] 32 | with codecs.open(glove_path, mode='r', encoding='utf-8') as f: 33 | for line in tqdm(f, total=total, desc="Load glove"): 34 | line = line.lstrip().rstrip().split(" ") 35 | vocab[line[0]] = len(vocab) 36 | vectors.append([float(x) for x in line[1:]]) 37 | assert len(vocab) == len(vectors) 38 | return vocab, np.asarray(vectors) 39 | 40 | 41 | def mean_vectors(data, emb, vocab): 42 | unk_vec = np.zeros(emb.shape[1]) 43 | mean_vecs = [] 44 | for record in data: 45 | vecs = [] 46 | for word in record["words"]: 47 | word = word.lower() 48 | if word in vocab: 49 | vec = emb[vocab[word]] 50 | else: 51 | vec = unk_vec 52 | vecs.append(vec) 53 | mean_vecs.append(np.mean(vecs, axis=0)) 54 | return mean_vecs 55 | 56 | 57 | def cosine_similarity(p0, p1): 58 | d = (norm(p0) * norm(p1)) 59 | if d > 0: 60 | return np.dot(p0, p1) / d 61 | return 0.0 62 | 63 | 64 | def knn(test_sents, train_embs, test_embs, k, path): 65 | for index, (sent, vec) in enumerate(zip(test_sents, test_embs)): 66 | assert index == sent["sent_id"] 67 | if (index + 1) % 100 == 0: 68 | print("%d" % (index + 1), flush=True, end=" ") 69 | sim = [cosine_similarity(train_vec, vec) for train_vec in train_embs] 70 | arg_sort = np.argsort(sim)[::-1][:k] 71 | sent["train_sent_ids"] = [int(arg) for arg in arg_sort] 72 | write_json(path, test_sents) 73 | 74 | 75 | def main(args): 76 | train_sents = load_json(args.train_json)[:args.data_size] 77 | test_sents = load_json(args.test_json)[:args.data_size] 78 | vocab, glove = load_glove(args.glove) 79 | print("Train sents: {:>7}".format(len(train_sents))) 80 | print("Test sents: {:>7}".format(len(test_sents))) 81 | train_embs = mean_vectors(train_sents, glove, vocab) 82 | test_embs = mean_vectors(test_sents, glove, vocab) 83 | if args.output_file.endswith(".json"): 84 | path = args.output_file 85 | else: 86 | path = args.output_file + ".json" 87 | knn(test_sents, train_embs, test_embs, args.k, path) 88 | 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser(description='SCRIPT') 92 | parser.add_argument('--train_json', 93 | type=str, 94 | default='data/conll2003/train.json', 95 | help='path to json-format data') 96 | parser.add_argument('--test_json', 97 | type=str, 98 | default='data/conll2003/test.json', 99 | help='path to json-format data') 100 | parser.add_argument('--output_file', 101 | default="output", 102 | help='output file name') 103 | parser.add_argument('--glove', 104 | type=str, 105 | default='data/emb/glove.6B.100d.txt', 106 | help='path to glove embeddings') 107 | parser.add_argument('--k', 108 | type=int, 109 | default=50, 110 | help='k') 111 | parser.add_argument('--data_size', 112 | type=int, 113 | default=100000000, 114 | help='number of sentences to be used') 115 | main(parser.parse_args()) 116 | -------------------------------------------------------------------------------- /train_knn_models.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import argparse 8 | import os 9 | import json 10 | 11 | from models.knn_models import KnnModel 12 | from utils.batchers.knn_batchers import BaseKnnBatcher 13 | from utils.preprocessors.knn_preprocessors import KnnPreprocessor 14 | 15 | 16 | def set_config(args, config): 17 | if args.raw_path: 18 | config["raw_path"] = args.raw_path 19 | if args.save_path: 20 | config["save_path"] = args.save_path 21 | config["train_set"] = os.path.join(args.save_path, "train.json") 22 | config["valid_set"] = os.path.join(args.save_path, "valid.json") 23 | config["vocab"] = os.path.join(args.save_path, "vocab.json") 24 | config["pretrained_emb"] = os.path.join(args.save_path, "glove_emb.npz") 25 | if args.train_set: 26 | config["train_set"] = args.train_set 27 | if args.valid_set: 28 | config["valid_set"] = args.valid_set 29 | if args.pretrained_emb: 30 | config["pretrained_emb"] = args.pretrained_emb 31 | if args.vocab: 32 | config["vocab"] = args.vocab 33 | if args.checkpoint_path: 34 | config["checkpoint_path"] = args.checkpoint_path 35 | config["summary_path"] = os.path.join(args.checkpoint_path, "summary") 36 | if args.summary_path: 37 | config["summary_path"] = args.summary_path 38 | if args.model_name: 39 | config["model_name"] = args.model_name 40 | if args.batch_size: 41 | config["batch_size"] = args.batch_size 42 | if args.data_size: 43 | config["data_size"] = args.data_size 44 | if args.bilstm_type: 45 | config["bilstm_type"] = args.bilstm_type 46 | if args.keep_prob: 47 | config["keep_prob"] = args.keep_prob 48 | if args.k: 49 | config["k"] = args.k 50 | if args.predict: 51 | config["predict"] = args.predict 52 | if args.max_span_len: 53 | config["max_span_len"] = args.max_span_len 54 | if args.max_n_spans: 55 | config["max_n_spans"] = args.max_n_spans 56 | if args.knn_sampling: 57 | config["knn_sampling"] = args.knn_sampling 58 | return config 59 | 60 | 61 | def main(args): 62 | config = json.load(open(args.config_file)) 63 | config = set_config(args, config) 64 | 65 | preprocessor = KnnPreprocessor(config) 66 | 67 | # create dataset from raw data files 68 | if not os.path.exists(config["save_path"]): 69 | preprocessor.preprocess() 70 | 71 | model = KnnModel(config, BaseKnnBatcher(config)) 72 | model.train() 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('--config_file', 78 | required=True, 79 | default='data/config/config.json', 80 | help='Configuration file') 81 | parser.add_argument('--raw_path', 82 | default=None, 83 | help='Raw data directory') 84 | parser.add_argument('--save_path', 85 | default=None, 86 | help='Save directory') 87 | parser.add_argument('--checkpoint_path', 88 | default=None, 89 | help='Checkpoint directory') 90 | parser.add_argument('--summary_path', 91 | default=None, 92 | help='Summary directory') 93 | parser.add_argument('--model_name', 94 | default=None, 95 | help='Model name') 96 | parser.add_argument('--batch_size', 97 | default=None, 98 | type=int, 99 | help='Batch size') 100 | parser.add_argument('--train_set', 101 | default=None, 102 | help='path to training set') 103 | parser.add_argument('--valid_set', 104 | default=None, 105 | help='path to training set') 106 | parser.add_argument('--pretrained_emb', 107 | default=None, 108 | 
help='path to pretrained embeddings') 109 | parser.add_argument('--vocab', 110 | default=None, 111 | help='path to vocabulary') 112 | parser.add_argument('--data_size', 113 | default=None, 114 | type=int, 115 | help='Data size') 116 | parser.add_argument('--bilstm_type', 117 | default=None, 118 | help='standard/interleave') 119 | parser.add_argument('--keep_prob', 120 | default=None, 121 | type=float, 122 | help='Keep (dropout) probability') 123 | parser.add_argument('--k', 124 | default=None, 125 | type=int, 126 | help='k-NN sentences') 127 | parser.add_argument('--predict', 128 | default='max_margin', 129 | help='prediction methods') 130 | parser.add_argument('--max_span_len', 131 | default=None, 132 | type=int, 133 | help='max span length') 134 | parser.add_argument('--max_n_spans', 135 | default=None, 136 | type=int, 137 | help='max num of spans') 138 | parser.add_argument('--knn_sampling', 139 | default=None, 140 | help='k-NN sentence sampling') 141 | main(parser.parse_args()) 142 | -------------------------------------------------------------------------------- /train_span_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import argparse 8 | import os 9 | import json 10 | 11 | from models.span_models import SpanModel 12 | from utils.batchers.span_batchers import BaseSpanBatcher 13 | from utils.preprocessors.span_preprocessors import SpanPreprocessor 14 | 15 | 16 | def set_config(args, config): 17 | if args.raw_path: 18 | config["raw_path"] = args.raw_path 19 | if args.save_path: 20 | config["save_path"] = args.save_path 21 | config["train_set"] = os.path.join(args.save_path, "train.json") 22 | config["valid_set"] = os.path.join(args.save_path, "valid.json") 23 | config["vocab"] = os.path.join(args.save_path, "vocab.json") 24 | config["pretrained_emb"] = os.path.join(args.save_path, "glove_emb.npz") 25 | if args.train_set: 26 | config["train_set"] = args.train_set 27 | if args.valid_set: 28 | config["valid_set"] = args.valid_set 29 | if args.pretrained_emb: 30 | config["pretrained_emb"] = args.pretrained_emb 31 | if args.vocab: 32 | config["vocab"] = args.vocab 33 | if args.checkpoint_path: 34 | config["checkpoint_path"] = args.checkpoint_path 35 | config["summary_path"] = os.path.join(args.checkpoint_path, "summary") 36 | if args.summary_path: 37 | config["summary_path"] = args.summary_path 38 | if args.model_name: 39 | config["model_name"] = args.model_name 40 | if args.batch_size: 41 | config["batch_size"] = args.batch_size 42 | if args.data_size: 43 | config["data_size"] = args.data_size 44 | if args.bilstm_type: 45 | config["bilstm_type"] = args.bilstm_type 46 | if args.keep_prob: 47 | config["keep_prob"] = args.keep_prob 48 | if args.max_span_len: 49 | config["max_span_len"] = args.max_span_len 50 | if args.max_n_spans: 51 | config["max_n_spans"] = args.max_n_spans 52 | return config 53 | 54 | 55 | def main(args): 56 | config = json.load(open(args.config_file)) 57 | config = set_config(args, config) 58 | 59 | preprocessor = SpanPreprocessor(config) 60 | 61 | # create dataset from raw data files 62 | if not os.path.exists(config["save_path"]): 63 | preprocessor.preprocess() 64 | 65 | print("Build a span model...") 66 | model = SpanModel(config, BaseSpanBatcher(config)) 67 | model.train() 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = 
argparse.ArgumentParser() 72 | parser.add_argument('--config_file', 73 | default='data/config/config.json', 74 | help='Configuration file') 75 | parser.add_argument('--raw_path', 76 | default=None, 77 | help='Raw data directory') 78 | parser.add_argument('--save_path', 79 | default=None, 80 | help='Save directory') 81 | parser.add_argument('--checkpoint_path', 82 | default=None, 83 | help='Checkpoint directory') 84 | parser.add_argument('--summary_path', 85 | default=None, 86 | help='Summary directory') 87 | parser.add_argument('--model_name', 88 | default=None, 89 | help='Model name') 90 | parser.add_argument('--batch_size', 91 | default=None, 92 | type=int, 93 | help='Batch size') 94 | parser.add_argument('--train_set', 95 | default=None, 96 | help='path to training set') 97 | parser.add_argument('--valid_set', 98 | default=None, 99 | help='path to training set') 100 | parser.add_argument('--pretrained_emb', 101 | default=None, 102 | help='path to pretrained embeddings') 103 | parser.add_argument('--vocab', 104 | default=None, 105 | help='path to vocabulary') 106 | parser.add_argument('--data_size', 107 | default=None, 108 | type=int, 109 | help='Data size') 110 | parser.add_argument('--bilstm_type', 111 | default=None, 112 | help='standard/interleave') 113 | parser.add_argument('--keep_prob', 114 | default=None, 115 | type=float, 116 | help='Keep (dropout) probability') 117 | parser.add_argument('--max_span_len', 118 | default=None, 119 | type=int, 120 | help='max span length') 121 | parser.add_argument('--max_n_spans', 122 | default=None, 123 | type=int, 124 | help='max num of spans') 125 | main(parser.parse_args()) 126 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.logger import * 2 | from utils.data_utils import * 3 | -------------------------------------------------------------------------------- /utils/batchers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiroki13/instance-based-ner/7bd8a29dfb1e13de0775b5814e8f9b27ec490008/utils/batchers/__init__.py -------------------------------------------------------------------------------- /utils/batchers/base_batchers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | 8 | class Batcher(object): 9 | 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | @staticmethod 14 | def pad_sequences(sequences, pad_tok=None, max_length=None): 15 | if pad_tok is None: 16 | # 0: "PAD" for words and chars, "O" for tags 17 | pad_tok = 0 18 | if max_length is None: 19 | max_length = max([len(seq) for seq in sequences]) 20 | sequence_padded, sequence_length = [], [] 21 | for seq in sequences: 22 | seq_ = seq[:max_length] + [pad_tok] * max(max_length - len(seq), 0) 23 | sequence_padded.append(seq_) 24 | sequence_length.append(min(len(seq), max_length)) 25 | return sequence_padded, sequence_length 26 | 27 | def pad_char_sequences(self, sequences, max_length=None, 28 | max_token_length=None): 29 | sequence_padded, sequence_length = [], [] 30 | if max_length is None: 31 | max_length = max(map(lambda x: len(x), sequences)) 32 | if max_token_length is None: 33 | max_token_length = max( 34 | 
[max(map(lambda x: len(x), seq)) for seq in sequences]) 35 | for seq in sequences: 36 | sp, sl = self.pad_sequences(seq, max_length=max_token_length) 37 | sequence_padded.append(sp) 38 | sequence_length.append(sl) 39 | sequence_padded, _ = self.pad_sequences(sequence_padded, 40 | pad_tok=[0] * max_token_length, 41 | max_length=max_length) 42 | sequence_length, _ = self.pad_sequences(sequence_length, 43 | max_length=max_length) 44 | return sequence_padded, sequence_length 45 | 46 | def make_each_batch(self, **kwargs): 47 | raise NotImplementedError 48 | 49 | def batchnize_dataset(self, **kwargs): 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /utils/batchers/knn_batchers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import random 8 | 9 | import numpy as np 10 | 11 | from utils.batchers.span_batchers import MaskSpanBatcher 12 | from utils.common import load_json 13 | 14 | 15 | class KnnBatcher(MaskSpanBatcher): 16 | 17 | def make_each_batch_for_targets(self, **kwargs): 18 | raise NotImplementedError 19 | 20 | def make_each_batch_for_neighbors(self, **kwargs): 21 | raise NotImplementedError 22 | 23 | def batchnize_neighbor_train_sents(self, **kwargs): 24 | raise NotImplementedError 25 | 26 | def make_batch_from_sent_ids(self, **kwargs): 27 | raise NotImplementedError 28 | 29 | def batchnize_span_reps_and_tags(self, train_batch_reps, train_batch_tags, 30 | train_batch_masks): 31 | """ 32 | :param train_batch_reps: 1D: num_sents, 2D: max_num_spans, 3D: dim 33 | :param train_batch_tags: 1D: num_sents, 2D: max_num_spans 34 | :param train_batch_masks: 1D: num_sents, 2D: max_num_spans 35 | :return: batch_span_rep_list: 1D: num_instances, 2D: dim 36 | :return: batch_tag_list: 1D: num_instances 37 | """ 38 | rep_list = [] 39 | tag_list = [] 40 | for reps, tags, masks in zip(train_batch_reps, 41 | train_batch_tags, 42 | train_batch_masks): 43 | for rep, tag, mask in zip(reps, tags, masks): 44 | if mask: 45 | rep_list.append(rep) 46 | tag_list.append(tag) 47 | return rep_list, tag_list 48 | 49 | @staticmethod 50 | def batchnize_span_reps_and_tag_one_hots(train_batch_reps, 51 | train_batch_tags, 52 | train_batch_masks, 53 | num_tags): 54 | """ 55 | :param train_batch_reps: 1D: num_sents, 2D: max_num_spans, 3D: dim 56 | :param train_batch_tags: 1D: num_sents, 2D: max_num_spans 57 | :param train_batch_masks: 1D: num_sents, 2D: max_num_spans 58 | :param num_tags: the number of possible tags 59 | :return: batch_span_rep_list: 1D: num_instances, 2D: dim 60 | :return: batch_tag_one_hot_list: 1D: num_instances, 2D: num_tags 61 | """ 62 | rep_list = [] 63 | tag_list = [] 64 | for reps, tags, masks in zip(train_batch_reps, 65 | train_batch_tags, 66 | train_batch_masks): 67 | for rep, tag, mask in zip(reps, tags, masks): 68 | if mask: 69 | rep_list.append(rep) 70 | tag_list.append(tag) 71 | 72 | tag_one_hot_list = np.zeros(shape=(len(tag_list), num_tags)) 73 | for i, tag_id in enumerate(tag_list): 74 | tag_one_hot_list[i][tag_id] = 1 75 | 76 | return rep_list, tag_one_hot_list 77 | 78 | 79 | class BaseKnnBatcher(KnnBatcher): 80 | 81 | def make_each_batch_for_targets(self, batch_words, batch_chars, batch_ids, 82 | max_span_len, max_n_spans, batch_tags=None): 83 | b_words, b_words_len = 
self.pad_sequences(batch_words) 84 | b_chars, _ = self.pad_char_sequences(batch_chars, max_token_length=20) 85 | batch = {"words": b_words, 86 | "chars": b_chars, 87 | "seq_len": b_words_len, 88 | "instance_ids": batch_ids} 89 | n_words = b_words_len[0] 90 | span_indices = self._make_span_indices(n_words=n_words, 91 | max_span_len=max_span_len, 92 | max_n_spans=max_n_spans) 93 | if max_n_spans: 94 | batch["span_indices"] = span_indices 95 | if batch_tags is not None: 96 | batch["tags"] = self._make_tag_sequences(batch_triples=batch_tags, 97 | indices=span_indices, 98 | n_words=n_words) 99 | return batch 100 | 101 | def make_each_batch_for_neighbors(self, batch_words, batch_chars, 102 | max_span_len, max_n_spans, 103 | batch_tags=None): 104 | b_words, b_words_len = self.pad_sequences(batch_words) 105 | b_chars, _ = self.pad_char_sequences(batch_chars, max_token_length=20) 106 | max_n_words = max(b_words_len) 107 | span_indices = self._make_span_indices(n_words=max_n_words, 108 | max_span_len=max_span_len, 109 | max_n_spans=max_n_spans) 110 | b_masks = self._make_masks(lengths=b_words_len, 111 | indices=span_indices, 112 | max_n_words=max_n_words) 113 | batch = {"words": b_words, 114 | "chars": b_chars, 115 | "seq_len": b_words_len, 116 | "masks": b_masks} 117 | if max_n_spans: 118 | batch["span_indices"] = span_indices 119 | if batch_tags is not None: 120 | batch["tags"] = self._make_tag_sequences(batch_triples=batch_tags, 121 | indices=span_indices, 122 | n_words=max_n_words) 123 | return batch 124 | 125 | def batchnize_dataset(self, data, data_name=None, batch_size=None, 126 | shuffle=True): 127 | max_span_len = self.config["max_span_len"] 128 | if data_name == "train": 129 | max_n_spans = self.config["max_n_spans"] 130 | else: 131 | if self.config["max_n_spans"] > 0: 132 | max_n_spans = 1000000 133 | else: 134 | max_n_spans = 0 135 | 136 | dataset = load_json(data) 137 | for instance_id, record in enumerate(dataset): 138 | record["instance_id"] = instance_id 139 | 140 | if shuffle: 141 | random.shuffle(dataset) 142 | dataset.sort(key=lambda record: len(record["words"])) 143 | 144 | batches = [] 145 | batch_words, batch_chars, batch_tags, batch_ids = [], [], [], [] 146 | prev_seq_len = len(dataset[0]["words"]) 147 | 148 | for record in dataset: 149 | seq_len = len(record["words"]) 150 | 151 | if len(batch_words) == batch_size or prev_seq_len != seq_len: 152 | batches.append(self.make_each_batch_for_targets(batch_words, 153 | batch_chars, 154 | batch_ids, 155 | max_span_len, 156 | max_n_spans, 157 | batch_tags)) 158 | batch_words, batch_chars, batch_tags, batch_ids = [], [], [], [] 159 | prev_seq_len = seq_len 160 | 161 | batch_words.append(record["words"]) 162 | batch_chars.append(record["chars"]) 163 | batch_tags.append(record["tags"]) 164 | batch_ids.append(record["instance_id"]) 165 | 166 | if len(batch_words) > 0: 167 | batches.append(self.make_each_batch_for_targets(batch_words, 168 | batch_chars, 169 | batch_ids, 170 | max_span_len, 171 | max_n_spans, 172 | batch_tags)) 173 | if shuffle: 174 | random.shuffle(batches) 175 | for batch in batches: 176 | yield batch 177 | 178 | def batchnize_neighbor_train_sents(self, train_sents, train_sent_ids, 179 | max_span_len, max_n_spans): 180 | batch_words, batch_chars, batch_tags = [], [], [] 181 | for sent_id in train_sent_ids: 182 | batch_words.append(train_sents[sent_id]["words"]) 183 | batch_chars.append(train_sents[sent_id]["chars"]) 184 | batch_tags.append(train_sents[sent_id]["tags"]) 185 | return 
self.make_each_batch_for_neighbors(batch_words, 186 | batch_chars, 187 | max_span_len, 188 | max_n_spans, 189 | batch_tags) 190 | 191 | def make_batch_from_sent_ids(self, train_sents, sent_ids): 192 | batch_words, batch_chars, batch_tags = [], [], [] 193 | max_n_spans = 0 194 | for sent_id in sent_ids: 195 | batch_words.append(train_sents[sent_id]["words"]) 196 | batch_chars.append(train_sents[sent_id]["chars"]) 197 | batch_tags.append(train_sents[sent_id]["tags"]) 198 | return self.make_each_batch_for_neighbors(batch_words, 199 | batch_chars, 200 | self.config["max_span_len"], 201 | max_n_spans, 202 | batch_tags) 203 | -------------------------------------------------------------------------------- /utils/batchers/span_batchers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import random 8 | 9 | import numpy as np 10 | 11 | from utils.common import load_json 12 | from utils.batchers.base_batchers import Batcher 13 | 14 | 15 | class SpanBatcher(Batcher): 16 | 17 | @staticmethod 18 | def _make_span_indices(n_words, max_span_len, max_n_spans=0): 19 | indices = np.triu(np.ones(shape=(n_words, n_words), dtype='int32')) 20 | mask = np.triu(np.ones(shape=(n_words, n_words), dtype='int32'), 21 | k=max_span_len) 22 | indices = np.nonzero(indices - mask) 23 | 24 | if max_n_spans > 0: 25 | indices = [(i, j) for i, j in zip(*indices)] 26 | random.shuffle(indices) 27 | return np.asarray(list(zip(*indices[:max_n_spans])), dtype='int32') 28 | return indices 29 | 30 | @staticmethod 31 | def _make_tag_sequences(batch_triples, indices, n_words): 32 | def _gen_tag_sequence(triples): 33 | matrix = np.zeros(shape=(n_words, n_words), dtype='int32') 34 | for (r, i, j) in triples: 35 | matrix[i][j] = r 36 | return [matrix[i][j] for (i, j) in zip(*indices)] 37 | 38 | return np.asarray([_gen_tag_sequence(triples) 39 | for triples in batch_triples], dtype='int32') 40 | 41 | def make_each_batch(self, **kwargs): 42 | raise NotImplementedError 43 | 44 | def batchnize_dataset(self, **kwargs): 45 | raise NotImplementedError 46 | 47 | 48 | class BaseSpanBatcher(SpanBatcher): 49 | 50 | def make_each_batch(self, batch_words, batch_chars, max_span_len=None, 51 | batch_tags=None): 52 | b_words, b_words_len = self.pad_sequences(batch_words) 53 | b_chars, _ = self.pad_char_sequences(batch_chars, max_token_length=20) 54 | n_words = b_words_len[0] 55 | span_indices = self._make_span_indices(n_words, max_span_len) 56 | if batch_tags is None: 57 | return {"words": b_words, 58 | "chars": b_chars, 59 | "seq_len": b_words_len} 60 | else: 61 | b_tags = self._make_tag_sequences(batch_tags, span_indices, n_words) 62 | return {"words": b_words, 63 | "chars": b_chars, 64 | "tags": b_tags, 65 | "seq_len": b_words_len} 66 | 67 | def batchnize_dataset(self, data, batch_size=None, shuffle=True): 68 | batches = [] 69 | max_span_len = self.config["max_span_len"] 70 | dataset = load_json(data) 71 | 72 | if shuffle: 73 | random.shuffle(dataset) 74 | dataset.sort(key=lambda record: len(record["words"])) 75 | 76 | prev_seq_len = len(dataset[0]["words"]) 77 | batch_words, batch_chars, batch_tags = [], [], [] 78 | 79 | for record in dataset: 80 | seq_len = len(record["words"]) 81 | 82 | if len(batch_words) == batch_size or prev_seq_len != seq_len: 83 | batches.append(self.make_each_batch(batch_words, 84 | batch_chars, 85 | 
max_span_len, 86 | batch_tags)) 87 | batch_words, batch_chars, batch_tags = [], [], [] 88 | prev_seq_len = seq_len 89 | 90 | batch_words.append(record["words"]) 91 | batch_chars.append(record["chars"]) 92 | batch_tags.append(record["tags"]) 93 | 94 | if len(batch_words) > 0: 95 | batches.append(self.make_each_batch(batch_words, 96 | batch_chars, 97 | max_span_len, 98 | batch_tags)) 99 | if shuffle: 100 | random.shuffle(batches) 101 | for batch in batches: 102 | yield batch 103 | 104 | 105 | class MaskSpanBatcher(SpanBatcher): 106 | 107 | @staticmethod 108 | def _make_masks(lengths, indices, max_n_words): 109 | def _gen_mask(n_words): 110 | matrix = np.zeros(shape=(max_n_words, max_n_words), dtype='float32') 111 | matrix[:n_words, :n_words] = 1.0 112 | return [matrix[i][j] for (i, j) in zip(*indices)] 113 | 114 | return np.asarray([_gen_mask(n_words) for n_words in lengths], 115 | dtype='float32') 116 | 117 | def make_each_batch(self, batch_words, batch_chars, max_span_len=None, 118 | max_n_spans=None, batch_tags=None): 119 | b_words, b_words_len = self.pad_sequences(batch_words) 120 | b_chars, _ = self.pad_char_sequences(batch_chars, max_token_length=20) 121 | max_n_words = max(b_words_len) 122 | span_indices = self._make_span_indices(max_n_words, max_span_len) 123 | b_masks = self._make_masks(b_words_len, span_indices, max_n_words) 124 | if batch_tags is None: 125 | return {"words": b_words, 126 | "chars": b_chars, 127 | "masks": b_masks, 128 | "seq_len": b_words_len} 129 | else: 130 | b_tags = self._make_tag_sequences(batch_tags, span_indices, max_n_words) 131 | return {"words": b_words, 132 | "chars": b_chars, 133 | "tags": b_tags, 134 | "masks": b_masks, 135 | "seq_len": b_words_len} 136 | 137 | def batchnize_dataset(self, data, batch_size=None, shuffle=True): 138 | max_span_len = self.config["max_span_len"] 139 | max_n_spans = None 140 | dataset = load_json(data) 141 | 142 | if shuffle: 143 | random.shuffle(dataset) 144 | 145 | batch_words, batch_chars, batch_tags = [], [], [] 146 | 147 | for record in dataset: 148 | if len(batch_words) == batch_size: 149 | yield self.make_each_batch(batch_words, batch_chars, max_span_len, 150 | max_n_spans, batch_tags) 151 | batch_words, batch_chars, batch_tags = [], [], [] 152 | 153 | batch_words.append(record["words"]) 154 | batch_chars.append(record["chars"]) 155 | batch_tags.append(record["tags"]) 156 | 157 | if len(batch_words) > 0: 158 | yield self.make_each_batch(batch_words, batch_chars, max_span_len, 159 | max_n_spans, batch_tags) 160 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import codecs 8 | import gzip 9 | import re 10 | import pickle 11 | import ujson 12 | import unicodedata 13 | 14 | PAD = "<PAD>" 15 | UNK = "<UNK>" 16 | NUM = "<NUM>" 17 | SPACE = "_SPACE" 18 | BOS = "<BOS>" 19 | EOS = "<EOS>" 20 | 21 | 22 | def load_json(filename): 23 | with codecs.open(filename, mode='r', encoding='utf-8') as f: 24 | dataset = ujson.load(f) 25 | return dataset 26 | 27 | 28 | def write_json(filename, data): 29 | with codecs.open(filename, mode="w", encoding="utf-8") as f: 30 | ujson.dump(data, f, ensure_ascii=False) 31 | 32 | 33 | def load_pickle(filename): 34 | with gzip.open(filename, 'rb') as gf: 35 | return pickle.load(gf) 36 | 37 | 38 | def 
write_pickle(filename, data): 39 | with gzip.open(filename + '.pkl.gz', 'wb') as gf: 40 | pickle.dump(data, gf, pickle.HIGHEST_PROTOCOL) 41 | 42 | 43 | def word_convert(word, keep_number=True, lowercase=True): 44 | if not keep_number: 45 | if is_digit(word): 46 | return NUM 47 | if lowercase: 48 | word = word.lower() 49 | return word 50 | 51 | 52 | def is_digit(word): 53 | try: 54 | float(word) 55 | return True 56 | except ValueError: 57 | pass 58 | try: 59 | unicodedata.numeric(word) 60 | return True 61 | except (TypeError, ValueError): 62 | pass 63 | result = re.compile(r'^[-+]?[0-9]+,[0-9]+$').match(word) 64 | if result: 65 | return True 66 | return False 67 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | from utils.common import load_json 8 | 9 | 10 | def metrics_for_multi_class_spans(batch_gold_spans, batch_pred_spans, 11 | null_label_id=0): 12 | correct = 0 13 | p_total = 0 14 | r_total = 0 15 | for gold_spans, pred_spans in zip(batch_gold_spans, batch_pred_spans): 16 | assert len(gold_spans) == len(pred_spans) 17 | r_total += sum([1 for e in gold_spans if e > null_label_id]) 18 | p_total += sum([1 for e in pred_spans if e > null_label_id]) 19 | correct += sum( 20 | [1 for t, p in zip(gold_spans, pred_spans) if t == p > null_label_id]) 21 | return correct, p_total, r_total 22 | 23 | 24 | def count_gold_and_system_outputs(batch_gold_spans, batch_pred_spans, 25 | null_label_id=0): 26 | correct = 0 27 | p_total = 0 28 | for gold_spans, pred_spans in zip(batch_gold_spans, batch_pred_spans): 29 | assert len(gold_spans) == len(pred_spans) 30 | p_total += sum([1 for e in pred_spans if e > null_label_id]) 31 | correct += sum( 32 | [1 for t, p in zip(gold_spans, pred_spans) if t == p > null_label_id]) 33 | return correct, p_total 34 | 35 | 36 | def count_gold_spans(path): 37 | n_gold_spans = 0 38 | for record in load_json(path): 39 | n_gold_spans += len(record["tags"]) 40 | return n_gold_spans 41 | 42 | 43 | def f_score(correct, p_total, r_total): 44 | precision = correct / p_total if p_total > 0 else 0. 45 | recall = correct / r_total if r_total > 0 else 0. 46 | f1 = (2 * precision * recall) / ( 47 | precision + recall) if precision + recall > 0 else 0. 
48 | return precision, recall, f1 49 | 50 | 51 | def align_data(data): 52 | """Given dict with lists, creates aligned strings 53 | Args: 54 | data: (dict) data["x"] = ["I", "love", "you"] 55 | (dict) data["y"] = ["O", "O", "O"] 56 | Returns: 57 | data_aligned: (dict) data_align["x"] = "I love you" 58 | data_align["y"] = "O O O " 59 | """ 60 | spacings = [max([len(seq[i]) for seq in data.values()]) 61 | for i in range(len(data[list(data.keys())[0]]))] 62 | data_aligned = dict() 63 | # for each entry, create aligned string 64 | for key, seq in data.items(): 65 | str_aligned = '' 66 | for token, spacing in zip(seq, spacings): 67 | str_aligned += token + ' ' * (spacing - len(token) + 1) 68 | data_aligned[key] = str_aligned 69 | return data_aligned 70 | 71 | 72 | def span2bio(spans, n_words, tag_dict=None): 73 | bio_tags = ['O' for _ in range(n_words)] 74 | for (label_id, pre_index, post_index) in spans: 75 | if tag_dict: 76 | label = tag_dict[label_id] 77 | else: 78 | label = str(label_id) 79 | bio_tags[pre_index] = 'B-%s' % label 80 | for index in range(pre_index + 1, post_index + 1): 81 | bio_tags[index] = 'I-%s' % label 82 | return bio_tags 83 | 84 | 85 | def bio2triple(tags): 86 | triples = [] 87 | for i, tag in enumerate(tags): 88 | if tag.startswith('B-'): 89 | label = tag[2:] 90 | triples.append([label, i, i]) 91 | elif tag.startswith('I-'): 92 | triples[-1] = triples[-1][:-1] + [i] 93 | return triples 94 | 95 | 96 | def bio_to_span(datasets): 97 | for dataset in datasets: 98 | for record in dataset: 99 | bio_tags = record["tags"] 100 | triples = bio2triple(bio_tags) 101 | record["tags"] = triples 102 | return datasets 103 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import logging 4 | import numpy as np 5 | 6 | 7 | def get_logger(filename): 8 | """Return a logger instance that writes in filename 9 | Args: 10 | filename: (string) path to log.txt 11 | Returns: 12 | logger: (instance of logger) 13 | """ 14 | logger = logging.getLogger('logger') 15 | logger.setLevel(logging.DEBUG) 16 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 17 | handler = logging.FileHandler(filename) 18 | handler.setLevel(logging.DEBUG) 19 | handler.setFormatter( 20 | logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 21 | logging.getLogger().addHandler(handler) 22 | return logger 23 | 24 | 25 | class Progbar(object): 26 | """Progbar class copied from keras (https://github.com/fchollet/keras/) 27 | Displays a progress bar. 28 | Small edit : added strict arg to update 29 | Arguments 30 | target: Total number of steps expected. 31 | interval: Minimum visual progress update interval (in seconds). 32 | """ 33 | 34 | def __init__(self, target, width=30, verbose=1): 35 | self.width = width 36 | self.target = target 37 | self.sum_values = {} 38 | self.unique_values = [] 39 | self.start = time.time() 40 | self.total_width = 0 41 | self.seen_so_far = 0 42 | self.verbose = verbose 43 | 44 | def update(self, current, values=None, exact=None, strict=None): 45 | """Updates the progress bar. 46 | Arguments 47 | current: Index of current step. 48 | values: List of tuples (name, value_for_last_step). 49 | The progress bar will display averages for these values. 50 | exact: List of tuples (name, value_for_last_step). 51 | The progress bar will display these values directly. 
52 | """ 53 | if strict is None: 54 | strict = [] 55 | if exact is None: 56 | exact = [] 57 | if values is None: 58 | values = [] 59 | for k, v in values: 60 | if type(v) == int: # for global steps 61 | if k not in self.sum_values: 62 | self.unique_values.append(k) 63 | self.sum_values[k] = v 64 | else: 65 | self.sum_values[k] = v 66 | else: 67 | if k not in self.sum_values: 68 | self.sum_values[k] = [v * (current - self.seen_so_far), 69 | current - self.seen_so_far] 70 | self.unique_values.append(k) 71 | else: 72 | self.sum_values[k][0] += v * (current - self.seen_so_far) 73 | self.sum_values[k][1] += (current - self.seen_so_far) 74 | for k, v in exact: 75 | if k not in self.sum_values: 76 | self.unique_values.append(k) 77 | self.sum_values[k] = [v, 1] 78 | 79 | for k, v in strict: 80 | if k not in self.sum_values: 81 | self.unique_values.append(k) 82 | self.sum_values[k] = v 83 | 84 | self.seen_so_far = current 85 | 86 | now = time.time() 87 | if self.verbose == 1: 88 | prev_total_width = self.total_width 89 | sys.stdout.write("\b" * prev_total_width) 90 | sys.stdout.write("\r") 91 | numdigits = int(np.floor(np.log10(self.target))) + 1 92 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 93 | bar = barstr % (current, self.target) 94 | prog = float(current) / self.target 95 | prog_width = int(self.width * prog) 96 | if prog_width > 0: 97 | bar += ('=' * (prog_width - 1)) 98 | if current < self.target: 99 | bar += '>' 100 | else: 101 | bar += '=' 102 | bar += ('.' * (self.width - prog_width)) 103 | bar += ']' 104 | sys.stdout.write(bar) 105 | self.total_width = len(bar) 106 | if current: 107 | time_per_unit = (now - self.start) / current 108 | else: 109 | time_per_unit = 0 110 | eta = time_per_unit * (self.target - current) 111 | info = '' 112 | if current < self.target: 113 | info += ' - ETA: %ds' % eta 114 | else: 115 | info += ' - %ds' % (now - self.start) 116 | for k in self.unique_values: 117 | if type(self.sum_values[k]) is list: 118 | info += ' - %s: %.4f' % ( 119 | k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 120 | else: 121 | info += ' - %s: %s' % (k, self.sum_values[k]) 122 | self.total_width += len(info) 123 | if prev_total_width > self.total_width: 124 | info += ((prev_total_width - self.total_width) * ' ') 125 | sys.stdout.write(info) 126 | sys.stdout.flush() 127 | if current >= self.target: 128 | sys.stdout.write("\n") 129 | if self.verbose == 2: 130 | if current >= self.target: 131 | info = '%ds' % (now - self.start) 132 | for k in self.unique_values: 133 | info += ' - %s: %.4f' % ( 134 | k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 135 | sys.stdout.write(info + "\n") 136 | 137 | def add(self, n, values=None): 138 | if values is None: 139 | values = [] 140 | self.update(self.seen_so_far + n, values) 141 | -------------------------------------------------------------------------------- /utils/preprocessors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiroki13/instance-based-ner/7bd8a29dfb1e13de0775b5814e8f9b27ec490008/utils/preprocessors/__init__.py -------------------------------------------------------------------------------- /utils/preprocessors/base_preprocessors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import codecs 8 | from 
collections import Counter 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from utils.common import PAD, UNK, NUM, word_convert 13 | 14 | glove_sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), 15 | '2B': int(1.2e6)} 16 | 17 | 18 | class Preprocessor(object): 19 | 20 | def __init__(self, config): 21 | self.config = config 22 | 23 | @staticmethod 24 | def raw_dataset_iter(filename, keep_number, lowercase): 25 | with codecs.open(filename, mode="r", encoding="utf-8") as f: 26 | words, tags = [], [] 27 | for line in f: 28 | line = line.lstrip().rstrip() 29 | if line.startswith("-DOCSTART-"): 30 | continue 31 | if len(line) == 0: 32 | if len(words) != 0: 33 | yield words, tags 34 | words, tags = [], [] 35 | else: 36 | line = line.split() 37 | word = line[0] 38 | tag = line[-1] 39 | word = word_convert(word, keep_number=keep_number, 40 | lowercase=lowercase) 41 | words.append(word) 42 | tags.append(tag) 43 | 44 | def load_dataset(self, filename, keep_number=False, lowercase=True): 45 | dataset = [] 46 | for words, tags in self.raw_dataset_iter(filename, keep_number, lowercase): 47 | dataset.append({"words": words, "tags": tags}) 48 | return dataset 49 | 50 | @staticmethod 51 | def load_glove_vocab(glove_path, glove_name): 52 | vocab = set() 53 | total = glove_sizes[glove_name] 54 | with codecs.open(glove_path, mode='r', encoding='utf-8') as f: 55 | for line in tqdm(f, total=total, desc="Load glove vocabulary"): 56 | line = line.lstrip().rstrip().split(" ") 57 | vocab.add(line[0]) 58 | return vocab 59 | 60 | @staticmethod 61 | def build_word_vocab(datasets): 62 | word_counter = Counter() 63 | for dataset in datasets: 64 | for record in dataset: 65 | words = record["words"] 66 | for word in words: 67 | word_counter[word] += 1 68 | word_vocab = [PAD, UNK, NUM] + [word for word, _ in 69 | word_counter.most_common(10000) if 70 | word != NUM] 71 | word_dict = dict([(word, idx) for idx, word in enumerate(word_vocab)]) 72 | return word_dict 73 | 74 | @staticmethod 75 | def build_char_vocab(datasets): 76 | char_counter = Counter() 77 | for dataset in datasets: 78 | for record in dataset: 79 | for word in record["words"]: 80 | for char in word: 81 | char_counter[char] += 1 82 | word_vocab = [PAD, UNK] + sorted( 83 | [char for char, _ in char_counter.most_common()]) 84 | word_dict = dict([(word, idx) for idx, word in enumerate(word_vocab)]) 85 | return word_dict 86 | 87 | @staticmethod 88 | def build_word_vocab_pretrained(datasets, glove_vocab): 89 | word_counter = Counter() 90 | for dataset in datasets: 91 | for record in dataset: 92 | words = record["words"] 93 | for word in words: 94 | word_counter[word] += 1 95 | # build word dict 96 | word_vocab = [PAD, UNK, NUM] + sorted(list(glove_vocab)) 97 | word_dict = dict([(word, idx) for idx, word in enumerate(word_vocab)]) 98 | return word_dict 99 | 100 | @staticmethod 101 | def filter_glove_emb(word_dict, glove_path, glove_name, dim): 102 | vectors = np.zeros([len(word_dict) - 3, dim]) 103 | with codecs.open(glove_path, mode='r', encoding='utf-8') as f: 104 | for line in tqdm(f, total=glove_sizes[glove_name], 105 | desc="Filter glove embeddings"): 106 | line = line.lstrip().rstrip().split(" ") 107 | word = line[0] 108 | vector = [float(x) for x in line[1:]] 109 | if word in word_dict: 110 | word_idx = word_dict[word] - 3 111 | vectors[word_idx] = np.asarray(vector) 112 | return vectors 113 | 114 | @staticmethod 115 | def build_tag_vocab(datasets): 116 | raise NotImplementedError 117 | 118 | @staticmethod 119 | def 
build_dataset(data, word_dict, char_dict, tag_dict): 120 | raise NotImplementedError 121 | 122 | def preprocess(self): 123 | raise NotImplementedError 124 | -------------------------------------------------------------------------------- /utils/preprocessors/knn_preprocessors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import os 8 | import numpy as np 9 | 10 | from utils.common import write_json, load_json, word_convert 11 | from utils.preprocessors.span_preprocessors import SpanPreprocessor 12 | 13 | 14 | class KnnPreprocessor(SpanPreprocessor): 15 | 16 | def load_dataset(self, filename, keep_number=False, lowercase=True): 17 | dataset = [] 18 | for record in load_json(filename): 19 | words = [word_convert(word, keep_number=keep_number, lowercase=lowercase) 20 | for word in record["words"]] 21 | if "train_sent_ids" in record: 22 | dataset.append({"sent_id": record["sent_id"], 23 | "words": words, 24 | "tags": record["spans"], 25 | "train_sent_ids": record["train_sent_ids"]}) 26 | else: 27 | dataset.append({"sent_id": record["sent_id"], 28 | "words": words, 29 | "tags": record["spans"]}) 30 | return dataset 31 | 32 | def preprocess(self): 33 | config = self.config 34 | os.makedirs(config["save_path"], exist_ok=True) 35 | 36 | # List[{'words': List[str], 'tags': List[str]}] 37 | train_data = self.load_dataset( 38 | os.path.join(config["raw_path"], "train.json"), 39 | keep_number=False, 40 | lowercase=True) 41 | valid_data = self.load_dataset( 42 | os.path.join(config["raw_path"], "valid.json"), 43 | keep_number=False, 44 | lowercase=True) 45 | train_data = train_data[:config["data_size"]] 46 | valid_data = valid_data[:config["data_size"]] 47 | 48 | # build vocabulary 49 | if config["use_pretrained"]: 50 | glove_path = self.config["glove_path"].format(config["glove_name"], 51 | config["emb_dim"]) 52 | glove_vocab = self.load_glove_vocab(glove_path, config["glove_name"]) 53 | word_dict = self.build_word_vocab_pretrained([train_data, valid_data], 54 | glove_vocab) 55 | vectors = self.filter_glove_emb(word_dict, 56 | glove_path, 57 | config["glove_name"], 58 | config["emb_dim"]) 59 | np.savez_compressed(config["pretrained_emb"], embeddings=vectors) 60 | else: 61 | word_dict = self.build_word_vocab([train_data, valid_data]) 62 | 63 | # build tag dict 64 | tag_dict = self.build_tag_vocab([train_data, valid_data]) 65 | 66 | # build char dict 67 | train_data = self.load_dataset( 68 | os.path.join(config["raw_path"], "train.json"), 69 | keep_number=True, 70 | lowercase=config["char_lowercase"]) 71 | valid_data = self.load_dataset( 72 | os.path.join(config["raw_path"], "valid.json"), 73 | keep_number=True, 74 | lowercase=config["char_lowercase"]) 75 | 76 | train_data = train_data[:config["data_size"]] 77 | valid_data = valid_data[:config["data_size"]] 78 | 79 | char_dict = self.build_char_vocab([train_data]) 80 | 81 | # create indices dataset 82 | # List[{'words': List[str], 'chars': List[List[str]], 'tags': List[str]}] 83 | train_set = self.build_dataset(train_data, word_dict, char_dict, tag_dict) 84 | valid_set = self.build_dataset(valid_data, word_dict, char_dict, tag_dict) 85 | vocab = {"word_dict": word_dict, 86 | "char_dict": char_dict, 87 | "tag_dict": tag_dict} 88 | 89 | print("Train Sents: %d" % len(train_set)) 90 | print("Valid Sents: %d" % len(valid_set)) 91 | 
92 | # write to file 93 | write_json(os.path.join(config["save_path"], "vocab.json"), vocab) 94 | write_json(os.path.join(config["save_path"], "train.json"), train_set) 95 | write_json(os.path.join(config["save_path"], "valid.json"), valid_set) 96 | -------------------------------------------------------------------------------- /utils/preprocessors/span_preprocessors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import os 8 | from collections import Counter 9 | import numpy as np 10 | 11 | from utils.common import write_json, load_json, UNK, word_convert 12 | from utils.preprocessors.base_preprocessors import Preprocessor 13 | 14 | 15 | class SpanPreprocessor(Preprocessor): 16 | 17 | def load_dataset(self, filename, keep_number=False, lowercase=True): 18 | dataset = [] 19 | for record in load_json(filename): 20 | words = [word_convert(word, keep_number=keep_number, lowercase=lowercase) 21 | for word in record["words"]] 22 | dataset.append({"sent_id": record["sent_id"], 23 | "words": words, 24 | "tags": record["spans"]}) 25 | return dataset 26 | 27 | @staticmethod 28 | def build_tag_vocab(datasets): 29 | tag_counter = Counter() 30 | for dataset in datasets: 31 | for record in dataset: 32 | for (tag, _, _) in record["tags"]: 33 | tag_counter[tag] += 1 34 | tag_vocab = ["O"] + [tag for tag, _ in tag_counter.most_common()] 35 | tag_dict = dict([(ner, idx) for idx, ner in enumerate(tag_vocab)]) 36 | return tag_dict 37 | 38 | @staticmethod 39 | def build_dataset(data, word_dict, char_dict, tag_dict): 40 | dataset = [] 41 | for record in data: 42 | chars_list = [] 43 | words = [] 44 | for word in record["words"]: 45 | chars = [char_dict[char] if char in char_dict else char_dict[UNK] for 46 | char in word] 47 | chars_list.append(chars) 48 | word = word_convert(word, keep_number=False, lowercase=True) 49 | words.append(word_dict[word] if word in word_dict else word_dict[UNK]) 50 | tags = [(tag_dict[tag], i, j) for (tag, i, j) in record["tags"]] 51 | dataset.append({"words": words, "chars": chars_list, "tags": tags}) 52 | return dataset 53 | 54 | def preprocess(self): 55 | config = self.config 56 | os.makedirs(config["save_path"], exist_ok=True) 57 | 58 | # List[{'words': List[str], 'tags': List[str]}] 59 | train_data = self.load_dataset( 60 | os.path.join(config["raw_path"], "train.json"), 61 | keep_number=False, 62 | lowercase=True) 63 | valid_data = self.load_dataset( 64 | os.path.join(config["raw_path"], "valid.json"), 65 | keep_number=False, 66 | lowercase=True) 67 | train_data = train_data[:config["data_size"]] 68 | valid_data = valid_data[:config["data_size"]] 69 | 70 | # build vocabulary 71 | if config["use_pretrained"]: 72 | glove_path = self.config["glove_path"].format(config["glove_name"], 73 | config["emb_dim"]) 74 | glove_vocab = self.load_glove_vocab(glove_path, config["glove_name"]) 75 | word_dict = self.build_word_vocab_pretrained([train_data, valid_data], 76 | glove_vocab) 77 | vectors = self.filter_glove_emb(word_dict, 78 | glove_path, 79 | config["glove_name"], 80 | config["emb_dim"]) 81 | np.savez_compressed(config["pretrained_emb"], embeddings=vectors) 82 | else: 83 | word_dict = self.build_word_vocab([train_data, valid_data]) 84 | 85 | # build tag dict 86 | tag_dict = self.build_tag_vocab([train_data, valid_data]) 87 | 88 | # build char dict 89 | 
train_data = self.load_dataset( 90 | os.path.join(config["raw_path"], "train.json"), 91 | keep_number=True, 92 | lowercase=config["char_lowercase"]) 93 | valid_data = self.load_dataset( 94 | os.path.join(config["raw_path"], "valid.json"), 95 | keep_number=True, 96 | lowercase=config["char_lowercase"]) 97 | 98 | train_data = train_data[:config["data_size"]] 99 | valid_data = valid_data[:config["data_size"]] 100 | 101 | if config["max_sent_len"] > 0: 102 | train_data = [record for record in train_data 103 | if len(record["words"]) <= config["max_sent_len"]] 104 | print( 105 | "Train Sents (remove max_sent_len: %d): %d" % (config["max_sent_len"], 106 | len(train_data))) 107 | 108 | char_dict = self.build_char_vocab([train_data]) 109 | 110 | # create indices dataset 111 | # List[{'words': List[str], 'chars': List[List[str]], 'tags': List[str]}] 112 | train_set = self.build_dataset(train_data, word_dict, char_dict, tag_dict) 113 | valid_set = self.build_dataset(valid_data, word_dict, char_dict, tag_dict) 114 | vocab = {"word_dict": word_dict, 115 | "char_dict": char_dict, 116 | "tag_dict": tag_dict} 117 | 118 | print("Train Sents: %d" % len(train_set)) 119 | print("Valid Sents: %d" % len(valid_set)) 120 | 121 | # write to file 122 | write_json(os.path.join(config["save_path"], "vocab.json"), vocab) 123 | write_json(os.path.join(config["save_path"], "train.json"), train_set) 124 | write_json(os.path.join(config["save_path"], "valid.json"), valid_set) 125 | --------------------------------------------------------------------------------
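A quick way to sanity-check the BIO/span conversion helpers in utils/data_utils.py is the round trip below. This is only an illustrative sketch, not a file in the repository: it assumes the repository root is on PYTHONPATH and uses a made-up four-token sentence with the standard CoNLL-2003 PER/LOC labels.

```
# Round-trip sketch for the helpers defined in utils/data_utils.py:
# bio2triple turns BIO tags into (label, start, end) triples, and
# span2bio converts such triples back into BIO tags.
from utils.data_utils import bio2triple, span2bio

tags = ["B-PER", "I-PER", "O", "B-LOC"]

triples = bio2triple(tags)
print(triples)  # [['PER', 0, 1], ['LOC', 3, 3]]

# Without a tag_dict, span2bio uses the label values as-is.
recovered = span2bio(triples, n_words=len(tags))
assert recovered == tags
```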