├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md ├── embedding │ └── .gitkeep ├── genia.dev.iob2 ├── genia.test.iob2 └── genia.train.iob2 ├── dataset.py ├── eval.py ├── model.py ├── train.py └── utils ├── json_util.py ├── path_util.py └── torch_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | 3 | # pycharm IDE files 4 | .idea/ 5 | 6 | !.gitkeep 7 | 8 | 9 | # tensorflow runs dir 10 | runs/ 11 | 12 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 13 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 14 | 15 | # User-specific stuff 16 | .idea/**/workspace.xml 17 | .idea/**/tasks.xml 18 | .idea/**/usage.statistics.xml 19 | .idea/**/dictionaries 20 | .idea/**/shelf 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. Uncomment if using 38 | # auto-import. 39 | # .idea/modules.xml 40 | # .idea/*.iml 41 | # .idea/modules 42 | 43 | # CMake 44 | cmake-build-*/ 45 | 46 | # Mongo Explorer plugin 47 | .idea/**/mongoSettings.xml 48 | 49 | # File-based project format 50 | *.iws 51 | 52 | # IntelliJ 53 | out/ 54 | 55 | # mpeltonen/sbt-idea plugin 56 | .idea_modules/ 57 | 58 | # JIRA plugin 59 | atlassian-ide-plugin.xml 60 | 61 | # Cursive Clojure plugin 62 | .idea/replstate.xml 63 | 64 | # Crashlytics plugin (for Android Studio and IntelliJ) 65 | com_crashlytics_export_strings.xml 66 | crashlytics.properties 67 | crashlytics-build.properties 68 | fabric.properties 69 | 70 | # Editor-based Rest Client 71 | .idea/httpRequests 72 | 73 | # Byte-compiled / optimized / DLL files 74 | __pycache__/ 75 | *.py[cod] 76 | *$py.class 77 | 78 | # C extensions 79 | *.so 80 | 81 | # Distribution / packaging 82 | .Python 83 | build/ 84 | develop-eggs/ 85 | dist/ 86 | downloads/ 87 | eggs/ 88 | .eggs/ 89 | lib/ 90 | lib64/ 91 | parts/ 92 | sdist/ 93 | var/ 94 | wheels/ 95 | *.egg-info/ 96 | .installed.cfg 97 | *.egg 98 | MANIFEST 99 | 100 | # PyInstaller 101 | # Usually these files are written by a python script from a template 102 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
103 | *.manifest 104 | *.spec 105 | 106 | # Installer logs 107 | pip-log.txt 108 | pip-delete-this-directory.txt 109 | 110 | # Unit test / coverage reports 111 | htmlcov/ 112 | .tox/ 113 | .coverage 114 | .coverage.* 115 | .cache 116 | nosetests.xml 117 | coverage.xml 118 | *.cover 119 | .hypothesis/ 120 | .pytest_cache/ 121 | 122 | # Translations 123 | *.mo 124 | *.pot 125 | 126 | # Django stuff: 127 | *.log 128 | local_settings.py 129 | db.sqlite3 130 | 131 | # Flask stuff: 132 | instance/ 133 | .webassets-cache 134 | 135 | # Scrapy stuff: 136 | .scrapy 137 | 138 | # Sphinx documentation 139 | docs/_build/ 140 | 141 | # PyBuilder 142 | target/ 143 | 144 | # Jupyter Notebook 145 | .ipynb_checkpoints 146 | 147 | # pyenv 148 | .python-version 149 | 150 | # celery beat schedule file 151 | celerybeat-schedule 152 | 153 | # SageMath parsed files 154 | *.sage.py 155 | 156 | # Environments 157 | .env 158 | .venv 159 | env/ 160 | venv/ 161 | ENV/ 162 | env.bak/ 163 | venv.bak/ 164 | 165 | # Spyder project settings 166 | .spyderproject 167 | .spyproject 168 | 169 | # Rope project settings 170 | .ropeproject 171 | 172 | # mkdocs documentation 173 | /site 174 | 175 | # mypy 176 | .mypy_cache/ 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An Implementation of Deep Exhaustive Model for Nested NER 2 | 3 | Original paper: [Sohrab, M. G., & Miwa, M. (2018). Deep Exhaustive Model for Nested Named Entity Recognition. In 2018 EMNLP](http://aclweb.org/anthology/D18-1309) 4 | 5 | # Requirements 6 | * `python 3.6.7` 7 | * `pytorch 1.0.0` 8 | * `numpy 1.15.3` 9 | * `gensim 3.6.0` 10 | * `scikit-learn 0.20.0` 11 | * `joblib 0.12.5` 12 | 13 | # Data Format 14 | Our processed `GENIA` dataset is in `./data/`. 15 | 16 | The data format is the same as in [Neural Layered Model, Ju et al. 2018 NAACL](https://github.com/meizhiju/layered-bilstm-crf) 17 | >Each line has multiple columns separated by a tab key. 18 | >Each line contains 19 | >``` 20 | >word label1 label2 label3 ... 
labelN
21 | >```
22 | >The number of labels (`N`) for each word is determined by the maximum nested level in the data set: `N = maximum nested level + 1`.
23 | >Each sentence is separated by an empty line.
24 | >For example, the two sentences `John killed Mary's husband. He was arrested last night` contain four entities: John (`PER`), Mary (`PER`), Mary's husband (`PER`), and He (`PER`).
25 | >The format for these two sentences is as follows:
26 | >```
27 | >John B-PER O O
28 | >killed O O O
29 | >Mary B-PER B-PER O
30 | >'s O I-PER O
31 | >husband O I-PER O
32 | >. O O O
33 | >
34 | >He B-PER O O
35 | >was O O O
36 | >arrested O O O
37 | >last O O O
38 | >night O O O
39 | >. O O O
40 | >```
41 | 
42 | # Pre-trained word embeddings
43 | * The [pre-trained word embeddings](https://drive.google.com/open?id=0BzMCqpcgEJgiUWs0ZnU0NlFTam8) used here are the same as in [Neural Layered Model](https://github.com/meizhiju/layered-bilstm-crf)
44 | 
45 | # Setup
46 | Download the pre-trained embeddings above, unzip the archive, and place `PubMed-shuffle-win-30.bin` in `./data/embedding/`.
47 | 
48 | # Usage
49 | ## Training
50 | 
51 | ```sh
52 | python3 train.py
53 | ```
54 | The best model found during training will be saved in `./data/model/`.
55 | ## Testing
56 | Set `model_url` in `main()` of `eval.py` to the path of the model saved during training, then run:
57 | ```sh
58 | python3 eval.py
59 | ```
60 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # pretrained word embeddings
2 | 
3 | * PubMed
4 |     * Same as [layered lstm model](https://github.com/meizhiju/layered-bilstm-crf#pretrained-word-embeddings)
5 |     * Download [here](https://drive.google.com/open?id=0BzMCqpcgEJgiUWs0ZnU0NlFTam8)
6 |     * [Related paper](http://www.aclweb.org/anthology/W16-2922)
7 | 
8 | 
--------------------------------------------------------------------------------
/data/embedding/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csJd/deep_exhaustive_model/59afe18864d86cdc72314e10fd8c9c72ae1ffa09/data/embedding/.gitkeep
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # created by deng on 2019-03-13
3 | 
4 | import os
5 | import numpy as np
6 | import joblib
7 | import torch
8 | from collections import defaultdict
9 | from gensim.models import KeyedVectors
10 | from torch.utils.data import Dataset
11 | from torch.nn.utils.rnn import pad_sequence
12 | 
13 | from utils.path_util import from_project_root, dirname
14 | import utils.json_util as ju
15 | 
16 | 
17 | def gen_sentence_tensors(sentence_list, device, data_url):
18 |     """ generate input tensors from sentence list
19 | 
20 |     Args:
21 |         sentence_list: list of raw sentences
22 |         device: torch device
23 |         data_url: url of the raw data file, used to locate the vocab files
24 | 
25 |     Returns:
26 |         sentences, tensor
27 |         sentence_lengths, tensor
28 |         sentence_words, list of tensor
29 |         sentence_word_lengths, list of tensor
30 |         sentence_word_indices, list of tensor
31 | 
32 |     """
33 |     vocab = ju.load(dirname(data_url) + '/vocab.json')
34 |     char_vocab = ju.load(dirname(data_url) + '/char_vocab.json')
35 | 
36 |     sentences = list()
37 |     sentence_words = list()
38 |     sentence_word_lengths = list()
39 |     sentence_word_indices = list()
40 | 
41 |     unk_idx = 1
42 |     for sent in sentence_list:
43 |         # word to
word id 44 | sentence = torch.LongTensor([vocab[word] if word in vocab else unk_idx 45 | for word in sent]).to(device) 46 | 47 | # char of word to char id 48 | words = list() 49 | for word in sent: 50 | words.append([char_vocab[ch] if ch in char_vocab else unk_idx 51 | for ch in word]) 52 | 53 | # save word lengths 54 | word_lengths = torch.LongTensor([len(word) for word in words]).to(device) 55 | 56 | # sorting lengths according to length 57 | word_lengths, word_indices = torch.sort(word_lengths, descending=True) 58 | 59 | # sorting word according word length 60 | words = np.array(words)[word_indices.cpu().numpy()] 61 | word_indices = word_indices.to(device) 62 | words = [torch.LongTensor(word).to(device) for word in words] 63 | 64 | # padding char tensor of words 65 | words = pad_sequence(words, batch_first=True).to(device) 66 | # (max_word_len, sent_len) 67 | 68 | sentences.append(sentence) 69 | sentence_words.append(words) 70 | sentence_word_lengths.append(word_lengths) 71 | sentence_word_indices.append(word_indices) 72 | 73 | # record sentence length and padding sentences 74 | sentence_lengths = [len(sentence) for sentence in sentences] 75 | # (batch_size) 76 | sentences = pad_sequence(sentences, batch_first=True).to(device) 77 | # (batch_size, max_sent_len) 78 | 79 | return sentences, sentence_lengths, sentence_words, sentence_word_lengths, sentence_word_indices 80 | 81 | 82 | class ExhaustiveDataset(Dataset): 83 | 84 | def __init__(self, data_url, device, max_region=10): 85 | super().__init__() 86 | self.x, self.y = load_raw_data(data_url) 87 | 88 | categories = set() 89 | for dic in self.y: 90 | categories = categories.union(dic.values()) 91 | self.categories = ['NA'] + sorted(categories) 92 | self.n_tags = len(self.categories) 93 | self.data_url = data_url 94 | self.max_region = max_region 95 | self.device = device 96 | 97 | def __getitem__(self, index): 98 | return self.x[index], self.y[index] 99 | 100 | def __len__(self): 101 | return len(self.x) 102 | 103 | def collate_func(self, data_list): 104 | data_list = sorted(data_list, key=lambda tup: len(tup[0]), reverse=True) 105 | sentence_list, records_list = zip(*data_list) # un zip 106 | max_sent_len = len(sentence_list[0]) 107 | sentence_tensors = gen_sentence_tensors(sentence_list, self.device, self.data_url) 108 | # (sentences, sentence_lengths, sentence_words, sentence_word_lengths, sentence_word_indices) 109 | 110 | region_labels = list() 111 | for records, length in zip(records_list, sentence_tensors[1]): 112 | labels = list() 113 | for region_size in range(1, self.max_region + 1): 114 | for start in range(0, max_sent_len - region_size + 1): 115 | if start + region_size > length: 116 | labels.append(self.n_tags) # for padding 117 | elif (start, start + region_size) in records: 118 | labels.append(self.categories.index(records[start, start + region_size])) 119 | else: 120 | labels.append(0) 121 | region_labels.append(labels) 122 | region_labels = torch.LongTensor(region_labels).to(self.device) 123 | # (batch_size, n_regions) 124 | 125 | return sentence_tensors, region_labels, records_list 126 | 127 | 128 | def gen_vocab_from_data(data_urls, pretrained_url, binary=True, update=False, min_count=1): 129 | """ generate vocabulary and embeddings from data file, generated vocab files will be saved in 130 | data dir 131 | 132 | Args: 133 | data_urls: url to data file(s), list or string 134 | pretrained_url: url to pretrained embedding file 135 | binary: binary for load word2vec 136 | update: force to update even vocab file exists 
137 |         min_count: minimum count of a word
138 | 
139 |     Returns:
140 |         generated word embedding url
141 |     """
142 | 
143 |     if isinstance(data_urls, str):
144 |         data_urls = [data_urls]
145 |     data_dir = os.path.dirname(data_urls[0])
146 |     vocab_url = os.path.join(data_dir, "vocab.json")
147 |     char_vocab_url = os.path.join(data_dir, "char_vocab.json")
148 |     embedding_url = os.path.join(data_dir, "embeddings.npy") if pretrained_url else None
149 | 
150 |     if (not update) and os.path.exists(vocab_url):
151 |         print("vocab file already exists")
152 |         return embedding_url
153 | 
154 |     vocab = set()
155 |     char_vocab = set()
156 |     word_counts = defaultdict(int)
157 |     print("generating vocab from", data_urls)
158 |     for data_url in data_urls:
159 |         with open(data_url, 'r', encoding='utf-8') as data_file:
160 |             for row in data_file:
161 |                 if row == '\n':
162 |                     continue
163 |                 token = row.split()[0]
164 |                 word_counts[token] += 1
165 |                 if word_counts[token] > min_count:
166 |                     vocab.add(token)
167 |                     char_vocab = char_vocab.union(token)
168 | 
169 |     # sort vocab in alphabetical order
170 |     vocab = sorted(vocab)
171 |     char_vocab = sorted(char_vocab)
172 | 
173 |     # generate word embeddings for vocab
174 |     if pretrained_url is not None:
175 |         print("generating pre-trained embedding from", pretrained_url)
176 |         kvs = KeyedVectors.load_word2vec_format(pretrained_url, binary=binary)
177 |         embeddings = list()
178 |         for word in vocab:
179 |             if word in kvs:
180 |                 embeddings.append(kvs[word])
181 |             else:
182 |                 embeddings.append(np.random.uniform(-0.25, 0.25, kvs.vector_size))
183 | 
184 |     char_vocab = ['<pad>'] + char_vocab  # index 0 is reserved for padding
185 |     vocab = ['<pad>', '<unk>'] + vocab  # index 0 for padding, index 1 for unknown words
186 |     ju.dump(ju.list_to_dict(vocab), vocab_url)
187 |     ju.dump(ju.list_to_dict(char_vocab), char_vocab_url)
188 | 
189 |     if pretrained_url is None:
190 |         return
191 |     embeddings = np.vstack([np.zeros(kvs.vector_size),  # for '<pad>'
192 |                             np.random.uniform(-0.25, 0.25, kvs.vector_size),  # for '<unk>'
193 |                             embeddings])
194 |     np.save(embedding_url, embeddings)
195 |     return embedding_url
196 | 
197 | 
198 | def infer_records(columns):
199 |     """ infer all entity records of a sentence
200 |     Args:
201 |         columns: label columns of a sentence in iob2 format
202 |     Returns:
203 |         entity records in the given sentence
204 |     """
205 |     records = dict()
206 |     for col in columns:
207 |         start = 0
208 |         while start < len(col):
209 |             end = start + 1
210 |             if col[start][0] == 'B':
211 |                 while end < len(col) and col[end][0] == 'I':
212 |                     end += 1
213 |                 records[(start, end)] = col[start][2:]
214 |             start = end
215 |     return records
216 | 
217 | 
218 | def load_raw_data(data_url, update=False):
219 |     """ load data into sentences and records
220 | 
221 |     Args:
222 |         data_url: url to data file
223 |         update: whether to force an update
224 |     Returns:
225 |         sentences (raw), records
226 |     """
227 | 
228 |     # load from pickle
229 |     save_url = data_url.replace('.bio', '.raw.pkl').replace('.iob2', '.raw.pkl')
230 |     if not update and os.path.exists(save_url):
231 |         return joblib.load(save_url)
232 | 
233 |     sentences = list()
234 |     records = list()
235 |     with open(data_url, 'r', encoding='utf-8') as iob_file:
236 |         first_line = iob_file.readline()
237 |         n_columns = first_line.count('\t')
238 |         # the JNLPBA dataset doesn't contain the extra 'O' column
239 |         if 'jnlpba' in data_url:
240 |             n_columns += 1
241 |         columns = [[x] for x in first_line.split()]
242 |         for line in iob_file:
243 |             if line != '\n':
244 |                 line_values = line.split()
245 |                 for i in range(n_columns):
246 |                     columns[i].append(line_values[i])
247 | 
248 |             else:  # end of a sentence
249 |                 sentence = columns[0]
250 |                 sentences.append(sentence)
251 |                 records.append(infer_records(columns[1:]))
252 |                 columns = [list() for i in range(n_columns)]
253 |     joblib.dump((sentences, records), save_url)
254 |     return sentences, records
255 | 
256 | 
257 | def prepare_vocab(data_urls, pretrained_url, update=True, min_count=1):
258 |     """ prepare vocab and embedding
259 | 
260 |     Args:
261 |         data_urls: urls to data files for preparing the vocab
262 |         pretrained_url: url to pretrained embedding file
263 |         min_count: minimum count of a word
264 |         update: whether to force an update
265 | 
266 |     """
267 |     binary = pretrained_url and pretrained_url.endswith('.bin')
268 |     return gen_vocab_from_data(data_urls, pretrained_url, binary=binary, update=update, min_count=min_count)
269 | 
270 | 
271 | def main():
272 |     data_urls = [from_project_root("data/genia.train.iob2"),
273 |                  from_project_root("data/genia.dev.iob2"),
274 |                  from_project_root("data/genia.test.iob2")]
275 |     prepare_vocab(data_urls, pretrained_url=None, update=True, min_count=1)  # pass a pretrained embedding url to also generate embeddings.npy
276 |     pass
277 | 
278 | 
279 | if __name__ == '__main__':
280 |     main()
281 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # created by deng on 2019-02-13
3 | 
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from sklearn.metrics import classification_report
7 | 
8 | from dataset import ExhaustiveDataset, gen_sentence_tensors
9 | from utils.torch_util import calc_f1
10 | from utils.path_util import from_project_root
11 | 
12 | 
13 | def evaluate(model, data_url):
14 |     """ evaluate model on a specific dataset
15 | 
16 |     Args:
17 |         model: model to evaluate
18 |         data_url: url to data for evaluating
19 | 
20 |     Returns:
21 |         metrics on dataset
22 | 
23 |     """
24 |     print("\nEvaluating model using data from", data_url, "\n")
25 |     max_region = model.max_region
26 |     dataset = ExhaustiveDataset(data_url, next(model.parameters()).device, max_region=max_region)
27 |     data_loader = DataLoader(dataset, batch_size=100, collate_fn=dataset.collate_func)
28 |     # switch to eval mode
29 |     model.eval()
30 | 
31 |     region_true_list = list()
32 |     region_pred_list = list()
33 |     region_true_count = 0
34 |     region_pred_count = 0
35 | 
36 |     with torch.no_grad():
37 |         for data, labels, records_list in data_loader:
38 |             batch_region_labels = torch.argmax(model.forward(*data), dim=1).cpu()
39 |             lengths = data[1]
40 |             batch_maxlen = lengths[0]
41 |             for region_labels, length, true_records in zip(batch_region_labels, lengths, records_list):
42 |                 pred_records = {}
43 |                 ind = 0
44 |                 for region_size in range(1, max_region + 1):
45 |                     for start in range(0, batch_maxlen - region_size + 1):
46 |                         end = start + region_size
47 |                         if 0 < region_labels[ind] < dataset.n_tags and end <= length:
48 |                             pred_records[(start, start + region_size)] = region_labels[ind]
49 |                         ind += 1
50 | 
51 |                 region_true_count += len(true_records)
52 |                 region_pred_count += len(pred_records)
53 | 
54 |                 for region in true_records:
55 |                     true_label = dataset.categories.index(true_records[region])
56 |                     pred_label = pred_records[region] if region in pred_records else 0
57 |                     region_true_list.append(true_label)
58 |                     region_pred_list.append(pred_label)
59 |                 for region in pred_records:
60 |                     if region not in true_records:
61 |                         region_pred_list.append(pred_records[region])
62 |                         region_true_list.append(0)
63 | 
64 |     print(classification_report(region_true_list, region_pred_list,
65 |                                 target_names=dataset.categories, digits=6))
66 | 
67 |
ret = dict() 68 | tp = 0 69 | for pv, tv in zip(region_pred_list, region_true_list): 70 | if pv == tv: 71 | tp += 1 72 | fp = region_pred_count - tp 73 | fn = region_true_count - tp 74 | 75 | ret['precision'], ret['recall'], ret['f1'] = calc_f1(tp, fp, fn) 76 | return ret 77 | 78 | 79 | def predict(model, sentences, categories, data_url): 80 | """ predict NER result for sentence list 81 | 82 | Args: 83 | model: trained model 84 | sentences: sentences to be predicted 85 | categories: category list to transform id into category 86 | data_url: data_url to locate vocab files, `vocab.json` and `char_vocab.json` should be in the folder of data_url 87 | 88 | Returns: 89 | predicted results [ {(start, end): type, }, ] 90 | 91 | """ 92 | max_region = model.max_region 93 | device = next(model.parameters()).device 94 | tensors = gen_sentence_tensors( 95 | sentences, device, data_url) 96 | pred_regions_list = torch.argmax(model.forward(*tensors), dim=1).cpu() 97 | 98 | lengths = tensors[1] 99 | pred_sentence_records = [] 100 | for pred_regions, length in zip(pred_regions_list, lengths): 101 | pred_records = {} 102 | ind = 0 103 | for region_size in range(1, max_region + 1): 104 | for start in range(0, lengths[0] - region_size + 1): 105 | if 0 < pred_regions[ind] < len(categories): 106 | pred_records[(start, start + region_size)] = \ 107 | categories[pred_regions[ind]] 108 | ind += 1 109 | pred_sentence_records.append(pred_records) 110 | return pred_sentence_records 111 | 112 | 113 | def predict_on_iob2(model, iob_url): 114 | """ predict on iob2 file and save the results 115 | 116 | Args: 117 | model: trained model 118 | iob_url: url to iob file 119 | 120 | """ 121 | 122 | save_url = iob_url.replace('.iob2', '.pred.txt') 123 | print("predicting on {} \n the result will be saved in {}".format( 124 | iob_url, save_url)) 125 | test_set = ExhaustiveDataset(iob_url, device=next( 126 | model.parameters()).device) 127 | 128 | model.eval() 129 | with open(save_url, 'w', encoding='utf-8', newline='\n') as save_file: 130 | for sentence, records in test_set: 131 | save_file.write(' '.join(sentence) + '\n') 132 | save_file.write("length = {} \n".format(len(sentence))) 133 | save_file.write("Gold: {}\n".format(str(records))) 134 | pred_result = str(predict(model, [sentence], test_set.categories, iob_url)[0]) 135 | save_file.write("Pred: {}\n\n".format(pred_result)) 136 | 137 | 138 | def main(): 139 | model_url = from_project_root("data/model/model.pt") 140 | print("loading model from", model_url) 141 | # model = torch.load(model_url, map_location='cpu') 142 | model = torch.load(model_url) 143 | test_url = from_project_root("data/genia.test.iob2") 144 | evaluate(model, test_url) 145 | # predict_on_iob2(model, test_url) 146 | pass 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # created by deng on 2018-12-31 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class ExhaustiveModel(nn.Module): 10 | 11 | def __init__(self, hidden_size, n_tags, max_region, embedding_url=None, bidirectional=True, lstm_layers=1, 12 | n_embeddings=None, embedding_dim=None, freeze=False, char_feat_dim=100, n_chars = 100): 13 | super().__init__() 14 | 15 | if embedding_url: 16 | self.embedding = nn.Embedding.from_pretrained( 17 | embeddings=torch.Tensor(np.load(embedding_url)), 18 | freeze=freeze 19 | 
) 20 | else: 21 | self.embedding = nn.Embedding(n_embeddings, embedding_dim, padding_idx=0) 22 | 23 | self.embedding_dim = self.embedding.embedding_dim 24 | self.char_feat_dim = char_feat_dim 25 | self.word_repr_dim = self.embedding_dim + self.char_feat_dim 26 | 27 | self.char_repr = CharLSTM( 28 | n_chars=n_chars, 29 | embedding_size=char_feat_dim // 2, 30 | hidden_size=char_feat_dim // 2, 31 | ) if char_feat_dim > 0 else None 32 | 33 | self.dropout = nn.Dropout(p=0.5) 34 | 35 | self.lstm = nn.LSTM( 36 | input_size=self.word_repr_dim, 37 | hidden_size=hidden_size, 38 | bidirectional=bidirectional, 39 | num_layers=lstm_layers, 40 | batch_first=True 41 | ) 42 | 43 | self.lstm_layers = lstm_layers 44 | self.n_tags = n_tags 45 | self.max_region = max_region 46 | self.n_hidden = (1 + bidirectional) * hidden_size 47 | 48 | self.region_clf = nn.Sequential( 49 | nn.ReLU(), 50 | nn.Linear(3 * self.n_hidden, n_tags), 51 | # nn.Softmax(), 52 | ) 53 | 54 | def forward(self, sentences, sentence_lengths, sentence_words, sentence_word_lengths, 55 | sentence_word_indices): 56 | 57 | # sentences (batch_size, max_sent_len) 58 | # sentence_length (batch_size) 59 | word_repr = self.embedding(sentences) 60 | # word_feat shape: (batch_size, max_sent_len, embedding_dim) 61 | 62 | # add character level feature 63 | if self.char_feat_dim > 0: 64 | # sentence_words (batch_size, *sent_len, max_word_len) 65 | # sentence_word_lengths (batch_size, *sent_len) 66 | # sentence_word_indices (batch_size, *sent_len, max_word_len) 67 | # char level feature 68 | char_feat = self.char_repr(sentence_words, sentence_word_lengths, sentence_word_indices) 69 | # char_feat shape: (batch_size, max_sent_len, char_feat_dim) 70 | 71 | # concatenate char level representation and word level one 72 | word_repr = torch.cat([word_repr, char_feat], dim=-1) 73 | # word_repr shape: (batch_size, max_sent_len, word_repr_dim) 74 | 75 | # word_repr = self.dropout(word_repr) 76 | 77 | packed = nn.utils.rnn.pack_padded_sequence(word_repr, sentence_lengths, batch_first=True) 78 | out, (hn, _) = self.lstm(packed) 79 | 80 | max_sent_len = sentences.shape[1] 81 | unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, total_length=max_sent_len, batch_first=True) 82 | # unpacked (batch_size, max_sent_len, n_hidden) 83 | unpacked = unpacked.transpose(0, 1) 84 | # unpacked (max_sent_len, batch_size, n_hidden) 85 | # shape of hn: (n_layers * n_directions, batch_size, hidden_size) 86 | 87 | max_len = sentence_lengths[0] 88 | regions = list() 89 | for region_size in range(1, self.max_region + 1): 90 | for start in range(0, max_len - region_size + 1): 91 | end = start + region_size 92 | regions.append(torch.cat([unpacked[start], torch.mean(unpacked[start:end], dim=0), 93 | unpacked[end - 1]], dim=-1)) 94 | # shape of each region: (batch_size, 3 * n_hidden) 95 | output = torch.stack([self.region_clf(region) for region in regions], dim=-1) 96 | # shape of each region_clf output: (batch_size, n_classes) 97 | # shape of output: (batch_size, n_classes, n_regions) 98 | return output 99 | 100 | 101 | class CharLSTM(nn.Module): 102 | 103 | def __init__(self, n_chars, embedding_size, hidden_size, lstm_layers=1, bidirectional=True): 104 | super().__init__() 105 | self.n_chars = n_chars 106 | self.embedding_size = embedding_size 107 | self.n_hidden = hidden_size * (1 + bidirectional) 108 | 109 | self.embedding = nn.Embedding(n_chars, embedding_size, padding_idx=0) 110 | 111 | self.lstm = nn.LSTM( 112 | input_size=embedding_size, 113 | hidden_size=hidden_size, 114 | 
bidirectional=bidirectional, 115 | num_layers=lstm_layers, 116 | batch_first=True, 117 | ) 118 | 119 | def sent_forward(self, words, lengths, indices): 120 | sent_len = words.shape[0] 121 | # words shape: (sent_len, max_word_len) 122 | 123 | embedded = self.embedding(words) 124 | # in_data shape: (sent_len, max_word_len, embedding_dim) 125 | 126 | packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True) 127 | _, (hn, _) = self.lstm(packed) 128 | # shape of hn: (n_layers * n_directions, sent_len, hidden_size) 129 | 130 | hn = hn.permute(1, 0, 2).contiguous().view(sent_len, -1) 131 | # shape of hn: (sent_len, n_layers * n_directions * hidden_size) = (sent_len, 2*hidden_size) 132 | 133 | # shape of indices: (sent_len, max_word_len) 134 | hn[indices] = hn # unsort hn 135 | # unsorted = hn.new_empty(hn.size()) 136 | # unsorted.scatter_(dim=0, index=indices.unsqueeze(-1).expand_as(hn), src=hn) 137 | return hn 138 | 139 | def forward(self, sentence_words, sentence_word_lengths, sentence_word_indices): 140 | # sentence_words [batch_size, *sent_len, max_word_len] 141 | # sentence_word_lengths [batch_size, *sent_len] 142 | # sentence_word_indices [batch_size, *sent_len, max_word_len] 143 | 144 | batch_size = len(sentence_words) 145 | batch_char_feat = torch.nn.utils.rnn.pad_sequence( 146 | [self.sent_forward(sentence_words[i], sentence_word_lengths[i], sentence_word_indices[i]) 147 | for i in range(batch_size)], batch_first=True) 148 | 149 | return batch_char_feat 150 | # (batch_size, sent_len, 2 * hidden_size) 151 | 152 | 153 | def main(): 154 | pass 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # created by deng on 2019-02-13 3 | 4 | import sys 5 | import os 6 | import torch 7 | import json 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from datetime import datetime 11 | 12 | import utils.json_util as ju 13 | from utils.path_util import from_project_root, exists 14 | from utils.torch_util import set_random_seed, get_device 15 | from dataset import prepare_vocab 16 | from dataset import ExhaustiveDataset 17 | from model import ExhaustiveModel 18 | from eval import evaluate 19 | 20 | RANDOM_SEED = 233 21 | set_random_seed(RANDOM_SEED) 22 | 23 | # EMBD_URL = None # fot not use pretrained embedding 24 | EMBD_URL = from_project_root("data/embedding/PubMed-shuffle-win-30.bin") 25 | VOCAB_URL = from_project_root("data/vocab.json") 26 | TRAIN_URL = from_project_root("data/genia.train.iob2") 27 | DEV_URL = from_project_root("data/genia.dev.iob2") 28 | TEST_URL = from_project_root("data/genia.test.iob2") 29 | 30 | LOG_PER_BATCH = 20 31 | 32 | 33 | def train(n_epochs=30, 34 | embedding_url=None, 35 | char_feat_dim=50, 36 | freeze=False, 37 | train_url=TRAIN_URL, 38 | dev_url=DEV_URL, 39 | test_url=None, 40 | max_region=10, 41 | learning_rate=0.001, 42 | batch_size=100, 43 | early_stop=5, 44 | clip_norm=5, 45 | device='auto', 46 | save_only_best = True 47 | ): 48 | """ Train deep exhaustive model, Sohrab et al. 
2018 EMNLP 49 | 50 | Args: 51 | n_epochs: number of epochs 52 | embedding_url: url to pretrained embedding file, set as None to use random embedding 53 | char_feat_dim: size of character level feature 54 | freeze: whether to freeze embedding 55 | train_url: url to train data 56 | dev_url: url to dev data 57 | test_url: url to test data for evaluating, set to None for not evaluating 58 | max_region: max entity region size 59 | learning_rate: learning rate 60 | batch_size: batch_size 61 | early_stop: early stop for training 62 | clip_norm: whether to perform norm clipping, set to 0 if not need 63 | device: device for torch 64 | save_only_best: only save model of best performance 65 | """ 66 | 67 | # print arguments 68 | arguments = json.dumps(vars(), indent=2) 69 | print("exhaustive model is training with arguments", arguments) 70 | device = get_device(device) 71 | 72 | train_set = ExhaustiveDataset(train_url, device=device, max_region=max_region) 73 | train_loader = DataLoader(train_set, batch_size=batch_size, drop_last=False, 74 | collate_fn=train_set.collate_func) 75 | 76 | vocab = ju.load(VOCAB_URL) 77 | n_words = len(vocab) 78 | char_vocab = ju.load(VOCAB_URL.replace('vocab', 'char_vocab')) 79 | n_chars = len(char_vocab) 80 | 81 | model = ExhaustiveModel( 82 | hidden_size=200, 83 | n_tags=train_set.n_tags + 1, 84 | char_feat_dim=char_feat_dim, 85 | embedding_url=embedding_url, 86 | bidirectional=True, 87 | max_region=max_region, 88 | n_embeddings=n_words, 89 | n_chars = n_chars, 90 | embedding_dim=200, 91 | freeze=freeze 92 | ) 93 | 94 | if device.type == 'cuda': 95 | print("using gpu,", torch.cuda.device_count(), "gpu(s) available!\n") 96 | # model = nn.DataParallel(model) 97 | else: 98 | print("using cpu\n") 99 | model = model.to(device) 100 | 101 | criterion = F.cross_entropy 102 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 103 | 104 | max_f1, max_f1_epoch, cnt = 0, 0, 0 105 | # ignore the padding part when calcuting loss 106 | tag_weights = torch.Tensor([1] * train_set.n_tags + [0]).to(device) 107 | best_model_url = None 108 | 109 | # train and evaluate model 110 | for epoch in range(n_epochs): 111 | # switch to train mode 112 | model.train() 113 | batch_id = 0 114 | for data, labels, _ in train_loader: 115 | optimizer.zero_grad() 116 | outputs = model.forward(*data) 117 | # use weight parameter to skip padding part 118 | loss = criterion(outputs, labels, weight=tag_weights) 119 | loss.backward() 120 | # gradient clipping 121 | if clip_norm > 0: 122 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm) 123 | optimizer.step() 124 | 125 | endl = '\n' if batch_id % LOG_PER_BATCH == 0 else '\r' 126 | sys.stdout.write("epoch #%d, batch #%d, loss: %.6f, %s%s" % 127 | (epoch, batch_id, loss.item(), datetime.now().strftime("%X"), endl)) 128 | sys.stdout.flush() 129 | batch_id += 1 130 | 131 | cnt += 1 132 | # metrics on development set 133 | dev_metrics = evaluate(model, dev_url) 134 | if dev_metrics['f1'] > max_f1: 135 | max_f1 = dev_metrics['f1'] 136 | max_f1_epoch = epoch 137 | if save_only_best and best_model_url: 138 | os.remove(best_model_url) 139 | best_model_url = from_project_root( 140 | "data/model/exhaustive_model_epoch%d_%f.pt" % (epoch, max_f1)) 141 | torch.save(model, best_model_url) 142 | cnt = 0 143 | 144 | print("maximum of f1 value: %.6f, in epoch #%d\n" % (max_f1, max_f1_epoch)) 145 | if cnt >= early_stop > 0: 146 | break 147 | print('\n') 148 | 149 | if test_url and best_model_url: 150 | model = torch.load(best_model_url) 151 | 
print("best model url:", best_model_url) 152 | print("evaluating on test dataset:", test_url) 153 | evaluate(model, test_url) 154 | 155 | print(arguments) 156 | 157 | 158 | def main(): 159 | start_time = datetime.now() 160 | embedding_url = prepare_vocab([TRAIN_URL, DEV_URL, TEST_URL], 161 | EMBD_URL, update=False, min_count=0) 162 | train(test_url=TEST_URL, embedding_url=embedding_url) 163 | print("finished in:", datetime.now() - start_time) 164 | pass 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /utils/json_util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # created by deng on 7/24/2018 3 | 4 | import json 5 | 6 | 7 | def load(json_url): 8 | """ load python object form json file 9 | 10 | Args: 11 | json_url: url to json file 12 | 13 | Returns: 14 | python object 15 | 16 | """ 17 | with open(json_url, "r", encoding="utf-8") as json_file: 18 | obj = json.load(json_file) 19 | return obj 20 | 21 | 22 | def dump(obj, json_url): 23 | """ dump python object into json file 24 | 25 | Args: 26 | obj: python object, more information here 27 | https://docs.python.org/2/library/json.html#encoders-and-decoders 28 | json_url: url to save json file 29 | 30 | """ 31 | with open(json_url, "w", encoding="utf-8", newline='\n') as json_file: 32 | json.dump(obj, json_file, separators=[',', ': '], indent=4, ensure_ascii=False) 33 | 34 | 35 | def sort_dict_by_value(dic, reverse=False): 36 | """ sort a dict by value 37 | 38 | Args: 39 | dic: the dict to be sorted 40 | reverse: reverse order or not 41 | 42 | Returns: 43 | sorted dict 44 | 45 | """ 46 | return dict(sorted(dic.items(), key=lambda x: x[1], reverse=reverse)) 47 | 48 | 49 | def list_to_dict(lis): 50 | """ transform list into value-index dict 51 | 52 | Args: 53 | lis: list 54 | 55 | Returns: 56 | value-index dict for lis 57 | 58 | """ 59 | dic = dict() 60 | for ind, value in enumerate(lis): 61 | dic[value] = ind 62 | return dic 63 | 64 | 65 | def main(): 66 | pass 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /utils/path_util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # created by deng on 7/23/2018 3 | 4 | from os.path import dirname, join, normpath, exists 5 | from os import makedirs 6 | import time 7 | 8 | # to get the absolute path of current project 9 | project_root_url = normpath(join(dirname(__file__), '..')) 10 | 11 | 12 | def from_project_root(rel_path, create=True): 13 | """ return system absolute path according to relative path, if path dirs not exists and create is True, 14 | required folders will be created 15 | 16 | Args: 17 | rel_path: relative path 18 | create: whether to create folds not exists 19 | 20 | Returns: 21 | str: absolute path 22 | 23 | """ 24 | abs_path = normpath(join(project_root_url, rel_path)) 25 | if create and not exists(dirname(abs_path)): 26 | makedirs(dirname(abs_path)) 27 | return abs_path 28 | 29 | 30 | def date_suffix(file_type=""): 31 | """ return the current date suffix,like '180723.csv' 32 | 33 | Args: 34 | file_type: file type suffix, like '.csv 35 | 36 | Returns: 37 | str: date suffix 38 | 39 | """ 40 | suffix = time.strftime("%y%m%d", time.localtime()) 41 | suffix += file_type 42 | return suffix 43 | 44 | 45 | def main(): 46 | """ for test """ 47 | print(project_root_url) 48 | 
    print(from_project_root('.gitignore'))
49 |     print(from_project_root('data/test.py', create=False))
50 |     print(date_suffix('.csv'))
51 |     print(date_suffix(""))
52 |     pass
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     main()
57 | 
--------------------------------------------------------------------------------
/utils/torch_util.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # created by deng on 2019-03-13
3 | 
4 | import numpy
5 | import torch
6 | import random
7 | 
8 | 
9 | def set_random_seed(seed):
10 |     """ set random seed for numpy and torch, more information here:
11 |         https://pytorch.org/docs/stable/notes/randomness.html
12 |     Args:
13 |         seed: the random seed to set
14 |     """
15 |     torch.manual_seed(seed)
16 |     numpy.random.seed(seed)
17 |     random.seed(seed)
18 |     torch.backends.cudnn.deterministic = True
19 |     torch.backends.cudnn.benchmark = False
20 | 
21 | 
22 | def get_device(name='auto'):
23 |     """ choose torch device
24 | 
25 |     Returns:
26 |         the device specified by name; if name is 'auto', cuda is chosen when available, otherwise cpu
27 | 
28 |     """
29 |     if name == 'auto':
30 |         return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31 |     return torch.device(name)
32 | 
33 | 
34 | def calc_f1(tp, fp, fn, print_result=True):
35 |     """ calculate precision, recall and micro f1
36 | 
37 |     Args:
38 |         tp: true positive count
39 |         fp: false positive count
40 |         fn: false negative count
41 |         print_result: whether to print the result
42 | 
43 |     Returns:
44 |         precision, recall, f1
45 | 
46 |     """
47 |     precision = 0 if tp + fp == 0 else tp / (tp + fp)
48 |     recall = 0 if tp + fn == 0 else tp / (tp + fn)
49 |     f1 = 0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
50 |     if print_result:
51 |         print(" precision = %f, recall = %f, micro_f1 = %f\n" % (precision, recall, f1))
52 |     return precision, recall, f1
53 | 
54 | 
55 | def main():
56 |     pass
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     main()
61 | 
--------------------------------------------------------------------------------
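For quick reference, below is a minimal usage sketch showing how the pieces above fit together: load a checkpoint saved by `train.py` and run `predict()` from `eval.py` on a single tokenized sentence. The checkpoint filename and the example sentence are illustrative assumptions rather than files shipped with the repository; it also assumes `data/vocab.json` and `data/char_vocab.json` have already been generated by `prepare_vocab` (training does this automatically).

```python
# Minimal usage sketch (assumed paths): load a trained checkpoint and predict
# nested entities for one tokenized sentence with the predict() helper from eval.py.
import torch

from dataset import ExhaustiveDataset
from eval import predict
from utils.path_util import from_project_root

model_url = from_project_root("data/model/model.pt")  # assumed checkpoint name; use the file saved by train.py
model = torch.load(model_url)  # add map_location='cpu' when no GPU is available
model.eval()

# the dataset is only needed here to recover the category list (index -> entity type)
test_url = from_project_root("data/genia.test.iob2")
dataset = ExhaustiveDataset(test_url, device=next(model.parameters()).device)

# an illustrative GENIA-style sentence, already tokenized
sentence = ["IL-2", "gene", "expression", "requires", "NF-kappa", "B", "."]
records = predict(model, [sentence], dataset.categories, test_url)[0]
print(records)  # dict mapping (start, end) token spans to predicted entity types
```

The returned dict maps `(start, end)` token spans to entity type names, which is the same record format produced by `infer_records` in `dataset.py` and consumed by `evaluate()`.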