├── .gitignore ├── LICENSE ├── README.md ├── censor.py ├── exps ├── censor_cls.ipynb ├── conll2003 BERTBiLSTMAttnCRF base BERT.ipynb ├── conll2003 BERTBiLSTMAttnNCRF base BERT.ipynb ├── conll2003 BERTBiLSTMCRF base BERT.ipynb ├── conll2003 BERTBiLSTMCRF.ipynb ├── fre BERTAttnCRF.ipynb ├── fre BERTBiLSTMAttnCRF-fit_BERT.ipynb ├── fre BERTBiLSTMAttnCRF.ipynb ├── fre BERTBiLSTMAttnNCRF-fit_BERT.ipynb ├── fre BERTBiLSTMAttnNCRF.ipynb ├── fre BERTBiLSTMCRF.ipynb ├── fre BERTBiLSTMNCRF.ipynb ├── fre BERTCRF.ipynb ├── fre BERTNCRF.ipynb └── prc fre.ipynb ├── modules ├── __init__.py ├── analyze_utils │ ├── __init__.py │ ├── main_metrics.py │ ├── plot_metrics.py │ └── utils.py ├── data │ ├── __init__.py │ ├── bert_data.py │ ├── bert_data_clf.py │ ├── conll2003 │ │ ├── __init__.py │ │ └── prc.py │ ├── download_data.py │ └── fre │ │ ├── __init__.py │ │ ├── bilou │ │ ├── __init__.py │ │ ├── from_bilou.py │ │ └── to_bilou.py │ │ ├── entity │ │ ├── __init__.py │ │ ├── document.py │ │ ├── taggedtoken.py │ │ └── token.py │ │ ├── prc.py │ │ ├── reader.py │ │ └── utils.py ├── layers │ ├── __init__.py │ ├── crf.py │ ├── decoders.py │ ├── embedders.py │ ├── layers.py │ └── ncrf.py ├── models │ ├── __init__.py │ ├── bert_models.py │ └── classifiers.py ├── train │ ├── __init__.py │ ├── optimization.py │ ├── train.py │ └── train_clf.py └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Sberbank AI
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ## 0. Papers
2 | Two solutions are based on this architecture:
3 | 1. [BSNLP 2019 ACL workshop](http://bsnlp.cs.helsinki.fi/shared_task.html): [solution](https://github.com/king-menin/slavic-ner) and [paper](https://arxiv.org/abs/1906.09978) for the multilingual shared task.
4 | 2. The second-place [solution](https://github.com/king-menin/AGRR-2019) for the [Dialogue AGRR-2019](https://github.com/dialogue-evaluation/AGRR-2019) task and the accompanying [paper](http://www.dialog-21.ru/media/4679/emelyanov-artemova-gapping_parsing_using_pretrained_embeddings__attention_mechanisn_and_ncrf.pdf).
5 | 
6 | ## 1. Description
7 | This repository contains a solution to the NER task based on the PyTorch [reimplementation](https://github.com/huggingface/pytorch-pretrained-BERT) of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert), which was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
8 | 
9 | This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)).
10 | 
11 | The old version is kept in the "old" branch.
12 | 
13 | ## 2. Usage
14 | ### 2.1 Create data
15 | ```
16 | from modules.data import bert_data
17 | data = bert_data.LearnData.create(
18 |     train_df_path=train_df_path,
19 |     valid_df_path=valid_df_path,
20 |     idx2labels_path="/path/to/vocab",
21 |     clear_cache=True
22 | )
23 | ```
24 | 
25 | ### 2.2 Create model
26 | ```
27 | from modules.models.bert_models import BERTBiLSTMAttnCRF
28 | model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label))
29 | ```
30 | 
31 | ### 2.3 Create Learner
32 | ```
33 | from modules.train.train import NerLearner
34 | num_epochs = 100
35 | learner = NerLearner(
36 |     model, data, "/path/for/save/best/model", t_total=num_epochs * len(data.train_dl))
37 | ```
38 | 
39 | ### 2.4 Predict
40 | ```
41 | from modules.data.bert_data import get_data_loader_for_predict
42 | learner.load_model()
43 | dl = get_data_loader_for_predict(data, df_path="/path/to/df/for/predict")
44 | preds = learner.predict(dl)
45 | ```
46 | 
47 | ### 2.5 Evaluate
48 | ```
49 | from sklearn_crfsuite.metrics import flat_classification_report
50 | from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer
51 | from modules.analyze_utils.plot_metrics import get_bert_span_report
52 | from modules.analyze_utils.main_metrics import precision_recall_f1
53 | 
54 | 
55 | pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
56 | true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])
57 | tokens_report = flat_classification_report(true_labels, pred_labels, digits=4)
58 | print(tokens_report)
59 | 
60 | results = precision_recall_f1(true_labels, pred_labels)
61 | ```
62 | 
63 | ## 3. Results
64 | We did not search for the best hyperparameters and obtained the following results.
65 | 
66 | | Model | Data set | Dev F1 tok | Dev F1 span | Test F1 tok | Test F1 span |
67 | |-|-|-|-|-|-|
68 | |**OURS**||||||
69 | | M-BERTCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8543 | 0.8409 |
70 | | M-BERTNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8637 | 0.8516 |
71 | | M-BERTBiLSTMCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8835 | **0.8718** |
72 | | M-BERTBiLSTMNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8632 | 0.8510 |
73 | | M-BERTAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8503 | 0.8346 |
74 | | M-BERTBiLSTMAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | **0.8839** | 0.8716 |
75 | | M-BERTBiLSTMAttnNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8807 | 0.8680 |
76 | | M-BERTBiLSTMAttnCRF-fit_BERT-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8823 | 0.8709 |
77 | | M-BERTBiLSTMAttnNCRF-fit_BERT-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8583 | 0.8456 |
78 | |-|-|-|-|-|-|
79 | | BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9629 | - | 0.9221 | - |
80 | | B-BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9635 | - | 0.9229 | - |
81 | | B-BERTBiLSTMAttnCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9614 | - | 0.9237 | - |
82 | | B-BERTBiLSTMAttnNCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9631 | - | **0.9249** | - |
83 | |**Current SOTA**||||||
84 | | DeepPavlov-RuBERT-NER | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | - | **0.8266** |
85 | | CSE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | - | - | **0.931** | - |
86 | | BERT-LARGE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.966 | - | 0.928 | - |
87 | | BERT-BASE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.964 | - | 0.924 | - |
88 | 
-------------------------------------------------------------------------------- /censor.py: --------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 | 
4 | 
5 | # Extend the import path and silence warnings before importing project modules,
6 | # so the `modules` package can be resolved at import time.
7 | warnings.filterwarnings("ignore")
8 | sys.path.append("../")
9 | 
10 | from modules.data import bert_data_clf
11 | from modules.models.classifiers import BERTBiLSTMAttnClassifier
12 | from modules.train.train_clf import NerLearner
13 | 
14 | 
15 | def main():
16 |     train_df_path = "/home/ubuntu/censor/train2.csv"
17 |     valid_df_path = "/home/ubuntu/censor/dev2.csv"
18 |     test_df_path = "/home/ubuntu/censor/test.csv"  # held-out set; not used during training
19 |     num_epochs = 100
20 | 
21 |     data = bert_data_clf.LearnDataClass.create(
22 |         train_df_path=train_df_path,
23 |         valid_df_path=valid_df_path,
24 |         idx2cls_path="/home/ubuntu/censor/idx2cls.txt",
25 |         clear_cache=False,
26 |         batch_size=64
27 |     )
28 | 
29 |     model = BERTBiLSTMAttnClassifier.create(len(data.train_ds.cls2idx), hidden_dim=768)
30 |     learner = NerLearner(
31 |         model, data, "/home/ubuntu/censor/cls.cpt4", t_total=num_epochs * len(data.train_dl))
32 |     learner.fit(epochs=num_epochs)
33 | 
34 | 
35 | if __name__ == "__main__":
36 |     main()
37 | 
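38 | 
39 | 
40 | # A minimal inference sketch, not part of the original script: it assumes that
41 | # modules.data.bert_data_clf exposes a get_data_loader_for_predict helper and
42 | # that this NerLearner provides load_model()/predict(), mirroring the NER API
43 | # shown in README section 2.4. Verify both against the modules before relying
44 | # on it.
45 | def predict_test(learner, data, test_df_path):
46 |     from modules.data.bert_data_clf import get_data_loader_for_predict
47 |     learner.load_model()  # restore the best checkpoint written during fit()
48 |     dl = get_data_loader_for_predict(data, df_path=test_df_path)
49 |     return learner.predict(dl)
50 | 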
-------------------------------------------------------------------------------- /exps/conll2003 BERTBiLSTMCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 21, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from modules.data.conll2003.prc import conll2003_preprocess" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 22, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "data_dir = \"/home/eartemov/ae/work/conll2003/\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 13, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "application/vnd.jupyter.widget-view+json": { 47 | "model_id": "f14d2a1ce44947ce98b9f430cf82caf1", 48 | "version_major": 2, 49 | "version_minor": 0 50 | }, 51 | "text/plain": [ 52 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.train', max=2195…" 53 | ] 54 | }, 55 | "metadata": {}, 56 | "output_type": "display_data" 57 | }, 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "\n" 63 | ] 64 | }, 65 | { 66 | "data": { 67 | "application/vnd.jupyter.widget-view+json": { 68 | "model_id": "fa6fddbba84a4de78a48e7503af8d616", 69 | "version_major": 2, 70 | "version_minor": 0 71 | }, 72 | "text/plain": [ 73 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testa', max=5504…" 74 | ] 75 | }, 76 | "metadata": {}, 77 | "output_type": "display_data" 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "\n" 84 | ] 85 | }, 86 | { 87 | "data": { 88 | "application/vnd.jupyter.widget-view+json": { 89 | "model_id": "e3c3e6e2ebdb4e51ba4051a1499e53cf", 90 | "version_major": 2, 91 | "version_minor": 0 92 | }, 93 | "text/plain": [ 94 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testb', max=5035…" 95 | ] 96 | }, 97 | "metadata": {}, 98 | "output_type": "display_data" 99 | }, 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "conll2003_preprocess(data_dir)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "## IO markup" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "### Train" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 3, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "from modules.data import bert_data" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. 
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 145 | ] 146 | }, 147 | { 148 | "data": { 149 | "application/vnd.jupyter.widget-view+json": { 150 | "model_id": "", 151 | "version_major": 2, 152 | "version_minor": 0 153 | }, 154 | "text/plain": [ 155 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=6973, style=ProgressStyle(descri…" 156 | ] 157 | }, 158 | "metadata": {}, 159 | "output_type": "display_data" 160 | }, 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "\r" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "data = bert_data.LearnData.create(\n", 171 | " train_df_path=\"/home/eartemov/ae/work/conll2003/eng.train.train.csv\",\n", 172 | " valid_df_path=\"/home/eartemov/ae/work/conll2003/eng.testa.dev.csv\",\n", 173 | " idx2labels_path=\"/home/eartemov/ae/work/conll2003/idx2labels.txt\",\n", 174 | " clear_cache=True\n", 175 | ")" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 5, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from modules.models.bert_models import BERTBiLSTMCRF" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "model = BERTBiLSTMCRF.create(\n", 194 | " len(data.train_ds.idx2label),\n", 195 | " lstm_dropout=0., crf_dropout=0.3)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 7, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from modules.train.train import NerLearner" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "num_epochs = 100" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 9, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "learner = NerLearner(\n", 223 | " model, data, \"/home/eartemov/ae/work/models/conll2003-BERTBiLSTMCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 10, 229 | "metadata": { 230 | "scrolled": true 231 | }, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "2235023" 237 | ] 238 | }, 239 | "execution_count": 10, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "model.get_n_trainable_params()" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "scrolled": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "learner.fit(epochs=num_epochs)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "### Predict" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 12, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "from modules.data.bert_data import get_data_loader_for_predict" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 13, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 14, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "application/vnd.jupyter.widget-view+json": { 292 | "model_id": "", 293 | "version_major": 2, 294 | "version_minor": 0 295 | }, 296 | "text/plain": [ 
297 | "HBox(children=(IntProgress(value=0, description='Predicting', max=109, style=ProgressStyle(description_width='…" 298 | ] 299 | }, 300 | "metadata": {}, 301 | "output_type": "display_data" 302 | }, 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "\r" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "preds = learner.predict(dl)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 15, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "from sklearn_crfsuite.metrics import flat_classification_report" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 16, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 331 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 17, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 341 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 18, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "assert pred_tokens == true_tokens\n", 351 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 20, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | " precision recall f1-score support\n", 364 | "\n", 365 | " I_ORG 0.9514 0.9509 0.9511 2016\n", 366 | " I_O 0.9968 0.9970 0.9969 41702\n", 367 | " I_MISC 0.9353 0.8974 0.9160 1257\n", 368 | " I_PER 0.9849 0.9825 0.9837 2856\n", 369 | " I_LOC 0.9697 0.9637 0.9667 1926\n", 370 | "\n", 371 | " micro avg 0.9917 0.9905 0.9911 49757\n", 372 | " macro avg 0.9676 0.9583 0.9629 49757\n", 373 | "weighted avg 0.9916 0.9905 0.9910 49757\n", 374 | "\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "print(tokens_report)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "### Test" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 12, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "from modules.data.bert_data import get_data_loader_for_predict" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 24, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "dl = get_data_loader_for_predict(data, df_path=\"/home/eartemov/ae/work/conll2003/eng.testb.dev.csv\")" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 25, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "data": { 414 | "application/vnd.jupyter.widget-view+json": { 415 | "model_id": "", 416 | "version_major": 2, 417 | "version_minor": 0 418 | }, 419 | "text/plain": [ 420 | "HBox(children=(IntProgress(value=0, description='Predicting', max=98, style=ProgressStyle(description_width='i…" 421 | ] 422 | }, 423 | "metadata": {}, 424 | "output_type": "display_data" 425 | }, 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "\r" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "preds = learner.predict(dl)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | 
"execution_count": 26, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "from sklearn_crfsuite.metrics import flat_classification_report" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 27, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 454 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 28, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 464 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 29, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "assert pred_tokens == true_tokens\n", 474 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 30, 480 | "metadata": {}, 481 | "outputs": [ 482 | { 483 | "name": "stdout", 484 | "output_type": "stream", 485 | "text": [ 486 | " precision recall f1-score support\n", 487 | "\n", 488 | " I_ORG 0.8988 0.9147 0.9067 2368\n", 489 | " I_O 0.9952 0.9917 0.9934 37573\n", 490 | " I_MISC 0.8163 0.8055 0.8108 910\n", 491 | " I_PER 0.9759 0.9770 0.9765 2698\n", 492 | " I_LOC 0.9170 0.9296 0.9233 1819\n", 493 | "\n", 494 | " micro avg 0.9822 0.9806 0.9814 45368\n", 495 | " macro avg 0.9206 0.9237 0.9221 45368\n", 496 | "weighted avg 0.9823 0.9806 0.9814 45368\n", 497 | "\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "print(tokens_report)" 503 | ] 504 | } 505 | ], 506 | "metadata": { 507 | "kernelspec": { 508 | "display_name": "Python 3", 509 | "language": "python", 510 | "name": "python3" 511 | }, 512 | "language_info": { 513 | "codemirror_mode": { 514 | "name": "ipython", 515 | "version": 3 516 | }, 517 | "file_extension": ".py", 518 | "mimetype": "text/x-python", 519 | "name": "python", 520 | "nbconvert_exporter": "python", 521 | "pygments_lexer": "ipython3", 522 | "version": "3.6.8" 523 | } 524 | }, 525 | "nbformat": 4, 526 | "nbformat_minor": 2 527 | } 528 | -------------------------------------------------------------------------------- /exps/fre BERTAttnCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 
52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stderr", 61 | "output_type": "stream", 62 | "text": [ 63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 64 | ] 65 | }, 66 | { 67 | "data": { 68 | "application/vnd.jupyter.widget-view+json": { 69 | "model_id": "", 70 | "version_major": 2, 71 | "version_minor": 0 72 | }, 73 | "text/plain": [ 74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "\r" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "data = bert_data.LearnData.create(\n", 90 | " train_df_path=train_df_path,\n", 91 | " valid_df_path=valid_df_path,\n", 92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels2.txt\",\n", 93 | " clear_cache=True\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from modules.models.bert_models import BERTAttnCRF" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "model = BERTAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 7, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "from modules.train.train import NerLearner" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 8, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "num_epochs = 100" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "learner = NerLearner(\n", 140 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTAttnCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "890617" 152 | ] 153 | }, 154 | "execution_count": 10, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "model.get_n_trainable_params()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "scrolled": true 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "learner.fit(epochs=num_epochs)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "### Predict" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 30, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "from modules.data.bert_data import get_data_loader_for_predict" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 31, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 32, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "application/vnd.jupyter.widget-view+json": { 207 
| "model_id": "", 208 | "version_major": 2, 209 | "version_minor": 0 210 | }, 211 | "text/plain": [ 212 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 213 | ] 214 | }, 215 | "metadata": {}, 216 | "output_type": "display_data" 217 | }, 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "\r" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "preds = learner.predict(dl)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 33, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "from sklearn_crfsuite.metrics import flat_classification_report" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 34, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 246 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 35, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 256 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 36, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "assert pred_tokens == true_tokens\n", 266 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 37, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "name": "stdout", 276 | "output_type": "stream", 277 | "text": [ 278 | " precision recall f1-score support\n", 279 | "\n", 280 | " I_ORG 0.8019 0.7415 0.7705 3865\n", 281 | " I_PER 0.9374 0.9569 0.9470 2112\n", 282 | " I_LOC 0.9007 0.7752 0.8333 1557\n", 283 | "\n", 284 | " micro avg 0.8620 0.8089 0.8346 7534\n", 285 | " macro avg 0.8800 0.8245 0.8503 7534\n", 286 | "weighted avg 0.8603 0.8089 0.8330 7534\n", 287 | "\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "print(tokens_report)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 38, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 39, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "processed 56409 tokens with 7534 phrases; found: 7070 phrases; correct: 6094.\n", 314 | "\n", 315 | "precision: 86.20%; recall: 80.89%; FB1: 83.46\n", 316 | "\n", 317 | "\tLOC: precision: 90.07%; recall: 77.52%; F1: 83.33 1340\n", 318 | "\n", 319 | "\tORG: precision: 80.19%; recall: 74.15%; F1: 77.05 3574\n", 320 | "\n", 321 | "\tPER: precision: 93.74%; recall: 95.69%; F1: 94.70 2156\n", 322 | "\n", 323 | "\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "results = precision_recall_f1(true_labels, pred_labels)" 329 | ] 330 | } 331 | ], 332 | "metadata": { 333 | "kernelspec": { 334 | "display_name": "Python 3", 335 | "language": "python", 336 | "name": "python3" 337 | }, 338 | "language_info": { 339 | "codemirror_mode": { 340 | "name": "ipython", 341 | "version": 3 342 | }, 343 | "file_extension": ".py", 344 | "mimetype": "text/x-python", 345 | "name": "python", 346 | 
"nbconvert_exporter": "python", 347 | "pygments_lexer": "ipython3", 348 | "version": "3.6.8" 349 | } 350 | }, 351 | "nbformat": 4, 352 | "nbformat_minor": 2 353 | } 354 | -------------------------------------------------------------------------------- /exps/fre BERTBiLSTMAttnCRF-fit_BERT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "device = \"cuda:0\"" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. 
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 73 | ] 74 | }, 75 | { 76 | "data": { 77 | "application/vnd.jupyter.widget-view+json": { 78 | "model_id": "", 79 | "version_major": 2, 80 | "version_minor": 0 81 | }, 82 | "text/plain": [ 83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 84 | ] 85 | }, 86 | "metadata": {}, 87 | "output_type": "display_data" 88 | }, 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "\r" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "data = bert_data.LearnData.create(\n", 99 | " train_df_path=train_df_path,\n", 100 | " valid_df_path=valid_df_path,\n", 101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels5.txt\",\n", 102 | " clear_cache=True,\n", 103 | " batch_size=8,\n", 104 | " device=device\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "from modules.models.bert_models import BERTBiLSTMAttnCRF" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3, is_freeze=False, device=device)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 8, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "from modules.train.train import NerLearner" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 9, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "num_epochs = 100" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 11, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "learner = NerLearner(\n", 151 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnCRF-fit_BERT-IO.cpt\",\n", 152 | " t_total=num_epochs * len(data.train_dl), lr=0.0001)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 12, 158 | "metadata": { 159 | "scrolled": true 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "180482937" 166 | ] 167 | }, 168 | "execution_count": 12, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "model.get_n_trainable_params()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": { 181 | "scrolled": true 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "learner.fit(epochs=num_epochs)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "### Predict" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 25, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "from modules.data.bert_data import get_data_loader_for_predict" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 26, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 27, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "application/vnd.jupyter.widget-view+json": { 221 | "model_id": "", 222 | "version_major": 2, 223 | "version_minor": 0 224 | }, 225 | "text/plain": [ 226 | 
"HBox(children=(IntProgress(value=0, description='Predicting', max=340, style=ProgressStyle(description_width='…" 227 | ] 228 | }, 229 | "metadata": {}, 230 | "output_type": "display_data" 231 | } 232 | ], 233 | "source": [ 234 | "preds = learner.predict(dl)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 28, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "from sklearn_crfsuite.metrics import flat_classification_report" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 29, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 253 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 30, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 263 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 31, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "assert pred_tokens == true_tokens\n", 273 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 32, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | " precision recall f1-score support\n", 286 | "\n", 287 | " I_ORG 0.8334 0.8191 0.8262 3865\n", 288 | " I_PER 0.9145 0.9825 0.9473 2112\n", 289 | " I_LOC 0.9342 0.8202 0.8735 1557\n", 290 | "\n", 291 | " micro avg 0.8767 0.8651 0.8709 7534\n", 292 | " macro avg 0.8940 0.8739 0.8823 7534\n", 293 | "weighted avg 0.8769 0.8651 0.8699 7534\n", 294 | "\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "print(tokens_report)" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 33, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 34, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "name": "stdout", 318 | "output_type": "stream", 319 | "text": [ 320 | "processed 56409 tokens with 7534 phrases; found: 7435 phrases; correct: 6518.\n", 321 | "\n", 322 | "precision: 87.67%; recall: 86.51%; FB1: 87.09\n", 323 | "\n", 324 | "\tLOC: precision: 93.42%; recall: 82.02%; F1: 87.35 1367\n", 325 | "\n", 326 | "\tORG: precision: 83.34%; recall: 81.91%; F1: 82.62 3799\n", 327 | "\n", 328 | "\tPER: precision: 91.45%; recall: 98.25%; F1: 94.73 2269\n", 329 | "\n", 330 | "\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "results = precision_recall_f1(true_labels, pred_labels)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 3", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | 
"version": "3.6.8" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 2 367 | } 368 | -------------------------------------------------------------------------------- /exps/fre BERTBiLSTMAttnCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stderr", 61 | "output_type": "stream", 62 | "text": [ 63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 64 | ] 65 | }, 66 | { 67 | "data": { 68 | "application/vnd.jupyter.widget-view+json": { 69 | "model_id": "", 70 | "version_major": 2, 71 | "version_minor": 0 72 | }, 73 | "text/plain": [ 74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "\r" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "data = bert_data.LearnData.create(\n", 90 | " train_df_path=train_df_path,\n", 91 | " valid_df_path=valid_df_path,\n", 92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels4.txt\",\n", 93 | " clear_cache=True\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from modules.models.bert_models import BERTBiLSTMAttnCRF" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 7, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "from modules.train.train import NerLearner" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 8, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "num_epochs = 100" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "learner = NerLearner(\n", 140 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnCRF-IO.cpt\", t_total=num_epochs * 
len(data.train_dl))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "metadata": { 147 | "scrolled": true 148 | }, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "2629497" 154 | ] 155 | }, 156 | "execution_count": 10, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "model.get_n_trainable_params()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "scrolled": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "learner.fit(epochs=num_epochs)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Predict" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 12, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from modules.data.bert_data import get_data_loader_for_predict" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 13, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 23, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "application/vnd.jupyter.widget-view+json": { 209 | "model_id": "", 210 | "version_major": 2, 211 | "version_minor": 0 212 | }, 213 | "text/plain": [ 214 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 215 | ] 216 | }, 217 | "metadata": {}, 218 | "output_type": "display_data" 219 | }, 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "\r" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "preds = learner.predict(dl)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 24, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "from sklearn_crfsuite.metrics import flat_classification_report" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 25, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 248 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 26, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 258 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 27, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "assert pred_tokens == true_tokens\n", 268 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 28, 274 | "metadata": { 275 | "scrolled": true 276 | }, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | " precision recall f1-score support\n", 283 | "\n", 284 | " I_ORG 0.8639 0.7803 0.8200 3865\n", 285 | " I_PER 0.9535 0.9706 0.9620 2112\n", 286 | " I_LOC 0.9066 0.8356 0.8697 1557\n", 287 | "\n", 288 | " micro avg 0.8998 0.8451 0.8716 7534\n", 289 | " macro avg 0.9080 0.8622 0.8839 7534\n", 290 | "weighted avg 
0.8979 0.8451 0.8701 7534\n", 291 | "\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "print(tokens_report)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 29, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 30, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "name": "stdout", 315 | "output_type": "stream", 316 | "text": [ 317 | "processed 56409 tokens with 7534 phrases; found: 7076 phrases; correct: 6367.\n", 318 | "\n", 319 | "precision: 89.98%; recall: 84.51%; FB1: 87.16\n", 320 | "\n", 321 | "\tLOC: precision: 90.66%; recall: 83.56%; F1: 86.97 1435\n", 322 | "\n", 323 | "\tORG: precision: 86.39%; recall: 78.03%; F1: 82.00 3491\n", 324 | "\n", 325 | "\tPER: precision: 95.35%; recall: 97.06%; F1: 96.20 2150\n", 326 | "\n", 327 | "\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "results = precision_recall_f1(true_labels, pred_labels)" 333 | ] 334 | } 335 | ], 336 | "metadata": { 337 | "kernelspec": { 338 | "display_name": "Python 3", 339 | "language": "python", 340 | "name": "python3" 341 | }, 342 | "language_info": { 343 | "codemirror_mode": { 344 | "name": "ipython", 345 | "version": 3 346 | }, 347 | "file_extension": ".py", 348 | "mimetype": "text/x-python", 349 | "name": "python", 350 | "nbconvert_exporter": "python", 351 | "pygments_lexer": "ipython3", 352 | "version": "3.6.8" 353 | } 354 | }, 355 | "nbformat": 4, 356 | "nbformat_minor": 2 357 | } 358 | -------------------------------------------------------------------------------- /exps/fre BERTBiLSTMAttnNCRF-fit_BERT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "device = \"cuda:1\"" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. 
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 73 | ] 74 | }, 75 | { 76 | "data": { 77 | "application/vnd.jupyter.widget-view+json": { 78 | "model_id": "", 79 | "version_major": 2, 80 | "version_minor": 0 81 | }, 82 | "text/plain": [ 83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 84 | ] 85 | }, 86 | "metadata": {}, 87 | "output_type": "display_data" 88 | }, 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "\r" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "data = bert_data.LearnData.create(\n", 99 | " train_df_path=train_df_path,\n", 100 | " valid_df_path=valid_df_path,\n", 101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels2.txt\",\n", 102 | " clear_cache=True,\n", 103 | " batch_size=8,\n", 104 | " device=device\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "from modules.models.bert_models import BERTBiLSTMAttnNCRF" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "build CRF...\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "model = BERTBiLSTMAttnNCRF.create(\n", 132 | " len(data.train_ds.idx2label), crf_dropout=0.3, nbest=len(data.train_ds.label2idx), is_freeze=False, hidden_dim=256,\n", 133 | " device=device)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 9, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "from modules.train.train import NerLearner" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 10, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "num_epochs = 100" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 14, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "learner = NerLearner(\n", 161 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnNCRF-fit_BERT-IO.cpt\",\n", 162 | " t_total=num_epochs * len(data.train_dl), lr=0.00001)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 12, 168 | "metadata": { 169 | "scrolled": true 170 | }, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "text/plain": [ 175 | "179004667" 176 | ] 177 | }, 178 | "execution_count": 12, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "model.get_n_trainable_params()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "scrolled": true 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "learner.fit(epochs=num_epochs)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### Predict" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 27, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "from modules.data.bert_data import get_data_loader_for_predict" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 28, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 29, 226 | "metadata": {}, 227 | 
"outputs": [ 228 | { 229 | "data": { 230 | "application/vnd.jupyter.widget-view+json": { 231 | "model_id": "", 232 | "version_major": 2, 233 | "version_minor": 0 234 | }, 235 | "text/plain": [ 236 | "HBox(children=(IntProgress(value=0, description='Predicting', max=340, style=ProgressStyle(description_width='…" 237 | ] 238 | }, 239 | "metadata": {}, 240 | "output_type": "display_data" 241 | } 242 | ], 243 | "source": [ 244 | "preds = learner.predict(dl)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 30, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "from sklearn_crfsuite.metrics import flat_classification_report" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 31, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 263 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 32, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 273 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 33, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "assert pred_tokens == true_tokens\n", 283 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 34, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | " precision recall f1-score support\n", 296 | "\n", 297 | " I_ORG 0.8761 0.7224 0.7918 3865\n", 298 | " I_PER 0.9207 0.9342 0.9274 2112\n", 299 | " I_LOC 0.8767 0.8356 0.8556 1557\n", 300 | "\n", 301 | " micro avg 0.8902 0.8051 0.8456 7534\n", 302 | " macro avg 0.8911 0.8307 0.8583 7534\n", 303 | "weighted avg 0.8887 0.8051 0.8430 7534\n", 304 | "\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "print(tokens_report)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 35, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 36, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "processed 56409 tokens with 7534 phrases; found: 6814 phrases; correct: 6066.\n", 331 | "\n", 332 | "precision: 89.02%; recall: 80.51%; FB1: 84.56\n", 333 | "\n", 334 | "\tLOC: precision: 87.67%; recall: 83.56%; F1: 85.56 1484\n", 335 | "\n", 336 | "\tORG: precision: 87.61%; recall: 72.24%; F1: 79.18 3187\n", 337 | "\n", 338 | "\tPER: precision: 92.07%; recall: 93.42%; F1: 92.74 2143\n", 339 | "\n", 340 | "\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "results = precision_recall_f1(true_labels, pred_labels)" 346 | ] 347 | } 348 | ], 349 | "metadata": { 350 | "kernelspec": { 351 | "display_name": "Python 3", 352 | "language": "python", 353 | "name": "python3" 354 | }, 355 | "language_info": { 356 | "codemirror_mode": { 357 | "name": "ipython", 358 | "version": 3 359 | }, 360 | "file_extension": ".py", 361 | "mimetype": "text/x-python", 362 | "name": "python", 363 | 
"nbconvert_exporter": "python", 364 | "pygments_lexer": "ipython3", 365 | "version": "3.6.8" 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 2 370 | } 371 | -------------------------------------------------------------------------------- /exps/fre BERTBiLSTMCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stderr", 61 | "output_type": "stream", 62 | "text": [ 63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 64 | ] 65 | }, 66 | { 67 | "data": { 68 | "application/vnd.jupyter.widget-view+json": { 69 | "model_id": "", 70 | "version_major": 2, 71 | "version_minor": 0 72 | }, 73 | "text/plain": [ 74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "\r" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 90 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"\n", 91 | "data = bert_data.LearnData.create(\n", 92 | " train_df_path=train_df_path,\n", 93 | " valid_df_path=valid_df_path,\n", 94 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels4.txt\",\n", 95 | " clear_cache=True\n", 96 | ")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 6, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from modules.models.bert_models import BERTBiLSTMCRF" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 7, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "model = BERTBiLSTMCRF.create(len(data.train_ds.idx2label), lstm_dropout=0., crf_dropout=0.3)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from modules.train.train import NerLearner" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 9, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "num_epochs = 100" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | 
"execution_count": 12, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "learner = NerLearner(\n", 142 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 13, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "2234745" 154 | ] 155 | }, 156 | "execution_count": 13, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "model.get_n_trainable_params()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "scrolled": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "learner.fit(epochs=num_epochs)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Predict" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 14, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from modules.data.bert_data import get_data_loader_for_predict" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 15, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 16, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "application/vnd.jupyter.widget-view+json": { 209 | "model_id": "", 210 | "version_major": 2, 211 | "version_minor": 0 212 | }, 213 | "text/plain": [ 214 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 215 | ] 216 | }, 217 | "metadata": {}, 218 | "output_type": "display_data" 219 | }, 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "\r" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "preds = learner.predict(dl)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 17, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "from sklearn_crfsuite.metrics import flat_classification_report" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 18, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 248 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 19, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 258 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 20, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "assert pred_tokens == true_tokens\n", 268 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 21, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | " precision recall f1-score support\n", 281 | "\n", 282 | " I_ORG 0.8579 0.7917 0.8235 3865\n", 283 | " I_PER 0.9510 0.9659 0.9584 2112\n", 284 | " I_LOC 
0.9053 0.8349 0.8687 1557\n", 285 | "\n", 286 | " micro avg 0.8954 0.8495 0.8718 7534\n", 287 | " macro avg 0.9047 0.8642 0.8835 7534\n", 288 | "weighted avg 0.8938 0.8495 0.8706 7534\n", 289 | "\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "print(tokens_report)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 26, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 27, 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "processed 56409 tokens with 7534 phrases; found: 7148 phrases; correct: 6400.\n", 316 | "\n", 317 | "precision: 89.54%; recall: 84.95%; FB1: 87.18\n", 318 | "\n", 319 | "\tLOC: precision: 90.53%; recall: 83.49%; F1: 86.87 1436\n", 320 | "\n", 321 | "\tORG: precision: 85.79%; recall: 79.17%; F1: 82.35 3567\n", 322 | "\n", 323 | "\tPER: precision: 95.10%; recall: 96.59%; F1: 95.84 2145\n", 324 | "\n", 325 | "\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "results = precision_recall_f1(true_labels, pred_labels)" 331 | ] 332 | } 333 | ], 334 | "metadata": { 335 | "kernelspec": { 336 | "display_name": "Python 3", 337 | "language": "python", 338 | "name": "python3" 339 | }, 340 | "language_info": { 341 | "codemirror_mode": { 342 | "name": "ipython", 343 | "version": 3 344 | }, 345 | "file_extension": ".py", 346 | "mimetype": "text/x-python", 347 | "name": "python", 348 | "nbconvert_exporter": "python", 349 | "pygments_lexer": "ipython3", 350 | "version": "3.6.8" 351 | } 352 | }, 353 | "nbformat": 4, 354 | "nbformat_minor": 2 355 | } 356 | -------------------------------------------------------------------------------- /exps/fre BERTBiLSTMNCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 6, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "device = \"cuda:2\"" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 7, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stderr", 60 | "output_type": "stream", 61 | "text": [ 62 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. 
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 63 | ] 64 | }, 65 | { 66 | "data": { 67 | "application/vnd.jupyter.widget-view+json": { 68 | "model_id": "", 69 | "version_major": 2, 70 | "version_minor": 0 71 | }, 72 | "text/plain": [ 73 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 74 | ] 75 | }, 76 | "metadata": {}, 77 | "output_type": "display_data" 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "\r" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "data = bert_data.LearnData.create(\n", 89 | " train_df_path=\"/home/eartemov/ae/work/factRuEval-2016/dev.csv\",\n", 90 | " valid_df_path=\"/home/eartemov/ae/work/factRuEval-2016/test.csv\",\n", 91 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels5.txt\",\n", 92 | " clear_cache=True,\n", 93 | " device=device\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from modules.models.bert_models import BERTBiLSTMNCRF" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 8, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "build CRF...\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "model = BERTBiLSTMNCRF.create(\n", 121 | " len(data.train_ds.idx2label), lstm_dropout=0., crf_dropout=0.3, nbest=len(data.train_ds.idx2label)-1, device=device)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 9, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "from modules.train.train import NerLearner" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 10, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "num_epochs = 100" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 11, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "learner = NerLearner(\n", 149 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMNCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 12, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "2235259" 161 | ] 162 | }, 163 | "execution_count": 12, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "model.get_n_trainable_params()" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "scrolled": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "learner.fit(epochs=num_epochs)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Eval" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 22, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "from modules.data.bert_data import get_data_loader_for_predict" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 23, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 24, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | 
"application/vnd.jupyter.widget-view+json": { 216 | "model_id": "", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 222 | ] 223 | }, 224 | "metadata": {}, 225 | "output_type": "display_data" 226 | }, 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "\r" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "preds = learner.predict(dl)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 25, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "from sklearn_crfsuite.metrics import flat_classification_report" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 26, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 255 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 27, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 265 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 28, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "assert pred_tokens == true_tokens\n", 275 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=learner.sup_labels, digits=4)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 29, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | " precision recall f1-score support\n", 288 | "\n", 289 | " I_O 0.9777 0.9886 0.9831 48875\n", 290 | " I_LOC 0.8996 0.7996 0.8467 1557\n", 291 | " I_PER 0.9373 0.9626 0.9498 2112\n", 292 | " I_ORG 0.8709 0.7281 0.7931 3865\n", 293 | "\n", 294 | " micro avg 0.9681 0.9646 0.9663 56409\n", 295 | " macro avg 0.9214 0.8697 0.8932 56409\n", 296 | "weighted avg 0.9667 0.9646 0.9651 56409\n", 297 | "\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "print(tokens_report)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 30, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 32, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "processed 56409 tokens with 7534 phrases; found: 6784 phrases; correct: 6092.\n", 324 | "\n", 325 | "precision: 89.80%; recall: 80.86%; FB1: 85.10\n", 326 | "\n", 327 | "\tLOC: precision: 89.96%; recall: 79.96%; F1: 84.67 1384\n", 328 | "\n", 329 | "\tORG: precision: 87.09%; recall: 72.81%; F1: 79.31 3231\n", 330 | "\n", 331 | "\tPER: precision: 93.73%; recall: 96.26%; F1: 94.98 2169\n", 332 | "\n", 333 | "\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "results = precision_recall_f1(true_labels, pred_labels)" 339 | ] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python 3", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | 
"file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.6.8" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | -------------------------------------------------------------------------------- /exps/fre BERTCRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "\n", 22 | "import sys\n", 23 | "import warnings\n", 24 | "\n", 25 | "\n", 26 | "warnings.filterwarnings(\"ignore\")\n", 27 | "sys.path.append(\"../\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## IO markup" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### Train" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from modules.data import bert_data" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 60 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. 
We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 73 | ] 74 | }, 75 | { 76 | "data": { 77 | "application/vnd.jupyter.widget-view+json": { 78 | "model_id": "", 79 | "version_major": 2, 80 | "version_minor": 0 81 | }, 82 | "text/plain": [ 83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 84 | ] 85 | }, 86 | "metadata": {}, 87 | "output_type": "display_data" 88 | }, 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "\r" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "data = bert_data.LearnData.create(\n", 99 | " train_df_path=train_df_path,\n", 100 | " valid_df_path=valid_df_path,\n", 101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels1.txt\",\n", 102 | " clear_cache=True\n", 103 | ")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "from modules.models.bert_models import BERTCRF" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "model = BERTCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "from modules.train.train import NerLearner" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 8, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "num_epochs = 100" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 9, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "learner = NerLearner(\n", 149 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 10, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "298489" 161 | ] 162 | }, 163 | "execution_count": 10, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "model.get_n_trainable_params()" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "scrolled": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "learner.fit(epochs=num_epochs)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Predict" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 13, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "from modules.data.bert_data import get_data_loader_for_predict" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 14, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 15, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "application/vnd.jupyter.widget-view+json": { 216 | "model_id": "", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 222 | ] 223 | }, 224 | "metadata": {}, 225 | "output_type": "display_data" 226 | 
}, 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "\r" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "preds = learner.predict(dl)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 16, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "from sklearn_crfsuite.metrics import flat_classification_report" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 17, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 255 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 18, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 265 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 19, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "assert pred_tokens == true_tokens\n", 275 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[5:], digits=4)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 20, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | " precision recall f1-score support\n", 288 | "\n", 289 | " I_LOC 0.8576 0.7932 0.8242 1557\n", 290 | " I_PER 0.9544 0.9616 0.9580 2112\n", 291 | " I_ORG 0.8150 0.7490 0.7806 3865\n", 292 | "\n", 293 | " micro avg 0.8653 0.8178 0.8409 7534\n", 294 | " macro avg 0.8757 0.8346 0.8543 7534\n", 295 | "weighted avg 0.8629 0.8178 0.8394 7534\n", 296 | "\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "print(tokens_report)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 21, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 22, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "processed 56409 tokens with 7534 phrases; found: 7120 phrases; correct: 6161.\n", 323 | "\n", 324 | "precision: 86.53%; recall: 81.78%; FB1: 84.09\n", 325 | "\n", 326 | "\tLOC: precision: 85.76%; recall: 79.32%; F1: 82.42 1440\n", 327 | "\n", 328 | "\tORG: precision: 81.50%; recall: 74.90%; F1: 78.06 3552\n", 329 | "\n", 330 | "\tPER: precision: 95.44%; recall: 96.16%; F1: 95.80 2128\n", 331 | "\n", 332 | "\n" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "results = precision_recall_f1(true_labels, pred_labels)" 338 | ] 339 | } 340 | ], 341 | "metadata": { 342 | "kernelspec": { 343 | "display_name": "Python 3", 344 | "language": "python", 345 | "name": "python3" 346 | }, 347 | "language_info": { 348 | "codemirror_mode": { 349 | "name": "ipython", 350 | "version": 3 351 | }, 352 | "file_extension": ".py", 353 | "mimetype": "text/x-python", 354 | "name": "python", 355 | "nbconvert_exporter": "python", 356 | "pygments_lexer": "ipython3", 357 | "version": "3.6.8" 358 | } 359 | }, 360 | "nbformat": 4, 361 | "nbformat_minor": 2 362 | } 363 | -------------------------------------------------------------------------------- /exps/fre BERTNCRF.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "\n", 13 | "import sys\n", 14 | "import warnings\n", 15 | "\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append(\"../\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## IO markup" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Train" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from modules.data import bert_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stderr", 61 | "output_type": "stream", 62 | "text": [ 63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n" 64 | ] 65 | }, 66 | { 67 | "data": { 68 | "application/vnd.jupyter.widget-view+json": { 69 | "model_id": "", 70 | "version_major": 2, 71 | "version_minor": 0 72 | }, 73 | "text/plain": [ 74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "\r" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "data = bert_data.LearnData.create(\n", 90 | " train_df_path=train_df_path,\n", 91 | " valid_df_path=valid_df_path,\n", 92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels.txt\",\n", 93 | " clear_cache=True\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from modules.models.bert_models import BERTNCRF" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "build CRF...\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "model = BERTNCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3, nbest=len(data.train_ds.idx2label)-1)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 7, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "from modules.train.train import NerLearner" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 8, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "num_epochs = 100" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 9, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "learner = NerLearner(\n", 148 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTNCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | 
"execution_count": 10, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "299259" 160 | ] 161 | }, 162 | "execution_count": 10, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "model.get_n_trainable_params()" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "scrolled": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "learner.fit(epochs=num_epochs)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 11, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "learner.load_model()" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "### Predict" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 12, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from modules.data.bert_data import get_data_loader_for_predict" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 13, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 14, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "application/vnd.jupyter.widget-view+json": { 224 | "model_id": "", 225 | "version_major": 2, 226 | "version_minor": 0 227 | }, 228 | "text/plain": [ 229 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…" 230 | ] 231 | }, 232 | "metadata": {}, 233 | "output_type": "display_data" 234 | }, 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "\r" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "preds = learner.predict(dl)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 15, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "from sklearn_crfsuite.metrics import flat_classification_report" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 16, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n", 263 | "from modules.analyze_utils.plot_metrics import get_bert_span_report" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 17, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n", 273 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 18, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "assert pred_tokens == true_tokens\n", 283 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[5:], digits=4)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 19, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | " precision recall f1-score support\n", 296 | "\n", 297 | " I_LOC 0.8765 0.7887 0.8303 1557\n", 298 | " I_PER 0.9598 0.9598 0.9598 2112\n", 299 | " I_ORG 0.7946 0.8078 0.8011 3865\n", 300 | "\n", 301 | " micro avg 0.8569 0.8464 0.8516 7534\n", 302 | " macro avg 0.8770 0.8521 0.8637 
7534\n", 303 | "weighted avg 0.8578 0.8464 0.8516 7534\n", 304 | "\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "print(tokens_report)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 20, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "from modules.analyze_utils.main_metrics import precision_recall_f1" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 21, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "processed 56409 tokens with 7534 phrases; found: 7442 phrases; correct: 6377.\n", 331 | "\n", 332 | "precision: 85.69%; recall: 84.64%; FB1: 85.16\n", 333 | "\n", 334 | "\tLOC: precision: 87.65%; recall: 78.87%; F1: 83.03 1401\n", 335 | "\n", 336 | "\tORG: precision: 79.46%; recall: 80.78%; F1: 80.11 3929\n", 337 | "\n", 338 | "\tPER: precision: 95.98%; recall: 95.98%; F1: 95.98 2112\n", 339 | "\n", 340 | "\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "results = precision_recall_f1(true_labels, pred_labels)" 346 | ] 347 | } 348 | ], 349 | "metadata": { 350 | "kernelspec": { 351 | "display_name": "Python 3", 352 | "language": "python", 353 | "name": "python3" 354 | }, 355 | "language_info": { 356 | "codemirror_mode": { 357 | "name": "ipython", 358 | "version": 3 359 | }, 360 | "file_extension": ".py", 361 | "mimetype": "text/x-python", 362 | "name": "python", 363 | "nbconvert_exporter": "python", 364 | "pygments_lexer": "ipython3", 365 | "version": "3.6.8" 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 2 370 | } 371 | -------------------------------------------------------------------------------- /exps/prc fre.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### FactRuEval-2016 preprocess\n", 8 | "More info about dataset: https://github.com/dialogue-evaluation/factRuEval-2016" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import sys\n", 18 | "import warnings\n", 19 | "\n", 20 | "\n", 21 | "warnings.filterwarnings(\"ignore\")\n", 22 | "sys.path.append(\"../\")" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from modules.data.fre import fact_ru_eval_preprocess" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "dev_dir = \"/home/eartemov/ae/work/factRuEval-2016/devset/\"\n", 41 | "test_dir = \"/home/eartemov/ae/work/factRuEval-2016/testset/\"\n", 42 | "dev_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n", 43 | "test_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "application/vnd.jupyter.widget-view+json": { 54 | "model_id": "43de7e40d1784421bb55921e2d0058f3", 55 | "version_major": 2, 56 | "version_minor": 0 57 | }, 58 | "text/plain": [ 59 | "HBox(children=(IntProgress(value=0, description='Process FactRuEval2016 dev set.', max=1519, style=ProgressSty…" 60 | ] 61 | }, 62 | "metadata": {}, 63 | "output_type": "display_data" 64 | }, 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "\n" 70 | ] 71 | }, 72 | { 73 | 
"data": { 74 | "application/vnd.jupyter.widget-view+json": { 75 | "model_id": "27617d0f776d4b37b6f893bcac517f24", 76 | "version_major": 2, 77 | "version_minor": 0 78 | }, 79 | "text/plain": [ 80 | "HBox(children=(IntProgress(value=0, description='Process FactRuEval2016 test set.', max=2715, style=ProgressSt…" 81 | ] 82 | }, 83 | "metadata": {}, 84 | "output_type": "display_data" 85 | }, 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "fact_ru_eval_preprocess(dev_dir, test_dir, dev_df_path, test_df_path)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "import pandas as pd" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/html": [ 115 | "
\n", 116 | "\n", 129 | "\n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | "
labelstextcls
0O O B_LOC O O O O O B_PER I_PER O O O O OСегодня в Москве на 40-й день после смерти Его...False
1O B_LOC I_LOC O O O O B_ORG O B_PER I_PER O O ...К Кронштадтскому бульвару , где болельщика « С...False
2O O O O O O O O O OИ тишина ... Все прошло мирно и не столь массовоTrue
3O O O O O O O O O O O O OПравда , были задержания , но , как пояснили в...True
4O O O O O O O O O OОдним словом , очередной « Русский марш » не с...True
\n", 171 | "
" 172 | ], 173 | "text/plain": [ 174 | " labels \\\n", 175 | "0 O O B_LOC O O O O O B_PER I_PER O O O O O \n", 176 | "1 O B_LOC I_LOC O O O O B_ORG O B_PER I_PER O O ... \n", 177 | "2 O O O O O O O O O O \n", 178 | "3 O O O O O O O O O O O O O \n", 179 | "4 O O O O O O O O O O \n", 180 | "\n", 181 | " text cls \n", 182 | "0 Сегодня в Москве на 40-й день после смерти Его... False \n", 183 | "1 К Кронштадтскому бульвару , где болельщика « С... False \n", 184 | "2 И тишина ... Все прошло мирно и не столь массово True \n", 185 | "3 Правда , были задержания , но , как пояснили в... True \n", 186 | "4 Одним словом , очередной « Русский марш » не с... True " 187 | ] 188 | }, 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "pd.read_csv(dev_df_path, sep=\"\\t\").head()" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.6.8" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import get_tqdm 2 | 3 | 4 | tqdm = get_tqdm() 5 | 6 | 7 | __all__ = ["tqdm"] 8 | -------------------------------------------------------------------------------- /modules/analyze_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | 3 | 4 | __all__ = ["read_json", "save_json"] 5 | -------------------------------------------------------------------------------- /modules/analyze_utils/main_metrics.py: -------------------------------------------------------------------------------- 1 | # This code is reused from https://github.com/deepmipt/DeepPavlov/blob/master/deeppavlov/metrics/fmeasure.py 2 | import itertools 3 | from collections import OrderedDict 4 | 5 | 6 | def chunk_finder(current_token, previous_token, tag): 7 | current_tag = current_token.split('_', 1)[-1] 8 | previous_tag = previous_token.split('_', 1)[-1] 9 | if previous_tag != tag: 10 | previous_tag = 'O' 11 | if current_tag != tag: 12 | current_tag = 'O' 13 | if (previous_tag == 'O' and current_token == 'B_' + tag) or \ 14 | (previous_token == 'I_' + tag and current_token == 'B_' + tag) or \ 15 | (previous_token == 'B_' + tag and current_token == 'B_' + tag) or \ 16 | (previous_tag == 'O' and current_token == 'I_' + tag): 17 | create_chunk = True 18 | else: 19 | create_chunk = False 20 | 21 | if (previous_token == 'I_' + tag and current_token == 'B_' + tag) or \ 22 | (previous_token == 'B_' + tag and current_token == 'B_' + tag) or \ 23 | (current_tag == 'O' and previous_token == 'I_' + tag) or \ 24 | (current_tag == 'O' and previous_token == 'B_' + tag): 25 | pop_out = True 26 | else: 27 | pop_out = False 28 | return create_chunk, pop_out 29 | 30 | 31 | def _global_stats_f1(results): 32 | total_true_entities = 0 33 | total_predicted_entities = 0 34 | total_precision = 0 35 | total_recall = 0 36 | total_f1 = 0 37 | total_correct = 0 38 | for tag in results: 39 | if tag == '__total__': 40 | continue 
41 | 42 | n_pred = results[tag]['n_pred'] 43 | n_true = results[tag]['n_true'] 44 | total_correct += results[tag]['tp'] 45 | total_true_entities += n_true 46 | total_predicted_entities += n_pred 47 | total_precision += results[tag]['precision'] * n_pred 48 | total_recall += results[tag]['recall'] * n_true 49 | total_f1 += results[tag]['f1'] * n_true 50 | if total_true_entities > 0: 51 | accuracy = total_correct / total_true_entities * 100 52 | total_recall = total_recall / total_true_entities 53 | else: 54 | accuracy = 0 55 | total_recall = 0 56 | if total_predicted_entities > 0: 57 | total_precision = total_precision / total_predicted_entities 58 | else: 59 | total_precision = 0 60 | 61 | if total_precision + total_recall > 0: 62 | total_f1 = 2 * total_precision * total_recall / (total_precision + total_recall) 63 | else: 64 | total_f1 = 0 65 | 66 | total_res = {'n_predicted_entities': total_predicted_entities, 67 | 'n_true_entities': total_true_entities, 68 | 'precision': total_precision, 69 | 'recall': total_recall, 70 | 'f1': total_f1} 71 | return total_res, accuracy, total_true_entities, total_predicted_entities, total_correct 72 | 73 | 74 | def precision_recall_f1(y_true, y_pred, print_results=True, short_report=False, entity_of_interest=None): 75 | y_true = list(itertools.chain(*y_true)) 76 | y_pred = list(itertools.chain(*y_pred)) 77 | # Find all tags 78 | tags = set() 79 | for tag in itertools.chain(y_true, y_pred): 80 | if tag not in ["O", "I_O", "B_O"]: 81 | current_tag = tag[2:] 82 | tags.add(current_tag) 83 | tags = sorted(list(tags)) 84 | 85 | results = OrderedDict() 86 | for tag in tags: 87 | results[tag] = OrderedDict() 88 | results['__total__'] = OrderedDict() 89 | n_tokens = len(y_true) 90 | # Firstly we find all chunks in the ground truth and prediction 91 | # For each chunk we write starting and ending indices 92 | 93 | for tag in tags: 94 | count = 0 95 | true_chunk = [] 96 | pred_chunk = [] 97 | y_true = [str(y) for y in y_true] 98 | y_pred = [str(y) for y in y_pred] 99 | prev_tag_true = 'O' 100 | prev_tag_pred = 'O' 101 | while count < n_tokens: 102 | yt = y_true[count] 103 | yp = y_pred[count] 104 | 105 | create_chunk_true, pop_out_true = chunk_finder(yt, prev_tag_true, tag) 106 | if pop_out_true: 107 | true_chunk[-1] = (true_chunk[-1], count - 1) 108 | if create_chunk_true: 109 | true_chunk.append(count) 110 | 111 | create_chunk_pred, pop_out_pred = chunk_finder(yp, prev_tag_pred, tag) 112 | if pop_out_pred: 113 | pred_chunk[-1] = (pred_chunk[-1], count - 1) 114 | if create_chunk_pred: 115 | pred_chunk.append(count) 116 | prev_tag_true = yt 117 | prev_tag_pred = yp 118 | count += 1 119 | 120 | if len(true_chunk) > 0 and not isinstance(true_chunk[-1], tuple): 121 | true_chunk[-1] = (true_chunk[-1], count - 1) 122 | if len(pred_chunk) > 0 and not isinstance(pred_chunk[-1], tuple): 123 | pred_chunk[-1] = (pred_chunk[-1], count - 1) 124 | 125 | # Then we find all correctly classified intervals 126 | # True positive results 127 | tp = len(set(pred_chunk).intersection(set(true_chunk))) 128 | # And then just calculate errors of the first and second kind 129 | # False negative 130 | fn = len(true_chunk) - tp 131 | # False positive 132 | fp = len(pred_chunk) - tp 133 | if tp + fp > 0: 134 | precision = tp / (tp + fp) * 100 135 | else: 136 | precision = 0 137 | if tp + fn > 0: 138 | recall = tp / (tp + fn) * 100 139 | else: 140 | recall = 0 141 | if precision + recall > 0: 142 | f1 = 2 * precision * recall / (precision + recall) 143 | else: 144 | f1 = 0 145 | 
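# store span-level (chunk) metrics for this tag; values are on a 0-100 scale, e.g. tp=3, fp=1, fn=2 gives precision=75.00, recall=60.00, f1=66.67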
results[tag]['precision'] = precision 146 | results[tag]['recall'] = recall 147 | results[tag]['f1'] = f1 148 | results[tag]['n_pred'] = len(pred_chunk) 149 | results[tag]['n_true'] = len(true_chunk) 150 | results[tag]['tp'] = tp 151 | results[tag]['fn'] = fn 152 | results[tag]['fp'] = fp 153 | 154 | results['__total__'], accuracy, total_true_entities, total_predicted_entities, total_correct = _global_stats_f1(results) 155 | results['__total__']['n_pred'] = total_predicted_entities 156 | results['__total__']['n_true'] = total_true_entities 157 | results['__total__']["n_tokens"] = n_tokens 158 | if print_results: 159 | _print_conll_report(results, short_report, entity_of_interest) 160 | return results 161 | 162 | 163 | def _print_conll_report(results, short_report=False, entity_of_interest=None): 164 | _, accuracy, total_true_entities, total_predicted_entities, total_correct = _global_stats_f1(results) 165 | n_tokens = results['__total__']["n_tokens"] 166 | tags = list(results.keys()) 167 | 168 | s = 'processed {len} tokens ' \ 169 | 'with {tot_true} phrases; ' \ 170 | 'found: {tot_pred} phrases;' \ 171 | ' correct: {tot_cor}.\n\n'.format(len=n_tokens, 172 | tot_true=total_true_entities, 173 | tot_pred=total_predicted_entities, 174 | tot_cor=total_correct) 175 | 176 | s += 'precision: {tot_prec:.2f}%; ' \ 177 | 'recall: {tot_recall:.2f}%; ' \ 178 | 'FB1: {tot_f1:.2f}\n\n'.format(acc=accuracy, 179 | tot_prec=results['__total__']['precision'], 180 | tot_recall=results['__total__']['recall'], 181 | tot_f1=results['__total__']['f1']) 182 | 183 | if not short_report: 184 | for tag in tags: 185 | if entity_of_interest is not None: 186 | if entity_of_interest in tag: 187 | s += '\t' + tag + ': precision: {tot_prec:.2f}%; ' \ 188 | 'recall: {tot_recall:.2f}%; ' \ 189 | 'F1: {tot_f1:.2f} ' \ 190 | '{tot_predicted}\n\n'.format(tot_prec=results[tag]['precision'], 191 | tot_recall=results[tag]['recall'], 192 | tot_f1=results[tag]['f1'], 193 | tot_predicted=results[tag]['n_pred']) 194 | elif tag != '__total__': 195 | s += '\t' + tag + ': precision: {tot_prec:.2f}%; ' \ 196 | 'recall: {tot_recall:.2f}%; ' \ 197 | 'F1: {tot_f1:.2f} ' \ 198 | '{tot_predicted}\n\n'.format(tot_prec=results[tag]['precision'], 199 | tot_recall=results[tag]['recall'], 200 | tot_f1=results[tag]['f1'], 201 | tot_predicted=results[tag]['n_pred']) 202 | elif entity_of_interest is not None: 203 | s += '\t' + entity_of_interest + ': precision: {tot_prec:.2f}%; ' \ 204 | 'recall: {tot_recall:.2f}%; ' \ 205 | 'F1: {tot_f1:.2f} ' \ 206 | '{tot_predicted}\n\n'.format(tot_prec=results[entity_of_interest]['precision'], 207 | tot_recall=results[entity_of_interest]['recall'], 208 | tot_f1=results[entity_of_interest]['f1'], 209 | tot_predicted=results[entity_of_interest]['n_pred']) 210 | print(s) 211 | -------------------------------------------------------------------------------- /modules/analyze_utils/plot_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from matplotlib import pyplot as plt 4 | from .utils import tokens2spans, bert_labels2tokens, voting_choicer, first_choicer 5 | from sklearn_crfsuite.metrics import flat_classification_report 6 | from sklearn.metrics import f1_score 7 | 8 | 9 | def plot_by_class_curve(history, metric_, sup_labels): 10 | by_class = get_by_class_metric(history, metric_, sup_labels) 11 | vals = list(by_class.values()) 12 | x = np.arange(len(vals[0])) 13 | args = [] 14 | for val in vals: 15 | 
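# interleave x with every per-class curve so a single plt.plot(*args) call draws them all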
args.append(x) 16 | args.append(val) 17 | plt.figure(figsize=(15, 10)) 18 | plt.grid(True) 19 | plt.plot(*args) 20 | plt.legend(list(by_class.keys())) 21 | _, _ = plt.yticks(np.arange(0, 1, step=0.1)) 22 | plt.show() 23 | 24 | 25 | def get_metrics_by_class(text_res, sup_labels): 26 | # text_res = flat_classification_report(y_true, y_pred, labels=labels, digits=3) 27 | res = {} 28 | for line in text_res.split("\n"): 29 | line = line.split() 30 | if len(line) and line[0] in sup_labels: 31 | res[line[0]] = {key: val for key, val in zip(["prec", "rec", "f1"], line[1:-1])} 32 | return res 33 | 34 | 35 | def get_by_class_metric(history, metric_, sup_labels): 36 | res = defaultdict(list) 37 | for h in history: 38 | h = get_metrics_by_class(h, sup_labels) 39 | for class_, metrics_ in h.items(): 40 | res[class_].append(float(metrics_[metric_])) 41 | return res 42 | 43 | 44 | def get_max_metric(history, metric_, sup_labels, return_idx=False): 45 | by_class = get_by_class_metric(history, metric_, sup_labels) 46 | by_class_arr = np.array(list(by_class.values())) 47 | idx = np.array(by_class_arr.sum(0)).argmax() 48 | if return_idx: 49 | return list(zip(by_class.keys(), by_class_arr[:, idx])), idx 50 | return list(zip(by_class.keys(), by_class_arr[:, idx])) 51 | 52 | 53 | def get_mean_max_metric(history, metric_="f1", return_idx=False): 54 | m_idx = 0 55 | if metric_ == "f1": 56 | m_idx = 2 57 | elif metric_ == "rec": 58 | m_idx = 1 59 | metrics = [float(h.split("\n")[-3].split()[2 + m_idx]) for h in history] 60 | idx = np.argmax(metrics) 61 | res = metrics[idx] 62 | if return_idx: 63 | return idx, res 64 | return res 65 | 66 | 67 | def get_bert_span_report(dl, preds, labels=None, fn=voting_choicer): 68 | pred_tokens, pred_labels = bert_labels2tokens(dl, preds, fn) 69 | true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset], fn) 70 | spans_pred = tokens2spans(pred_tokens, pred_labels) 71 | spans_true = tokens2spans(true_tokens, true_labels) 72 | res_t = [] 73 | res_p = [] 74 | for pred_span, true_span in zip(spans_pred, spans_true): 75 | text2span = {t: l for t, l in pred_span} 76 | for (pt, pl), (tt, tl) in zip(pred_span, true_span): 77 | res_t.append(tl) 78 | if tt in text2span: 79 | res_p.append(pl) 80 | else: 81 | res_p.append("O") 82 | return flat_classification_report([res_t], [res_p], labels=labels, digits=4) 83 | 84 | 85 | def analyze_bert_errors(dl, labels, fn=voting_choicer): 86 | errors = [] 87 | res_tokens = [] 88 | res_labels = [] 89 | r_labels = [x.labels for x in dl.dataset] 90 | for f, l_, rl in zip(dl.dataset, labels, r_labels): 91 | label = fn(f.tok_map, l_) 92 | label_r = fn(f.tok_map, rl) 93 | prev_idx = 0 94 | errors_ = [] 95 | # if len(label_r) > 1: 96 | # assert len(label_r) == len(f.tokens) - 1 97 | for idx, (lbl, rl, t) in enumerate(zip(label, label_r, f.tokens)): 98 | if lbl != rl: 99 | errors_.append( 100 | {"token: ": t, 101 | "real_label": rl, 102 | "pred_label": lbl, 103 | "bert_token": f.bert_tokens[prev_idx:f.tok_map[idx]], 104 | "real_bert_label": f.labels[prev_idx:f.tok_map[idx]], 105 | "pred_bert_label": l_[prev_idx:f.tok_map[idx]], 106 | "text_example": " ".join(f.tokens[1:-1]), 107 | "labels": " ".join(label_r[1:])}) 108 | prev_idx = f.tok_map[idx] 109 | errors.append(errors_) 110 | res_tokens.append(f.tokens[1:-1]) 111 | res_labels.append(label[1:]) 112 | return res_tokens, res_labels, errors 113 | 114 | 115 | def get_f1_score(y_true, y_pred, labels): 116 | res_t = [] 117 | res_p = [] 118 | for yts, yps in zip(y_true, y_pred): 119 | for yt, yp
in zip(yts, yps): 120 | res_t.append(yt) 121 | res_p.append(yp) 122 | return f1_score(res_t, res_p, average="macro", labels=labels) 123 | -------------------------------------------------------------------------------- /modules/analyze_utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | import json 4 | import numpy 5 | 6 | 7 | def voting_choicer(tok_map, labels): 8 | label = [] 9 | prev_idx = 0 10 | for origin_idx in tok_map: 11 | votes = [] 12 | for l in labels[prev_idx:origin_idx]: 13 | if l != "X": 14 | votes.append(l) 15 | vote_labels = Counter(votes) 16 | if not len(vote_labels): 17 | vote_labels = {"B_O": 1} 18 | # vote_labels = Counter(c) 19 | lb = sorted(list(vote_labels), key=lambda x: vote_labels[x]) 20 | if len(lb): 21 | label.append(lb[-1]) 22 | prev_idx = origin_idx 23 | if origin_idx < 0: 24 | break 25 | 26 | return label 27 | 28 | 29 | def first_choicer(tok_map, labels): 30 | label = [] 31 | prev_idx = 0 32 | for origin_idx in tok_map: 33 | l = labels[prev_idx] 34 | if l in ["X"]: 35 | l = "B_O" 36 | if l == "B_O": 37 | for ll in labels[prev_idx + 1:origin_idx]: 38 | if ll not in ["B_O", "I_O", "X"]: 39 | l = ll 40 | break 41 | label.append(l) 42 | prev_idx = origin_idx 43 | if origin_idx < 0: 44 | break 45 | # assert "[SEP]" not in label 46 | return label 47 | 48 | 49 | def bert_labels2tokens(dl, labels, fn=voting_choicer): 50 | res_tokens = [] 51 | res_labels = [] 52 | for f, l in zip(dl.dataset, labels): 53 | label = fn(f.tok_map, l[1:]) 54 | 55 | res_tokens.append(f.tokens[1:-1]) 56 | res_labels.append(label[1:]) 57 | return res_tokens, res_labels 58 | 59 | 60 | def tokens2spans_(tokens_, labels_): 61 | res = [] 62 | idx_ = 0 63 | while idx_ < len(labels_): 64 | label = labels_[idx_] 65 | if label in ["I_O", "B_O", "O"]: 66 | res.append((tokens_[idx_], "O")) 67 | idx_ += 1 68 | elif label == "": 69 | break 70 | elif label == "[CLS]" or label == "": 71 | res.append((tokens_[idx_], label)) 72 | idx_ += 1 73 | else: 74 | span = [tokens_[idx_]] 75 | try: 76 | span_label = labels_[idx_].split("_")[1] 77 | except IndexError: 78 | print(label, labels_[idx_].split("_")) 79 | span_label = None 80 | idx_ += 1 81 | while idx_ < len(labels_) and labels_[idx_] not in ["I_O", "B_O", "O"] \ 82 | and labels_[idx_].split("_")[0] == "I": 83 | if span_label == labels_[idx_].split("_")[1]: 84 | span.append(tokens_[idx_]) 85 | idx_ += 1 86 | else: 87 | break 88 | res.append((" ".join(span), span_label)) 89 | return res 90 | 91 | 92 | def tokens2spans(tokens, labels): 93 | assert len(tokens) == len(labels) 94 | 95 | return list(map(lambda x: tokens2spans_(*x), zip(tokens, labels))) 96 | 97 | 98 | def encode_position(pos, emb_dim=10): 99 | """The sinusoid position encoding""" 100 | 101 | # keep dim 0 for padding token position encoding zero vector 102 | if pos == 0: 103 | return np.zeros(emb_dim) 104 | position_enc = np.array( 105 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]) 106 | 107 | # apply sin on 0th,2nd,4th...emb_dim 108 | position_enc[0::2] = np.sin(position_enc[0::2]) 109 | # apply cos on 1st,3rd,5th...emb_dim 110 | position_enc[1::2] = np.cos(position_enc[1::2]) 111 | return list(position_enc.reshape(-1)) 112 | 113 | 114 | class JsonEncoder(json.JSONEncoder): 115 | def default(self, obj): 116 | if isinstance(obj, numpy.integer): 117 | return int(obj) 118 | elif isinstance(obj, numpy.floating): 119 | return float(obj) 120 | elif isinstance(obj, 
numpy.ndarray): 121 | return obj.tolist() 122 | else: 123 | return super(JsonEncoder, self).default(obj) 124 | 125 | 126 | def jsonify(data): 127 | return json.dumps(data, cls=JsonEncoder) 128 | 129 | 130 | def read_json(config): 131 | if isinstance(config, str): 132 | with open(config, "r") as f: 133 | config = json.load(f) 134 | return config 135 | 136 | 137 | def save_json(config, path): 138 | with open(path, "w") as file: 139 | json.dump(config, file, cls=JsonEncoder) 140 | -------------------------------------------------------------------------------- /modules/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/data/__init__.py -------------------------------------------------------------------------------- /modules/data/bert_data_clf.py: -------------------------------------------------------------------------------- 1 | from .bert_data import TextDataLoader 2 | from pytorch_pretrained_bert import BertTokenizer 3 | from modules.utils import read_config, if_none 4 | from modules import tqdm 5 | import pandas as pd 6 | from copy import deepcopy 7 | 8 | 9 | class InputFeature(object): 10 | """A single set of features of data.""" 11 | 12 | def __init__( 13 | self, 14 | # Bert data 15 | bert_tokens, input_ids, input_mask, input_type_ids, 16 | # Origin data 17 | tokens, tok_map, 18 | # Cls data 19 | cls=None, id_cls=None): 20 | """ 21 | Data has the following structure. 22 | data[0]: list, tokens ids 23 | data[1]: list, tokens mask 24 | data[2]: list, tokens type ids (for bert) 25 | """ 26 | self.data = [] 27 | # Bert data 28 | self.bert_tokens = bert_tokens 29 | self.input_ids = input_ids 30 | self.data.append(input_ids) 31 | self.input_mask = input_mask 32 | self.data.append(input_mask) 33 | self.input_type_ids = input_type_ids 34 | self.data.append(input_type_ids) 35 | # Classification data 36 | self.cls = cls 37 | self.id_cls = id_cls 38 | if cls is not None: 39 | self.data.append(id_cls) 40 | # Origin data 41 | self.tokens = tokens 42 | self.tok_map = tok_map 43 | 44 | def __iter__(self): 45 | return iter(self.data) 46 | 47 | 48 | class TextDataSet(object): 49 | 50 | @classmethod 51 | def from_config(cls, config, clear_cache=False, df=None): 52 | return cls.create(**read_config(config), clear_cache=clear_cache, df=df) 53 | 54 | @classmethod 55 | def create(cls, 56 | df_path=None, 57 | idx2cls=None, 58 | idx2cls_path=None, 59 | min_char_len=1, 60 | model_name="bert-base-multilingual-cased", 61 | max_sequence_length=424, 62 | pad_idx=0, 63 | clear_cache=False, 64 | df=None, tokenizer=None): 65 | if tokenizer is None: 66 | tokenizer = BertTokenizer.from_pretrained(model_name) 67 | config = { 68 | "min_char_len": min_char_len, 69 | "model_name": model_name, 70 | "max_sequence_length": max_sequence_length, 71 | "clear_cache": clear_cache, 72 | "df_path": df_path, 73 | "pad_idx": pad_idx, 74 | "idx2cls_path": idx2cls_path 75 | } 76 | if df is None and df_path is not None: 77 | df = pd.read_csv(df_path, sep='\t', engine='python') 78 | elif df is None: 79 | df = pd.DataFrame(columns=["text", "clf"]) 80 | if clear_cache: 81 | _, idx2cls = cls.create_vocabs(df, idx2cls_path, idx2cls) 82 | self = cls(tokenizer, df=df, config=config, idx2cls=idx2cls) 83 | self.load(df=df) 84 | return self 85 | 86 | @staticmethod 87 | def create_vocabs( 88 | df, idx2cls_path, idx2cls=None): 89 | idx2cls = idx2cls 90 | cls2idx = {} 91 | if idx2cls is not None: 92 | 
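# a label list was passed in: only the reverse mapping is built here; otherwise both vocabs are collected from the dataframe below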
cls2idx = {label: idx for idx, label in enumerate(idx2cls)} 93 | else: 94 | idx2cls = [] 95 | for _, row in tqdm(df.iterrows(), total=len(df), leave=False, desc="Creating labels vocabs"): 96 | if row.cls not in cls2idx: 97 | cls2idx[row.cls] = len(cls2idx) 98 | idx2cls.append(row.cls) 99 | 100 | with open(idx2cls_path, "w", encoding="utf-8") as f: 101 | for label in idx2cls: 102 | f.write("{}\n".format(label)) 103 | 104 | return cls2idx, idx2cls 105 | 106 | def load(self, df_path=None, df=None): 107 | df_path = if_none(df_path, self.config["df_path"]) 108 | if df is None: 109 | self.df = pd.read_csv(df_path, sep='\t') 110 | 111 | self.idx2cls = [] 112 | self.cls2idx = {} 113 | with open(self.config["idx2cls_path"], "r", encoding="utf-8") as f: 114 | for idx, label in enumerate(f.readlines()): 115 | label = label.strip() 116 | self.cls2idx[label] = idx 117 | self.idx2cls.append(label) 118 | 119 | def create_feature(self, row): 120 | bert_tokens = [] 121 | orig_tokens = row.text.split() 122 | tok_map = [] 123 | for orig_token in orig_tokens: 124 | cur_tokens = self.tokenizer.tokenize(orig_token) 125 | if self.config["max_sequence_length"] - 2 < len(bert_tokens) + len(cur_tokens): 126 | break 127 | cur_tokens = self.tokenizer.tokenize(orig_token) 128 | tok_map.append(len(bert_tokens)) 129 | bert_tokens.extend(cur_tokens) 130 | 131 | orig_tokens = ["[CLS]"] + orig_tokens + ["[SEP]"] 132 | 133 | input_ids = self.tokenizer.convert_tokens_to_ids(['[CLS]'] + bert_tokens + ['[SEP]']) 134 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 135 | # tokens are attended to. 136 | input_mask = [1] * len(input_ids) 137 | # Zero-pad up to the sequence length. 138 | while len(input_ids) < self.config["max_sequence_length"]: 139 | input_ids.append(self.config["pad_idx"]) 140 | input_mask.append(0) 141 | tok_map.append(-1) 142 | input_type_ids = [0] * len(input_ids) 143 | cls = str(row.cls) 144 | id_cls = self.cls2idx[cls] 145 | return InputFeature( 146 | # Bert data 147 | bert_tokens=bert_tokens, 148 | input_ids=input_ids, 149 | input_mask=input_mask, 150 | input_type_ids=input_type_ids, 151 | # Origin data 152 | tokens=orig_tokens, 153 | tok_map=tok_map, 154 | # Cls 155 | cls=cls, id_cls=id_cls 156 | ) 157 | 158 | def __getitem__(self, item): 159 | if self.config["df_path"] is None and self.df is None: 160 | raise ValueError("Should setup df_path or df.") 161 | if self.df is None: 162 | self.load() 163 | 164 | return self.create_feature(self.df.iloc[item]) 165 | 166 | def __len__(self): 167 | return len(self.df) if self.df is not None else 0 168 | 169 | def save(self, df_path=None): 170 | df_path = if_none(df_path, self.config["df_path"]) 171 | self.df.to_csv(df_path, sep='\t', index=False) 172 | 173 | def __init__( 174 | self, tokenizer, 175 | df=None, 176 | config=None, 177 | idx2cls=None): 178 | self.df = df 179 | self.tokenizer = tokenizer 180 | self.config = config 181 | self.label2idx = None 182 | 183 | self.idx2cls = idx2cls 184 | if idx2cls is not None: 185 | self.cls2idx = {label: idx for idx, label in enumerate(idx2cls)} 186 | 187 | 188 | class LearnDataClass(object): 189 | def __init__(self, train_ds=None, train_dl=None, valid_ds=None, valid_dl=None): 190 | self.train_ds = train_ds 191 | self.train_dl = train_dl 192 | self.valid_ds = valid_ds 193 | self.valid_dl = valid_dl 194 | 195 | @classmethod 196 | def create(cls, 197 | # DataSet params 198 | train_df_path, 199 | valid_df_path, 200 | idx2cls=None, 201 | idx2cls_path=None, 202 | min_char_len=1, 203 | 
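--------------------------------------------------------------------------------
A sketch of building the classification dataset straight from an in-memory frame; "idx2cls.txt" is a hypothetical vocab path that create() writes because clear_cache=True, and the first call downloads the multilingual BERT vocab for tokenization.

import pandas as pd
from modules.data.bert_data_clf import TextDataSet

df = pd.DataFrame({"text": ["book a taxi", "weather for tomorrow"],
                   "cls": ["taxi", "weather"]})
ds = TextDataSet.create(idx2cls_path="idx2cls.txt", df=df, clear_cache=True)
feature = ds[0]
# feature.data = [input_ids, input_mask, input_type_ids, id_cls]
print(feature.tokens, feature.cls, feature.id_cls)
--------------------------------------------------------------------------------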
model_name="bert-base-multilingual-cased", 204 | max_sequence_length=424, 205 | pad_idx=0, 206 | clear_cache=False, 207 | train_df=None, 208 | valid_df=None, 209 | # DataLoader params 210 | device="cuda", batch_size=16): 211 | train_ds = None 212 | train_dl = None 213 | valid_ds = None 214 | valid_dl = None 215 | if idx2cls_path is not None: 216 | train_ds = TextDataSet.create( 217 | train_df_path, 218 | idx2cls=idx2cls, 219 | idx2cls_path=idx2cls_path, 220 | min_char_len=min_char_len, 221 | model_name=model_name, 222 | max_sequence_length=max_sequence_length, 223 | pad_idx=pad_idx, 224 | clear_cache=clear_cache, 225 | df=train_df) 226 | if len(train_ds): 227 | train_dl = TextDataLoader(train_ds, device=device, shuffle=True, batch_size=batch_size) 228 | if valid_df_path is not None: 229 | valid_ds = TextDataSet.create( 230 | valid_df_path, 231 | idx2cls=train_ds.idx2cls, 232 | idx2cls_path=idx2cls_path, 233 | min_char_len=min_char_len, 234 | model_name=model_name, 235 | max_sequence_length=max_sequence_length, 236 | pad_idx=pad_idx, 237 | clear_cache=False, 238 | df=valid_df, tokenizer=train_ds.tokenizer) 239 | valid_dl = TextDataLoader(valid_ds, device=device, batch_size=batch_size) 240 | 241 | self = cls(train_ds, train_dl, valid_ds, valid_dl) 242 | self.device = device 243 | self.batch_size = batch_size 244 | return self 245 | 246 | def load(self): 247 | if self.train_ds is not None: 248 | self.train_ds.load() 249 | if self.valid_ds is not None: 250 | self.valid_ds.load() 251 | 252 | def save(self): 253 | if self.train_ds is not None: 254 | self.train_ds.save() 255 | if self.valid_ds is not None: 256 | self.valid_ds.save() 257 | 258 | 259 | def get_data_loader_for_predict(data, df_path=None, df=None): 260 | config = deepcopy(data.train_ds.config) 261 | config["df_path"] = df_path 262 | config["clear_cache"] = False 263 | ds = TextDataSet.create( 264 | idx2cls=data.train_ds.idx2cls, 265 | df=df, tokenizer=data.train_ds.tokenizer, **config) 266 | return TextDataLoader( 267 | ds, device=data.device, batch_size=data.batch_size, shuffle=False), ds 268 | -------------------------------------------------------------------------------- /modules/data/conll2003/__init__.py: -------------------------------------------------------------------------------- 1 | from .prc import conll2003_preprocess 2 | 3 | 4 | __all__ = ["conll2003_preprocess"] 5 | -------------------------------------------------------------------------------- /modules/data/conll2003/prc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from modules import tqdm 3 | import argparse 4 | import codecs 5 | import os 6 | 7 | 8 | def conll2003_preprocess( 9 | data_dir, train_name="eng.train", dev_name="eng.testa", test_name="eng.testb"): 10 | train_f = read_data(os.path.join(data_dir, train_name)) 11 | dev_f = read_data(os.path.join(data_dir, dev_name)) 12 | test_f = read_data(os.path.join(data_dir, test_name)) 13 | 14 | train = pd.DataFrame({"labels": [x[0] for x in train_f], "text": [x[1] for x in train_f]}) 15 | train["cls"] = train["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()])) 16 | train.to_csv(os.path.join(data_dir, "{}.train.csv".format(train_name)), index=False, sep="\t") 17 | 18 | dev = pd.DataFrame({"labels": [x[0] for x in dev_f], "text": [x[1] for x in dev_f]}) 19 | dev["cls"] = dev["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()])) 20 | dev.to_csv(os.path.join(data_dir, "{}.dev.csv".format(dev_name)), 
index=False, sep="\t") 21 | 22 | test_ = pd.DataFrame({"labels": [x[0] for x in test_f], "text": [x[1] for x in test_f]}) 23 | test_["cls"] = test_["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()])) 24 | test_.to_csv(os.path.join(data_dir, "{}.dev.csv".format(test_name)), index=False, sep="\t") 25 | 26 | 27 | def read_data(input_file): 28 | """Reads a BIO data.""" 29 | with codecs.open(input_file, "r", encoding="utf-8") as f: 30 | lines = [] 31 | words = [] 32 | labels = [] 33 | f_lines = f.readlines() 34 | for line in tqdm(f_lines, total=len(f_lines), desc="Process {}".format(input_file)): 35 | contends = line.strip() 36 | word = line.strip().split(' ')[0] 37 | label = line.strip().split(' ')[-1] 38 | if contends.startswith("-DOCSTART-"): 39 | words.append('') 40 | continue 41 | 42 | if len(contends) == 0 and not len(words): 43 | words.append("") 44 | 45 | if len(contends) == 0 and words[-1] == '.': 46 | lbl = ' '.join([label for label in labels if len(label) > 0]) 47 | w = ' '.join([word for word in words if len(word) > 0]) 48 | lines.append([lbl, w]) 49 | words = [] 50 | labels = [] 51 | continue 52 | words.append(word) 53 | labels.append(label.replace("-", "_")) 54 | return lines 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--data_dir', type=str) 60 | parser.add_argument('--train_name', type=str, default="eng.train") 61 | parser.add_argument('--dev_name', type=str, default="eng.testa") 62 | parser.add_argument('--test_name', type=str, default="eng.testb") 63 | return vars(parser.parse_args()) 64 | 65 | 66 | if __name__ == "__main__": 67 | conll2003_preprocess(**parse_args()) 68 | -------------------------------------------------------------------------------- /modules/data/download_data.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import sys 3 | import os 4 | 5 | 6 | tasks_urls = { 7 | "conll2003": [ 8 | ["eng.testa", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa"], 9 | ["eng.testb", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb"], 10 | ["eng.train", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train"] 11 | ]} 12 | 13 | 14 | def download_data(task_name, data_dir): 15 | req = urllib 16 | if sys.version_info >= (3, 0): 17 | req = urllib.request 18 | for data_file, url in tasks_urls[task_name]: 19 | if not os.path.exists(data_dir): 20 | os.mkdir(data_dir) 21 | _ = req.urlretrieve(url, os.path.join(data_dir, data_file)) 22 | -------------------------------------------------------------------------------- /modules/data/fre/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .reader import Reader as FREReader 3 | from .prc import fact_ru_eval_preprocess 4 | 5 | __all__ = ["FREReader", "fact_ru_eval_preprocess"] 6 | -------------------------------------------------------------------------------- /modules/data/fre/bilou/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /modules/data/fre/bilou/from_bilou.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def untag(list_of_tags, list_of_tokens): 5 | """ 6 | :param list_of_tags: 7 | :param list_of_tokens: 8 | :return: 9 | """ 10 
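--------------------------------------------------------------------------------
Fetching and preprocessing CoNLL-2003 end to end ("data/conll2003" is a hypothetical target directory). Note the resulting file names: the test split reuses the ".dev.csv" suffix, so the outputs are eng.train.train.csv, eng.testa.dev.csv and eng.testb.dev.csv.

from modules.data.download_data import download_data
from modules.data.conll2003 import conll2003_preprocess

download_data("conll2003", "data/conll2003")
conll2003_preprocess("data/conll2003")
--------------------------------------------------------------------------------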
| if len(list_of_tags) == len(list_of_tokens): 11 | dict_of_final_ne = {} 12 | ne_words = [] 13 | ne_tag = None 14 | 15 | for index in range(len(list_of_tokens)): 16 | if not ((ne_tag is not None) ^ (ne_words != [])): 17 | current_tag = list_of_tags[index] 18 | current_token = list_of_tokens[index] 19 | 20 | if current_tag.startswith('B') or current_tag.startswith('I'): 21 | dict_of_final_ne, ne_words, ne_tag = __check_bi( 22 | dict_of_final_ne, ne_words, ne_tag, current_tag, current_token) 23 | elif current_tag.startswith('L'): 24 | dict_of_final_ne, ne_words, ne_tag = __check_l( 25 | dict_of_final_ne, ne_words, ne_tag, current_tag, current_token) 26 | elif current_tag.startswith('O'): 27 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag) 28 | 29 | elif current_tag.startswith('U'): 30 | dict_of_final_ne, ne_words, ne_tag = __check_u(dict_of_final_ne, ne_words, ne_tag, current_tag, 31 | current_token) 32 | else: 33 | raise ValueError("tag contains no BILOU tags") 34 | else: 35 | if ne_tag is None: 36 | raise Exception('Somehow ne_tag is None and ne_words is not None') 37 | else: 38 | raise Exception('Somehow ne_words is None and ne_tag is not None') 39 | 40 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag) 41 | return __to_output_format(dict_of_final_ne) 42 | else: 43 | raise ValueError('lengths are not equal') 44 | 45 | 46 | def __check_bi(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token): 47 | if ne_tag is None and ne_words == []: 48 | ne_tag = current_tag[1:] 49 | ne_words = [current_token] 50 | else: 51 | if current_tag.startswith('I') and ne_tag == current_tag[1:]: 52 | ne_words.append(current_token) 53 | else: 54 | dict_of_final_ne, ne_words, ne_tag = __replace_by_new(dict_of_final_ne, ne_words, ne_tag, current_tag, 55 | current_token) 56 | return dict_of_final_ne, ne_words, ne_tag 57 | 58 | 59 | def __check_l(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token): 60 | if ne_tag == current_tag[1:]: 61 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words+[current_token], ne_tag) 62 | else: 63 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag) 64 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, [current_token], current_tag[1:]) 65 | return dict_of_final_ne, ne_words, ne_tag 66 | 67 | 68 | def __check_u(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token): 69 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag) 70 | return __finish_ne_if_required(dict_of_final_ne, [current_token], current_tag[1:]) 71 | 72 | 73 | def __replace_by_new(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token): 74 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag) 75 | ne_tag = current_tag[1:] 76 | ne_words = [current_token] 77 | return dict_of_final_ne, ne_words, ne_tag 78 | 79 | 80 | def __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag): 81 | if ne_tag is not None and ne_words != []: 82 | dict_of_final_ne[tuple(ne_words)] = ne_tag 83 | ne_tag = None 84 | ne_words = [] 85 | return dict_of_final_ne, ne_words, ne_tag 86 | 87 | 88 | def __to_output_format(dict_nes): 89 | """ 90 | :param dict_nes: 91 | :return: 92 | """ 93 | list_of_results_for_output = [] 94 | 95 | for tokens_tuple, tag in dict_nes.items(): 96 | position = 
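--------------------------------------------------------------------------------
A hand-traced sketch of BILOU decoding. untag() ultimately needs token objects with get_position()/get_length() accessors (the entity Token class exposes position and length as properties instead), so a hypothetical stand-in class is used here; positions and lengths are character offsets into "Ivan Petrov lived".

from modules.data.fre.bilou.from_bilou import untag

class Tok:  # minimal stand-in with the accessors the output step expects
    def __init__(self, position, length):
        self._p, self._l = position, length
    def get_position(self):
        return self._p
    def get_length(self):
        return self._l

tokens = [Tok(0, 4), Tok(5, 6), Tok(12, 5)]
print(untag(["B_PER", "L_PER", "O"], tokens))
# [['_PER', 0, 11]] - the tag keeps a leading "_" because untag strips only
# the first character of "B_PER"
--------------------------------------------------------------------------------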
int(tokens_tuple[0].get_position()) 97 | length = int(tokens_tuple[-1].get_position()) + int(tokens_tuple[-1].get_length()) - position 98 | list_of_results_for_output.append([tag, position, length]) 99 | 100 | return list_of_results_for_output 101 | -------------------------------------------------------------------------------- /modules/data/fre/bilou/to_bilou.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ..entity.taggedtoken import TaggedToken 3 | 4 | 5 | def get_tagged_tokens_from(dict_of_nes, token_list): 6 | list_of_tagged_tokens = [TaggedToken('O', token_list[i]) for i in range(len(token_list))] 7 | dict_of_tokens_with_indexes = {token_list[i].id: i for i in range(len(token_list))} 8 | 9 | for ne in dict_of_nes.values(): 10 | for tokenid in ne['tokens_list']: 11 | try: 12 | tag = format_tag(tokenid, ne) 13 | except ValueError: 14 | tag = "O" 15 | id_in_token_tuple = dict_of_tokens_with_indexes[tokenid] 16 | token = token_list[id_in_token_tuple] 17 | list_of_tagged_tokens[id_in_token_tuple] = TaggedToken(tag, token) 18 | return list_of_tagged_tokens 19 | 20 | 21 | def format_tag(tokenid, ne): 22 | bilou = __choose_bilou_tag_for(tokenid, ne['tokens_list']) 23 | formatted_tag = __tag_to_fact_ru_eval_format(ne['tag']) 24 | return "{}_{}".format(bilou, formatted_tag) 25 | 26 | 27 | def __choose_bilou_tag_for(token_id, token_list): 28 | if len(token_list) == 1: 29 | return 'B' 30 | elif len(token_list) > 1: 31 | if token_list.index(token_id) == 0: 32 | return 'B' 33 | else: 34 | return 'I' 35 | 36 | 37 | def __tag_to_fact_ru_eval_format(tag): 38 | if tag == 'Person': 39 | return 'PER' 40 | elif tag == 'Org': 41 | return 'ORG' 42 | elif tag == 'Location': 43 | return 'LOC' 44 | elif tag == 'LocOrg': 45 | return 'LOC' 46 | elif tag == 'Project': 47 | return 'ORG' 48 | else: 49 | raise ValueError('tag ' + tag + " is not the right tag") 50 | -------------------------------------------------------------------------------- /modules/data/fre/entity/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modules/data/fre/entity/document.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from .token import Token 3 | from .taggedtoken import TaggedToken 4 | from collections import defaultdict 5 | from ..bilou import to_bilou 6 | 7 | 8 | class Document(object): 9 | def __init__(self, path, tagged=True, encoding="utf-8"): 10 | self.path = path 11 | self.tagged = tagged 12 | self.encoding = encoding 13 | self.tokens = [] 14 | self.tagged_tokens = [] 15 | self.load() 16 | 17 | def to_text_tokens(self): 18 | return [token.text for token in self.tokens] 19 | 20 | def get_tags(self): 21 | return [token.get_tag() for token in self.tagged_tokens] 22 | 23 | def load(self): 24 | self.tokens = self.__get_tokens_from_file() 25 | if self.tagged: 26 | self.tagged_tokens = self.__get_tagged_tokens_from() 27 | else: 28 | self.tagged_tokens = [TaggedToken(None, token) for token in self.tokens] 29 | return self 30 | 31 | def parse_file(self, path): 32 | with codecs.open(path, 'r', encoding=self.encoding, errors="ignore") as file: 33 | rows = file.read().split('\n') 34 | return [row.split(' # ')[0].split() for row in rows if len(row) != 0] 35 | 36 | def __get_tokens_from_file(self): 37 | rows = self.parse_file(self.path + '.tokens') 38 | 
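--------------------------------------------------------------------------------
Despite the BILOU name, the chooser above only ever emits B or I (a single-token entity gets B, not U, and spans have no closing L), so the encoding is effectively BIO; the L_/U_ branches in from_bilou handle tags this encoder never produces. A hand-checked call:

from modules.data.fre.bilou import to_bilou

ne = {"tokens_list": ["10", "11"], "tag": "Person"}
print(to_bilou.format_tag("10", ne), to_bilou.format_tag("11", ne))  # B_PER I_PER
--------------------------------------------------------------------------------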
tokens = [] 39 | for token_str in rows: 40 | tokens.append(Token().from_sting(token_str)) 41 | return tokens 42 | 43 | def __get_tagged_tokens_from(self): 44 | span_dict = self.__span_id2token_ids(self.path + '.spans', [token.id for token in self.tokens]) 45 | object_dict = self.__to_dict_of_objects(self.path + '.objects') 46 | dict_of_nes = self.__merge(object_dict, span_dict, self.tokens) 47 | return to_bilou.get_tagged_tokens_from(dict_of_nes, self.tokens) 48 | 49 | def __span_id2token_ids(self, span_file, token_ids): 50 | span_list = self.parse_file(span_file) 51 | dict_of_spans = {} 52 | for span in span_list: 53 | span_id = span[0] 54 | span_start = span[4] 55 | span_length_in_tokens = int(span[5]) 56 | list_of_token_of_spans = self.__find_tokens_for(span_start, span_length_in_tokens, token_ids) 57 | dict_of_spans[span_id] = list_of_token_of_spans 58 | return dict_of_spans 59 | 60 | @staticmethod 61 | def __find_tokens_for(start, length, token_ids): 62 | list_of_tokens = [] 63 | index = token_ids.index(start) 64 | for i in range(length): 65 | list_of_tokens.append(token_ids[index + i]) 66 | return list_of_tokens 67 | 68 | def __to_dict_of_objects(self, object_file): 69 | object_list = self.parse_file(object_file) 70 | dict_of_objects = {} 71 | for obj in object_list: 72 | object_id = obj[0] 73 | object_tag = obj[1] 74 | object_spans = obj[2:] 75 | dict_of_objects[object_id] = {'tag': object_tag, 'spans': object_spans} 76 | return dict_of_objects 77 | 78 | def __merge(self, object_dict, span_dict, tokens): 79 | ne_dict = self.__get_dict_of_nes(object_dict, span_dict) 80 | return self.__clean(ne_dict, tokens) 81 | 82 | @staticmethod 83 | def __get_dict_of_nes(object_dict, span_dict): 84 | ne_dict = defaultdict(set) 85 | for obj_id, obj_values in object_dict.items(): 86 | for span in obj_values['spans']: 87 | ne_dict[(obj_id, obj_values['tag'])].update(span_dict[span]) 88 | for ne in ne_dict: 89 | ne_dict[ne] = sorted(list(set([int(i) for i in ne_dict[ne]]))) 90 | return ne_dict 91 | 92 | def __clean(self, ne_dict, tokens): 93 | sorted_nes = sorted(ne_dict.items(), key=self.__sort_by_tokens) 94 | dict_of_tokens_by_id = {} 95 | for i in range(len(tokens)): 96 | dict_of_tokens_by_id[tokens[i].id] = i 97 | result_nes = {} 98 | if len(sorted_nes) != 0: 99 | start_ne = sorted_nes[0] 100 | for ne in sorted_nes: 101 | if self.__not_intersect(start_ne[1], ne[1]): 102 | result_nes[start_ne[0][0]] = { 103 | 'tokens_list': self.__check_order(start_ne[1], dict_of_tokens_by_id, tokens), 104 | 'tag': start_ne[0][1]} 105 | start_ne = ne 106 | else: 107 | result_tokens_list = self.__check_normal_form(start_ne[1], ne[1]) 108 | start_ne = (start_ne[0], result_tokens_list) 109 | result_nes[start_ne[0][0]] = { 110 | 'tokens_list': self.__check_order(start_ne[1], dict_of_tokens_by_id, tokens), 111 | 'tag': start_ne[0][1]} 112 | return result_nes 113 | 114 | @staticmethod 115 | def __sort_by_tokens(tokens): 116 | ids_as_int = [int(token_id) for token_id in tokens[1]] 117 | return min(ids_as_int), -max(ids_as_int) 118 | 119 | @staticmethod 120 | def __not_intersect(start_ne, current_ne): 121 | intersection = set.intersection(set(start_ne), set(current_ne)) 122 | return intersection == set() 123 | 124 | def __check_normal_form(self, start_ne, ne): 125 | all_tokens = set.union(set(start_ne), set(ne)) 126 | return self.__find_all_range_of_tokens(all_tokens) 127 | 128 | @staticmethod 129 | def __find_all_range_of_tokens(tokens): 130 | tokens = sorted(tokens) 131 | if (tokens[-1] - tokens[0] - len(tokens)) < 5: 
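# The rule below bridges small gaps between an entity's token ids: if the id
# span minus the id count is under 5, the missing ids are filled in. Hand-traced:
#   __find_all_range_of_tokens({3, 4, 7})  -> [3, 4, 5, 6, 7]   (7 - 3 - 3 = 1 < 5)
#   __find_all_range_of_tokens({3, 40})    -> [3, 40]           (40 - 3 - 2 = 35)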
132 |             return list(range(tokens[0], tokens[-1] + 1))
133 |         else:
134 |             return tokens
135 |
136 |     def __check_order(self, list_of_tokens, dict_of_tokens_by_id, tokens):
137 |         list_of_tokens = [str(i) for i in self.__find_all_range_of_tokens(list_of_tokens)]
138 |         result = []
139 |         for token in list_of_tokens:
140 |             if token in dict_of_tokens_by_id:
141 |                 result.append((token, dict_of_tokens_by_id[token]))
142 |         result = sorted(result, key=self.__sort_by_position)
143 |         result = self.__add_quotation_marks(result, tokens)
144 |         return [r[0] for r in result]
145 |
146 |     @staticmethod
147 |     def __sort_by_position(result_tuple):
148 |         return result_tuple[1]
149 |
150 |     @staticmethod
151 |     def __add_quotation_marks(result, tokens):
152 |         result_tokens_texts = [tokens[token[1]].text for token in result]
153 |         prev_pos = result[0][1] - 1
154 |         next_pos = result[-1][1] + 1
155 |
156 |         if prev_pos >= 0 and tokens[prev_pos].text == '«' \
157 |                 and '»' in result_tokens_texts and '«' not in result_tokens_texts:
158 |             result = [(tokens[prev_pos].id, prev_pos)] + result
159 |
160 |         if next_pos < len(tokens) and tokens[next_pos].text == '»' \
161 |                 and '«' in result_tokens_texts and '»' not in result_tokens_texts:
162 |             result = result + [(tokens[next_pos].id, next_pos)]
163 |
164 |         return result
165 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/taggedtoken.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | class TaggedToken(object):
5 |
6 |     @property
7 |     def text(self):
8 |         return self.__token.text
9 |
10 |     def __init__(self, tag, token):
11 |         self.__tag = tag
12 |         self.__token = token
13 |
14 |     def get_token(self):
15 |         return self.__token
16 |
17 |     def get_tag(self):
18 |         return self.__tag
19 |
20 |     def __repr__(self):
21 |         if self.__tag:
22 |             return "<" + self.__tag + "_" + str(self.__token) + ">"
23 |         else:
24 |             return ""
25 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/token.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | class Token(object):
5 |     __token_id__ = 0
6 |
7 |     @property
8 |     def length(self):
9 |         return self.__length
10 |
11 |     @property
12 |     def position(self):
13 |         return self.__position
14 |
15 |     @property
16 |     def id(self):
17 |         return self.__id
18 |
19 |     @property
20 |     def text(self):
21 |         return self.__text
22 |
23 |     @property
24 |     def all(self):
25 |         return self.__id, self.__position, self.__length, self.__text
26 |
27 |     @property
28 |     def tag(self):
29 |         return self.__tag
30 |
31 |     def __init__(self, token_id=None, position=None, length=None, text=None):
32 |         self.__id = token_id
33 |         if token_id is None:
34 |             self.__id = Token.__token_id__
35 |             Token.__token_id__ += 1
36 |         self.__position = position
37 |         self.__length = length
38 |         self.__text = text
39 |         self.__tag = None
40 |
41 |     def from_sting(self, string):  # sic: Document.__get_tokens_from_file calls this spelling
42 |         self.__id, self.__position, self.__length, self.__text = string
43 |         return self
44 |
45 |     def __len__(self):
46 |         return int(self.__length)
47 |
48 |     def __str__(self):
49 |         return self.__text
50 |
51 |     def __repr__(self):
52 |         return "<<" + str(self.__id) + "_" + str(self.__text) + ">>"
53 |
--------------------------------------------------------------------------------
/modules/data/fre/prc.py:
--------------------------------------------------------------------------------
1 | from modules.data.fre.reader import
Reader 2 | import pandas as pd 3 | from modules import tqdm 4 | import argparse 5 | 6 | 7 | def fact_ru_eval_preprocess(dev_dir, test_dir, dev_df_path, test_df_path): 8 | dev_reader = Reader(dev_dir) 9 | dev_reader.read_dir() 10 | dev_texts, dev_tags = dev_reader.split() 11 | res_tags = [] 12 | res_tokens = [] 13 | for tag, tokens in tqdm(zip(dev_tags, dev_texts), total=len(dev_tags), desc="Process FactRuEval2016 dev set."): 14 | if len(tag): 15 | res_tags.append(tag) 16 | res_tokens.append(tokens) 17 | dev = pd.DataFrame({"labels": list(map(" ".join, res_tags)), "text": list(map(" ".join, res_tokens))}) 18 | dev["clf"] = dev["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()])) 19 | dev.to_csv(dev_df_path, index=False, sep="\t") 20 | 21 | test_reader = Reader(test_dir) 22 | test_reader.read_dir() 23 | test_texts, test_tags = test_reader.split() 24 | res_tags = [] 25 | res_tokens = [] 26 | for tag, tokens in tqdm(zip(test_tags, test_texts), total=len(test_tags), desc="Process FactRuEval2016 test set."): 27 | if len(tag): 28 | res_tags.append(tag) 29 | res_tokens.append(tokens) 30 | valid = pd.DataFrame({"labels": list(map(" ".join, res_tags)), "text": list(map(" ".join, res_tokens))}) 31 | valid["clf"] = valid["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()])) 32 | valid.to_csv(test_df_path, index=False, sep="\t") 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('-dd', '--dev_dir', type=str) 38 | parser.add_argument('-td', '--test_dir', type=str) 39 | parser.add_argument('-ddp', '--dev_df_path', type=str) 40 | parser.add_argument('-tdp', '--test_df_path', type=str) 41 | return vars(parser.parse_args()) 42 | 43 | 44 | if __name__ == "__main__": 45 | fact_ru_eval_preprocess(**parse_args()) 46 | -------------------------------------------------------------------------------- /modules/data/fre/reader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pandas as pd 3 | from .utils import get_file_names 4 | from .entity.document import Document 5 | 6 | 7 | class Reader(object): 8 | 9 | def __init__(self, 10 | dir_path, 11 | document_creator=Document, 12 | get_file_names_=get_file_names, 13 | tagged=True): 14 | self.path = dir_path 15 | self.tagged = tagged 16 | self.documents = [] 17 | self.document_creator = document_creator 18 | self.get_file_names = get_file_names_ 19 | 20 | def split(self, use_morph=False): 21 | res_texts = [] 22 | res_tags = [] 23 | for doc in self.documents: 24 | sent_tokens = [] 25 | sent_tags = [] 26 | for token in doc.tagged_tokens: 27 | if token.get_tag() == "O" and token.text == ".": 28 | res_texts.append(tuple(sent_tokens)) 29 | res_tags.append(tuple(sent_tags)) 30 | sent_tokens = [] 31 | sent_tags = [] 32 | else: 33 | text = token.text 34 | sent_tokens.append(text) 35 | sent_tags.append(token.get_tag()) 36 | if use_morph: 37 | return res_texts, res_tags 38 | return res_texts, res_tags 39 | 40 | def to_data_frame(self, split=False): 41 | if split: 42 | docs = self.split() 43 | else: 44 | docs = [] 45 | for doc in self.documents: 46 | docs.append([(token.text, token.get_tag()) for token in doc.tagged_tokens]) 47 | 48 | texts = [] 49 | tags = [] 50 | for sent in docs: 51 | sample_text = [] 52 | sample_tag = [] 53 | for text, tag in sent: 54 | sample_text.append(text) 55 | sample_tag.append(tag) 56 | texts.append(" ".join(sample_text)) 57 | tags.append(" ".join(sample_tag)) 58 | return 
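--------------------------------------------------------------------------------
Reading a FactRuEval-2016 directory (the devset path is hypothetical); split() cuts documents into sentences on untagged "." tokens.

from modules.data.fre import FREReader

reader = FREReader("factRuEval-2016/devset")
reader.read_dir()
texts, tags = reader.split()
print(texts[0][:5], tags[0][:5])

The same flow should also be scriptable end to end:
python -m modules.data.fre.prc -dd devset -td testset -ddp dev.csv -tdp test.csv
--------------------------------------------------------------------------------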
pd.DataFrame({"texts": texts, "tags": tags}, columns=["texts", "tags"]) 59 | 60 | def read_dir(self): 61 | for path in self.get_file_names(self.path): 62 | self.documents.append(self.document_creator(path, self.tagged)) 63 | 64 | def get_text_tokens(self): 65 | return [doc.to_text_tokens() for doc in self.documents] 66 | 67 | def get_text_tags(self): 68 | return [doc.get_tags() for doc in self.documents] 69 | -------------------------------------------------------------------------------- /modules/data/fre/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_file_names(path): 5 | res = [] 6 | for root, dirs, files in os.walk(path): 7 | for file in files: 8 | if file.endswith('.tokens'): 9 | res.append(os.path.join(root, os.path.splitext(file)[0])) 10 | return res 11 | -------------------------------------------------------------------------------- /modules/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/layers/__init__.py -------------------------------------------------------------------------------- /modules/layers/crf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | # TODO: move to utils 6 | def log_sum_exp(tensor, dim=0): 7 | """LogSumExp operation.""" 8 | m, _ = torch.max(tensor, dim) 9 | m_exp = m.unsqueeze(-1).expand_as(tensor) 10 | return m + torch.log(torch.sum(torch.exp(tensor - m_exp), dim)) 11 | 12 | 13 | def sequence_mask(lens, max_len=None): 14 | batch_size = lens.size(0) 15 | 16 | if max_len is None: 17 | max_len = lens.max().item() 18 | 19 | ranges = torch.arange(0, max_len).long() 20 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len) 21 | 22 | if lens.data.is_cuda: 23 | ranges = ranges.cuda() 24 | 25 | lens_exp = lens.unsqueeze(1).expand_as(ranges) 26 | mask = ranges < lens_exp 27 | 28 | return mask 29 | 30 | 31 | class CRF(nn.Module): 32 | def forward(self, *input_): 33 | return self.viterbi_decode(*input_) 34 | 35 | def __init__(self, label_size): 36 | super(CRF, self).__init__() 37 | 38 | self.label_size = label_size 39 | self.start = self.label_size - 2 40 | self.end = self.label_size - 1 41 | transition = torch.randn(self.label_size, self.label_size) 42 | self.transition = nn.Parameter(transition) 43 | self.initialize() 44 | 45 | def initialize(self): 46 | self.transition.data[:, self.end] = -100.0 47 | self.transition.data[self.start, :] = -100.0 48 | 49 | @staticmethod 50 | def pad_logits(logits): 51 | # lens = lens.data 52 | batch_size, seq_len, label_num = logits.size() 53 | # pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-1000.0), 54 | # requires_grad=False) 55 | pads = logits.new_full((batch_size, seq_len, 2), -1000.0, 56 | requires_grad=False) 57 | logits = torch.cat([logits, pads], dim=2) 58 | return logits 59 | 60 | def calc_binary_score(self, labels, lens): 61 | batch_size, seq_len = labels.size() 62 | 63 | # labels_ext = Variable(labels.data.new(batch_size, seq_len + 2)) 64 | labels_ext = labels.new_empty((batch_size, seq_len + 2)) 65 | labels_ext[:, 0] = self.start 66 | labels_ext[:, 1:-1] = labels 67 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long() 68 | # pad_stop = Variable(labels.data.new(1).fill_(self.end)) 69 | pad_stop = labels.new_full((1,), self.end, requires_grad=False) 70 | pad_stop = 
pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2) 71 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext 72 | labels = labels_ext 73 | 74 | trn = self.transition 75 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size()) 76 | lbl_r = labels[:, 1:] 77 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0)) 78 | trn_row = torch.gather(trn_exp, 1, lbl_rexp) 79 | 80 | lbl_lexp = labels[:, :-1].unsqueeze(-1) 81 | trn_scr = torch.gather(trn_row, 2, lbl_lexp) 82 | trn_scr = trn_scr.squeeze(-1) 83 | 84 | mask = sequence_mask(lens + 1).float() 85 | trn_scr = trn_scr * mask 86 | score = trn_scr 87 | 88 | return score 89 | 90 | @staticmethod 91 | def calc_unary_score(logits, labels, lens): 92 | labels_exp = labels.unsqueeze(-1) 93 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1) 94 | mask = sequence_mask(lens).float() 95 | scores = scores * mask 96 | return scores 97 | 98 | def calc_gold_score(self, logits, labels, lens): 99 | unary_score = self.calc_unary_score(logits, labels, lens).sum( 100 | 1).squeeze(-1) 101 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1) 102 | return unary_score + binary_score 103 | 104 | def calc_norm_score(self, logits, lens): 105 | batch_size, seq_len, feat_dim = logits.size() 106 | # alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0) 107 | alpha = logits.new_full((batch_size, self.label_size), -100.0) 108 | alpha[:, self.start] = 0 109 | # alpha = Variable(alpha) 110 | lens_ = lens.clone() 111 | 112 | logits_t = logits.transpose(1, 0) 113 | for logit in logits_t: 114 | logit_exp = logit.unsqueeze(-1).expand(batch_size, 115 | *self.transition.size()) 116 | alpha_exp = alpha.unsqueeze(1).expand(batch_size, 117 | *self.transition.size()) 118 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp) 119 | mat = logit_exp + alpha_exp + trans_exp 120 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1) 121 | 122 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha) 123 | alpha = mask * alpha_nxt + (1 - mask) * alpha 124 | lens_ = lens_ - 1 125 | 126 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha) 127 | norm = log_sum_exp(alpha, 1).squeeze(-1) 128 | 129 | return norm 130 | 131 | def viterbi_decode(self, logits, lens): 132 | """Borrowed from pytorch tutorial 133 | Arguments: 134 | logits: [batch_size, seq_len, n_labels] FloatTensor 135 | lens: [batch_size] LongTensor 136 | """ 137 | batch_size, seq_len, n_labels = logits.size() 138 | # vit = logits.data.new(batch_size, self.label_size).fill_(-10000) 139 | vit = logits.new_full((batch_size, self.label_size), -100.0) 140 | vit[:, self.start] = 0 141 | # vit = Variable(vit) 142 | c_lens = lens.clone() 143 | 144 | logits_t = logits.transpose(1, 0) 145 | pointers = [] 146 | for logit in logits_t: 147 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels) 148 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp) 149 | vit_trn_sum = vit_exp + trn_exp 150 | vt_max, vt_argmax = vit_trn_sum.max(2) 151 | 152 | vt_max = vt_max.squeeze(-1) 153 | vit_nxt = vt_max + logit 154 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0)) 155 | 156 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt) 157 | vit = mask * vit_nxt + (1 - mask) * vit 158 | 159 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt) 160 | vit += mask * self.transition[self.end].unsqueeze( 161 | 0).expand_as(vit_nxt) 162 | 163 | c_lens = c_lens - 1 164 | 165 | pointers = torch.cat(pointers) 166 | scores, idx = vit.max(1) 167 | # idx = 
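--------------------------------------------------------------------------------
A minimal sketch of driving the CRF directly; label_size must include the two extra start/end states that pad_logits appends, and the norm-minus-gold difference is the usual CRF negative log-likelihood.

import torch
from modules.layers.crf import CRF

n_labels = 5
crf = CRF(label_size=n_labels + 2)
logits = torch.randn(2, 7, n_labels)            # [batch, seq_len, n_labels]
labels = torch.randint(0, n_labels, (2, 7))
lens = torch.tensor([7, 4])

padded = crf.pad_logits(logits)                 # [batch, seq_len, n_labels + 2]
nll = (crf.calc_norm_score(padded, lens)
       - crf.calc_gold_score(padded, labels, lens)).mean()
scores, paths = crf(padded, lens)               # viterbi_decode; paths: [batch, seq_len]
--------------------------------------------------------------------------------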
idx.squeeze(-1) 168 | paths = [idx.unsqueeze(1)] 169 | for argmax in reversed(pointers): 170 | idx_exp = idx.unsqueeze(-1) 171 | idx = torch.gather(argmax, 1, idx_exp) 172 | idx = idx.squeeze(-1) 173 | 174 | paths.insert(0, idx.unsqueeze(1)) 175 | 176 | paths = torch.cat(paths[1:], 1) 177 | scores = scores.squeeze(-1) 178 | 179 | return scores, paths 180 | -------------------------------------------------------------------------------- /modules/layers/embedders.py: -------------------------------------------------------------------------------- 1 | from pytorch_pretrained_bert import BertModel 2 | import torch 3 | 4 | 5 | class BERTEmbedder(torch.nn.Module): 6 | def __init__(self, model, config): 7 | super(BERTEmbedder, self).__init__() 8 | self.config = config 9 | self.model = model 10 | if self.config["mode"] == "weighted": 11 | self.bert_weights = torch.nn.Parameter(torch.FloatTensor(12, 1)) 12 | self.bert_gamma = torch.nn.Parameter(torch.FloatTensor(1, 1)) 13 | self.init_weights() 14 | 15 | def init_weights(self): 16 | if self.config["mode"] == "weighted": 17 | torch.nn.init.xavier_normal(self.bert_gamma) 18 | torch.nn.init.xavier_normal(self.bert_weights) 19 | 20 | @classmethod 21 | def create( 22 | cls, model_name='bert-base-multilingual-cased', 23 | device="cuda", mode="weighted", 24 | is_freeze=True): 25 | config = { 26 | "model_name": model_name, 27 | "device": device, 28 | "mode": mode, 29 | "is_freeze": is_freeze 30 | } 31 | model = BertModel.from_pretrained(model_name) 32 | model.to(device) 33 | model.train() 34 | self = cls(model, config) 35 | if is_freeze: 36 | self.freeze() 37 | return self 38 | 39 | @classmethod 40 | def from_config(cls, config): 41 | return cls.create(**config) 42 | 43 | def forward(self, batch): 44 | """ 45 | batch has the following structure: 46 | data[0]: list, tokens ids 47 | data[1]: list, tokens mask 48 | data[2]: list, tokens type ids (for bert) 49 | data[3]: list, bert labels ids 50 | """ 51 | encoded_layers, _ = self.model( 52 | input_ids=batch[0], 53 | token_type_ids=batch[2], 54 | attention_mask=batch[1], 55 | output_all_encoded_layers=self.config["mode"] == "weighted") 56 | if self.config["mode"] == "weighted": 57 | encoded_layers = torch.stack([a * b for a, b in zip(encoded_layers, self.bert_weights)]) 58 | return self.bert_gamma * torch.sum(encoded_layers, dim=0) 59 | return encoded_layers 60 | 61 | def freeze(self): 62 | for param in self.model.parameters(): 63 | param.requires_grad = False 64 | -------------------------------------------------------------------------------- /modules/layers/layers.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import init 6 | from torch.nn.utils import rnn as rnn_utils 7 | import math 8 | 9 | 10 | class BiLSTM(nn.Module): 11 | 12 | def __init__(self, embedding_size=768, hidden_dim=512, rnn_layers=1, dropout=0.5): 13 | super(BiLSTM, self).__init__() 14 | self.embedding_size = embedding_size 15 | self.hidden_dim = hidden_dim 16 | self.rnn_layers = rnn_layers 17 | self.dropout = nn.Dropout(dropout) 18 | self.lstm = nn.LSTM( 19 | embedding_size, 20 | hidden_dim // 2, 21 | rnn_layers, batch_first=True, bidirectional=True) 22 | 23 | def forward(self, input_, input_mask): 24 | length = input_mask.sum(-1) 25 | sorted_lengths, sorted_idx = torch.sort(length, descending=True) 26 | input_ = input_[sorted_idx] 27 | packed_input = 
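--------------------------------------------------------------------------------
What mode="weighted" computes, sketched with random tensors standing in for the 12 encoder layers (BERTEmbedder.create itself downloads the pretrained weights):

import torch

layers = [torch.randn(2, 8, 768) for _ in range(12)]   # 12 x [batch, seq, hidden]
weights = torch.randn(12, 1)                            # bert_weights
gamma = torch.randn(1, 1)                               # bert_gamma
mixed = gamma * torch.stack([a * b for a, b in zip(layers, weights)]).sum(dim=0)
print(mixed.shape)                                      # torch.Size([2, 8, 768])
--------------------------------------------------------------------------------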
rnn_utils.pack_padded_sequence(input_, sorted_lengths.data.tolist(), batch_first=True) 28 | output, (hidden, _) = self.lstm(packed_input) 29 | padded_outputs = rnn_utils.pad_packed_sequence(output, batch_first=True)[0] 30 | _, reversed_idx = torch.sort(sorted_idx) 31 | return padded_outputs[reversed_idx], hidden[:, reversed_idx] 32 | 33 | @classmethod 34 | def create(cls, *args, **kwargs): 35 | return cls(*args, **kwargs) 36 | 37 | 38 | class Linear(nn.Linear): 39 | def __init__(self, 40 | in_features: int, 41 | out_features: int, 42 | bias: bool = True): 43 | super(Linear, self).__init__(in_features, out_features, bias=bias) 44 | init.orthogonal_(self.weight) 45 | 46 | 47 | class Linears(nn.Module): 48 | def __init__(self, 49 | in_features, 50 | out_features, 51 | hiddens, 52 | bias=True, 53 | activation='tanh'): 54 | super(Linears, self).__init__() 55 | assert len(hiddens) > 0 56 | 57 | self.in_features = in_features 58 | self.out_features = self.output_size = out_features 59 | 60 | in_dims = [in_features] + hiddens[:-1] 61 | self.linears = nn.ModuleList([Linear(in_dim, out_dim, bias=bias) 62 | for in_dim, out_dim 63 | in zip(in_dims, hiddens)]) 64 | self.output_linear = Linear(hiddens[-1], out_features, bias=bias) 65 | self.activation = getattr(functional, activation) 66 | 67 | def forward(self, inputs): 68 | linear_outputs = inputs 69 | for linear in self.linears: 70 | linear_outputs = linear.forward(linear_outputs) 71 | linear_outputs = self.activation(linear_outputs) 72 | return self.output_linear.forward(linear_outputs) 73 | 74 | 75 | # Reused from https://github.com/JayParks/transformer/ 76 | class ScaledDotProductAttention(nn.Module): 77 | def __init__(self, d_k, dropout=.1): 78 | super(ScaledDotProductAttention, self).__init__() 79 | self.scale_factor = np.sqrt(d_k) 80 | self.softmax = nn.Softmax(dim=-1) 81 | self.dropout = nn.Dropout(dropout) 82 | 83 | def forward(self, q, k, v, attn_mask=None): 84 | # q: [b_size x len_q x d_k] 85 | # k: [b_size x len_k x d_k] 86 | # v: [b_size x len_v x d_v] note: (len_k == len_v) 87 | attn = torch.bmm(q, k.transpose(1, 2)) / self.scale_factor # attn: [b_size x len_q x len_k] 88 | if attn_mask is not None: 89 | print(attn_mask.size(), attn.size()) 90 | assert attn_mask.size() == attn.size() 91 | attn.data.masked_fill_(attn_mask, -float('inf')) 92 | 93 | attn = self.softmax(attn) 94 | attn = self.dropout(attn) 95 | outputs = torch.bmm(attn, v) # outputs: [b_size x len_q x d_v] 96 | 97 | return outputs, attn 98 | 99 | 100 | class LayerNormalization(nn.Module): 101 | def __init__(self, d_hid, eps=1e-3): 102 | super(LayerNormalization, self).__init__() 103 | self.gamma = nn.Parameter(torch.ones(d_hid), requires_grad=True) 104 | self.beta = nn.Parameter(torch.zeros(d_hid), requires_grad=True) 105 | self.eps = eps 106 | 107 | def forward(self, z): 108 | mean = z.mean(dim=-1, keepdim=True,) 109 | std = z.std(dim=-1, keepdim=True,) 110 | ln_out = (z - mean.expand_as(z)) / (std.expand_as(z) + self.eps) 111 | ln_out = self.gamma.expand_as(ln_out) * ln_out + self.beta.expand_as(ln_out) 112 | 113 | return ln_out 114 | 115 | 116 | class _MultiHeadAttention(nn.Module): 117 | def __init__(self, d_k, d_v, d_model, n_heads, dropout): 118 | super(_MultiHeadAttention, self).__init__() 119 | self.d_k = d_k 120 | self.d_v = d_v 121 | self.d_model = d_model 122 | self.n_heads = n_heads 123 | self.w_q = nn.Parameter(torch.FloatTensor(n_heads, d_model, d_k)) 124 | self.w_k = nn.Parameter(torch.FloatTensor(n_heads, d_model, d_k)) 125 | self.w_v = 
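--------------------------------------------------------------------------------
Shape check for the scaled dot-product attention above (attn_mask, when given, marks the positions to suppress and also triggers the debug print in forward, so None is passed here):

import torch
from modules.layers.layers import ScaledDotProductAttention

attn = ScaledDotProductAttention(d_k=64)
q = torch.randn(2, 5, 64)
k = torch.randn(2, 7, 64)
v = torch.randn(2, 7, 32)
out, weights = attn(q, k, v)    # out: [2, 5, 32], weights: [2, 5, 7]
--------------------------------------------------------------------------------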
nn.Parameter(torch.FloatTensor(n_heads, d_model, d_v)) 126 | 127 | self.attention = ScaledDotProductAttention(d_k, dropout) 128 | 129 | init.xavier_normal(self.w_q) 130 | init.xavier_normal(self.w_k) 131 | init.xavier_normal(self.w_v) 132 | 133 | def forward(self, q, k, v, attn_mask=None): 134 | (d_k, d_v, d_model, n_heads) = (self.d_k, self.d_v, self.d_model, self.n_heads) 135 | b_size = k.size(0) 136 | 137 | q_s = q.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_q x d_model] 138 | k_s = k.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_k x d_model] 139 | v_s = v.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_v x d_model] 140 | 141 | q_s = torch.bmm(q_s, self.w_q).view(b_size * n_heads, -1, d_k) # [b_size * n_heads x len_q x d_k] 142 | k_s = torch.bmm(k_s, self.w_k).view(b_size * n_heads, -1, d_k) # [b_size * n_heads x len_k x d_k] 143 | v_s = torch.bmm(v_s, self.w_v).view(b_size * n_heads, -1, d_v) # [b_size * n_heads x len_v x d_v] 144 | 145 | # perform attention, result_size = [b_size * n_heads x len_q x d_v] 146 | if attn_mask is not None: 147 | attn_mask = attn_mask.repeat(n_heads, 1, 1) 148 | outputs, attn = self.attention(q_s, k_s, v_s, attn_mask=attn_mask) 149 | 150 | # return a list of tensors of shape [b_size x len_q x d_v] (length: n_heads) 151 | return torch.split(outputs, b_size, dim=0), attn 152 | 153 | 154 | class MultiHeadAttention(nn.Module): 155 | def __init__(self, d_k, d_v, d_model, n_heads, dropout): 156 | super(MultiHeadAttention, self).__init__() 157 | self.attention = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout) 158 | self.proj = Linear(n_heads * d_v, d_model) 159 | self.dropout = nn.Dropout(dropout) 160 | self.layer_norm = LayerNormalization(d_model) 161 | 162 | def forward(self, q, k, v, attn_mask): 163 | # q: [b_size x len_q x d_model] 164 | # k: [b_size x len_k x d_model] 165 | # v: [b_size x len_v x d_model] note (len_k == len_v) 166 | residual = q 167 | # outputs: a list of tensors of shape [b_size x len_q x d_v] (length: n_heads) 168 | outputs, attn = self.attention(q, k, v, attn_mask=attn_mask) 169 | # concatenate 'n_heads' multi-head attentions 170 | outputs = torch.cat(outputs, dim=-1) 171 | # project back to residual size, result_size = [b_size x len_q x d_model] 172 | outputs = self.proj(outputs) 173 | outputs = self.dropout(outputs) 174 | 175 | return self.layer_norm(residual + outputs), attn 176 | 177 | 178 | class _BahdanauAttention(nn.Module): 179 | def __init__(self, method, hidden_size): 180 | super(_BahdanauAttention, self).__init__() 181 | self.method = method 182 | self.hidden_size = hidden_size 183 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 184 | self.v = nn.Parameter(torch.rand(hidden_size)) 185 | stdv = 1. / math.sqrt(self.v.size(0)) 186 | self.v.data.normal_(mean=0, std=stdv) 187 | 188 | def forward(self, hidden, encoder_outputs, mask=None): 189 | """ 190 | :param hidden: 191 | previous hidden state of the decoder, in shape (layers*directions,B,H) 192 | :param encoder_outputs: 193 | encoder outputs from Encoder, in shape (T,B,H) 194 | :param mask: 195 | used for masking. 
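--------------------------------------------------------------------------------
The multi-head wrapper keeps d_model through the residual connection and layer norm; a self-attention sketch over a [batch, len, d_model] input:

import torch
from modules.layers.layers import MultiHeadAttention

mha = MultiHeadAttention(d_k=64, d_v=64, d_model=256, n_heads=4, dropout=0.1)
x = torch.randn(2, 9, 256)
out, attn = mha(x, x, x, attn_mask=None)   # out: [2, 9, 256]
--------------------------------------------------------------------------------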
NoneType or tensor in shape (B) indicating sequence length 196 | :return 197 | attention energies in shape (B,T) 198 | """ 199 | max_len = encoder_outputs.size(0) 200 | # this_batch_size = encoder_outputs.size(1) 201 | H = hidden.repeat(max_len, 1, 1).transpose(0, 1) 202 | # [B*T*H] 203 | encoder_outputs = encoder_outputs.transpose(0, 1) 204 | # compute attention score 205 | attn_energies = self.score(H, encoder_outputs) 206 | if mask is not None: 207 | attn_energies = attn_energies.masked_fill(mask, -1e18) 208 | # normalize with softmax 209 | return functional.softmax(attn_energies).unsqueeze(1) 210 | 211 | def score(self, hidden, encoder_outputs): 212 | # [B*T*2H]->[B*T*H] 213 | energy = functional.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) 214 | # [B*H*T] 215 | energy = energy.transpose(2, 1) 216 | # [B*1*H] 217 | v = self.v.repeat(encoder_outputs.data.shape[0], 1).unsqueeze(1) 218 | # [B*1*T] 219 | energy = torch.bmm(v, energy) 220 | # [B*T] 221 | return energy.squeeze(1) 222 | 223 | 224 | class BahdanauAttention(nn.Module): 225 | """Reused from https://github.com/chrisbangun/pytorch-seq2seq_with_attention/""" 226 | 227 | def __init__(self, hidden_dim=128, query_dim=128, memory_dim=128): 228 | super(BahdanauAttention, self).__init__() 229 | 230 | self.hidden_dim = hidden_dim 231 | self.query_dim = query_dim 232 | self.memory_dim = memory_dim 233 | self.sofmax = nn.Softmax() 234 | 235 | self.query_layer = nn.Linear(query_dim, hidden_dim, bias=False) 236 | self.memory_layer = nn.Linear(memory_dim, hidden_dim, bias=False) 237 | self.alignment_layer = nn.Linear(hidden_dim, 1, bias=False) 238 | 239 | def alignment_score(self, query, keys): 240 | query = self.query_layer(query) 241 | keys = self.memory_layer(keys) 242 | 243 | extendded_query = query.unsqueeze(1) 244 | alignment = self.alignment_layer(functional.tanh(extendded_query + keys)) 245 | return alignment.squeeze(2) 246 | 247 | def forward(self, query, keys): 248 | alignment_score = self.alignment_score(query, keys) 249 | weight = functional.softmax(alignment_score) 250 | context = weight.unsqueeze(2) * keys 251 | total_context = context.sum(1) 252 | return total_context, alignment_score 253 | -------------------------------------------------------------------------------- /modules/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/models/__init__.py -------------------------------------------------------------------------------- /modules/models/classifiers.py: -------------------------------------------------------------------------------- 1 | from .bert_models import BERTNerModel 2 | from modules.layers.decoders import * 3 | from modules.layers.embedders import * 4 | from modules.layers.layers import * 5 | 6 | 7 | class BERTLinearsClassifier(BERTNerModel): 8 | 9 | def __init__(self, embeddings, linear, dropout, activation, device="cuda"): 10 | super(BERTLinearsClassifier, self).__init__() 11 | self.embeddings = embeddings 12 | self.linear = linear 13 | self.dropout = dropout 14 | self.activation = activation 15 | self.intent_loss = nn.CrossEntropyLoss() 16 | self.to(device) 17 | 18 | @staticmethod 19 | def pool(x, bs, is_max): 20 | """Pool the tensor along the seq_len dimension.""" 21 | f = functional.adaptive_max_pool1d if is_max else functional.adaptive_avg_pool1d 22 | return f(x.permute(1, 2, 0), (1,)).view(bs, -1) 23 | 24 | def forward(self, batch): 25 | 
input_embeddings = self.embeddings(batch) 26 | output = self.dropout(input_embeddings).transpose(0, 1) 27 | sl, bs, _ = output.size() 28 | output = self.pool(output, bs, True) 29 | output = self.linear(output) 30 | return self.activation(output).argmax(-1) 31 | 32 | def score(self, batch): 33 | input_embeddings = self.embeddings(batch) 34 | output = self.dropout(input_embeddings).transpose(0, 1) 35 | sl, bs, _ = output.size() 36 | output = self.pool(output, bs, True) 37 | output = self.linear(output) 38 | return self.intent_loss(self.activation(output), batch[-1]) 39 | 40 | @classmethod 41 | def create(cls, 42 | intent_size, 43 | # BertEmbedder params 44 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True, 45 | # Decoder params 46 | embedding_size=768, clf_dropout=0.3, num_hiddens=2, 47 | activation="tanh", 48 | # Global params 49 | device="cuda"): 50 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze) 51 | linear = Linears(embedding_size, intent_size, [embedding_size // 2**idx for idx in range(num_hiddens)]) 52 | dropout = nn.Dropout(clf_dropout) 53 | activation = getattr(functional, activation) 54 | return cls(embeddings, linear, dropout, activation, device) 55 | 56 | 57 | class BERTLinearClassifier(BERTNerModel): 58 | 59 | def __init__(self, embeddings, linear, dropout, activation, device="cuda"): 60 | super(BERTLinearClassifier, self).__init__() 61 | self.embeddings = embeddings 62 | self.linear = linear 63 | self.dropout = dropout 64 | self.activation = activation 65 | self.intent_loss = nn.CrossEntropyLoss() 66 | self.to(device) 67 | 68 | @staticmethod 69 | def pool(x, bs, is_max): 70 | """Pool the tensor along the seq_len dimension.""" 71 | f = functional.adaptive_max_pool1d if is_max else functional.adaptive_avg_pool1d 72 | return f(x.permute(1, 2, 0), (1,)).view(bs, -1) 73 | 74 | def forward(self, batch): 75 | input_embeddings = self.embeddings(batch) 76 | output = self.dropout(input_embeddings).transpose(0, 1) 77 | sl, bs, _ = output.size() 78 | output = self.pool(output, bs, True) 79 | output = self.linear(output) 80 | return self.activation(output).argmax(-1) 81 | 82 | def score(self, batch): 83 | input_embeddings = self.embeddings(batch) 84 | output = self.dropout(input_embeddings).transpose(0, 1) 85 | sl, bs, _ = output.size() 86 | output = self.pool(output, bs, True) 87 | output = self.linear(output) 88 | return self.intent_loss(self.activation(output), batch[-1]) 89 | 90 | @classmethod 91 | def create(cls, 92 | intent_size, 93 | # BertEmbedder params 94 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True, 95 | # Decoder params 96 | embedding_size=768, clf_dropout=0.3, 97 | activation="sigmoid", 98 | # Global params 99 | device="cuda"): 100 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze) 101 | linear = Linear(embedding_size, intent_size) 102 | dropout = nn.Dropout(clf_dropout) 103 | activation = getattr(functional, activation) 104 | return cls(embeddings, linear, dropout, activation, device) 105 | 106 | 107 | class BERTBaseClassifier(BERTNerModel): 108 | 109 | def __init__(self, embeddings, clf, device="cuda"): 110 | super(BERTBaseClassifier, self).__init__() 111 | self.embeddings = embeddings 112 | self.clf = clf 113 | self.to(device) 114 | 115 | def forward(self, batch): 116 | input_embeddings = self.embeddings(batch) 117 | return self.clf(input_embeddings) 118 | 119 | def score(self, batch): 120 | input_, 
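--------------------------------------------------------------------------------
The shared pool() helper collapses the sequence axis; a shape sketch (building the full classifier, commented out, downloads BERT):

import torch
from modules.models.classifiers import BERTLinearClassifier

x = torch.randn(12, 4, 768)   # [seq_len, batch, hidden], as produced after the transpose
pooled = BERTLinearClassifier.pool(x, bs=4, is_max=True)
print(pooled.shape)           # torch.Size([4, 768])
# model = BERTLinearClassifier.create(intent_size=10, device="cpu")
--------------------------------------------------------------------------------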
labels_mask, input_type_ids, cls_ids = batch 121 | input_embeddings = self.embeddings(batch) 122 | return self.clf.score(input_embeddings, cls_ids) 123 | 124 | @classmethod 125 | def create(cls, 126 | intent_size, 127 | # BertEmbedder params 128 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True, 129 | # Decoder params 130 | embedding_size=768, clf_dropout=0.3, 131 | # Global params 132 | device="cuda"): 133 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze) 134 | clf = ClassDecoder(intent_size, embedding_size, clf_dropout) 135 | return cls(embeddings, clf, device) 136 | 137 | 138 | class BERTBiLSTMAttnClassifier(BERTNerModel): 139 | 140 | def __init__(self, embeddings, lstm, attn, clf, device="cuda"): 141 | super(BERTBiLSTMAttnClassifier, self).__init__() 142 | self.embeddings = embeddings 143 | self.lstm = lstm 144 | self.attn = attn 145 | self.clf = clf 146 | self.to(device) 147 | 148 | def forward(self, batch): 149 | input_, labels_mask, input_type_ids = batch[:3] 150 | input_embeddings = self.embeddings(batch) 151 | output, _ = self.lstm.forward(input_embeddings, labels_mask) 152 | output, _ = self.attn(output, output, output, None) 153 | return self.clf(output) 154 | 155 | def score(self, batch): 156 | input_, labels_mask, input_type_ids = batch[:3] 157 | input_embeddings = self.embeddings(batch) 158 | output, _ = self.lstm.forward(input_embeddings, labels_mask) 159 | output, _ = self.attn(output, output, output, None) 160 | return self.clf.score(output, batch[-1]) 161 | 162 | @classmethod 163 | def create(cls, 164 | intent_size, 165 | # BertEmbedder params 166 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True, 167 | # Decoder params 168 | clf_dropout=0.3, 169 | # BiLSTM 170 | hidden_dim=512, rnn_layers=1, lstm_dropout=0.3, 171 | # Attn params 172 | embedding_size=768, key_dim=64, val_dim=64, num_heads=3, attn_dropout=0.3, 173 | # Global params 174 | device="cuda"): 175 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze) 176 | lstm = BiLSTM.create( 177 | embedding_size=embedding_size, hidden_dim=hidden_dim, rnn_layers=rnn_layers, dropout=lstm_dropout) 178 | attn = MultiHeadAttention(key_dim, val_dim, hidden_dim, num_heads, attn_dropout) 179 | clf = ClassDecoder(intent_size, hidden_dim, clf_dropout) 180 | return cls(embeddings, lstm, attn, clf, device) 181 | -------------------------------------------------------------------------------- /modules/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/train/__init__.py -------------------------------------------------------------------------------- /modules/train/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/optimization.py 3 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch optimization for BERT model."""
17 |
18 | import math
19 | import torch
20 | from torch.optim import Optimizer
21 | from torch.optim.optimizer import required
22 | from torch.nn.utils import clip_grad_norm_
23 |
24 |
25 | def warmup_cosine(x, warmup=0.002):
26 |     if x < warmup:
27 |         return x / warmup
28 |     return 0.5 * (1.0 + math.cos(math.pi * x))
29 |
30 |
31 | def warmup_constant(x, warmup=0.002):
32 |     if x < warmup:
33 |         return x / warmup
34 |     return 1.0
35 |
36 |
37 | def warmup_linear(x, warmup=0.002):
38 |     if x < warmup:
39 |         return x / warmup
40 |     return 1.0 - x
41 |
42 |
43 | SCHEDULES = {
44 |     'warmup_cosine': warmup_cosine,
45 |     'warmup_constant': warmup_constant,
46 |     'warmup_linear': warmup_linear,
47 | }
48 |
49 |
50 | class BertAdam(Optimizer):
51 |     """Implements BERT version of Adam algorithm with weight decay fix.
52 |     Params:
53 |         lr: learning rate
54 |         warmup: portion of t_total for the warmup, -1 means no warmup. Default: 0.1
55 |         t_total: total number of training steps for the learning
56 |             rate schedule, -1 means constant learning rate. Default: -1
57 |         schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
58 |         b1: Adam's b1. Default: 0.8
59 |         b2: Adam's b2. Default: 0.999
60 |         e: Adam's epsilon. Default: 1e-6
61 |         weight_decay: Weight decay. Default: 0.01
62 |         max_grad_norm: Maximum norm for the gradients (-1 means no clipping).
50 | class BertAdam(Optimizer):
51 |     """Implements BERT version of Adam algorithm with weight decay fix.
52 |     Params:
53 |         lr: learning rate
54 |         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
55 |         t_total: total number of training steps for the learning
56 |             rate schedule, -1 means constant learning rate. Default: -1
57 |         schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
58 |         b1: Adams b1. Default: 0.8
59 |         b2: Adams b2. Default: 0.999
60 |         e: Adams epsilon. Default: 1e-6
61 |         weight_decay: Weight decay. Default: 0.01
62 |         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
63 |     """
64 |     def __init__(self, model, lr=required, warmup=0.1, t_total=-1, schedule='warmup_linear',
65 |                  b1=0.8, b2=0.999, e=1e-6, weight_decay=0.01,
66 |                  max_grad_norm=1.0):
67 |         if lr is not required and lr < 0.0:
68 |             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
69 |         if schedule not in SCHEDULES:
70 |             raise ValueError("Invalid schedule parameter: {}".format(schedule))
71 |         if not 0.0 <= warmup < 1.0 and not warmup == -1:
72 |             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
73 |         if not 0.0 <= b1 < 1.0:
74 |             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
75 |         if not 0.0 <= b2 < 1.0:
76 |             raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
77 |         if not e >= 0.0:
78 |             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
79 |         defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
80 |                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
81 |                         max_grad_norm=max_grad_norm)
82 |         # Prepare optimizer
83 |         param_optimizer = list(model.named_parameters())
84 |         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']  # biases and LayerNorm are exempt from weight decay
85 |         optimizer_grouped_parameters = [
86 |             {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},  # was hardcoded to 0.01, ignoring the argument
87 |             {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
88 |         ]
89 |         super(BertAdam, self).__init__(optimizer_grouped_parameters, defaults)
90 |         self.global_step = 1
91 |         self.t_total = t_total
92 | 
93 |     def get_lr(self):
94 |         lr = []
95 |         for group in self.param_groups:
96 |             for p in group['params']:
97 |                 state = self.state[p]
98 |                 if len(state) == 0:
99 |                     return [0]
100 |                 if group['t_total'] != -1:
101 |                     schedule_fct = SCHEDULES[group['schedule']]
102 |                     lr_scheduled = group['lr'] * schedule_fct(state['step'] / group['t_total'], group['warmup'])
103 |                 else:
104 |                     lr_scheduled = group['lr']
105 |                 lr.append(lr_scheduled)
106 |         return lr
107 | 
108 |     def update_lr(self):
109 |         if 0 < self.t_total:
110 |             lr_this_step = self.defaults["lr"] * SCHEDULES[self.defaults["schedule"]](self.global_step / self.t_total, self.defaults["warmup"])  # honor the configured schedule (was always warmup_linear)
111 |             for param_group in self.param_groups:
112 |                 param_group['lr'] = lr_this_step
113 | 
114 |     def step(self, closure=None):
115 |         """Performs a single optimization step.
116 |         Arguments:
117 |             closure (callable, optional): A closure that reevaluates the model
118 |                 and returns the loss.
119 |         """
120 |         self.update_lr()
121 |         loss = None
122 |         if closure is not None:
123 |             loss = closure()
124 | 
125 |         for group in self.param_groups:
126 |             for p in group['params']:
127 |                 if p.grad is None:
128 |                     continue
129 |                 grad = p.grad.data
130 |                 if grad.is_sparse:
131 |                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
132 | 
133 |                 state = self.state[p]
134 | 
135 |                 # State initialization
136 |                 if len(state) == 0:
137 |                     state['step'] = 0
138 |                     # Exponential moving average of gradient values
139 |                     state['next_m'] = torch.zeros_like(p.data)
140 |                     # Exponential moving average of squared gradient values
141 |                     state['next_v'] = torch.zeros_like(p.data)
142 | 
143 |                 next_m, next_v = state['next_m'], state['next_v']
144 |                 beta1, beta2 = group['b1'], group['b2']
145 | 
146 |                 # Add grad clipping
147 |                 if group['max_grad_norm'] > 0:
148 |                     clip_grad_norm_(p, group['max_grad_norm'])
149 | 
150 |                 # Decay the first and second moment running average coefficient
151 |                 # In-place operations to update the averages at the same time
152 |                 next_m.mul_(beta1).add_(1 - beta1, grad)  # legacy signature; newer torch spells it add_(grad, alpha=1 - beta1)
153 |                 next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)  # legacy form of addcmul_(grad, grad, value=1 - beta2)
154 |                 update = next_m / (next_v.sqrt() + group['e'])
155 | 
156 |                 # Just adding the square of the weights to the loss function is *not*
157 |                 # the correct way of using L2 regularization/weight decay with Adam,
158 |                 # since that will interact with the m and v parameters in strange ways.
159 |                 #
160 |                 # Instead we want to decay the weights in a manner that doesn't interact
161 |                 # with the m/v parameters. This is equivalent to adding the square
162 |                 # of the weights to the loss with plain (non-momentum) SGD.
163 |                 if group['weight_decay'] > 0.0:
164 |                     update += group['weight_decay'] * p.data
165 | 
166 |                 if group['t_total'] != -1:
167 |                     schedule_fct = SCHEDULES[group['schedule']]
168 |                     lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
169 |                 else:
170 |                     lr_scheduled = group['lr']
171 | 
172 |                 update_with_lr = lr_scheduled * update
173 |                 p.data.add_(-update_with_lr)
174 | 
175 |                 state['step'] += 1
176 | 
177 |                 # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
178 |                 # No bias correction
179 |                 # bias_correction1 = 1 - beta1 ** state['step']
180 |                 # bias_correction2 = 1 - beta2 ** state['step']
181 |         self.global_step += 1
182 |         return loss
183 | 
--------------------------------------------------------------------------------
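A short usage sketch for BertAdam, mirroring how NerLearner (below) drives it. `model`, `train_dl` and `num_epochs` are assumed placeholders; `model.score(batch)` returning the loss follows this repo's model API:

    from modules.train.optimization import BertAdam

    optimizer = BertAdam(model, lr=0.001, warmup=0.1,
                         t_total=num_epochs * len(train_dl),  # enables the warmup schedule
                         b1=0.8, b2=0.9, max_grad_norm=1.0)
    for batch in train_dl:
        model.zero_grad()
        loss = model.score(batch)
        loss.backward()
        optimizer.step()       # warmup-scheduled lr, decoupled weight decay, clipping
        optimizer.zero_grad()

Note that t_total=-1 keeps the learning rate constant; NerLearner likewise only passes a real t_total when fit(..., resume_history=False) rebuilds the optimizer.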
/modules/train/train.py:
--------------------------------------------------------------------------------
1 | from modules import tqdm
2 | from sklearn_crfsuite.metrics import flat_classification_report
3 | import logging
4 | import torch
5 | from .optimization import BertAdam
6 | from modules.analyze_utils.plot_metrics import get_mean_max_metric
7 | from modules.data.bert_data import get_data_loader_for_predict
8 | 
9 | 
10 | def train_step(dl, model, optimizer, num_epoch=1):
11 |     model.train()
12 |     epoch_loss = 0
13 |     idx = 0
14 |     pr = tqdm(dl, total=len(dl), leave=False)
15 |     for batch in pr:
16 |         idx += 1
17 |         model.zero_grad()
18 |         loss = model.score(batch)
19 |         loss.backward()
20 |         optimizer.step()
21 |         optimizer.zero_grad()
22 |         loss = loss.data.cpu().tolist()  # 0-dim tensor -> Python float
23 |         epoch_loss += loss
24 |         pr.set_description("train loss: {}".format(epoch_loss / idx))
25 |         torch.cuda.empty_cache()  # release cached GPU memory between batches
26 |     logging.info("\nepoch {}, average train epoch loss={:.5}\n".format(
27 |         num_epoch, epoch_loss / idx))
28 | 
29 | 
30 | def transformed_result(preds, mask, id2label,
target_all=None, pad_idx=0): 31 | preds_cpu = [] 32 | targets_cpu = [] 33 | lc = len(id2label) 34 | if target_all is not None: 35 | for batch_p, batch_t, batch_m in zip(preds, target_all, mask): 36 | for pred, true_, bm in zip(batch_p, batch_t, batch_m): 37 | sent = [] 38 | sent_t = [] 39 | bm = bm.sum().cpu().data.tolist() 40 | for p, t in zip(pred[:bm], true_[:bm]): 41 | p = p.cpu().data.tolist() 42 | p = p if p < lc else pad_idx 43 | sent.append(p) 44 | sent_t.append(t.cpu().data.tolist()) 45 | preds_cpu.append([id2label[w] for w in sent]) 46 | targets_cpu.append([id2label[w] for w in sent_t]) 47 | else: 48 | for batch_p, batch_m in zip(preds, mask): 49 | 50 | for pred, bm in zip(batch_p, batch_m): 51 | assert len(pred) == len(bm) 52 | bm = bm.sum().cpu().data.tolist() 53 | sent = pred[:bm].cpu().data.tolist() 54 | preds_cpu.append([id2label[w] for w in sent]) 55 | if target_all is not None: 56 | return preds_cpu, targets_cpu 57 | else: 58 | return preds_cpu 59 | 60 | 61 | def transformed_result_cls(preds, target_all, cls2label, return_target=True): 62 | preds_cpu = [] 63 | targets_cpu = [] 64 | for batch_p, batch_t in zip(preds, target_all): 65 | for pred, true_ in zip(batch_p, batch_t): 66 | preds_cpu.append(cls2label[pred.cpu().data.tolist()]) 67 | if return_target: 68 | targets_cpu.append(cls2label[true_.cpu().data.tolist()]) 69 | if return_target: 70 | return preds_cpu, targets_cpu 71 | return preds_cpu 72 | 73 | 74 | def validate_step(dl, model, id2label, sup_labels, id2cls=None): 75 | model.eval() 76 | idx = 0 77 | preds_cpu, targets_cpu = [], [] 78 | preds_cpu_cls, targets_cpu_cls = [], [] 79 | for batch in tqdm(dl, total=len(dl), leave=False): 80 | idx += 1 81 | labels_mask, labels_ids = batch[1], batch[3] 82 | preds = model.forward(batch) 83 | if id2cls is not None: 84 | preds, preds_cls = preds 85 | preds_cpu_, targets_cpu_ = transformed_result_cls([preds_cls], [batch[-1]], id2cls) 86 | preds_cpu_cls.extend(preds_cpu_) 87 | targets_cpu_cls.extend(targets_cpu_) 88 | preds_cpu_, targets_cpu_ = transformed_result([preds], [labels_mask], id2label, [labels_ids]) 89 | preds_cpu.extend(preds_cpu_) 90 | targets_cpu.extend(targets_cpu_) 91 | clf_report = flat_classification_report(targets_cpu, preds_cpu, labels=sup_labels, digits=3) 92 | if id2cls is not None: 93 | clf_report_cls = flat_classification_report([targets_cpu_cls], [preds_cpu_cls], digits=3) 94 | return clf_report, clf_report_cls 95 | return clf_report 96 | 97 | 98 | def predict(dl, model, id2label, id2cls=None): 99 | model.eval() 100 | idx = 0 101 | preds_cpu = [] 102 | preds_cpu_cls = [] 103 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Predicting"): 104 | idx += 1 105 | labels_mask, labels_ids = batch[1], batch[3] 106 | preds = model.forward(batch) 107 | if id2cls is not None: 108 | preds, preds_cls = preds 109 | preds_cpu_ = transformed_result_cls([preds_cls], [preds_cls], id2cls, False) 110 | preds_cpu_cls.extend(preds_cpu_) 111 | 112 | preds_cpu_ = transformed_result([preds], [labels_mask], id2label) 113 | preds_cpu.extend(preds_cpu_) 114 | if id2cls is not None: 115 | return preds_cpu, preds_cpu_cls 116 | return preds_cpu 117 | 118 | 119 | class NerLearner(object): 120 | 121 | def __init__(self, model, data, best_model_path, lr=0.001, betas=[0.8, 0.9], clip=1.0, 122 | verbose=True, sup_labels=None, t_total=-1, warmup=0.1, weight_decay=0.01, 123 | validate_every=1, schedule="warmup_linear", e=1e-6): 124 | logging.basicConfig(level=logging.INFO) 125 | self.model = model 126 | self.optimizer = 
BertAdam(model, lr, t_total=t_total, b1=betas[0], b2=betas[1], max_grad_norm=clip)  # rebuilt with the full settings in fit() when resume_history=False
127 |         self.optimizer_defaults = dict(
128 |             model=model, lr=lr, warmup=warmup, t_total=t_total, schedule=schedule,
129 |             b1=betas[0], b2=betas[1], e=e, weight_decay=weight_decay,
130 |             max_grad_norm=clip)
131 | 
132 |         self.lr = lr
133 |         self.betas = betas
134 |         self.clip = clip
135 |         self.sup_labels = sup_labels
136 |         self.t_total = t_total
137 |         self.warmup = warmup
138 |         self.weight_decay = weight_decay
139 |         self.validate_every = validate_every
140 |         self.schedule = schedule
141 |         self.data = data
142 |         self.e = e
143 |         if sup_labels is None:
144 |             sup_labels = data.train_ds.idx2label[4:]  # default to all labels except the leading service labels
145 |         self.sup_labels = sup_labels
146 |         self.best_model_path = best_model_path
147 |         self.verbose = verbose
148 |         self.history = []
149 |         self.cls_history = []
150 |         self.epoch = 0
151 |         self.best_target_metric = 0.
152 | 
153 |     def fit(self, epochs=100, resume_history=True, target_metric="f1"):
154 |         if not resume_history:
155 |             self.optimizer_defaults["t_total"] = epochs * len(self.data.train_dl)
156 |             self.optimizer = BertAdam(**self.optimizer_defaults)
157 |             self.history = []
158 |             self.cls_history = []
159 |             self.epoch = 0
160 |             self.best_target_metric = 0.
161 |         elif self.verbose:
162 |             logging.info("Resuming train... Current epoch {}.".format(self.epoch))
163 |         try:
164 |             for _ in range(epochs):
165 |                 self.epoch += 1
166 |                 self.fit_one_cycle(self.epoch, target_metric)
167 |         except KeyboardInterrupt:
168 |             pass  # allow Ctrl-C to stop training early without losing the learner state
169 | 
170 |     def fit_one_cycle(self, epoch, target_metric="f1"):
171 |         train_step(self.data.train_dl, self.model, self.optimizer, epoch)
172 |         if epoch % self.validate_every == 0:
173 |             if self.data.train_ds.is_cls:
174 |                 rep, rep_cls = validate_step(
175 |                     self.data.valid_dl, self.model, self.data.train_ds.idx2label, self.sup_labels,
176 |                     self.data.train_ds.idx2cls)
177 |                 self.cls_history.append(rep_cls)
178 |             else:
179 |                 rep = validate_step(
180 |                     self.data.valid_dl, self.model, self.data.train_ds.idx2label, self.sup_labels)
181 |             self.history.append(rep)
182 |             idx, metric = get_mean_max_metric(self.history, target_metric, True)
183 |             if self.verbose:
184 |                 logging.info("on epoch {} by max_{}: {}".format(idx, target_metric, metric))
185 |                 print(self.history[-1])
186 |                 if self.data.train_ds.is_cls:
187 |                     logging.info("on epoch {} classification report:".format(epoch))
188 |                     print(self.cls_history[-1])
189 |             # Store best model
190 |             if self.best_target_metric < metric:
191 |                 self.best_target_metric = metric
192 |                 if self.verbose:
193 |                     logging.info("Saving new best model...")
194 |                 self.save_model()
195 | 
196 |     def predict(self, dl=None, df_path=None, df=None):
197 |         if dl is None:
198 |             dl = get_data_loader_for_predict(self.data, df_path, df)
199 |         if self.data.train_ds.is_cls:
200 |             return predict(dl, self.model, self.data.train_ds.idx2label, self.data.train_ds.idx2cls)
201 |         return predict(dl, self.model, self.data.train_ds.idx2label)
202 | 
203 |     def save_model(self, path=None):
204 |         path = path if path else self.best_model_path
205 |         torch.save(self.model.state_dict(), path)
206 | 
207 |     def load_model(self, path=None):
208 |         path = path if path else self.best_model_path
209 |         self.model.load_state_dict(torch.load(path))
--------------------------------------------------------------------------------
/modules/train/train_clf.py:
--------------------------------------------------------------------------------
1 | from modules import tqdm
2 | from sklearn_crfsuite.metrics import flat_classification_report
3 | 
import logging 4 | import torch 5 | from .optimization import BertAdam 6 | from modules.analyze_utils.plot_metrics import get_mean_max_metric 7 | from modules.data.bert_data_clf import get_data_loader_for_predict 8 | 9 | 10 | def train_step(dl, model, optimizer, num_epoch=1): 11 | model.train() 12 | epoch_loss = 0 13 | idx = 0 14 | pr = tqdm(dl, total=len(dl), leave=False) 15 | for batch in pr: 16 | idx += 1 17 | model.zero_grad() 18 | loss = model.score(batch) 19 | loss.backward() 20 | optimizer.step() 21 | optimizer.zero_grad() 22 | loss = loss.data.cpu().tolist() 23 | epoch_loss += loss 24 | pr.set_description("train loss: {}".format(epoch_loss / idx)) 25 | torch.cuda.empty_cache() 26 | logging.info("\nepoch {}, average train epoch loss={:.5}\n".format( 27 | num_epoch, epoch_loss / idx)) 28 | 29 | 30 | def transformed_result_cls(preds, target_all, cls2label, return_target=True): 31 | preds_cpu = [] 32 | targets_cpu = [] 33 | for batch_p, batch_t in zip(preds, target_all): 34 | for pred, true_ in zip(batch_p, batch_t): 35 | preds_cpu.append(cls2label[pred.cpu().data.tolist()]) 36 | if return_target: 37 | targets_cpu.append(cls2label[true_.cpu().data.tolist()]) 38 | if return_target: 39 | return preds_cpu, targets_cpu 40 | return preds_cpu 41 | 42 | 43 | def validate_step(dl, model, id2cls): 44 | model.eval() 45 | idx = 0 46 | preds_cpu_cls, targets_cpu_cls = [], [] 47 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Validation"): 48 | idx += 1 49 | preds_cls = model.forward(batch) 50 | preds_cpu_, targets_cpu_ = transformed_result_cls([preds_cls], [batch[-1]], id2cls) 51 | preds_cpu_cls.extend(preds_cpu_) 52 | targets_cpu_cls.extend(targets_cpu_) 53 | clf_report_cls = flat_classification_report([targets_cpu_cls], [preds_cpu_cls], digits=4) 54 | return clf_report_cls 55 | 56 | 57 | def predict(dl, model, id2cls): 58 | model.eval() 59 | idx = 0 60 | preds_cpu_cls = [] 61 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Predicting"): 62 | idx += 1 63 | preds_cls = model.forward(batch) 64 | preds_cpu_ = transformed_result_cls([preds_cls], [preds_cls], id2cls, False) 65 | preds_cpu_cls.extend(preds_cpu_) 66 | 67 | return preds_cpu_cls 68 | 69 | 70 | class NerLearner(object): 71 | 72 | def __init__(self, model, data, best_model_path, lr=0.001, betas=[0.8, 0.9], clip=1.0, 73 | verbose=True, t_total=-1, warmup=0.1, weight_decay=0.01, 74 | validate_every=1, schedule="warmup_linear", e=1e-6): 75 | logging.basicConfig(level=logging.INFO) 76 | self.model = model 77 | self.optimizer = BertAdam(model, lr, t_total=t_total, b1=betas[0], b2=betas[1], max_grad_norm=clip) 78 | self.optimizer_defaults = dict( 79 | model=model, lr=lr, warmup=warmup, t_total=t_total, schedule=schedule, 80 | b1=betas[0], b2=betas[1], e=e, weight_decay=weight_decay, 81 | max_grad_norm=clip) 82 | 83 | self.lr = lr 84 | self.betas = betas 85 | self.clip = clip 86 | self.t_total = t_total 87 | self.warmup = warmup 88 | self.weight_decay = weight_decay 89 | self.validate_every = validate_every 90 | self.schedule = schedule 91 | self.data = data 92 | self.e = e 93 | self.best_model_path = best_model_path 94 | self.verbose = verbose 95 | self.cls_history = [] 96 | self.epoch = 0 97 | self.best_target_metric = 0. 
98 | 
99 |     def fit(self, epochs=100, resume_history=True, target_metric="f1"):
100 |         if not resume_history:
101 |             self.optimizer_defaults["t_total"] = epochs * len(self.data.train_dl)
102 |             self.optimizer = BertAdam(**self.optimizer_defaults)
103 |             self.cls_history = []
104 |             self.epoch = 0
105 |             self.best_target_metric = 0.
106 |         elif self.verbose:
107 |             logging.info("Resuming train... Current epoch {}.".format(self.epoch))
108 |         try:
109 |             for _ in range(epochs):
110 |                 self.epoch += 1
111 |                 self.fit_one_cycle(self.epoch, target_metric)
112 |         except KeyboardInterrupt:
113 |             pass  # allow manual early stopping
114 | 
115 |     def fit_one_cycle(self, epoch, target_metric="f1"):
116 |         train_step(self.data.train_dl, self.model, self.optimizer, epoch)
117 |         if epoch % self.validate_every == 0:
118 |             rep_cls = validate_step(self.data.valid_dl, self.model, self.data.train_ds.idx2cls)
119 |             self.cls_history.append(rep_cls)
120 |             idx, metric = get_mean_max_metric(self.cls_history, target_metric, True)
121 |             if self.verbose:
122 |                 logging.info("on epoch {} by max_{}: {}".format(idx, target_metric, metric))
123 |                 print(self.cls_history[-1])
124 | 
125 |             # Store best model
126 |             if self.best_target_metric < metric:
127 |                 self.best_target_metric = metric
128 |                 if self.verbose:
129 |                     logging.info("Saving new best model...")
130 |                 self.save_model()
131 | 
132 |     def predict(self, dl=None, df_path=None, df=None):
133 |         if dl is None:
134 |             dl, ds = get_data_loader_for_predict(self.data, df_path, df)  # returns (loader, dataset); only the loader is needed here
135 |         return predict(dl, self.model, self.data.train_ds.idx2cls)
136 | 
137 |     def save_model(self, path=None):
138 |         path = path if path else self.best_model_path
139 |         torch.save(self.model.state_dict(), path)
140 | 
141 |     def load_model(self, path=None):
142 |         path = path if path else self.best_model_path
143 |         self.model.load_state_dict(torch.load(path))
--------------------------------------------------------------------------------
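A hedged usage sketch for this classification learner. The model factory and the `data` object are assumptions; the data container is only expected to expose train_dl, valid_dl and train_ds.idx2cls, as the code above requires:

    from modules.train.train_clf import NerLearner
    from modules.models.classifiers import BERTBiLSTMAttnClassifier

    model = BERTBiLSTMAttnClassifier.create(intent_size=len(data.train_ds.idx2cls))
    learner = NerLearner(model, data, best_model_path="clf_best.cpt", lr=0.0001)
    # resume_history=False recreates BertAdam with t_total = epochs * len(train_dl),
    # so the warmup_linear schedule covers the whole run
    learner.fit(epochs=10, resume_history=False, target_metric="f1")
    preds = learner.predict(df_path="valid.csv")  # class labels decoded via idx2cls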
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy
4 | import bson
5 | import sys
6 | 
7 | 
8 | def ipython_info():
9 |     ip = False
10 |     if 'ipykernel' in sys.modules:
11 |         ip = 'notebook'
12 |     elif 'IPython' in sys.modules:
13 |         ip = 'terminal'
14 |     return ip
15 | 
16 | 
17 | def get_tqdm():
18 |     ip = ipython_info()
19 |     if ip == "terminal" or not ip:
20 |         from tqdm import tqdm
21 |         return tqdm
22 |     else:
23 |         try:
24 |             from tqdm import tqdm_notebook
25 |             return tqdm_notebook
26 |         except ImportError:  # fall back to the console bar if the notebook widget is unavailable
27 |             from tqdm import tqdm
28 |             return tqdm
29 | 
30 | 
31 | class JsonEncoder(json.JSONEncoder):
32 |     def default(self, obj):
33 |         if isinstance(obj, numpy.integer):
34 |             return int(obj)
35 |         elif isinstance(obj, numpy.floating):
36 |             return float(obj)
37 |         elif isinstance(obj, numpy.ndarray):
38 |             return obj.tolist()
39 |         elif isinstance(obj, bson.ObjectId):
40 |             return str(obj)
41 |         else:
42 |             return super(JsonEncoder, self).default(obj)
43 | 
44 | 
45 | def jsonify(data):
46 |     return json.dumps(data, cls=JsonEncoder)
47 | 
48 | 
49 | def read_config(config):
50 |     if isinstance(config, str):
51 |         with open(config, "r", encoding="utf-8") as f:
52 |             config = json.load(f)
53 |     return config
54 | 
55 | 
56 | def save_config(config, path):
57 |     with open(path, "w") as file:
58 |         json.dump(config, file, cls=JsonEncoder)
59 | 
60 | 
61 | def if_none(origin, other):
62 |     return other if origin is None else origin
63 | 
64 | 
65 | def get_files_path_from_dir(path):
66 |     f = []
67 |     for dir_path, dir_names, filenames in os.walk(path):
68 |         for f_name in filenames:
69 |             f.append(os.path.join(dir_path, f_name))  # portable path join instead of manual "/" concatenation
70 |     return f
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bson
2 | pandas
3 | scikit-learn
4 | sklearn-crfsuite
5 | tqdm
6 | rusenttokenize
7 | numpy
8 | nltk
9 | torch
10 | matplotlib
--------------------------------------------------------------------------------
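Finally, a hedged end-to-end sketch of the NER pipeline these modules implement. The data builder and model factory are left as placeholders, since their exact names are inferred from the tree and notebook titles (modules/data/bert_data.py, modules/models/bert_models.py); only the NerLearner API below is verbatim from modules/train/train.py:

    from modules.train.train import NerLearner

    data = ...   # hypothetical: built by modules/data/bert_data.py; must expose
                 # train_dl, valid_dl and train_ds (idx2label, is_cls)
    model = ...  # hypothetical: e.g. a BERTBiLSTMCRF from modules/models/bert_models.py,
                 # built via its .create(...) classmethod

    learner = NerLearner(model, data, best_model_path="ner_best.cpt",
                         lr=0.001, betas=[0.8, 0.9], clip=1.0)
    learner.fit(epochs=100, resume_history=False, target_metric="f1")

    learner.load_model()                         # reload the best checkpoint
    preds = learner.predict(df_path="test.csv")  # per-sentence lists of label strings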