├── .gitignore
├── LICENSE
├── README.md
├── censor.py
├── exps
│   ├── censor_cls.ipynb
│   ├── conll2003 BERTBiLSTMAttnCRF base BERT.ipynb
│   ├── conll2003 BERTBiLSTMAttnNCRF base BERT.ipynb
│   ├── conll2003 BERTBiLSTMCRF base BERT.ipynb
│   ├── conll2003 BERTBiLSTMCRF.ipynb
│   ├── fre BERTAttnCRF.ipynb
│   ├── fre BERTBiLSTMAttnCRF-fit_BERT.ipynb
│   ├── fre BERTBiLSTMAttnCRF.ipynb
│   ├── fre BERTBiLSTMAttnNCRF-fit_BERT.ipynb
│   ├── fre BERTBiLSTMAttnNCRF.ipynb
│   ├── fre BERTBiLSTMCRF.ipynb
│   ├── fre BERTBiLSTMNCRF.ipynb
│   ├── fre BERTCRF.ipynb
│   ├── fre BERTNCRF.ipynb
│   └── prc fre.ipynb
├── modules
│   ├── __init__.py
│   ├── analyze_utils
│   │   ├── __init__.py
│   │   ├── main_metrics.py
│   │   ├── plot_metrics.py
│   │   └── utils.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── bert_data.py
│   │   ├── bert_data_clf.py
│   │   ├── conll2003
│   │   │   ├── __init__.py
│   │   │   └── prc.py
│   │   ├── download_data.py
│   │   └── fre
│   │       ├── __init__.py
│   │       ├── bilou
│   │       │   ├── __init__.py
│   │       │   ├── from_bilou.py
│   │       │   └── to_bilou.py
│   │       ├── entity
│   │       │   ├── __init__.py
│   │       │   ├── document.py
│   │       │   ├── taggedtoken.py
│   │       │   └── token.py
│   │       ├── prc.py
│   │       ├── reader.py
│   │       └── utils.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── crf.py
│   │   ├── decoders.py
│   │   ├── embedders.py
│   │   ├── layers.py
│   │   └── ncrf.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── bert_models.py
│   │   └── classifiers.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── optimization.py
│   │   ├── train.py
│   │   └── train_clf.py
│   └── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Sberbank AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## 0. Papers
2 | Two solutions are based on this architecture:
3 | 1. [BSNLP 2019 ACL workshop](http://bsnlp.cs.helsinki.fi/shared_task.html): a [solution](https://github.com/king-menin/slavic-ner) and [paper](https://arxiv.org/abs/1906.09978) for the multilingual shared task.
4 | 2. The second-place [solution](https://github.com/king-menin/AGRR-2019) for the [Dialogue AGRR-2019](https://github.com/dialogue-evaluation/AGRR-2019) task, with the accompanying [paper](http://www.dialog-21.ru/media/4679/emelyanov-artemova-gapping_parsing_using_pretrained_embeddings__attention_mechanisn_and_ncrf.pdf).
5 |
6 | ## 1. Description
7 | This repository contains a solution to the NER task, based on the PyTorch [reimplementation](https://github.com/huggingface/pytorch-pretrained-BERT) of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert), which was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
8 |
9 | This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)).
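
For example, a Google checkpoint can be turned into a PyTorch weight file with the conversion utility that ships with pytorch-pretrained-BERT. This is a minimal sketch, not part of this repository: the paths are placeholders and the exact module path may differ between pytorch-pretrained-BERT versions.
```
# Sketch: convert a TensorFlow BERT checkpoint to PyTorch weights.
# Requires TensorFlow to be installed; all paths below are placeholders.
from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    "/path/to/bert_model.ckpt",    # TF checkpoint prefix
    "/path/to/bert_config.json",   # BERT config used to build the model
    "/path/to/pytorch_model.bin")  # where to write the PyTorch weights
```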
10 |
11 | The old version is kept in the "old" branch.
12 |
13 | ## 2. Usage
14 | ### 2.1 Create data
15 | ```
16 | from modules.data import bert_data
17 | data = bert_data.LearnData.create(
18 | train_df_path=train_df_path,
19 | valid_df_path=valid_df_path,
20 | idx2labels_path="/path/to/vocab",
21 | clear_cache=True
22 | )
23 | ```
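
The train and validation files are plain CSVs. The sketch below shows one way to build such a file by hand; the `labels`/`text` column names and the space-separated IO tags are assumptions inferred from the preprocessing code (`modules/data/conll2003/prc.py`, `exps/prc fre.ipynb`), not a documented schema.
```
import pandas as pd

# Hypothetical two-sentence toy corpus in IO markup (I_O marks non-entity tokens).
df = pd.DataFrame({
    "labels": ["I_PER I_O I_O I_LOC", "I_O I_ORG I_ORG"],
    "text": ["John lives in London", "the United Nations"],
})
df.to_csv("/path/to/train.csv", index=False)
```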
24 |
25 | ### 2.2 Create model
26 | ```
27 | from modules.models.bert_models import BERTBiLSTMAttnCRF
28 | model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label))
29 | ```
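
The other architectures from the `exps/` notebooks are created through the same interface; for example, the call below mirrors `exps/conll2003 BERTBiLSTMCRF.ipynb`.
```
from modules.models.bert_models import BERTBiLSTMCRF

# Same .create() interface; the dropout keyword arguments are taken from the notebook.
model = BERTBiLSTMCRF.create(
    len(data.train_ds.idx2label),
    lstm_dropout=0., crf_dropout=0.3)
```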
30 |
31 | ### 2.3 Create Learner
32 | ```
33 | from modules.train.train import NerLearner
34 | num_epochs = 100
35 | learner = NerLearner(
36 | model, data, "/path/for/save/best/model", t_total=num_epochs * len(data.train_dl))
37 | ```
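
Training is then a single call, as in the `exps/` notebooks:
```
# Runs training/validation for num_epochs; the best model is saved
# to the path passed to NerLearner above.
learner.fit(epochs=num_epochs)
```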
38 |
39 | ### 2.4 Predict
40 | ```
41 | from modules.data.bert_data import get_data_loader_for_predict
42 | learner.load_model()
43 | dl = get_data_loader_for_predict(data, df_path="/path/to/df/for/predict")
44 | preds = learner.predict(dl)
45 | ```
46 |
47 | ### 2.5 Evaluate
48 | ```
49 | from sklearn_crfsuite.metrics import flat_classification_report
50 | from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer
51 | from modules.analyze_utils.plot_metrics import get_bert_span_report
52 | from modules.analyze_utils.main_metrics import precision_recall_f1
53 |
54 |
55 | pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
56 | true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])
57 | tokens_report = flat_classification_report(true_labels, pred_labels, digits=4)
58 | print(tokens_report)
59 |
60 | results = precision_recall_f1(true_labels, pred_labels)
61 | ```
62 |
63 | ## 3. Results
64 | We did not search for the best parameters and obtained the following results.
65 |
66 | | Model | Data set | Dev F1 tok | Dev F1 span | Test F1 tok | Test F1 span |
67 | |-|-|-|-|-|-|
68 | |**OURS**||||||
69 | | M-BERTCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8543 | 0.8409
70 | | M-BERTNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8637 | 0.8516
71 | | M-BERTBiLSTMCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8835 | **0.8718**
72 | | M-BERTBiLSTMNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8632 | 0.8510
73 | | M-BERTAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8503 | 0.8346
74 | | M-BERTBiLSTMAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | **0.8839** | 0.8716
75 | | M-BERTBiLSTMAttnNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8807 | 0.8680
76 | | M-BERTBiLSTMAttnCRF-fit_BERT-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8823 | 0.8709
77 | | M-BERTBiLSTMAttnNCRF-fit_BERT-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8583 | 0.8456
78 | |-|-|-|-|-|-|
79 | | BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9629 | - | 0.9221 | -
80 | | B-BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9635 | - | 0.9229 | -
81 | | B-BERTBiLSTMAttnCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9614 | - | 0.9237 | -
82 | | B-BERTBiLSTMAttnNCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9631 | - | **0.9249** | -
83 | |**Current SOTA**||||||
84 | | DeepPavlov-RuBERT-NER | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | - | **0.8266**
85 | | CSE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | - | - | **0.931** | -
86 | | BERT-LARGE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.966 | - | 0.928 | -
87 | | BERT-BASE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.964 | - | 0.924 | -
88 |
--------------------------------------------------------------------------------
/censor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 | from modules.data import bert_data_clf
4 | from modules.models.classifiers import BERTBiLSTMAttnClassifier
5 | from modules.train.train_clf import NerLearner
6 |
7 |
8 | warnings.filterwarnings("ignore")
9 | sys.path.append("../")
10 |
11 |
12 | def main():
13 | train_df_path = "/home/ubuntu/censor/train2.csv"
14 | valid_df_path = "/home/ubuntu/censor/dev2.csv"
15 | test_df_path = "/home/ubuntu/censor/test.csv"
16 | num_epochs = 100
17 |
18 |
19 | data = bert_data_clf.LearnDataClass.create(
20 | train_df_path=train_df_path,
21 | valid_df_path=valid_df_path,
22 | idx2cls_path="/home/ubuntu/censor/idx2cls.txt",
23 | clear_cache=False,
24 | batch_size=64
25 | )
26 |
27 | model = BERTBiLSTMAttnClassifier.create(len(data.train_ds.cls2idx), hidden_dim=768)
28 | learner = NerLearner(
29 | model, data, "/home/ubuntu/censor/cls.cpt4", t_total=num_epochs * len(data.train_dl))
30 | learner.fit(epochs=num_epochs)
31 |
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/exps/conll2003 BERTBiLSTMCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 21,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "from modules.data.conll2003.prc import conll2003_preprocess"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 22,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "data_dir = \"/home/eartemov/ae/work/conll2003/\""
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 13,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "application/vnd.jupyter.widget-view+json": {
47 | "model_id": "f14d2a1ce44947ce98b9f430cf82caf1",
48 | "version_major": 2,
49 | "version_minor": 0
50 | },
51 | "text/plain": [
52 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.train', max=2195…"
53 | ]
54 | },
55 | "metadata": {},
56 | "output_type": "display_data"
57 | },
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "\n"
63 | ]
64 | },
65 | {
66 | "data": {
67 | "application/vnd.jupyter.widget-view+json": {
68 | "model_id": "fa6fddbba84a4de78a48e7503af8d616",
69 | "version_major": 2,
70 | "version_minor": 0
71 | },
72 | "text/plain": [
73 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testa', max=5504…"
74 | ]
75 | },
76 | "metadata": {},
77 | "output_type": "display_data"
78 | },
79 | {
80 | "name": "stdout",
81 | "output_type": "stream",
82 | "text": [
83 | "\n"
84 | ]
85 | },
86 | {
87 | "data": {
88 | "application/vnd.jupyter.widget-view+json": {
89 | "model_id": "e3c3e6e2ebdb4e51ba4051a1499e53cf",
90 | "version_major": 2,
91 | "version_minor": 0
92 | },
93 | "text/plain": [
94 | "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testb', max=5035…"
95 | ]
96 | },
97 | "metadata": {},
98 | "output_type": "display_data"
99 | },
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "\n"
105 | ]
106 | }
107 | ],
108 | "source": [
109 | "conll2003_preprocess(data_dir)"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "## IO markup"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "### Train"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 3,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "from modules.data import bert_data"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 4,
138 | "metadata": {},
139 | "outputs": [
140 | {
141 | "name": "stderr",
142 | "output_type": "stream",
143 | "text": [
144 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
145 | ]
146 | },
147 | {
148 | "data": {
149 | "application/vnd.jupyter.widget-view+json": {
150 | "model_id": "",
151 | "version_major": 2,
152 | "version_minor": 0
153 | },
154 | "text/plain": [
155 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=6973, style=ProgressStyle(descri…"
156 | ]
157 | },
158 | "metadata": {},
159 | "output_type": "display_data"
160 | },
161 | {
162 | "name": "stdout",
163 | "output_type": "stream",
164 | "text": [
165 | "\r"
166 | ]
167 | }
168 | ],
169 | "source": [
170 | "data = bert_data.LearnData.create(\n",
171 | " train_df_path=\"/home/eartemov/ae/work/conll2003/eng.train.train.csv\",\n",
172 | " valid_df_path=\"/home/eartemov/ae/work/conll2003/eng.testa.dev.csv\",\n",
173 | " idx2labels_path=\"/home/eartemov/ae/work/conll2003/idx2labels.txt\",\n",
174 | " clear_cache=True\n",
175 | ")"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 5,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "from modules.models.bert_models import BERTBiLSTMCRF"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 6,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "model = BERTBiLSTMCRF.create(\n",
194 | " len(data.train_ds.idx2label),\n",
195 | " lstm_dropout=0., crf_dropout=0.3)"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 7,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "from modules.train.train import NerLearner"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 8,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "num_epochs = 100"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 9,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "learner = NerLearner(\n",
223 | " model, data, \"/home/eartemov/ae/work/models/conll2003-BERTBiLSTMCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 10,
229 | "metadata": {
230 | "scrolled": true
231 | },
232 | "outputs": [
233 | {
234 | "data": {
235 | "text/plain": [
236 | "2235023"
237 | ]
238 | },
239 | "execution_count": 10,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "model.get_n_trainable_params()"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {
252 | "scrolled": true
253 | },
254 | "outputs": [],
255 | "source": [
256 | "learner.fit(epochs=num_epochs)"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "metadata": {},
262 | "source": [
263 | "### Predict"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 12,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "from modules.data.bert_data import get_data_loader_for_predict"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 13,
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 14,
287 | "metadata": {},
288 | "outputs": [
289 | {
290 | "data": {
291 | "application/vnd.jupyter.widget-view+json": {
292 | "model_id": "",
293 | "version_major": 2,
294 | "version_minor": 0
295 | },
296 | "text/plain": [
297 | "HBox(children=(IntProgress(value=0, description='Predicting', max=109, style=ProgressStyle(description_width='…"
298 | ]
299 | },
300 | "metadata": {},
301 | "output_type": "display_data"
302 | },
303 | {
304 | "name": "stdout",
305 | "output_type": "stream",
306 | "text": [
307 | "\r"
308 | ]
309 | }
310 | ],
311 | "source": [
312 | "preds = learner.predict(dl)"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 15,
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "from sklearn_crfsuite.metrics import flat_classification_report"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 16,
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
331 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 17,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
341 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": 18,
347 | "metadata": {},
348 | "outputs": [],
349 | "source": [
350 | "assert pred_tokens == true_tokens\n",
351 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": 20,
357 | "metadata": {},
358 | "outputs": [
359 | {
360 | "name": "stdout",
361 | "output_type": "stream",
362 | "text": [
363 | " precision recall f1-score support\n",
364 | "\n",
365 | " I_ORG 0.9514 0.9509 0.9511 2016\n",
366 | " I_O 0.9968 0.9970 0.9969 41702\n",
367 | " I_MISC 0.9353 0.8974 0.9160 1257\n",
368 | " I_PER 0.9849 0.9825 0.9837 2856\n",
369 | " I_LOC 0.9697 0.9637 0.9667 1926\n",
370 | "\n",
371 | " micro avg 0.9917 0.9905 0.9911 49757\n",
372 | " macro avg 0.9676 0.9583 0.9629 49757\n",
373 | "weighted avg 0.9916 0.9905 0.9910 49757\n",
374 | "\n"
375 | ]
376 | }
377 | ],
378 | "source": [
379 | "print(tokens_report)"
380 | ]
381 | },
382 | {
383 | "cell_type": "markdown",
384 | "metadata": {},
385 | "source": [
386 | "### Test"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 12,
392 | "metadata": {},
393 | "outputs": [],
394 | "source": [
395 | "from modules.data.bert_data import get_data_loader_for_predict"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 24,
401 | "metadata": {},
402 | "outputs": [],
403 | "source": [
404 | "dl = get_data_loader_for_predict(data, df_path=\"/home/eartemov/ae/work/conll2003/eng.testb.dev.csv\")"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 25,
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "data": {
414 | "application/vnd.jupyter.widget-view+json": {
415 | "model_id": "",
416 | "version_major": 2,
417 | "version_minor": 0
418 | },
419 | "text/plain": [
420 | "HBox(children=(IntProgress(value=0, description='Predicting', max=98, style=ProgressStyle(description_width='i…"
421 | ]
422 | },
423 | "metadata": {},
424 | "output_type": "display_data"
425 | },
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "\r"
431 | ]
432 | }
433 | ],
434 | "source": [
435 | "preds = learner.predict(dl)"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": 26,
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "from sklearn_crfsuite.metrics import flat_classification_report"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 27,
450 | "metadata": {},
451 | "outputs": [],
452 | "source": [
453 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
454 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": 28,
460 | "metadata": {},
461 | "outputs": [],
462 | "source": [
463 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
464 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 29,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "assert pred_tokens == true_tokens\n",
474 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 30,
480 | "metadata": {},
481 | "outputs": [
482 | {
483 | "name": "stdout",
484 | "output_type": "stream",
485 | "text": [
486 | " precision recall f1-score support\n",
487 | "\n",
488 | " I_ORG 0.8988 0.9147 0.9067 2368\n",
489 | " I_O 0.9952 0.9917 0.9934 37573\n",
490 | " I_MISC 0.8163 0.8055 0.8108 910\n",
491 | " I_PER 0.9759 0.9770 0.9765 2698\n",
492 | " I_LOC 0.9170 0.9296 0.9233 1819\n",
493 | "\n",
494 | " micro avg 0.9822 0.9806 0.9814 45368\n",
495 | " macro avg 0.9206 0.9237 0.9221 45368\n",
496 | "weighted avg 0.9823 0.9806 0.9814 45368\n",
497 | "\n"
498 | ]
499 | }
500 | ],
501 | "source": [
502 | "print(tokens_report)"
503 | ]
504 | }
505 | ],
506 | "metadata": {
507 | "kernelspec": {
508 | "display_name": "Python 3",
509 | "language": "python",
510 | "name": "python3"
511 | },
512 | "language_info": {
513 | "codemirror_mode": {
514 | "name": "ipython",
515 | "version": 3
516 | },
517 | "file_extension": ".py",
518 | "mimetype": "text/x-python",
519 | "name": "python",
520 | "nbconvert_exporter": "python",
521 | "pygments_lexer": "ipython3",
522 | "version": "3.6.8"
523 | }
524 | },
525 | "nbformat": 4,
526 | "nbformat_minor": 2
527 | }
528 |
--------------------------------------------------------------------------------
/exps/fre BERTAttnCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stderr",
61 | "output_type": "stream",
62 | "text": [
63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "application/vnd.jupyter.widget-view+json": {
69 | "model_id": "",
70 | "version_major": 2,
71 | "version_minor": 0
72 | },
73 | "text/plain": [
74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
75 | ]
76 | },
77 | "metadata": {},
78 | "output_type": "display_data"
79 | },
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "\r"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "data = bert_data.LearnData.create(\n",
90 | " train_df_path=train_df_path,\n",
91 | " valid_df_path=valid_df_path,\n",
92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels2.txt\",\n",
93 | " clear_cache=True\n",
94 | ")"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "from modules.models.bert_models import BERTAttnCRF"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "model = BERTAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 7,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "from modules.train.train import NerLearner"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 8,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "num_epochs = 100"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 9,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "learner = NerLearner(\n",
140 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTAttnCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 10,
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "data": {
150 | "text/plain": [
151 | "890617"
152 | ]
153 | },
154 | "execution_count": 10,
155 | "metadata": {},
156 | "output_type": "execute_result"
157 | }
158 | ],
159 | "source": [
160 | "model.get_n_trainable_params()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {
167 | "scrolled": true
168 | },
169 | "outputs": [],
170 | "source": [
171 | "learner.fit(epochs=num_epochs)"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {},
177 | "source": [
178 | "### Predict"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 30,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "from modules.data.bert_data import get_data_loader_for_predict"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 31,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 32,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "data": {
206 | "application/vnd.jupyter.widget-view+json": {
207 | "model_id": "",
208 | "version_major": 2,
209 | "version_minor": 0
210 | },
211 | "text/plain": [
212 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
213 | ]
214 | },
215 | "metadata": {},
216 | "output_type": "display_data"
217 | },
218 | {
219 | "name": "stdout",
220 | "output_type": "stream",
221 | "text": [
222 | "\r"
223 | ]
224 | }
225 | ],
226 | "source": [
227 | "preds = learner.predict(dl)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 33,
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "from sklearn_crfsuite.metrics import flat_classification_report"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 34,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
246 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 35,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
256 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 36,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "assert pred_tokens == true_tokens\n",
266 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 37,
272 | "metadata": {},
273 | "outputs": [
274 | {
275 | "name": "stdout",
276 | "output_type": "stream",
277 | "text": [
278 | " precision recall f1-score support\n",
279 | "\n",
280 | " I_ORG 0.8019 0.7415 0.7705 3865\n",
281 | " I_PER 0.9374 0.9569 0.9470 2112\n",
282 | " I_LOC 0.9007 0.7752 0.8333 1557\n",
283 | "\n",
284 | " micro avg 0.8620 0.8089 0.8346 7534\n",
285 | " macro avg 0.8800 0.8245 0.8503 7534\n",
286 | "weighted avg 0.8603 0.8089 0.8330 7534\n",
287 | "\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "print(tokens_report)"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 38,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": 39,
307 | "metadata": {},
308 | "outputs": [
309 | {
310 | "name": "stdout",
311 | "output_type": "stream",
312 | "text": [
313 | "processed 56409 tokens with 7534 phrases; found: 7070 phrases; correct: 6094.\n",
314 | "\n",
315 | "precision: 86.20%; recall: 80.89%; FB1: 83.46\n",
316 | "\n",
317 | "\tLOC: precision: 90.07%; recall: 77.52%; F1: 83.33 1340\n",
318 | "\n",
319 | "\tORG: precision: 80.19%; recall: 74.15%; F1: 77.05 3574\n",
320 | "\n",
321 | "\tPER: precision: 93.74%; recall: 95.69%; F1: 94.70 2156\n",
322 | "\n",
323 | "\n"
324 | ]
325 | }
326 | ],
327 | "source": [
328 | "results = precision_recall_f1(true_labels, pred_labels)"
329 | ]
330 | }
331 | ],
332 | "metadata": {
333 | "kernelspec": {
334 | "display_name": "Python 3",
335 | "language": "python",
336 | "name": "python3"
337 | },
338 | "language_info": {
339 | "codemirror_mode": {
340 | "name": "ipython",
341 | "version": 3
342 | },
343 | "file_extension": ".py",
344 | "mimetype": "text/x-python",
345 | "name": "python",
346 | "nbconvert_exporter": "python",
347 | "pygments_lexer": "ipython3",
348 | "version": "3.6.8"
349 | }
350 | },
351 | "nbformat": 4,
352 | "nbformat_minor": 2
353 | }
354 |
--------------------------------------------------------------------------------
/exps/fre BERTBiLSTMAttnCRF-fit_BERT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "device = \"cuda:0\""
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 5,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stderr",
70 | "output_type": "stream",
71 | "text": [
72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
73 | ]
74 | },
75 | {
76 | "data": {
77 | "application/vnd.jupyter.widget-view+json": {
78 | "model_id": "",
79 | "version_major": 2,
80 | "version_minor": 0
81 | },
82 | "text/plain": [
83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
84 | ]
85 | },
86 | "metadata": {},
87 | "output_type": "display_data"
88 | },
89 | {
90 | "name": "stdout",
91 | "output_type": "stream",
92 | "text": [
93 | "\r"
94 | ]
95 | }
96 | ],
97 | "source": [
98 | "data = bert_data.LearnData.create(\n",
99 | " train_df_path=train_df_path,\n",
100 | " valid_df_path=valid_df_path,\n",
101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels5.txt\",\n",
102 | " clear_cache=True,\n",
103 | " batch_size=8,\n",
104 | " device=device\n",
105 | ")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "from modules.models.bert_models import BERTBiLSTMAttnCRF"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 7,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3, is_freeze=False, device=device)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 8,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "from modules.train.train import NerLearner"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 9,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "num_epochs = 100"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 11,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "learner = NerLearner(\n",
151 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnCRF-fit_BERT-IO.cpt\",\n",
152 | " t_total=num_epochs * len(data.train_dl), lr=0.0001)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 12,
158 | "metadata": {
159 | "scrolled": true
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/plain": [
165 | "180482937"
166 | ]
167 | },
168 | "execution_count": 12,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | }
172 | ],
173 | "source": [
174 | "model.get_n_trainable_params()"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {
181 | "scrolled": true
182 | },
183 | "outputs": [],
184 | "source": [
185 | "learner.fit(epochs=num_epochs)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "### Predict"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 25,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "from modules.data.bert_data import get_data_loader_for_predict"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 26,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 27,
216 | "metadata": {},
217 | "outputs": [
218 | {
219 | "data": {
220 | "application/vnd.jupyter.widget-view+json": {
221 | "model_id": "",
222 | "version_major": 2,
223 | "version_minor": 0
224 | },
225 | "text/plain": [
226 | "HBox(children=(IntProgress(value=0, description='Predicting', max=340, style=ProgressStyle(description_width='…"
227 | ]
228 | },
229 | "metadata": {},
230 | "output_type": "display_data"
231 | }
232 | ],
233 | "source": [
234 | "preds = learner.predict(dl)"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 28,
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "from sklearn_crfsuite.metrics import flat_classification_report"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 29,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
253 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 30,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
263 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 31,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "assert pred_tokens == true_tokens\n",
273 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 32,
279 | "metadata": {},
280 | "outputs": [
281 | {
282 | "name": "stdout",
283 | "output_type": "stream",
284 | "text": [
285 | " precision recall f1-score support\n",
286 | "\n",
287 | " I_ORG 0.8334 0.8191 0.8262 3865\n",
288 | " I_PER 0.9145 0.9825 0.9473 2112\n",
289 | " I_LOC 0.9342 0.8202 0.8735 1557\n",
290 | "\n",
291 | " micro avg 0.8767 0.8651 0.8709 7534\n",
292 | " macro avg 0.8940 0.8739 0.8823 7534\n",
293 | "weighted avg 0.8769 0.8651 0.8699 7534\n",
294 | "\n"
295 | ]
296 | }
297 | ],
298 | "source": [
299 | "print(tokens_report)"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 33,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": 34,
314 | "metadata": {},
315 | "outputs": [
316 | {
317 | "name": "stdout",
318 | "output_type": "stream",
319 | "text": [
320 | "processed 56409 tokens with 7534 phrases; found: 7435 phrases; correct: 6518.\n",
321 | "\n",
322 | "precision: 87.67%; recall: 86.51%; FB1: 87.09\n",
323 | "\n",
324 | "\tLOC: precision: 93.42%; recall: 82.02%; F1: 87.35 1367\n",
325 | "\n",
326 | "\tORG: precision: 83.34%; recall: 81.91%; F1: 82.62 3799\n",
327 | "\n",
328 | "\tPER: precision: 91.45%; recall: 98.25%; F1: 94.73 2269\n",
329 | "\n",
330 | "\n"
331 | ]
332 | }
333 | ],
334 | "source": [
335 | "results = precision_recall_f1(true_labels, pred_labels)"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "metadata": {},
342 | "outputs": [],
343 | "source": []
344 | }
345 | ],
346 | "metadata": {
347 | "kernelspec": {
348 | "display_name": "Python 3",
349 | "language": "python",
350 | "name": "python3"
351 | },
352 | "language_info": {
353 | "codemirror_mode": {
354 | "name": "ipython",
355 | "version": 3
356 | },
357 | "file_extension": ".py",
358 | "mimetype": "text/x-python",
359 | "name": "python",
360 | "nbconvert_exporter": "python",
361 | "pygments_lexer": "ipython3",
362 | "version": "3.6.8"
363 | }
364 | },
365 | "nbformat": 4,
366 | "nbformat_minor": 2
367 | }
368 |
--------------------------------------------------------------------------------
/exps/fre BERTBiLSTMAttnCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stderr",
61 | "output_type": "stream",
62 | "text": [
63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "application/vnd.jupyter.widget-view+json": {
69 | "model_id": "",
70 | "version_major": 2,
71 | "version_minor": 0
72 | },
73 | "text/plain": [
74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
75 | ]
76 | },
77 | "metadata": {},
78 | "output_type": "display_data"
79 | },
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "\r"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "data = bert_data.LearnData.create(\n",
90 | " train_df_path=train_df_path,\n",
91 | " valid_df_path=valid_df_path,\n",
92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels4.txt\",\n",
93 | " clear_cache=True\n",
94 | ")"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "from modules.models.bert_models import BERTBiLSTMAttnCRF"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 7,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "from modules.train.train import NerLearner"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 8,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "num_epochs = 100"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 9,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "learner = NerLearner(\n",
140 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 10,
146 | "metadata": {
147 | "scrolled": true
148 | },
149 | "outputs": [
150 | {
151 | "data": {
152 | "text/plain": [
153 | "2629497"
154 | ]
155 | },
156 | "execution_count": 10,
157 | "metadata": {},
158 | "output_type": "execute_result"
159 | }
160 | ],
161 | "source": [
162 | "model.get_n_trainable_params()"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "metadata": {
169 | "scrolled": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "learner.fit(epochs=num_epochs)"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "### Predict"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 12,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "from modules.data.bert_data import get_data_loader_for_predict"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 13,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 23,
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "data": {
208 | "application/vnd.jupyter.widget-view+json": {
209 | "model_id": "",
210 | "version_major": 2,
211 | "version_minor": 0
212 | },
213 | "text/plain": [
214 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
215 | ]
216 | },
217 | "metadata": {},
218 | "output_type": "display_data"
219 | },
220 | {
221 | "name": "stdout",
222 | "output_type": "stream",
223 | "text": [
224 | "\r"
225 | ]
226 | }
227 | ],
228 | "source": [
229 | "preds = learner.predict(dl)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 24,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "from sklearn_crfsuite.metrics import flat_classification_report"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 25,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
248 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 26,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
258 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 27,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "assert pred_tokens == true_tokens\n",
268 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 28,
274 | "metadata": {
275 | "scrolled": true
276 | },
277 | "outputs": [
278 | {
279 | "name": "stdout",
280 | "output_type": "stream",
281 | "text": [
282 | " precision recall f1-score support\n",
283 | "\n",
284 | " I_ORG 0.8639 0.7803 0.8200 3865\n",
285 | " I_PER 0.9535 0.9706 0.9620 2112\n",
286 | " I_LOC 0.9066 0.8356 0.8697 1557\n",
287 | "\n",
288 | " micro avg 0.8998 0.8451 0.8716 7534\n",
289 | " macro avg 0.9080 0.8622 0.8839 7534\n",
290 | "weighted avg 0.8979 0.8451 0.8701 7534\n",
291 | "\n"
292 | ]
293 | }
294 | ],
295 | "source": [
296 | "print(tokens_report)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 29,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": [
305 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 30,
311 | "metadata": {},
312 | "outputs": [
313 | {
314 | "name": "stdout",
315 | "output_type": "stream",
316 | "text": [
317 | "processed 56409 tokens with 7534 phrases; found: 7076 phrases; correct: 6367.\n",
318 | "\n",
319 | "precision: 89.98%; recall: 84.51%; FB1: 87.16\n",
320 | "\n",
321 | "\tLOC: precision: 90.66%; recall: 83.56%; F1: 86.97 1435\n",
322 | "\n",
323 | "\tORG: precision: 86.39%; recall: 78.03%; F1: 82.00 3491\n",
324 | "\n",
325 | "\tPER: precision: 95.35%; recall: 97.06%; F1: 96.20 2150\n",
326 | "\n",
327 | "\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "results = precision_recall_f1(true_labels, pred_labels)"
333 | ]
334 | }
335 | ],
336 | "metadata": {
337 | "kernelspec": {
338 | "display_name": "Python 3",
339 | "language": "python",
340 | "name": "python3"
341 | },
342 | "language_info": {
343 | "codemirror_mode": {
344 | "name": "ipython",
345 | "version": 3
346 | },
347 | "file_extension": ".py",
348 | "mimetype": "text/x-python",
349 | "name": "python",
350 | "nbconvert_exporter": "python",
351 | "pygments_lexer": "ipython3",
352 | "version": "3.6.8"
353 | }
354 | },
355 | "nbformat": 4,
356 | "nbformat_minor": 2
357 | }
358 |
--------------------------------------------------------------------------------
/exps/fre BERTBiLSTMAttnNCRF-fit_BERT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "device = \"cuda:1\""
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 5,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stderr",
70 | "output_type": "stream",
71 | "text": [
72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
73 | ]
74 | },
75 | {
76 | "data": {
77 | "application/vnd.jupyter.widget-view+json": {
78 | "model_id": "",
79 | "version_major": 2,
80 | "version_minor": 0
81 | },
82 | "text/plain": [
83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
84 | ]
85 | },
86 | "metadata": {},
87 | "output_type": "display_data"
88 | },
89 | {
90 | "name": "stdout",
91 | "output_type": "stream",
92 | "text": [
93 | "\r"
94 | ]
95 | }
96 | ],
97 | "source": [
98 | "data = bert_data.LearnData.create(\n",
99 | " train_df_path=train_df_path,\n",
100 | " valid_df_path=valid_df_path,\n",
101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels2.txt\",\n",
102 | " clear_cache=True,\n",
103 | " batch_size=8,\n",
104 | " device=device\n",
105 | ")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "from modules.models.bert_models import BERTBiLSTMAttnNCRF"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 8,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "build CRF...\n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "model = BERTBiLSTMAttnNCRF.create(\n",
132 | " len(data.train_ds.idx2label), crf_dropout=0.3, nbest=len(data.train_ds.label2idx), is_freeze=False, hidden_dim=256,\n",
133 | " device=device)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 9,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "from modules.train.train import NerLearner"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 10,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "num_epochs = 100"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 14,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "learner = NerLearner(\n",
161 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMAttnNCRF-fit_BERT-IO.cpt\",\n",
162 | " t_total=num_epochs * len(data.train_dl), lr=0.00001)"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 12,
168 | "metadata": {
169 | "scrolled": true
170 | },
171 | "outputs": [
172 | {
173 | "data": {
174 | "text/plain": [
175 | "179004667"
176 | ]
177 | },
178 | "execution_count": 12,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "model.get_n_trainable_params()"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {
191 | "scrolled": true
192 | },
193 | "outputs": [],
194 | "source": [
195 | "learner.fit(epochs=num_epochs)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "### Predict"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 27,
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "from modules.data.bert_data import get_data_loader_for_predict"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 28,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 29,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "data": {
230 | "application/vnd.jupyter.widget-view+json": {
231 | "model_id": "",
232 | "version_major": 2,
233 | "version_minor": 0
234 | },
235 | "text/plain": [
236 | "HBox(children=(IntProgress(value=0, description='Predicting', max=340, style=ProgressStyle(description_width='…"
237 | ]
238 | },
239 | "metadata": {},
240 | "output_type": "display_data"
241 | }
242 | ],
243 | "source": [
244 | "preds = learner.predict(dl)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 30,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "from sklearn_crfsuite.metrics import flat_classification_report"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 31,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
263 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 32,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
273 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 33,
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "assert pred_tokens == true_tokens\n",
283 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 34,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | " precision recall f1-score support\n",
296 | "\n",
297 | " I_ORG 0.8761 0.7224 0.7918 3865\n",
298 | " I_PER 0.9207 0.9342 0.9274 2112\n",
299 | " I_LOC 0.8767 0.8356 0.8556 1557\n",
300 | "\n",
301 | " micro avg 0.8902 0.8051 0.8456 7534\n",
302 | " macro avg 0.8911 0.8307 0.8583 7534\n",
303 | "weighted avg 0.8887 0.8051 0.8430 7534\n",
304 | "\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "print(tokens_report)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 35,
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 36,
324 | "metadata": {},
325 | "outputs": [
326 | {
327 | "name": "stdout",
328 | "output_type": "stream",
329 | "text": [
330 | "processed 56409 tokens with 7534 phrases; found: 6814 phrases; correct: 6066.\n",
331 | "\n",
332 | "precision: 89.02%; recall: 80.51%; FB1: 84.56\n",
333 | "\n",
334 | "\tLOC: precision: 87.67%; recall: 83.56%; F1: 85.56 1484\n",
335 | "\n",
336 | "\tORG: precision: 87.61%; recall: 72.24%; F1: 79.18 3187\n",
337 | "\n",
338 | "\tPER: precision: 92.07%; recall: 93.42%; F1: 92.74 2143\n",
339 | "\n",
340 | "\n"
341 | ]
342 | }
343 | ],
344 | "source": [
345 | "results = precision_recall_f1(true_labels, pred_labels)"
346 | ]
347 | }
348 | ],
349 | "metadata": {
350 | "kernelspec": {
351 | "display_name": "Python 3",
352 | "language": "python",
353 | "name": "python3"
354 | },
355 | "language_info": {
356 | "codemirror_mode": {
357 | "name": "ipython",
358 | "version": 3
359 | },
360 | "file_extension": ".py",
361 | "mimetype": "text/x-python",
362 | "name": "python",
363 | "nbconvert_exporter": "python",
364 | "pygments_lexer": "ipython3",
365 | "version": "3.6.8"
366 | }
367 | },
368 | "nbformat": 4,
369 | "nbformat_minor": 2
370 | }
371 |
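
The notebook above and the fre notebooks that follow all run the same train / predict / evaluate loop; they differ mainly in the model class, checkpoint path, idx2labels file, and a few hyperparameters. The fit_BERT variant above fine-tunes the whole encoder (179,004,667 trainable parameters) and therefore passes lr=0.00001 explicitly, while the other variants report only about 0.3M-2.2M trainable parameters, which suggests the BERT encoder stays frozen there, and they rely on NerLearner's default learning rate. A condensed sketch of the shared loop, assembled from the notebook cells (paths and settings are the notebooks' own; treat this as a sketch rather than a drop-in script):

# Condensed sketch of the shared pipeline; every call below appears in the notebooks,
# here simply collected in one place (shown with the BERTBiLSTMCRF model class).
from modules.data import bert_data
from modules.data.bert_data import get_data_loader_for_predict
from modules.models.bert_models import BERTBiLSTMCRF
from modules.train.train import NerLearner
from modules.analyze_utils.utils import bert_labels2tokens
from modules.analyze_utils.main_metrics import precision_recall_f1

data = bert_data.LearnData.create(
    train_df_path="/home/eartemov/ae/work/factRuEval-2016/dev.csv",
    valid_df_path="/home/eartemov/ae/work/factRuEval-2016/test.csv",
    idx2labels_path="/home/eartemov/ae/work/factRuEval-2016/idx2labels.txt",
    clear_cache=True)

model = BERTBiLSTMCRF.create(len(data.train_ds.idx2label), lstm_dropout=0., crf_dropout=0.3)

num_epochs = 100
learner = NerLearner(
    model, data, "/home/eartemov/ae/work/models/fre-BERTBiLSTMCRF-IO.cpt",
    t_total=num_epochs * len(data.train_dl))
learner.fit(epochs=num_epochs)

# Evaluation: recover word-level labels from the wordpiece predictions,
# then score with the CoNLL-style span metric used in the notebooks.
dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config["df_path"])
preds = learner.predict(dl)
pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])
results = precision_recall_f1(true_labels, pred_labels)
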
--------------------------------------------------------------------------------
/exps/fre BERTBiLSTMCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 5,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stderr",
61 | "output_type": "stream",
62 | "text": [
63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "application/vnd.jupyter.widget-view+json": {
69 | "model_id": "",
70 | "version_major": 2,
71 | "version_minor": 0
72 | },
73 | "text/plain": [
74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
75 | ]
76 | },
77 | "metadata": {},
78 | "output_type": "display_data"
79 | },
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "\r"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
90 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\"\n",
91 | "data = bert_data.LearnData.create(\n",
92 | " train_df_path=train_df_path,\n",
93 | " valid_df_path=valid_df_path,\n",
94 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels4.txt\",\n",
95 | " clear_cache=True\n",
96 | ")"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 6,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "from modules.models.bert_models import BERTBiLSTMCRF"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 7,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "model = BERTBiLSTMCRF.create(len(data.train_ds.idx2label), lstm_dropout=0., crf_dropout=0.3)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 8,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "from modules.train.train import NerLearner"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 9,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "num_epochs = 100"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 12,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "learner = NerLearner(\n",
142 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 13,
148 | "metadata": {},
149 | "outputs": [
150 | {
151 | "data": {
152 | "text/plain": [
153 | "2234745"
154 | ]
155 | },
156 | "execution_count": 13,
157 | "metadata": {},
158 | "output_type": "execute_result"
159 | }
160 | ],
161 | "source": [
162 | "model.get_n_trainable_params()"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "metadata": {
169 | "scrolled": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "learner.fit(epochs=num_epochs)"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "### Predict"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 14,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "from modules.data.bert_data import get_data_loader_for_predict"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 15,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 16,
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "data": {
208 | "application/vnd.jupyter.widget-view+json": {
209 | "model_id": "",
210 | "version_major": 2,
211 | "version_minor": 0
212 | },
213 | "text/plain": [
214 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
215 | ]
216 | },
217 | "metadata": {},
218 | "output_type": "display_data"
219 | },
220 | {
221 | "name": "stdout",
222 | "output_type": "stream",
223 | "text": [
224 | "\r"
225 | ]
226 | }
227 | ],
228 | "source": [
229 | "preds = learner.predict(dl)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 17,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "from sklearn_crfsuite.metrics import flat_classification_report"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 18,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
248 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 19,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
258 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 20,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "assert pred_tokens == true_tokens\n",
268 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=[\"I_ORG\", \"I_PER\", \"I_LOC\"], digits=4)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 21,
274 | "metadata": {},
275 | "outputs": [
276 | {
277 | "name": "stdout",
278 | "output_type": "stream",
279 | "text": [
280 | " precision recall f1-score support\n",
281 | "\n",
282 | " I_ORG 0.8579 0.7917 0.8235 3865\n",
283 | " I_PER 0.9510 0.9659 0.9584 2112\n",
284 | " I_LOC 0.9053 0.8349 0.8687 1557\n",
285 | "\n",
286 | " micro avg 0.8954 0.8495 0.8718 7534\n",
287 | " macro avg 0.9047 0.8642 0.8835 7534\n",
288 | "weighted avg 0.8938 0.8495 0.8706 7534\n",
289 | "\n"
290 | ]
291 | }
292 | ],
293 | "source": [
294 | "print(tokens_report)"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 26,
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 27,
309 | "metadata": {},
310 | "outputs": [
311 | {
312 | "name": "stdout",
313 | "output_type": "stream",
314 | "text": [
315 | "processed 56409 tokens with 7534 phrases; found: 7148 phrases; correct: 6400.\n",
316 | "\n",
317 | "precision: 89.54%; recall: 84.95%; FB1: 87.18\n",
318 | "\n",
319 | "\tLOC: precision: 90.53%; recall: 83.49%; F1: 86.87 1436\n",
320 | "\n",
321 | "\tORG: precision: 85.79%; recall: 79.17%; F1: 82.35 3567\n",
322 | "\n",
323 | "\tPER: precision: 95.10%; recall: 96.59%; F1: 95.84 2145\n",
324 | "\n",
325 | "\n"
326 | ]
327 | }
328 | ],
329 | "source": [
330 | "results = precision_recall_f1(true_labels, pred_labels)"
331 | ]
332 | }
333 | ],
334 | "metadata": {
335 | "kernelspec": {
336 | "display_name": "Python 3",
337 | "language": "python",
338 | "name": "python3"
339 | },
340 | "language_info": {
341 | "codemirror_mode": {
342 | "name": "ipython",
343 | "version": 3
344 | },
345 | "file_extension": ".py",
346 | "mimetype": "text/x-python",
347 | "name": "python",
348 | "nbconvert_exporter": "python",
349 | "pygments_lexer": "ipython3",
350 | "version": "3.6.8"
351 | }
352 | },
353 | "nbformat": 4,
354 | "nbformat_minor": 2
355 | }
356 |
--------------------------------------------------------------------------------
/exps/fre BERTBiLSTMNCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 6,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "device = \"cuda:2\""
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 7,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "name": "stderr",
60 | "output_type": "stream",
61 | "text": [
62 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
63 | ]
64 | },
65 | {
66 | "data": {
67 | "application/vnd.jupyter.widget-view+json": {
68 | "model_id": "",
69 | "version_major": 2,
70 | "version_minor": 0
71 | },
72 | "text/plain": [
73 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
74 | ]
75 | },
76 | "metadata": {},
77 | "output_type": "display_data"
78 | },
79 | {
80 | "name": "stdout",
81 | "output_type": "stream",
82 | "text": [
83 | "\r"
84 | ]
85 | }
86 | ],
87 | "source": [
88 | "data = bert_data.LearnData.create(\n",
89 | " train_df_path=\"/home/eartemov/ae/work/factRuEval-2016/dev.csv\",\n",
90 | " valid_df_path=\"/home/eartemov/ae/work/factRuEval-2016/test.csv\",\n",
91 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels5.txt\",\n",
92 | " clear_cache=True,\n",
93 | " device=device\n",
94 | ")"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 4,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "from modules.models.bert_models import BERTBiLSTMNCRF"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 8,
109 | "metadata": {},
110 | "outputs": [
111 | {
112 | "name": "stdout",
113 | "output_type": "stream",
114 | "text": [
115 | "build CRF...\n"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "model = BERTBiLSTMNCRF.create(\n",
121 | " len(data.train_ds.idx2label), lstm_dropout=0., crf_dropout=0.3, nbest=len(data.train_ds.idx2label)-1, device=device)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 9,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "from modules.train.train import NerLearner"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 10,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "num_epochs = 100"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 11,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "learner = NerLearner(\n",
149 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTBiLSTMNCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 12,
155 | "metadata": {},
156 | "outputs": [
157 | {
158 | "data": {
159 | "text/plain": [
160 | "2235259"
161 | ]
162 | },
163 | "execution_count": 12,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "model.get_n_trainable_params()"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {
176 | "scrolled": true
177 | },
178 | "outputs": [],
179 | "source": [
180 | "learner.fit(epochs=num_epochs)"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "### Eval"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 22,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "from modules.data.bert_data import get_data_loader_for_predict"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 23,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 24,
211 | "metadata": {},
212 | "outputs": [
213 | {
214 | "data": {
215 | "application/vnd.jupyter.widget-view+json": {
216 | "model_id": "",
217 | "version_major": 2,
218 | "version_minor": 0
219 | },
220 | "text/plain": [
221 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
222 | ]
223 | },
224 | "metadata": {},
225 | "output_type": "display_data"
226 | },
227 | {
228 | "name": "stdout",
229 | "output_type": "stream",
230 | "text": [
231 | "\r"
232 | ]
233 | }
234 | ],
235 | "source": [
236 | "preds = learner.predict(dl)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 25,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "from sklearn_crfsuite.metrics import flat_classification_report"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 26,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
255 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 27,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
265 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 28,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "assert pred_tokens == true_tokens\n",
275 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=learner.sup_labels, digits=4)"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 29,
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | " precision recall f1-score support\n",
288 | "\n",
289 | " I_O 0.9777 0.9886 0.9831 48875\n",
290 | " I_LOC 0.8996 0.7996 0.8467 1557\n",
291 | " I_PER 0.9373 0.9626 0.9498 2112\n",
292 | " I_ORG 0.8709 0.7281 0.7931 3865\n",
293 | "\n",
294 | " micro avg 0.9681 0.9646 0.9663 56409\n",
295 | " macro avg 0.9214 0.8697 0.8932 56409\n",
296 | "weighted avg 0.9667 0.9646 0.9651 56409\n",
297 | "\n"
298 | ]
299 | }
300 | ],
301 | "source": [
302 | "print(tokens_report)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 30,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 32,
317 | "metadata": {},
318 | "outputs": [
319 | {
320 | "name": "stdout",
321 | "output_type": "stream",
322 | "text": [
323 | "processed 56409 tokens with 7534 phrases; found: 6784 phrases; correct: 6092.\n",
324 | "\n",
325 | "precision: 89.80%; recall: 80.86%; FB1: 85.10\n",
326 | "\n",
327 | "\tLOC: precision: 89.96%; recall: 79.96%; F1: 84.67 1384\n",
328 | "\n",
329 | "\tORG: precision: 87.09%; recall: 72.81%; F1: 79.31 3231\n",
330 | "\n",
331 | "\tPER: precision: 93.73%; recall: 96.26%; F1: 94.98 2169\n",
332 | "\n",
333 | "\n"
334 | ]
335 | }
336 | ],
337 | "source": [
338 | "results = precision_recall_f1(true_labels, pred_labels)"
339 | ]
340 | }
341 | ],
342 | "metadata": {
343 | "kernelspec": {
344 | "display_name": "Python 3",
345 | "language": "python",
346 | "name": "python3"
347 | },
348 | "language_info": {
349 | "codemirror_mode": {
350 | "name": "ipython",
351 | "version": 3
352 | },
353 | "file_extension": ".py",
354 | "mimetype": "text/x-python",
355 | "name": "python",
356 | "nbconvert_exporter": "python",
357 | "pygments_lexer": "ipython3",
358 | "version": "3.6.8"
359 | }
360 | },
361 | "nbformat": 4,
362 | "nbformat_minor": 2
363 | }
364 |
--------------------------------------------------------------------------------
/exps/fre BERTCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 12,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "The autoreload extension is already loaded. To reload it, use:\n",
13 | " %reload_ext autoreload\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2\n",
20 | "\n",
21 | "\n",
22 | "import sys\n",
23 | "import warnings\n",
24 | "\n",
25 | "\n",
26 | "warnings.filterwarnings(\"ignore\")\n",
27 | "sys.path.append(\"../\")"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## IO markup"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "### Train"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "from modules.data import bert_data"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 3,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
60 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stderr",
70 | "output_type": "stream",
71 | "text": [
72 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
73 | ]
74 | },
75 | {
76 | "data": {
77 | "application/vnd.jupyter.widget-view+json": {
78 | "model_id": "",
79 | "version_major": 2,
80 | "version_minor": 0
81 | },
82 | "text/plain": [
83 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
84 | ]
85 | },
86 | "metadata": {},
87 | "output_type": "display_data"
88 | },
89 | {
90 | "name": "stdout",
91 | "output_type": "stream",
92 | "text": [
93 | "\r"
94 | ]
95 | }
96 | ],
97 | "source": [
98 | "data = bert_data.LearnData.create(\n",
99 | " train_df_path=train_df_path,\n",
100 | " valid_df_path=valid_df_path,\n",
101 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels1.txt\",\n",
102 | " clear_cache=True\n",
103 | ")"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "from modules.models.bert_models import BERTCRF"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 6,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "model = BERTCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 7,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "from modules.train.train import NerLearner"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 8,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "num_epochs = 100"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 9,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "learner = NerLearner(\n",
149 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 10,
155 | "metadata": {},
156 | "outputs": [
157 | {
158 | "data": {
159 | "text/plain": [
160 | "298489"
161 | ]
162 | },
163 | "execution_count": 10,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "model.get_n_trainable_params()"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {
176 | "scrolled": true
177 | },
178 | "outputs": [],
179 | "source": [
180 | "learner.fit(epochs=num_epochs)"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "### Predict"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 13,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "from modules.data.bert_data import get_data_loader_for_predict"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 14,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 15,
211 | "metadata": {},
212 | "outputs": [
213 | {
214 | "data": {
215 | "application/vnd.jupyter.widget-view+json": {
216 | "model_id": "",
217 | "version_major": 2,
218 | "version_minor": 0
219 | },
220 | "text/plain": [
221 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
222 | ]
223 | },
224 | "metadata": {},
225 | "output_type": "display_data"
226 | },
227 | {
228 | "name": "stdout",
229 | "output_type": "stream",
230 | "text": [
231 | "\r"
232 | ]
233 | }
234 | ],
235 | "source": [
236 | "preds = learner.predict(dl)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 16,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "from sklearn_crfsuite.metrics import flat_classification_report"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 17,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
255 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 18,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
265 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 19,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "assert pred_tokens == true_tokens\n",
275 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[5:], digits=4)"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 20,
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | " precision recall f1-score support\n",
288 | "\n",
289 | " I_LOC 0.8576 0.7932 0.8242 1557\n",
290 | " I_PER 0.9544 0.9616 0.9580 2112\n",
291 | " I_ORG 0.8150 0.7490 0.7806 3865\n",
292 | "\n",
293 | " micro avg 0.8653 0.8178 0.8409 7534\n",
294 | " macro avg 0.8757 0.8346 0.8543 7534\n",
295 | "weighted avg 0.8629 0.8178 0.8394 7534\n",
296 | "\n"
297 | ]
298 | }
299 | ],
300 | "source": [
301 | "print(tokens_report)"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": 21,
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 22,
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "name": "stdout",
320 | "output_type": "stream",
321 | "text": [
322 | "processed 56409 tokens with 7534 phrases; found: 7120 phrases; correct: 6161.\n",
323 | "\n",
324 | "precision: 86.53%; recall: 81.78%; FB1: 84.09\n",
325 | "\n",
326 | "\tLOC: precision: 85.76%; recall: 79.32%; F1: 82.42 1440\n",
327 | "\n",
328 | "\tORG: precision: 81.50%; recall: 74.90%; F1: 78.06 3552\n",
329 | "\n",
330 | "\tPER: precision: 95.44%; recall: 96.16%; F1: 95.80 2128\n",
331 | "\n",
332 | "\n"
333 | ]
334 | }
335 | ],
336 | "source": [
337 | "results = precision_recall_f1(true_labels, pred_labels)"
338 | ]
339 | }
340 | ],
341 | "metadata": {
342 | "kernelspec": {
343 | "display_name": "Python 3",
344 | "language": "python",
345 | "name": "python3"
346 | },
347 | "language_info": {
348 | "codemirror_mode": {
349 | "name": "ipython",
350 | "version": 3
351 | },
352 | "file_extension": ".py",
353 | "mimetype": "text/x-python",
354 | "name": "python",
355 | "nbconvert_exporter": "python",
356 | "pygments_lexer": "ipython3",
357 | "version": "3.6.8"
358 | }
359 | },
360 | "nbformat": 4,
361 | "nbformat_minor": 2
362 | }
363 |
--------------------------------------------------------------------------------
/exps/fre BERTNCRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "\n",
13 | "import sys\n",
14 | "import warnings\n",
15 | "\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append(\"../\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## IO markup"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### Train"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from modules.data import bert_data"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "train_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
51 | "valid_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stderr",
61 | "output_type": "stream",
62 | "text": [
63 | "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "application/vnd.jupyter.widget-view+json": {
69 | "model_id": "",
70 | "version_major": 2,
71 | "version_minor": 0
72 | },
73 | "text/plain": [
74 | "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=1519, style=ProgressStyle(descri…"
75 | ]
76 | },
77 | "metadata": {},
78 | "output_type": "display_data"
79 | },
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "\r"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "data = bert_data.LearnData.create(\n",
90 | " train_df_path=train_df_path,\n",
91 | " valid_df_path=valid_df_path,\n",
92 | " idx2labels_path=\"/home/eartemov/ae/work/factRuEval-2016/idx2labels.txt\",\n",
93 | " clear_cache=True\n",
94 | ")"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "from modules.models.bert_models import BERTNCRF"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "metadata": {},
110 | "outputs": [
111 | {
112 | "name": "stdout",
113 | "output_type": "stream",
114 | "text": [
115 | "build CRF...\n"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "model = BERTNCRF.create(len(data.train_ds.idx2label), crf_dropout=0.3, nbest=len(data.train_ds.idx2label)-1)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 7,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "from modules.train.train import NerLearner"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 8,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "num_epochs = 100"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 9,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "learner = NerLearner(\n",
148 | " model, data, \"/home/eartemov/ae/work/models/fre-BERTNCRF-IO.cpt\", t_total=num_epochs * len(data.train_dl))"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 10,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "data": {
158 | "text/plain": [
159 | "299259"
160 | ]
161 | },
162 | "execution_count": 10,
163 | "metadata": {},
164 | "output_type": "execute_result"
165 | }
166 | ],
167 | "source": [
168 | "model.get_n_trainable_params()"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "scrolled": true
176 | },
177 | "outputs": [],
178 | "source": [
179 | "learner.fit(epochs=num_epochs)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 11,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "learner.load_model()"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "### Predict"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 12,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "from modules.data.bert_data import get_data_loader_for_predict"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 13,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 14,
219 | "metadata": {},
220 | "outputs": [
221 | {
222 | "data": {
223 | "application/vnd.jupyter.widget-view+json": {
224 | "model_id": "",
225 | "version_major": 2,
226 | "version_minor": 0
227 | },
228 | "text/plain": [
229 | "HBox(children=(IntProgress(value=0, description='Predicting', max=170, style=ProgressStyle(description_width='…"
230 | ]
231 | },
232 | "metadata": {},
233 | "output_type": "display_data"
234 | },
235 | {
236 | "name": "stdout",
237 | "output_type": "stream",
238 | "text": [
239 | "\r"
240 | ]
241 | }
242 | ],
243 | "source": [
244 | "preds = learner.predict(dl)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 15,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "from sklearn_crfsuite.metrics import flat_classification_report"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 16,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
263 | "from modules.analyze_utils.plot_metrics import get_bert_span_report"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 17,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
273 | "true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 18,
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "assert pred_tokens == true_tokens\n",
283 | "tokens_report = flat_classification_report(true_labels, pred_labels, labels=data.train_ds.idx2label[5:], digits=4)"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 19,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | " precision recall f1-score support\n",
296 | "\n",
297 | " I_LOC 0.8765 0.7887 0.8303 1557\n",
298 | " I_PER 0.9598 0.9598 0.9598 2112\n",
299 | " I_ORG 0.7946 0.8078 0.8011 3865\n",
300 | "\n",
301 | " micro avg 0.8569 0.8464 0.8516 7534\n",
302 | " macro avg 0.8770 0.8521 0.8637 7534\n",
303 | "weighted avg 0.8578 0.8464 0.8516 7534\n",
304 | "\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "print(tokens_report)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 20,
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "from modules.analyze_utils.main_metrics import precision_recall_f1"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 21,
324 | "metadata": {},
325 | "outputs": [
326 | {
327 | "name": "stdout",
328 | "output_type": "stream",
329 | "text": [
330 | "processed 56409 tokens with 7534 phrases; found: 7442 phrases; correct: 6377.\n",
331 | "\n",
332 | "precision: 85.69%; recall: 84.64%; FB1: 85.16\n",
333 | "\n",
334 | "\tLOC: precision: 87.65%; recall: 78.87%; F1: 83.03 1401\n",
335 | "\n",
336 | "\tORG: precision: 79.46%; recall: 80.78%; F1: 80.11 3929\n",
337 | "\n",
338 | "\tPER: precision: 95.98%; recall: 95.98%; F1: 95.98 2112\n",
339 | "\n",
340 | "\n"
341 | ]
342 | }
343 | ],
344 | "source": [
345 | "results = precision_recall_f1(true_labels, pred_labels)"
346 | ]
347 | }
348 | ],
349 | "metadata": {
350 | "kernelspec": {
351 | "display_name": "Python 3",
352 | "language": "python",
353 | "name": "python3"
354 | },
355 | "language_info": {
356 | "codemirror_mode": {
357 | "name": "ipython",
358 | "version": 3
359 | },
360 | "file_extension": ".py",
361 | "mimetype": "text/x-python",
362 | "name": "python",
363 | "nbconvert_exporter": "python",
364 | "pygments_lexer": "ipython3",
365 | "version": "3.6.8"
366 | }
367 | },
368 | "nbformat": 4,
369 | "nbformat_minor": 2
370 | }
371 |
--------------------------------------------------------------------------------
/exps/prc fre.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### FactRuEval-2016 preprocess\n",
8 | "More info about the dataset: https://github.com/dialogue-evaluation/factRuEval-2016"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import sys\n",
18 | "import warnings\n",
19 | "\n",
20 | "\n",
21 | "warnings.filterwarnings(\"ignore\")\n",
22 | "sys.path.append(\"../\")"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from modules.data.fre import fact_ru_eval_preprocess"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 3,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "dev_dir = \"/home/eartemov/ae/work/factRuEval-2016/devset/\"\n",
41 | "test_dir = \"/home/eartemov/ae/work/factRuEval-2016/testset/\"\n",
42 | "dev_df_path = \"/home/eartemov/ae/work/factRuEval-2016/dev.csv\"\n",
43 | "test_df_path = \"/home/eartemov/ae/work/factRuEval-2016/test.csv\""
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 4,
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "data": {
53 | "application/vnd.jupyter.widget-view+json": {
54 | "model_id": "43de7e40d1784421bb55921e2d0058f3",
55 | "version_major": 2,
56 | "version_minor": 0
57 | },
58 | "text/plain": [
59 | "HBox(children=(IntProgress(value=0, description='Process FactRuEval2016 dev set.', max=1519, style=ProgressSty…"
60 | ]
61 | },
62 | "metadata": {},
63 | "output_type": "display_data"
64 | },
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "\n"
70 | ]
71 | },
72 | {
73 | "data": {
74 | "application/vnd.jupyter.widget-view+json": {
75 | "model_id": "27617d0f776d4b37b6f893bcac517f24",
76 | "version_major": 2,
77 | "version_minor": 0
78 | },
79 | "text/plain": [
80 | "HBox(children=(IntProgress(value=0, description='Process FactRuEval2016 test set.', max=2715, style=ProgressSt…"
81 | ]
82 | },
83 | "metadata": {},
84 | "output_type": "display_data"
85 | },
86 | {
87 | "name": "stdout",
88 | "output_type": "stream",
89 | "text": [
90 | "\n"
91 | ]
92 | }
93 | ],
94 | "source": [
95 | "fact_ru_eval_preprocess(dev_dir, test_dir, dev_df_path, test_df_path)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 5,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "import pandas as pd"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 6,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "data": {
173 | "text/plain": [
174 | " labels \\\n",
175 | "0 O O B_LOC O O O O O B_PER I_PER O O O O O \n",
176 | "1 O B_LOC I_LOC O O O O B_ORG O B_PER I_PER O O ... \n",
177 | "2 O O O O O O O O O O \n",
178 | "3 O O O O O O O O O O O O O \n",
179 | "4 O O O O O O O O O O \n",
180 | "\n",
181 | " text cls \n",
182 | "0 Сегодня в Москве на 40-й день после смерти Его... False \n",
183 | "1 К Кронштадтскому бульвару , где болельщика « С... False \n",
184 | "2 И тишина ... Все прошло мирно и не столь массово True \n",
185 | "3 Правда , были задержания , но , как пояснили в... True \n",
186 | "4 Одним словом , очередной « Русский марш » не с... True "
187 | ]
188 | },
189 | "execution_count": 6,
190 | "metadata": {},
191 | "output_type": "execute_result"
192 | }
193 | ],
194 | "source": [
195 | "pd.read_csv(dev_df_path, sep=\"\\t\").head()"
196 | ]
197 | }
198 | ],
199 | "metadata": {
200 | "kernelspec": {
201 | "display_name": "Python 3",
202 | "language": "python",
203 | "name": "python3"
204 | },
205 | "language_info": {
206 | "codemirror_mode": {
207 | "name": "ipython",
208 | "version": 3
209 | },
210 | "file_extension": ".py",
211 | "mimetype": "text/x-python",
212 | "name": "python",
213 | "nbconvert_exporter": "python",
214 | "pygments_lexer": "ipython3",
215 | "version": "3.6.8"
216 | }
217 | },
218 | "nbformat": 4,
219 | "nbformat_minor": 2
220 | }
221 |
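
The preprocessing above writes tab-separated CSVs with three columns: labels (a space-separated tag sequence), text (the whitespace-tokenized sentence) and cls. In the rows shown in the head, the tag sequence lines up one tag per token; the short check below verifies that alignment. The check is an illustrative addition, not part of the original preprocessing.

# Illustrative sanity check on the dev.csv produced above: one label per whitespace token.
import pandas as pd

df = pd.read_csv("/home/eartemov/ae/work/factRuEval-2016/dev.csv", sep="\t")
assert all(len(l.split()) == len(t.split()) for l, t in zip(df["labels"], df["text"]))
print(df.shape)
print(df["cls"].value_counts())  # distribution of the auxiliary cls flag
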
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import get_tqdm
2 |
3 |
4 | tqdm = get_tqdm()
5 |
6 |
7 | __all__ = ["tqdm"]
8 |
--------------------------------------------------------------------------------
/modules/analyze_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
3 |
4 | __all__ = ["read_json", "save_json"]
5 |
--------------------------------------------------------------------------------
/modules/analyze_utils/main_metrics.py:
--------------------------------------------------------------------------------
1 | # This code is reused from https://github.com/deepmipt/DeepPavlov/blob/master/deeppavlov/metrics/fmeasure.py
2 | import itertools
3 | from collections import OrderedDict
4 |
5 |
6 | def chunk_finder(current_token, previous_token, tag):
7 | current_tag = current_token.split('_', 1)[-1]
8 | previous_tag = previous_token.split('_', 1)[-1]
9 | if previous_tag != tag:
10 | previous_tag = 'O'
11 | if current_tag != tag:
12 | current_tag = 'O'
13 | if (previous_tag == 'O' and current_token == 'B_' + tag) or \
14 | (previous_token == 'I_' + tag and current_token == 'B_' + tag) or \
15 | (previous_token == 'B_' + tag and current_token == 'B_' + tag) or \
16 | (previous_tag == 'O' and current_token == 'I_' + tag):
17 | create_chunk = True
18 | else:
19 | create_chunk = False
20 |
21 | if (previous_token == 'I_' + tag and current_token == 'B_' + tag) or \
22 | (previous_token == 'B_' + tag and current_token == 'B_' + tag) or \
23 | (current_tag == 'O' and previous_token == 'I_' + tag) or \
24 | (current_tag == 'O' and previous_token == 'B_' + tag):
25 | pop_out = True
26 | else:
27 | pop_out = False
28 | return create_chunk, pop_out
29 |
30 |
31 | def _global_stats_f1(results):
32 | total_true_entities = 0
33 | total_predicted_entities = 0
34 | total_precision = 0
35 | total_recall = 0
36 | total_f1 = 0
37 | total_correct = 0
38 | for tag in results:
39 | if tag == '__total__':
40 | continue
41 |
42 | n_pred = results[tag]['n_pred']
43 | n_true = results[tag]['n_true']
44 | total_correct += results[tag]['tp']
45 | total_true_entities += n_true
46 | total_predicted_entities += n_pred
47 | total_precision += results[tag]['precision'] * n_pred
48 | total_recall += results[tag]['recall'] * n_true
49 | total_f1 += results[tag]['f1'] * n_true
50 | if total_true_entities > 0:
51 | accuracy = total_correct / total_true_entities * 100
52 | total_recall = total_recall / total_true_entities
53 | else:
54 | accuracy = 0
55 | total_recall = 0
56 | if total_predicted_entities > 0:
57 | total_precision = total_precision / total_predicted_entities
58 | else:
59 | total_precision = 0
60 |
61 | if total_precision + total_recall > 0:
62 | total_f1 = 2 * total_precision * total_recall / (total_precision + total_recall)
63 | else:
64 | total_f1 = 0
65 |
66 | total_res = {'n_predicted_entities': total_predicted_entities,
67 | 'n_true_entities': total_true_entities,
68 | 'precision': total_precision,
69 | 'recall': total_recall,
70 | 'f1': total_f1}
71 | return total_res, accuracy, total_true_entities, total_predicted_entities, total_correct
72 |
73 |
74 | def precision_recall_f1(y_true, y_pred, print_results=True, short_report=False, entity_of_interest=None):
75 | y_true = list(itertools.chain(*y_true))
76 | y_pred = list(itertools.chain(*y_pred))
77 | # Find all tags
78 | tags = set()
79 | for tag in itertools.chain(y_true, y_pred):
80 | if tag not in ["O", "I_O", "B_O"]:
81 | current_tag = tag[2:]
82 | tags.add(current_tag)
83 | tags = sorted(list(tags))
84 |
85 | results = OrderedDict()
86 | for tag in tags:
87 | results[tag] = OrderedDict()
88 | results['__total__'] = OrderedDict()
89 | n_tokens = len(y_true)
90 | # Firstly we find all chunks in the ground truth and prediction
91 | # For each chunk we write starting and ending indices
92 |
93 | for tag in tags:
94 | count = 0
95 | true_chunk = []
96 | pred_chunk = []
97 | y_true = [str(y) for y in y_true]
98 | y_pred = [str(y) for y in y_pred]
99 | prev_tag_true = 'O'
100 | prev_tag_pred = 'O'
101 | while count < n_tokens:
102 | yt = y_true[count]
103 | yp = y_pred[count]
104 |
105 | create_chunk_true, pop_out_true = chunk_finder(yt, prev_tag_true, tag)
106 | if pop_out_true:
107 | true_chunk[-1] = (true_chunk[-1], count - 1)
108 | if create_chunk_true:
109 | true_chunk.append(count)
110 |
111 | create_chunk_pred, pop_out_pred = chunk_finder(yp, prev_tag_pred, tag)
112 | if pop_out_pred:
113 | pred_chunk[-1] = (pred_chunk[-1], count - 1)
114 | if create_chunk_pred:
115 | pred_chunk.append(count)
116 | prev_tag_true = yt
117 | prev_tag_pred = yp
118 | count += 1
119 |
120 | if len(true_chunk) > 0 and not isinstance(true_chunk[-1], tuple):
121 | true_chunk[-1] = (true_chunk[-1], count - 1)
122 | if len(pred_chunk) > 0 and not isinstance(pred_chunk[-1], tuple):
123 | pred_chunk[-1] = (pred_chunk[-1], count - 1)
124 |
125 | # Then we find all correctly classified intervals
126 | # True positive results
127 | tp = len(set(pred_chunk).intersection(set(true_chunk)))
128 | # And then just calculate errors of the first and second kind
129 | # False negative
130 | fn = len(true_chunk) - tp
131 | # False positive
132 | fp = len(pred_chunk) - tp
133 | if tp + fp > 0:
134 | precision = tp / (tp + fp) * 100
135 | else:
136 | precision = 0
137 | if tp + fn > 0:
138 | recall = tp / (tp + fn) * 100
139 | else:
140 | recall = 0
141 | if precision + recall > 0:
142 | f1 = 2 * precision * recall / (precision + recall)
143 | else:
144 | f1 = 0
145 | results[tag]['precision'] = precision
146 | results[tag]['recall'] = recall
147 | results[tag]['f1'] = f1
148 | results[tag]['n_pred'] = len(pred_chunk)
149 | results[tag]['n_true'] = len(true_chunk)
150 | results[tag]['tp'] = tp
151 | results[tag]['fn'] = fn
152 | results[tag]['fp'] = fp
153 |
154 | results['__total__'], accuracy, total_true_entities, total_predicted_entities, total_correct = _global_stats_f1(results)
155 | results['__total__']['n_pred'] = total_predicted_entities
156 | results['__total__']['n_true'] = total_true_entities
157 | results['__total__']["n_tokens"] = n_tokens
158 | if print_results:
159 | _print_conll_report(results, short_report, entity_of_interest)
160 | return results
161 |
162 |
163 | def _print_conll_report(results, short_report=False, entity_of_interest=None):
164 | _, accuracy, total_true_entities, total_predicted_entities, total_correct = _global_stats_f1(results)
165 | n_tokens = results['__total__']["n_tokens"]
166 | tags = list(results.keys())
167 |
168 | s = 'processed {len} tokens ' \
169 | 'with {tot_true} phrases; ' \
170 | 'found: {tot_pred} phrases;' \
171 | ' correct: {tot_cor}.\n\n'.format(len=n_tokens,
172 | tot_true=total_true_entities,
173 | tot_pred=total_predicted_entities,
174 | tot_cor=total_correct)
175 |
176 | s += 'precision: {tot_prec:.2f}%; ' \
177 | 'recall: {tot_recall:.2f}%; ' \
178 | 'FB1: {tot_f1:.2f}\n\n'.format(acc=accuracy,
179 | tot_prec=results['__total__']['precision'],
180 | tot_recall=results['__total__']['recall'],
181 | tot_f1=results['__total__']['f1'])
182 |
183 | if not short_report:
184 | for tag in tags:
185 | if entity_of_interest is not None:
186 | if entity_of_interest in tag:
187 | s += '\t' + tag + ': precision: {tot_prec:.2f}%; ' \
188 | 'recall: {tot_recall:.2f}%; ' \
189 | 'F1: {tot_f1:.2f} ' \
190 | '{tot_predicted}\n\n'.format(tot_prec=results[tag]['precision'],
191 | tot_recall=results[tag]['recall'],
192 | tot_f1=results[tag]['f1'],
193 | tot_predicted=results[tag]['n_pred'])
194 | elif tag != '__total__':
195 | s += '\t' + tag + ': precision: {tot_prec:.2f}%; ' \
196 | 'recall: {tot_recall:.2f}%; ' \
197 | 'F1: {tot_f1:.2f} ' \
198 | '{tot_predicted}\n\n'.format(tot_prec=results[tag]['precision'],
199 | tot_recall=results[tag]['recall'],
200 | tot_f1=results[tag]['f1'],
201 | tot_predicted=results[tag]['n_pred'])
202 | elif entity_of_interest is not None:
203 | s += '\t' + entity_of_interest + ': precision: {tot_prec:.2f}%; ' \
204 | 'recall: {tot_recall:.2f}%; ' \
205 | 'F1: {tot_f1:.2f} ' \
206 | '{tot_predicted}\n\n'.format(tot_prec=results[entity_of_interest]['precision'],
207 | tot_recall=results[entity_of_interest]['recall'],
208 | tot_f1=results[entity_of_interest]['f1'],
209 | tot_predicted=results[entity_of_interest]['n_pred'])
210 | print(s)
211 |
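
precision_recall_f1 expects lists of per-sentence tag sequences in the B_/I_ scheme used throughout the repo; it prints the CoNLL-style summary seen in the notebook outputs and returns the underlying numbers as a nested dict. A toy call with invented tags, just to show the input shape and the returned structure:

# Toy example with invented tags (not project data): two sentences, one missed ORG span.
from modules.analyze_utils.main_metrics import precision_recall_f1

y_true = [["B_PER", "I_PER", "O", "B_ORG"], ["O", "B_LOC", "O"]]
y_pred = [["B_PER", "I_PER", "O", "O"],     ["O", "B_LOC", "O"]]

results = precision_recall_f1(y_true, y_pred)  # prints the report by default
print(results["__total__"]["f1"], results["ORG"]["recall"])  # aggregate and per-tag numbers
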
--------------------------------------------------------------------------------
/modules/analyze_utils/plot_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 | from matplotlib import pyplot as plt
4 | from .utils import tokens2spans, bert_labels2tokens, voting_choicer, first_choicer
5 | from sklearn_crfsuite.metrics import flat_classification_report
6 | from sklearn.metrics import f1_score
7 |
8 |
9 | def plot_by_class_curve(history, metric_, sup_labels):
10 | by_class = get_by_class_metric(history, metric_, sup_labels)
11 | vals = list(by_class.values())
12 | x = np.arange(len(vals[0]))
13 | args = []
14 | for val in vals:
15 | args.append(x)
16 | args.append(val)
17 | plt.figure(figsize=(15, 10))
18 | plt.grid(True)
19 | plt.plot(*args)
20 | plt.legend(list(by_class.keys()))
21 | _, _ = plt.yticks(np.arange(0, 1, step=0.1))
22 | plt.show()
23 |
24 |
25 | def get_metrics_by_class(text_res, sup_labels):
26 | # text_res = flat_classification_report(y_true, y_pred, labels=labels, digits=3)
27 | res = {}
28 | for line in text_res.split("\n"):
29 | line = line.split()
30 | if len(line) and line[0] in sup_labels:
31 | res[line[0]] = {key: val for key, val in zip(["prec", "rec", "f1"], line[1:-1])}
32 | return res
33 |
34 |
35 | def get_by_class_metric(history, metric_, sup_labels):
36 | res = defaultdict(list)
37 | for h in history:
38 | h = get_metrics_by_class(h, sup_labels)
39 | for class_, metrics_ in h.items():
40 | res[class_].append(float(metrics_[metric_]))
41 | return res
42 |
43 |
44 | def get_max_metric(history, metric_, sup_labels, return_idx=False):
45 | by_class = get_by_class_metric(history, metric_, sup_labels)
46 | by_class_arr = np.array(list(by_class.values()))
47 | idx = np.array(by_class_arr.sum(0)).argmax()
48 | if return_idx:
49 | return list(zip(by_class.keys(), by_class_arr[:, idx])), idx
50 | return list(zip(by_class.keys(), by_class_arr[:, idx]))
51 |
52 |
53 | def get_mean_max_metric(history, metric_="f1", return_idx=False):
54 | m_idx = 0
55 | if metric_ == "f1":
56 | m_idx = 2
57 |     elif metric_ == "rec":
58 | m_idx = 1
59 | metrics = [float(h.split("\n")[-3].split()[2 + m_idx]) for h in history]
60 | idx = np.argmax(metrics)
61 | res = metrics[idx]
62 | if return_idx:
63 | return idx, res
64 | return res
65 |
66 |
67 | def get_bert_span_report(dl, preds, labels=None, fn=voting_choicer):
68 |     pred_tokens, pred_labels = bert_labels2tokens(dl, preds, fn=fn)
69 |     true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset], fn=fn)
70 | spans_pred = tokens2spans(pred_tokens, pred_labels)
71 | spans_true = tokens2spans(true_tokens, true_labels)
72 | res_t = []
73 | res_p = []
74 | for pred_span, true_span in zip(spans_pred, spans_true):
75 | text2span = {t: l for t, l in pred_span}
76 | for (pt, pl), (tt, tl) in zip(pred_span, true_span):
77 | res_t.append(tl)
78 | if tt in text2span:
79 | res_p.append(pl)
80 | else:
81 | res_p.append("O")
82 | return flat_classification_report([res_t], [res_p], labels=labels, digits=4)
83 |
84 |
85 | def analyze_bert_errors(dl, labels, fn=voting_choicer):
86 | errors = []
87 | res_tokens = []
88 | res_labels = []
89 | r_labels = [x.labels for x in dl.dataset]
90 | for f, l_, rl in zip(dl.dataset, labels, r_labels):
91 | label = fn(f.tok_map, l_)
92 | label_r = fn(f.tok_map, rl)
93 | prev_idx = 0
94 | errors_ = []
95 | # if len(label_r) > 1:
96 | # assert len(label_r) == len(f.tokens) - 1
97 | for idx, (lbl, rl, t) in enumerate(zip(label, label_r, f.tokens)):
98 | if lbl != rl:
99 | errors_.append(
100 | {"token: ": t,
101 | "real_label": rl,
102 | "pred_label": lbl,
103 | "bert_token": f.bert_tokens[prev_idx:f.tok_map[idx]],
104 | "real_bert_label": f.labels[prev_idx:f.tok_map[idx]],
105 | "pred_bert_label": l_[prev_idx:f.tok_map[idx]],
106 | "text_example": " ".join(f.tokens[1:-1]),
107 | "labels": " ".join(label_r[1:])})
108 | prev_idx = f.tok_map[idx]
109 | errors.append(errors_)
110 | res_tokens.append(f.tokens[1:-1])
111 | res_labels.append(label[1:])
112 | return res_tokens, res_labels, errors
113 |
114 |
115 | def get_f1_score(y_true, y_pred, labels):
116 | res_t = []
117 | res_p = []
118 | for yts, yps in zip(y_true, y_pred):
119 | for yt, yp in zip(yts, yps):
120 | res_t.append(yt)
121 | res_p.append(yp)
122 | return f1_score(res_t, res_p, average="macro", labels=labels)
123 |
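
get_bert_span_report is imported in the experiment notebooks but not called there; it builds a span-level flat_classification_report from the prediction data loader and the raw predictions. A hedged usage sketch, where dl and preds are the objects produced in the notebooks and the label list is an assumption chosen to mirror the entity types reported elsewhere:

# Hedged sketch: dl comes from get_data_loader_for_predict(...) and preds from
# learner.predict(dl); the label list below is an assumption, not a value from the repo.
from modules.analyze_utils.plot_metrics import get_bert_span_report

span_report = get_bert_span_report(dl, preds, labels=["LOC", "ORG", "PER"])
print(span_report)
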
--------------------------------------------------------------------------------
/modules/analyze_utils/utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import numpy as np
3 | import json
4 | import numpy
5 |
6 |
7 | def voting_choicer(tok_map, labels):
8 | label = []
9 | prev_idx = 0
10 | for origin_idx in tok_map:
11 | votes = []
12 | for l in labels[prev_idx:origin_idx]:
13 | if l != "X":
14 | votes.append(l)
15 | vote_labels = Counter(votes)
16 | if not len(vote_labels):
17 | vote_labels = {"B_O": 1}
18 | # vote_labels = Counter(c)
19 | lb = sorted(list(vote_labels), key=lambda x: vote_labels[x])
20 | if len(lb):
21 | label.append(lb[-1])
22 | prev_idx = origin_idx
23 | if origin_idx < 0:
24 | break
25 |
26 | return label
27 |
28 |
29 | def first_choicer(tok_map, labels):
30 | label = []
31 | prev_idx = 0
32 | for origin_idx in tok_map:
33 | l = labels[prev_idx]
34 | if l in ["X"]:
35 | l = "B_O"
36 | if l == "B_O":
37 | for ll in labels[prev_idx + 1:origin_idx]:
38 | if ll not in ["B_O", "I_O", "X"]:
39 | l = ll
40 | break
41 | label.append(l)
42 | prev_idx = origin_idx
43 | if origin_idx < 0:
44 | break
45 | # assert "[SEP]" not in label
46 | return label
47 |
48 |
49 | def bert_labels2tokens(dl, labels, fn=voting_choicer):
50 | res_tokens = []
51 | res_labels = []
52 | for f, l in zip(dl.dataset, labels):
53 | label = fn(f.tok_map, l[1:])
54 |
55 | res_tokens.append(f.tokens[1:-1])
56 | res_labels.append(label[1:])
57 | return res_tokens, res_labels
58 |
59 |
60 | def tokens2spans_(tokens_, labels_):
61 | res = []
62 | idx_ = 0
63 | while idx_ < len(labels_):
64 | label = labels_[idx_]
65 | if label in ["I_O", "B_O", "O"]:
66 | res.append((tokens_[idx_], "O"))
67 | idx_ += 1
68 | elif label == "":
69 | break
70 | elif label == "[CLS]" or label == "":
71 | res.append((tokens_[idx_], label))
72 | idx_ += 1
73 | else:
74 | span = [tokens_[idx_]]
75 | try:
76 | span_label = labels_[idx_].split("_")[1]
77 | except IndexError:
78 | print(label, labels_[idx_].split("_"))
79 | span_label = None
80 | idx_ += 1
81 | while idx_ < len(labels_) and labels_[idx_] not in ["I_O", "B_O", "O"] \
82 | and labels_[idx_].split("_")[0] == "I":
83 | if span_label == labels_[idx_].split("_")[1]:
84 | span.append(tokens_[idx_])
85 | idx_ += 1
86 | else:
87 | break
88 | res.append((" ".join(span), span_label))
89 | return res
90 |
91 |
92 | def tokens2spans(tokens, labels):
93 | assert len(tokens) == len(labels)
94 |
95 | return list(map(lambda x: tokens2spans_(*x), zip(tokens, labels)))
96 |
97 |
98 | def encode_position(pos, emb_dim=10):
99 | """The sinusoid position encoding"""
100 |
101 | # keep dim 0 for padding token position encoding zero vector
102 | if pos == 0:
103 | return np.zeros(emb_dim)
104 | position_enc = np.array(
105 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)])
106 |
107 | # apply sin on 0th,2nd,4th...emb_dim
108 | position_enc[0::2] = np.sin(position_enc[0::2])
109 | # apply cos on 1st,3rd,5th...emb_dim
110 | position_enc[1::2] = np.cos(position_enc[1::2])
111 | return list(position_enc.reshape(-1))
112 |
113 |
114 | class JsonEncoder(json.JSONEncoder):
115 | def default(self, obj):
116 | if isinstance(obj, numpy.integer):
117 | return int(obj)
118 | elif isinstance(obj, numpy.floating):
119 | return float(obj)
120 | elif isinstance(obj, numpy.ndarray):
121 | return obj.tolist()
122 | else:
123 | return super(JsonEncoder, self).default(obj)
124 |
125 |
126 | def jsonify(data):
127 | return json.dumps(data, cls=JsonEncoder)
128 |
129 |
130 | def read_json(config):
131 | if isinstance(config, str):
132 | with open(config, "r") as f:
133 | config = json.load(f)
134 | return config
135 |
136 |
137 | def save_json(config, path):
138 | with open(path, "w") as file:
139 | json.dump(config, file, cls=JsonEncoder)
140 |
--------------------------------------------------------------------------------
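The helpers above are chained to project subtoken-level predictions back onto the original tokens and then onto entity spans, exactly as done in main_metrics.py. A minimal sketch of that chain, where dl (a NER data loader) and predicted_bert_labels (per-subtoken label strings from a model, including the [CLS] position) are assumed to already exist:

from modules.analyze_utils.utils import bert_labels2tokens, tokens2spans, voting_choicer

# Map subtoken labels back onto the original tokens via each feature's tok_map.
pred_tokens, pred_labels = bert_labels2tokens(dl, predicted_bert_labels, fn=voting_choicer)
true_tokens, true_labels = bert_labels2tokens(dl, [f.bert_labels for f in dl.dataset])

# Collapse B_/I_ runs into (span_text, span_label) pairs, one list per sentence.
pred_spans = tokens2spans(pred_tokens, pred_labels)
true_spans = tokens2spans(true_tokens, true_labels)
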
/modules/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/data/__init__.py
--------------------------------------------------------------------------------
/modules/data/bert_data_clf.py:
--------------------------------------------------------------------------------
1 | from .bert_data import TextDataLoader
2 | from pytorch_pretrained_bert import BertTokenizer
3 | from modules.utils import read_config, if_none
4 | from modules import tqdm
5 | import pandas as pd
6 | from copy import deepcopy
7 |
8 |
9 | class InputFeature(object):
10 | """A single set of features of data."""
11 |
12 | def __init__(
13 | self,
14 | # Bert data
15 | bert_tokens, input_ids, input_mask, input_type_ids,
16 | # Origin data
17 | tokens, tok_map,
18 | # Cls data
19 | cls=None, id_cls=None):
20 | """
21 |         Data has the following structure (data[3] holds the class id and is appended only when cls is given).
22 | data[0]: list, tokens ids
23 | data[1]: list, tokens mask
24 | data[2]: list, tokens type ids (for bert)
25 | """
26 | self.data = []
27 | # Bert data
28 | self.bert_tokens = bert_tokens
29 | self.input_ids = input_ids
30 | self.data.append(input_ids)
31 | self.input_mask = input_mask
32 | self.data.append(input_mask)
33 | self.input_type_ids = input_type_ids
34 | self.data.append(input_type_ids)
35 | # Classification data
36 | self.cls = cls
37 | self.id_cls = id_cls
38 | if cls is not None:
39 | self.data.append(id_cls)
40 | # Origin data
41 | self.tokens = tokens
42 | self.tok_map = tok_map
43 |
44 | def __iter__(self):
45 | return iter(self.data)
46 |
47 |
48 | class TextDataSet(object):
49 |
50 | @classmethod
51 | def from_config(cls, config, clear_cache=False, df=None):
52 | return cls.create(**read_config(config), clear_cache=clear_cache, df=df)
53 |
54 | @classmethod
55 | def create(cls,
56 | df_path=None,
57 | idx2cls=None,
58 | idx2cls_path=None,
59 | min_char_len=1,
60 | model_name="bert-base-multilingual-cased",
61 | max_sequence_length=424,
62 | pad_idx=0,
63 | clear_cache=False,
64 | df=None, tokenizer=None):
65 | if tokenizer is None:
66 | tokenizer = BertTokenizer.from_pretrained(model_name)
67 | config = {
68 | "min_char_len": min_char_len,
69 | "model_name": model_name,
70 | "max_sequence_length": max_sequence_length,
71 | "clear_cache": clear_cache,
72 | "df_path": df_path,
73 | "pad_idx": pad_idx,
74 | "idx2cls_path": idx2cls_path
75 | }
76 | if df is None and df_path is not None:
77 | df = pd.read_csv(df_path, sep='\t', engine='python')
78 | elif df is None:
79 | df = pd.DataFrame(columns=["text", "clf"])
80 | if clear_cache:
81 | _, idx2cls = cls.create_vocabs(df, idx2cls_path, idx2cls)
82 | self = cls(tokenizer, df=df, config=config, idx2cls=idx2cls)
83 | self.load(df=df)
84 | return self
85 |
86 | @staticmethod
87 | def create_vocabs(
88 | df, idx2cls_path, idx2cls=None):
89 | idx2cls = idx2cls
90 | cls2idx = {}
91 | if idx2cls is not None:
92 | cls2idx = {label: idx for idx, label in enumerate(idx2cls)}
93 | else:
94 | idx2cls = []
95 | for _, row in tqdm(df.iterrows(), total=len(df), leave=False, desc="Creating labels vocabs"):
96 | if row.cls not in cls2idx:
97 | cls2idx[row.cls] = len(cls2idx)
98 | idx2cls.append(row.cls)
99 |
100 | with open(idx2cls_path, "w", encoding="utf-8") as f:
101 | for label in idx2cls:
102 | f.write("{}\n".format(label))
103 |
104 | return cls2idx, idx2cls
105 |
106 | def load(self, df_path=None, df=None):
107 | df_path = if_none(df_path, self.config["df_path"])
108 |         if df is None:
109 |             df = pd.read_csv(df_path, sep='\t')
110 |         self.df = df
111 | self.idx2cls = []
112 | self.cls2idx = {}
113 | with open(self.config["idx2cls_path"], "r", encoding="utf-8") as f:
114 | for idx, label in enumerate(f.readlines()):
115 | label = label.strip()
116 | self.cls2idx[label] = idx
117 | self.idx2cls.append(label)
118 |
119 | def create_feature(self, row):
120 | bert_tokens = []
121 | orig_tokens = row.text.split()
122 | tok_map = []
123 | for orig_token in orig_tokens:
124 | cur_tokens = self.tokenizer.tokenize(orig_token)
125 | if self.config["max_sequence_length"] - 2 < len(bert_tokens) + len(cur_tokens):
126 | break
127 | cur_tokens = self.tokenizer.tokenize(orig_token)
128 | tok_map.append(len(bert_tokens))
129 | bert_tokens.extend(cur_tokens)
130 |
131 | orig_tokens = ["[CLS]"] + orig_tokens + ["[SEP]"]
132 |
133 | input_ids = self.tokenizer.convert_tokens_to_ids(['[CLS]'] + bert_tokens + ['[SEP]'])
134 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
135 | # tokens are attended to.
136 | input_mask = [1] * len(input_ids)
137 | # Zero-pad up to the sequence length.
138 | while len(input_ids) < self.config["max_sequence_length"]:
139 | input_ids.append(self.config["pad_idx"])
140 | input_mask.append(0)
141 | tok_map.append(-1)
142 | input_type_ids = [0] * len(input_ids)
143 | cls = str(row.cls)
144 | id_cls = self.cls2idx[cls]
145 | return InputFeature(
146 | # Bert data
147 | bert_tokens=bert_tokens,
148 | input_ids=input_ids,
149 | input_mask=input_mask,
150 | input_type_ids=input_type_ids,
151 | # Origin data
152 | tokens=orig_tokens,
153 | tok_map=tok_map,
154 | # Cls
155 | cls=cls, id_cls=id_cls
156 | )
157 |
158 | def __getitem__(self, item):
159 | if self.config["df_path"] is None and self.df is None:
160 | raise ValueError("Should setup df_path or df.")
161 | if self.df is None:
162 | self.load()
163 |
164 | return self.create_feature(self.df.iloc[item])
165 |
166 | def __len__(self):
167 | return len(self.df) if self.df is not None else 0
168 |
169 | def save(self, df_path=None):
170 | df_path = if_none(df_path, self.config["df_path"])
171 | self.df.to_csv(df_path, sep='\t', index=False)
172 |
173 | def __init__(
174 | self, tokenizer,
175 | df=None,
176 | config=None,
177 | idx2cls=None):
178 | self.df = df
179 | self.tokenizer = tokenizer
180 | self.config = config
181 | self.label2idx = None
182 |
183 | self.idx2cls = idx2cls
184 | if idx2cls is not None:
185 | self.cls2idx = {label: idx for idx, label in enumerate(idx2cls)}
186 |
187 |
188 | class LearnDataClass(object):
189 | def __init__(self, train_ds=None, train_dl=None, valid_ds=None, valid_dl=None):
190 | self.train_ds = train_ds
191 | self.train_dl = train_dl
192 | self.valid_ds = valid_ds
193 | self.valid_dl = valid_dl
194 |
195 | @classmethod
196 | def create(cls,
197 | # DataSet params
198 | train_df_path,
199 | valid_df_path,
200 | idx2cls=None,
201 | idx2cls_path=None,
202 | min_char_len=1,
203 | model_name="bert-base-multilingual-cased",
204 | max_sequence_length=424,
205 | pad_idx=0,
206 | clear_cache=False,
207 | train_df=None,
208 | valid_df=None,
209 | # DataLoader params
210 | device="cuda", batch_size=16):
211 | train_ds = None
212 | train_dl = None
213 | valid_ds = None
214 | valid_dl = None
215 | if idx2cls_path is not None:
216 | train_ds = TextDataSet.create(
217 | train_df_path,
218 | idx2cls=idx2cls,
219 | idx2cls_path=idx2cls_path,
220 | min_char_len=min_char_len,
221 | model_name=model_name,
222 | max_sequence_length=max_sequence_length,
223 | pad_idx=pad_idx,
224 | clear_cache=clear_cache,
225 | df=train_df)
226 | if len(train_ds):
227 | train_dl = TextDataLoader(train_ds, device=device, shuffle=True, batch_size=batch_size)
228 | if valid_df_path is not None:
229 | valid_ds = TextDataSet.create(
230 | valid_df_path,
231 | idx2cls=train_ds.idx2cls,
232 | idx2cls_path=idx2cls_path,
233 | min_char_len=min_char_len,
234 | model_name=model_name,
235 | max_sequence_length=max_sequence_length,
236 | pad_idx=pad_idx,
237 | clear_cache=False,
238 | df=valid_df, tokenizer=train_ds.tokenizer)
239 | valid_dl = TextDataLoader(valid_ds, device=device, batch_size=batch_size)
240 |
241 | self = cls(train_ds, train_dl, valid_ds, valid_dl)
242 | self.device = device
243 | self.batch_size = batch_size
244 | return self
245 |
246 | def load(self):
247 | if self.train_ds is not None:
248 | self.train_ds.load()
249 | if self.valid_ds is not None:
250 | self.valid_ds.load()
251 |
252 | def save(self):
253 | if self.train_ds is not None:
254 | self.train_ds.save()
255 | if self.valid_ds is not None:
256 | self.valid_ds.save()
257 |
258 |
259 | def get_data_loader_for_predict(data, df_path=None, df=None):
260 | config = deepcopy(data.train_ds.config)
261 | config["df_path"] = df_path
262 | config["clear_cache"] = False
263 | ds = TextDataSet.create(
264 | idx2cls=data.train_ds.idx2cls,
265 | df=df, tokenizer=data.train_ds.tokenizer, **config)
266 | return TextDataLoader(
267 | ds, device=data.device, batch_size=data.batch_size, shuffle=False), ds
268 |
--------------------------------------------------------------------------------
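A minimal sketch of wiring these classes together for a classification run. The CSV paths and the label-vocabulary path are assumptions; the frames are expected to be tab-separated with text and cls columns, as emitted by the CoNLL-2003 preprocessor below:

from modules.data.bert_data_clf import LearnDataClass, get_data_loader_for_predict

data = LearnDataClass.create(
    train_df_path="data/conll2003/eng.train.train.csv",
    valid_df_path="data/conll2003/eng.testa.dev.csv",
    idx2cls_path="data/conll2003/idx2cls.txt",   # label vocabulary, written from the train frame when clear_cache=True
    clear_cache=True,
    model_name="bert-base-multilingual-cased",
    device="cuda", batch_size=16)

# Build a loader over another labelled frame, reusing the training tokenizer and label vocabulary.
dl, ds = get_data_loader_for_predict(data, df_path="data/conll2003/eng.testb.dev.csv")
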
/modules/data/conll2003/__init__.py:
--------------------------------------------------------------------------------
1 | from .prc import conll2003_preprocess
2 |
3 |
4 | __all__ = ["conll2003_preprocess"]
5 |
--------------------------------------------------------------------------------
/modules/data/conll2003/prc.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from modules import tqdm
3 | import argparse
4 | import codecs
5 | import os
6 |
7 |
8 | def conll2003_preprocess(
9 | data_dir, train_name="eng.train", dev_name="eng.testa", test_name="eng.testb"):
10 | train_f = read_data(os.path.join(data_dir, train_name))
11 | dev_f = read_data(os.path.join(data_dir, dev_name))
12 | test_f = read_data(os.path.join(data_dir, test_name))
13 |
14 | train = pd.DataFrame({"labels": [x[0] for x in train_f], "text": [x[1] for x in train_f]})
15 | train["cls"] = train["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()]))
16 | train.to_csv(os.path.join(data_dir, "{}.train.csv".format(train_name)), index=False, sep="\t")
17 |
18 | dev = pd.DataFrame({"labels": [x[0] for x in dev_f], "text": [x[1] for x in dev_f]})
19 | dev["cls"] = dev["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()]))
20 | dev.to_csv(os.path.join(data_dir, "{}.dev.csv".format(dev_name)), index=False, sep="\t")
21 |
22 | test_ = pd.DataFrame({"labels": [x[0] for x in test_f], "text": [x[1] for x in test_f]})
23 | test_["cls"] = test_["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()]))
24 | test_.to_csv(os.path.join(data_dir, "{}.dev.csv".format(test_name)), index=False, sep="\t")
25 |
26 |
27 | def read_data(input_file):
28 | """Reads a BIO data."""
29 | with codecs.open(input_file, "r", encoding="utf-8") as f:
30 | lines = []
31 | words = []
32 | labels = []
33 | f_lines = f.readlines()
34 | for line in tqdm(f_lines, total=len(f_lines), desc="Process {}".format(input_file)):
35 | contends = line.strip()
36 | word = line.strip().split(' ')[0]
37 | label = line.strip().split(' ')[-1]
38 | if contends.startswith("-DOCSTART-"):
39 | words.append('')
40 | continue
41 |
42 | if len(contends) == 0 and not len(words):
43 | words.append("")
44 |
45 | if len(contends) == 0 and words[-1] == '.':
46 | lbl = ' '.join([label for label in labels if len(label) > 0])
47 | w = ' '.join([word for word in words if len(word) > 0])
48 | lines.append([lbl, w])
49 | words = []
50 | labels = []
51 | continue
52 | words.append(word)
53 | labels.append(label.replace("-", "_"))
54 | return lines
55 |
56 |
57 | def parse_args():
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument('--data_dir', type=str)
60 | parser.add_argument('--train_name', type=str, default="eng.train")
61 | parser.add_argument('--dev_name', type=str, default="eng.testa")
62 | parser.add_argument('--test_name', type=str, default="eng.testb")
63 | return vars(parser.parse_args())
64 |
65 |
66 | if __name__ == "__main__":
67 | conll2003_preprocess(**parse_args())
68 |
--------------------------------------------------------------------------------
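A short usage sketch of the preprocessor; the data directory is an assumption and must already contain the three raw CoNLL-2003 files (see download_data.py below):

from modules.data.conll2003.prc import conll2003_preprocess

# Writes tab-separated train/dev/test frames with "labels", "text" and "cls" columns
# next to the raw files in data_dir.
conll2003_preprocess("data/conll2003")
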
/modules/data/download_data.py:
--------------------------------------------------------------------------------
1 | import urllib
2 | import sys
3 | import os
4 |
5 |
6 | tasks_urls = {
7 | "conll2003": [
8 | ["eng.testa", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa"],
9 | ["eng.testb", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb"],
10 | ["eng.train", "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train"]
11 | ]}
12 |
13 |
14 | def download_data(task_name, data_dir):
15 | req = urllib
16 | if sys.version_info >= (3, 0):
17 |         from urllib import request as req
18 | for data_file, url in tasks_urls[task_name]:
19 | if not os.path.exists(data_dir):
20 | os.mkdir(data_dir)
21 | _ = req.urlretrieve(url, os.path.join(data_dir, data_file))
22 |
--------------------------------------------------------------------------------
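A one-line sketch of fetching the corpus consumed by the CoNLL-2003 preprocessor above; the target directory is an assumption:

from modules.data.download_data import download_data

download_data("conll2003", "data/conll2003")   # fetches eng.train, eng.testa and eng.testb
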
/modules/data/fre/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from .reader import Reader as FREReader
3 | from .prc import fact_ru_eval_preprocess
4 |
5 | __all__ = ["FREReader", "fact_ru_eval_preprocess"]
6 |
--------------------------------------------------------------------------------
/modules/data/fre/bilou/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/modules/data/fre/bilou/from_bilou.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | def untag(list_of_tags, list_of_tokens):
5 | """
6 |     :param list_of_tags: BILOU tag strings aligned one-to-one with list_of_tokens
7 |     :param list_of_tokens: Token objects of a single document
8 |     :return: list of [tag, position, length] triples, one per recovered named entity
9 | """
10 | if len(list_of_tags) == len(list_of_tokens):
11 | dict_of_final_ne = {}
12 | ne_words = []
13 | ne_tag = None
14 |
15 | for index in range(len(list_of_tokens)):
16 | if not ((ne_tag is not None) ^ (ne_words != [])):
17 | current_tag = list_of_tags[index]
18 | current_token = list_of_tokens[index]
19 |
20 | if current_tag.startswith('B') or current_tag.startswith('I'):
21 | dict_of_final_ne, ne_words, ne_tag = __check_bi(
22 | dict_of_final_ne, ne_words, ne_tag, current_tag, current_token)
23 | elif current_tag.startswith('L'):
24 | dict_of_final_ne, ne_words, ne_tag = __check_l(
25 | dict_of_final_ne, ne_words, ne_tag, current_tag, current_token)
26 | elif current_tag.startswith('O'):
27 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag)
28 |
29 | elif current_tag.startswith('U'):
30 | dict_of_final_ne, ne_words, ne_tag = __check_u(dict_of_final_ne, ne_words, ne_tag, current_tag,
31 | current_token)
32 | else:
33 | raise ValueError("tag contains no BILOU tags")
34 | else:
35 | if ne_tag is None:
36 | raise Exception('Somehow ne_tag is None and ne_words is not None')
37 | else:
38 | raise Exception('Somehow ne_words is None and ne_tag is not None')
39 |
40 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag)
41 | return __to_output_format(dict_of_final_ne)
42 | else:
43 | raise ValueError('lengths are not equal')
44 |
45 |
46 | def __check_bi(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token):
47 | if ne_tag is None and ne_words == []:
48 | ne_tag = current_tag[1:]
49 | ne_words = [current_token]
50 | else:
51 | if current_tag.startswith('I') and ne_tag == current_tag[1:]:
52 | ne_words.append(current_token)
53 | else:
54 | dict_of_final_ne, ne_words, ne_tag = __replace_by_new(dict_of_final_ne, ne_words, ne_tag, current_tag,
55 | current_token)
56 | return dict_of_final_ne, ne_words, ne_tag
57 |
58 |
59 | def __check_l(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token):
60 | if ne_tag == current_tag[1:]:
61 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words+[current_token], ne_tag)
62 | else:
63 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag)
64 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, [current_token], current_tag[1:])
65 | return dict_of_final_ne, ne_words, ne_tag
66 |
67 |
68 | def __check_u(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token):
69 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag)
70 | return __finish_ne_if_required(dict_of_final_ne, [current_token], current_tag[1:])
71 |
72 |
73 | def __replace_by_new(dict_of_final_ne, ne_words, ne_tag, current_tag, current_token):
74 | dict_of_final_ne, ne_words, ne_tag = __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag)
75 | ne_tag = current_tag[1:]
76 | ne_words = [current_token]
77 | return dict_of_final_ne, ne_words, ne_tag
78 |
79 |
80 | def __finish_ne_if_required(dict_of_final_ne, ne_words, ne_tag):
81 | if ne_tag is not None and ne_words != []:
82 | dict_of_final_ne[tuple(ne_words)] = ne_tag
83 | ne_tag = None
84 | ne_words = []
85 | return dict_of_final_ne, ne_words, ne_tag
86 |
87 |
88 | def __to_output_format(dict_nes):
89 | """
90 |     :param dict_nes: mapping from a tuple of entity tokens to the entity tag
91 |     :return: list of [tag, position, length] triples in the FactRuEval output format
92 | """
93 | list_of_results_for_output = []
94 |
95 | for tokens_tuple, tag in dict_nes.items():
96 | position = int(tokens_tuple[0].get_position())
97 | length = int(tokens_tuple[-1].get_position()) + int(tokens_tuple[-1].get_length()) - position
98 | list_of_results_for_output.append([tag, position, length])
99 |
100 | return list_of_results_for_output
101 |
--------------------------------------------------------------------------------
/modules/data/fre/bilou/to_bilou.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from ..entity.taggedtoken import TaggedToken
3 |
4 |
5 | def get_tagged_tokens_from(dict_of_nes, token_list):
6 | list_of_tagged_tokens = [TaggedToken('O', token_list[i]) for i in range(len(token_list))]
7 | dict_of_tokens_with_indexes = {token_list[i].id: i for i in range(len(token_list))}
8 |
9 | for ne in dict_of_nes.values():
10 | for tokenid in ne['tokens_list']:
11 | try:
12 | tag = format_tag(tokenid, ne)
13 | except ValueError:
14 | tag = "O"
15 | id_in_token_tuple = dict_of_tokens_with_indexes[tokenid]
16 | token = token_list[id_in_token_tuple]
17 | list_of_tagged_tokens[id_in_token_tuple] = TaggedToken(tag, token)
18 | return list_of_tagged_tokens
19 |
20 |
21 | def format_tag(tokenid, ne):
22 | bilou = __choose_bilou_tag_for(tokenid, ne['tokens_list'])
23 | formatted_tag = __tag_to_fact_ru_eval_format(ne['tag'])
24 | return "{}_{}".format(bilou, formatted_tag)
25 |
26 |
27 | def __choose_bilou_tag_for(token_id, token_list):
28 | if len(token_list) == 1:
29 | return 'B'
30 | elif len(token_list) > 1:
31 | if token_list.index(token_id) == 0:
32 | return 'B'
33 | else:
34 | return 'I'
35 |
36 |
37 | def __tag_to_fact_ru_eval_format(tag):
38 | if tag == 'Person':
39 | return 'PER'
40 | elif tag == 'Org':
41 | return 'ORG'
42 | elif tag == 'Location':
43 | return 'LOC'
44 | elif tag == 'LocOrg':
45 | return 'LOC'
46 | elif tag == 'Project':
47 | return 'ORG'
48 | else:
49 | raise ValueError('tag ' + tag + " is not the right tag")
50 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/document.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | from .token import Token
3 | from .taggedtoken import TaggedToken
4 | from collections import defaultdict
5 | from ..bilou import to_bilou
6 |
7 |
8 | class Document(object):
9 | def __init__(self, path, tagged=True, encoding="utf-8"):
10 | self.path = path
11 | self.tagged = tagged
12 | self.encoding = encoding
13 | self.tokens = []
14 | self.tagged_tokens = []
15 | self.load()
16 |
17 | def to_text_tokens(self):
18 | return [token.text for token in self.tokens]
19 |
20 | def get_tags(self):
21 | return [token.get_tag() for token in self.tagged_tokens]
22 |
23 | def load(self):
24 | self.tokens = self.__get_tokens_from_file()
25 | if self.tagged:
26 | self.tagged_tokens = self.__get_tagged_tokens_from()
27 | else:
28 | self.tagged_tokens = [TaggedToken(None, token) for token in self.tokens]
29 | return self
30 |
31 | def parse_file(self, path):
32 | with codecs.open(path, 'r', encoding=self.encoding, errors="ignore") as file:
33 | rows = file.read().split('\n')
34 | return [row.split(' # ')[0].split() for row in rows if len(row) != 0]
35 |
36 | def __get_tokens_from_file(self):
37 | rows = self.parse_file(self.path + '.tokens')
38 | tokens = []
39 | for token_str in rows:
40 | tokens.append(Token().from_sting(token_str))
41 | return tokens
42 |
43 | def __get_tagged_tokens_from(self):
44 | span_dict = self.__span_id2token_ids(self.path + '.spans', [token.id for token in self.tokens])
45 | object_dict = self.__to_dict_of_objects(self.path + '.objects')
46 | dict_of_nes = self.__merge(object_dict, span_dict, self.tokens)
47 | return to_bilou.get_tagged_tokens_from(dict_of_nes, self.tokens)
48 |
49 | def __span_id2token_ids(self, span_file, token_ids):
50 | span_list = self.parse_file(span_file)
51 | dict_of_spans = {}
52 | for span in span_list:
53 | span_id = span[0]
54 | span_start = span[4]
55 | span_length_in_tokens = int(span[5])
56 | list_of_token_of_spans = self.__find_tokens_for(span_start, span_length_in_tokens, token_ids)
57 | dict_of_spans[span_id] = list_of_token_of_spans
58 | return dict_of_spans
59 |
60 | @staticmethod
61 | def __find_tokens_for(start, length, token_ids):
62 | list_of_tokens = []
63 | index = token_ids.index(start)
64 | for i in range(length):
65 | list_of_tokens.append(token_ids[index + i])
66 | return list_of_tokens
67 |
68 | def __to_dict_of_objects(self, object_file):
69 | object_list = self.parse_file(object_file)
70 | dict_of_objects = {}
71 | for obj in object_list:
72 | object_id = obj[0]
73 | object_tag = obj[1]
74 | object_spans = obj[2:]
75 | dict_of_objects[object_id] = {'tag': object_tag, 'spans': object_spans}
76 | return dict_of_objects
77 |
78 | def __merge(self, object_dict, span_dict, tokens):
79 | ne_dict = self.__get_dict_of_nes(object_dict, span_dict)
80 | return self.__clean(ne_dict, tokens)
81 |
82 | @staticmethod
83 | def __get_dict_of_nes(object_dict, span_dict):
84 | ne_dict = defaultdict(set)
85 | for obj_id, obj_values in object_dict.items():
86 | for span in obj_values['spans']:
87 | ne_dict[(obj_id, obj_values['tag'])].update(span_dict[span])
88 | for ne in ne_dict:
89 | ne_dict[ne] = sorted(list(set([int(i) for i in ne_dict[ne]])))
90 | return ne_dict
91 |
92 | def __clean(self, ne_dict, tokens):
93 | sorted_nes = sorted(ne_dict.items(), key=self.__sort_by_tokens)
94 | dict_of_tokens_by_id = {}
95 | for i in range(len(tokens)):
96 | dict_of_tokens_by_id[tokens[i].id] = i
97 | result_nes = {}
98 | if len(sorted_nes) != 0:
99 | start_ne = sorted_nes[0]
100 | for ne in sorted_nes:
101 | if self.__not_intersect(start_ne[1], ne[1]):
102 | result_nes[start_ne[0][0]] = {
103 | 'tokens_list': self.__check_order(start_ne[1], dict_of_tokens_by_id, tokens),
104 | 'tag': start_ne[0][1]}
105 | start_ne = ne
106 | else:
107 | result_tokens_list = self.__check_normal_form(start_ne[1], ne[1])
108 | start_ne = (start_ne[0], result_tokens_list)
109 | result_nes[start_ne[0][0]] = {
110 | 'tokens_list': self.__check_order(start_ne[1], dict_of_tokens_by_id, tokens),
111 | 'tag': start_ne[0][1]}
112 | return result_nes
113 |
114 | @staticmethod
115 | def __sort_by_tokens(tokens):
116 | ids_as_int = [int(token_id) for token_id in tokens[1]]
117 | return min(ids_as_int), -max(ids_as_int)
118 |
119 | @staticmethod
120 | def __not_intersect(start_ne, current_ne):
121 | intersection = set.intersection(set(start_ne), set(current_ne))
122 | return intersection == set()
123 |
124 | def __check_normal_form(self, start_ne, ne):
125 | all_tokens = set.union(set(start_ne), set(ne))
126 | return self.__find_all_range_of_tokens(all_tokens)
127 |
128 | @staticmethod
129 | def __find_all_range_of_tokens(tokens):
130 | tokens = sorted(tokens)
131 | if (tokens[-1] - tokens[0] - len(tokens)) < 5:
132 | return list(range(tokens[0], tokens[-1] + 1))
133 | else:
134 | return tokens
135 |
136 | def __check_order(self, list_of_tokens, dict_of_tokens_by_id, tokens):
137 | list_of_tokens = [str(i) for i in self.__find_all_range_of_tokens(list_of_tokens)]
138 | result = []
139 | for token in list_of_tokens:
140 | if token in dict_of_tokens_by_id:
141 | result.append((token, dict_of_tokens_by_id[token]))
142 | result = sorted(result, key=self.__sort_by_position)
143 | result = self.__add_quotation_marks(result, tokens)
144 | return [r[0] for r in result]
145 |
146 | @staticmethod
147 | def __sort_by_position(result_tuple):
148 | return result_tuple[1]
149 |
150 | @staticmethod
151 | def __add_quotation_marks(result, tokens):
152 | result_tokens_texts = [tokens[token[1]].text for token in result]
153 | prev_pos = result[0][1] - 1
154 | next_pos = result[-1][1] + 1
155 |
156 | if prev_pos >= 0 and tokens[prev_pos].text == '«' \
157 | and '»' in result_tokens_texts and '«' not in result_tokens_texts:
158 | result = [(tokens[prev_pos].id, prev_pos)] + result
159 |
160 | if next_pos < len(tokens) and tokens[next_pos].text == '»' \
161 | and '«' in result_tokens_texts and '»' not in result_tokens_texts:
162 | result = result + [(tokens[next_pos].id, next_pos)]
163 |
164 | return result
165 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/taggedtoken.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | class TaggedToken(object):
5 |
6 | @property
7 | def text(self):
8 | return self.__token.text
9 |
10 | def __init__(self, tag, token):
11 | self.__tag = tag
12 | self.__token = token
13 |
14 | def get_token(self):
15 | return self.__token
16 |
17 | def get_tag(self):
18 | return self.__tag
19 |
20 | def __repr__(self):
21 | if self.__tag:
22 | return "<" + self.__tag + "_" + str(self.__token) + ">"
23 | else:
24 | return ""
25 |
--------------------------------------------------------------------------------
/modules/data/fre/entity/token.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | class Token(object):
5 | __token_id__ = 0
6 |
7 | @property
8 | def length(self):
9 | return self.__length
10 |
11 | @property
12 | def position(self):
13 | return self.__position
14 |
15 | @property
16 | def id(self):
17 | return self.__id
18 |
19 | @property
20 | def text(self):
21 | return self.__text
22 |
23 | @property
24 | def all(self):
25 | return self.__id, self.__position, self.__length, self.__text
26 |
27 | @property
28 | def tag(self):
29 |         return self.__tag
30 |
31 | def __init__(self, token_id=None, position=None, length=None, text=None):
32 | self.__id = token_id
33 | if token_id is None:
34 | self.__id = Token.__token_id__
35 | Token.__token_id__ += 1
36 | self.__position = position
37 | self.__length = length
38 | self.__text = text
39 | self.__tag = None
40 |
41 | def from_sting(self, string):
42 | self.__id, self.__position, self.__length, self.__text = string
43 | return self
44 |
45 | def __len__(self):
46 | return self.__length
47 |
48 | def __str__(self):
49 | return self.__text
50 |
51 | def __repr__(self):
52 | return "<<" + self.__id + "_" + self.__text + ">>"
53 |
--------------------------------------------------------------------------------
/modules/data/fre/prc.py:
--------------------------------------------------------------------------------
1 | from modules.data.fre.reader import Reader
2 | import pandas as pd
3 | from modules import tqdm
4 | import argparse
5 |
6 |
7 | def fact_ru_eval_preprocess(dev_dir, test_dir, dev_df_path, test_df_path):
8 | dev_reader = Reader(dev_dir)
9 | dev_reader.read_dir()
10 | dev_texts, dev_tags = dev_reader.split()
11 | res_tags = []
12 | res_tokens = []
13 | for tag, tokens in tqdm(zip(dev_tags, dev_texts), total=len(dev_tags), desc="Process FactRuEval2016 dev set."):
14 | if len(tag):
15 | res_tags.append(tag)
16 | res_tokens.append(tokens)
17 | dev = pd.DataFrame({"labels": list(map(" ".join, res_tags)), "text": list(map(" ".join, res_tokens))})
18 | dev["clf"] = dev["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()]))
19 | dev.to_csv(dev_df_path, index=False, sep="\t")
20 |
21 | test_reader = Reader(test_dir)
22 | test_reader.read_dir()
23 | test_texts, test_tags = test_reader.split()
24 | res_tags = []
25 | res_tokens = []
26 | for tag, tokens in tqdm(zip(test_tags, test_texts), total=len(test_tags), desc="Process FactRuEval2016 test set."):
27 | if len(tag):
28 | res_tags.append(tag)
29 | res_tokens.append(tokens)
30 | valid = pd.DataFrame({"labels": list(map(" ".join, res_tags)), "text": list(map(" ".join, res_tokens))})
31 | valid["clf"] = valid["labels"].apply(lambda x: all([y.split("_")[0] == "O" for y in x.split()]))
32 | valid.to_csv(test_df_path, index=False, sep="\t")
33 |
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('-dd', '--dev_dir', type=str)
38 | parser.add_argument('-td', '--test_dir', type=str)
39 | parser.add_argument('-ddp', '--dev_df_path', type=str)
40 | parser.add_argument('-tdp', '--test_df_path', type=str)
41 | return vars(parser.parse_args())
42 |
43 |
44 | if __name__ == "__main__":
45 | fact_ru_eval_preprocess(**parse_args())
46 |
--------------------------------------------------------------------------------
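A minimal sketch of calling the FactRuEval-2016 preprocessing from Python rather than the CLI; the directory names follow the usual layout of the factRuEval-2016 corpus and, like the output paths, are assumptions:

from modules.data.fre.prc import fact_ru_eval_preprocess

fact_ru_eval_preprocess(
    dev_dir="factRuEval-2016/devset",
    test_dir="factRuEval-2016/testset",
    dev_df_path="data/fre/dev.csv",
    test_df_path="data/fre/test.csv")
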
/modules/data/fre/reader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pandas as pd
3 | from .utils import get_file_names
4 | from .entity.document import Document
5 |
6 |
7 | class Reader(object):
8 |
9 | def __init__(self,
10 | dir_path,
11 | document_creator=Document,
12 | get_file_names_=get_file_names,
13 | tagged=True):
14 | self.path = dir_path
15 | self.tagged = tagged
16 | self.documents = []
17 | self.document_creator = document_creator
18 | self.get_file_names = get_file_names_
19 |
20 | def split(self, use_morph=False):
21 | res_texts = []
22 | res_tags = []
23 | for doc in self.documents:
24 | sent_tokens = []
25 | sent_tags = []
26 | for token in doc.tagged_tokens:
27 | if token.get_tag() == "O" and token.text == ".":
28 | res_texts.append(tuple(sent_tokens))
29 | res_tags.append(tuple(sent_tags))
30 | sent_tokens = []
31 | sent_tags = []
32 | else:
33 | text = token.text
34 | sent_tokens.append(text)
35 | sent_tags.append(token.get_tag())
36 | if use_morph:
37 | return res_texts, res_tags
38 | return res_texts, res_tags
39 |
40 | def to_data_frame(self, split=False):
41 | if split:
42 | docs = self.split()
43 | else:
44 | docs = []
45 | for doc in self.documents:
46 | docs.append([(token.text, token.get_tag()) for token in doc.tagged_tokens])
47 |
48 | texts = []
49 | tags = []
50 | for sent in docs:
51 | sample_text = []
52 | sample_tag = []
53 | for text, tag in sent:
54 | sample_text.append(text)
55 | sample_tag.append(tag)
56 | texts.append(" ".join(sample_text))
57 | tags.append(" ".join(sample_tag))
58 | return pd.DataFrame({"texts": texts, "tags": tags}, columns=["texts", "tags"])
59 |
60 | def read_dir(self):
61 | for path in self.get_file_names(self.path):
62 | self.documents.append(self.document_creator(path, self.tagged))
63 |
64 | def get_text_tokens(self):
65 | return [doc.to_text_tokens() for doc in self.documents]
66 |
67 | def get_text_tags(self):
68 | return [doc.get_tags() for doc in self.documents]
69 |
--------------------------------------------------------------------------------
/modules/data/fre/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def get_file_names(path):
5 | res = []
6 | for root, dirs, files in os.walk(path):
7 | for file in files:
8 | if file.endswith('.tokens'):
9 | res.append(os.path.join(root, os.path.splitext(file)[0]))
10 | return res
11 |
--------------------------------------------------------------------------------
/modules/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/layers/__init__.py
--------------------------------------------------------------------------------
/modules/layers/crf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | # TODO: move to utils
6 | def log_sum_exp(tensor, dim=0):
7 | """LogSumExp operation."""
8 | m, _ = torch.max(tensor, dim)
9 | m_exp = m.unsqueeze(-1).expand_as(tensor)
10 | return m + torch.log(torch.sum(torch.exp(tensor - m_exp), dim))
11 |
12 |
13 | def sequence_mask(lens, max_len=None):
14 | batch_size = lens.size(0)
15 |
16 | if max_len is None:
17 | max_len = lens.max().item()
18 |
19 | ranges = torch.arange(0, max_len).long()
20 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len)
21 |
22 | if lens.data.is_cuda:
23 | ranges = ranges.cuda()
24 |
25 | lens_exp = lens.unsqueeze(1).expand_as(ranges)
26 | mask = ranges < lens_exp
27 |
28 | return mask
29 |
30 |
31 | class CRF(nn.Module):
32 | def forward(self, *input_):
33 | return self.viterbi_decode(*input_)
34 |
35 | def __init__(self, label_size):
36 | super(CRF, self).__init__()
37 |
38 | self.label_size = label_size
39 | self.start = self.label_size - 2
40 | self.end = self.label_size - 1
41 | transition = torch.randn(self.label_size, self.label_size)
42 | self.transition = nn.Parameter(transition)
43 | self.initialize()
44 |
45 | def initialize(self):
46 | self.transition.data[:, self.end] = -100.0
47 | self.transition.data[self.start, :] = -100.0
48 |
49 | @staticmethod
50 | def pad_logits(logits):
51 | # lens = lens.data
52 | batch_size, seq_len, label_num = logits.size()
53 | # pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-1000.0),
54 | # requires_grad=False)
55 | pads = logits.new_full((batch_size, seq_len, 2), -1000.0,
56 | requires_grad=False)
57 | logits = torch.cat([logits, pads], dim=2)
58 | return logits
59 |
60 | def calc_binary_score(self, labels, lens):
61 | batch_size, seq_len = labels.size()
62 |
63 | # labels_ext = Variable(labels.data.new(batch_size, seq_len + 2))
64 | labels_ext = labels.new_empty((batch_size, seq_len + 2))
65 | labels_ext[:, 0] = self.start
66 | labels_ext[:, 1:-1] = labels
67 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long()
68 | # pad_stop = Variable(labels.data.new(1).fill_(self.end))
69 | pad_stop = labels.new_full((1,), self.end, requires_grad=False)
70 | pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
71 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext
72 | labels = labels_ext
73 |
74 | trn = self.transition
75 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size())
76 | lbl_r = labels[:, 1:]
77 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0))
78 | trn_row = torch.gather(trn_exp, 1, lbl_rexp)
79 |
80 | lbl_lexp = labels[:, :-1].unsqueeze(-1)
81 | trn_scr = torch.gather(trn_row, 2, lbl_lexp)
82 | trn_scr = trn_scr.squeeze(-1)
83 |
84 | mask = sequence_mask(lens + 1).float()
85 | trn_scr = trn_scr * mask
86 | score = trn_scr
87 |
88 | return score
89 |
90 | @staticmethod
91 | def calc_unary_score(logits, labels, lens):
92 | labels_exp = labels.unsqueeze(-1)
93 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1)
94 | mask = sequence_mask(lens).float()
95 | scores = scores * mask
96 | return scores
97 |
98 | def calc_gold_score(self, logits, labels, lens):
99 | unary_score = self.calc_unary_score(logits, labels, lens).sum(
100 | 1).squeeze(-1)
101 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1)
102 | return unary_score + binary_score
103 |
104 | def calc_norm_score(self, logits, lens):
105 | batch_size, seq_len, feat_dim = logits.size()
106 | # alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0)
107 | alpha = logits.new_full((batch_size, self.label_size), -100.0)
108 | alpha[:, self.start] = 0
109 | # alpha = Variable(alpha)
110 | lens_ = lens.clone()
111 |
112 | logits_t = logits.transpose(1, 0)
113 | for logit in logits_t:
114 | logit_exp = logit.unsqueeze(-1).expand(batch_size,
115 | *self.transition.size())
116 | alpha_exp = alpha.unsqueeze(1).expand(batch_size,
117 | *self.transition.size())
118 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp)
119 | mat = logit_exp + alpha_exp + trans_exp
120 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1)
121 |
122 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha)
123 | alpha = mask * alpha_nxt + (1 - mask) * alpha
124 | lens_ = lens_ - 1
125 |
126 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha)
127 | norm = log_sum_exp(alpha, 1).squeeze(-1)
128 |
129 | return norm
130 |
131 | def viterbi_decode(self, logits, lens):
132 | """Borrowed from pytorch tutorial
133 | Arguments:
134 | logits: [batch_size, seq_len, n_labels] FloatTensor
135 | lens: [batch_size] LongTensor
136 | """
137 | batch_size, seq_len, n_labels = logits.size()
138 | # vit = logits.data.new(batch_size, self.label_size).fill_(-10000)
139 | vit = logits.new_full((batch_size, self.label_size), -100.0)
140 | vit[:, self.start] = 0
141 | # vit = Variable(vit)
142 | c_lens = lens.clone()
143 |
144 | logits_t = logits.transpose(1, 0)
145 | pointers = []
146 | for logit in logits_t:
147 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels)
148 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp)
149 | vit_trn_sum = vit_exp + trn_exp
150 | vt_max, vt_argmax = vit_trn_sum.max(2)
151 |
152 | vt_max = vt_max.squeeze(-1)
153 | vit_nxt = vt_max + logit
154 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0))
155 |
156 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt)
157 | vit = mask * vit_nxt + (1 - mask) * vit
158 |
159 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt)
160 | vit += mask * self.transition[self.end].unsqueeze(
161 | 0).expand_as(vit_nxt)
162 |
163 | c_lens = c_lens - 1
164 |
165 | pointers = torch.cat(pointers)
166 | scores, idx = vit.max(1)
167 | # idx = idx.squeeze(-1)
168 | paths = [idx.unsqueeze(1)]
169 | for argmax in reversed(pointers):
170 | idx_exp = idx.unsqueeze(-1)
171 | idx = torch.gather(argmax, 1, idx_exp)
172 | idx = idx.squeeze(-1)
173 |
174 | paths.insert(0, idx.unsqueeze(1))
175 |
176 | paths = torch.cat(paths[1:], 1)
177 | scores = scores.squeeze(-1)
178 |
179 | return scores, paths
180 |
--------------------------------------------------------------------------------
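A self-contained sketch of how this CRF layer is typically driven. The shapes and the negative-log-likelihood expression are assumptions made for illustration; the actual wiring of the layer into a model lives in decoders.py:

import torch
from modules.layers.crf import CRF

n_labels, batch_size, seq_len = 9, 2, 5
crf = CRF(label_size=n_labels + 2)                      # two extra slots for the start/end states

emissions = torch.randn(batch_size, seq_len, n_labels)  # per-token label scores from an encoder
logits = CRF.pad_logits(emissions)                      # -> [batch, seq_len, n_labels + 2]
labels = torch.randint(0, n_labels, (batch_size, seq_len))
lens = torch.tensor([seq_len, seq_len - 1])

# CRF negative log-likelihood: partition function minus the score of the gold path.
nll = (crf.calc_norm_score(logits, lens) - crf.calc_gold_score(logits, labels, lens)).mean()

# Viterbi decoding; forward() dispatches to viterbi_decode.
scores, paths = crf(logits, lens)
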
/modules/layers/embedders.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert import BertModel
2 | import torch
3 |
4 |
5 | class BERTEmbedder(torch.nn.Module):
6 | def __init__(self, model, config):
7 | super(BERTEmbedder, self).__init__()
8 | self.config = config
9 | self.model = model
10 | if self.config["mode"] == "weighted":
11 | self.bert_weights = torch.nn.Parameter(torch.FloatTensor(12, 1))
12 | self.bert_gamma = torch.nn.Parameter(torch.FloatTensor(1, 1))
13 | self.init_weights()
14 |
15 | def init_weights(self):
16 | if self.config["mode"] == "weighted":
17 | torch.nn.init.xavier_normal(self.bert_gamma)
18 | torch.nn.init.xavier_normal(self.bert_weights)
19 |
20 | @classmethod
21 | def create(
22 | cls, model_name='bert-base-multilingual-cased',
23 | device="cuda", mode="weighted",
24 | is_freeze=True):
25 | config = {
26 | "model_name": model_name,
27 | "device": device,
28 | "mode": mode,
29 | "is_freeze": is_freeze
30 | }
31 | model = BertModel.from_pretrained(model_name)
32 | model.to(device)
33 | model.train()
34 | self = cls(model, config)
35 | if is_freeze:
36 | self.freeze()
37 | return self
38 |
39 | @classmethod
40 | def from_config(cls, config):
41 | return cls.create(**config)
42 |
43 | def forward(self, batch):
44 | """
45 | batch has the following structure:
46 | data[0]: list, tokens ids
47 | data[1]: list, tokens mask
48 | data[2]: list, tokens type ids (for bert)
49 | data[3]: list, bert labels ids
50 | """
51 | encoded_layers, _ = self.model(
52 | input_ids=batch[0],
53 | token_type_ids=batch[2],
54 | attention_mask=batch[1],
55 | output_all_encoded_layers=self.config["mode"] == "weighted")
56 | if self.config["mode"] == "weighted":
57 | encoded_layers = torch.stack([a * b for a, b in zip(encoded_layers, self.bert_weights)])
58 | return self.bert_gamma * torch.sum(encoded_layers, dim=0)
59 | return encoded_layers
60 |
61 | def freeze(self):
62 | for param in self.model.parameters():
63 | param.requires_grad = False
64 |
--------------------------------------------------------------------------------
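A minimal sketch of using the embedder on its own. The batch layout follows the forward() docstring above; the dummy tensors and the CPU device are assumptions:

import torch
from modules.layers.embedders import BERTEmbedder

embedder = BERTEmbedder.create(
    model_name="bert-base-multilingual-cased", device="cpu", mode="weighted", is_freeze=True)

input_ids = torch.randint(1000, 2000, (2, 16))   # dummy wordpiece ids
input_mask = torch.ones_like(input_ids)          # 1 for real tokens, 0 for padding
type_ids = torch.zeros_like(input_ids)

embeddings = embedder([input_ids, input_mask, type_ids])   # -> [2, 16, 768]
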
/modules/layers/layers.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import init
6 | from torch.nn.utils import rnn as rnn_utils
7 | import math
8 |
9 |
10 | class BiLSTM(nn.Module):
11 |
12 | def __init__(self, embedding_size=768, hidden_dim=512, rnn_layers=1, dropout=0.5):
13 | super(BiLSTM, self).__init__()
14 | self.embedding_size = embedding_size
15 | self.hidden_dim = hidden_dim
16 | self.rnn_layers = rnn_layers
17 | self.dropout = nn.Dropout(dropout)
18 | self.lstm = nn.LSTM(
19 | embedding_size,
20 | hidden_dim // 2,
21 | rnn_layers, batch_first=True, bidirectional=True)
22 |
23 | def forward(self, input_, input_mask):
24 | length = input_mask.sum(-1)
25 | sorted_lengths, sorted_idx = torch.sort(length, descending=True)
26 | input_ = input_[sorted_idx]
27 | packed_input = rnn_utils.pack_padded_sequence(input_, sorted_lengths.data.tolist(), batch_first=True)
28 | output, (hidden, _) = self.lstm(packed_input)
29 | padded_outputs = rnn_utils.pad_packed_sequence(output, batch_first=True)[0]
30 | _, reversed_idx = torch.sort(sorted_idx)
31 | return padded_outputs[reversed_idx], hidden[:, reversed_idx]
32 |
33 | @classmethod
34 | def create(cls, *args, **kwargs):
35 | return cls(*args, **kwargs)
36 |
37 |
38 | class Linear(nn.Linear):
39 | def __init__(self,
40 | in_features: int,
41 | out_features: int,
42 | bias: bool = True):
43 | super(Linear, self).__init__(in_features, out_features, bias=bias)
44 | init.orthogonal_(self.weight)
45 |
46 |
47 | class Linears(nn.Module):
48 | def __init__(self,
49 | in_features,
50 | out_features,
51 | hiddens,
52 | bias=True,
53 | activation='tanh'):
54 | super(Linears, self).__init__()
55 | assert len(hiddens) > 0
56 |
57 | self.in_features = in_features
58 | self.out_features = self.output_size = out_features
59 |
60 | in_dims = [in_features] + hiddens[:-1]
61 | self.linears = nn.ModuleList([Linear(in_dim, out_dim, bias=bias)
62 | for in_dim, out_dim
63 | in zip(in_dims, hiddens)])
64 | self.output_linear = Linear(hiddens[-1], out_features, bias=bias)
65 | self.activation = getattr(functional, activation)
66 |
67 | def forward(self, inputs):
68 | linear_outputs = inputs
69 | for linear in self.linears:
70 | linear_outputs = linear.forward(linear_outputs)
71 | linear_outputs = self.activation(linear_outputs)
72 | return self.output_linear.forward(linear_outputs)
73 |
74 |
75 | # Reused from https://github.com/JayParks/transformer/
76 | class ScaledDotProductAttention(nn.Module):
77 | def __init__(self, d_k, dropout=.1):
78 | super(ScaledDotProductAttention, self).__init__()
79 | self.scale_factor = np.sqrt(d_k)
80 | self.softmax = nn.Softmax(dim=-1)
81 | self.dropout = nn.Dropout(dropout)
82 |
83 | def forward(self, q, k, v, attn_mask=None):
84 | # q: [b_size x len_q x d_k]
85 | # k: [b_size x len_k x d_k]
86 | # v: [b_size x len_v x d_v] note: (len_k == len_v)
87 | attn = torch.bmm(q, k.transpose(1, 2)) / self.scale_factor # attn: [b_size x len_q x len_k]
88 | if attn_mask is not None:
89 | print(attn_mask.size(), attn.size())
90 | assert attn_mask.size() == attn.size()
91 | attn.data.masked_fill_(attn_mask, -float('inf'))
92 |
93 | attn = self.softmax(attn)
94 | attn = self.dropout(attn)
95 | outputs = torch.bmm(attn, v) # outputs: [b_size x len_q x d_v]
96 |
97 | return outputs, attn
98 |
99 |
100 | class LayerNormalization(nn.Module):
101 | def __init__(self, d_hid, eps=1e-3):
102 | super(LayerNormalization, self).__init__()
103 | self.gamma = nn.Parameter(torch.ones(d_hid), requires_grad=True)
104 | self.beta = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
105 | self.eps = eps
106 |
107 | def forward(self, z):
108 | mean = z.mean(dim=-1, keepdim=True,)
109 | std = z.std(dim=-1, keepdim=True,)
110 | ln_out = (z - mean.expand_as(z)) / (std.expand_as(z) + self.eps)
111 | ln_out = self.gamma.expand_as(ln_out) * ln_out + self.beta.expand_as(ln_out)
112 |
113 | return ln_out
114 |
115 |
116 | class _MultiHeadAttention(nn.Module):
117 | def __init__(self, d_k, d_v, d_model, n_heads, dropout):
118 | super(_MultiHeadAttention, self).__init__()
119 | self.d_k = d_k
120 | self.d_v = d_v
121 | self.d_model = d_model
122 | self.n_heads = n_heads
123 | self.w_q = nn.Parameter(torch.FloatTensor(n_heads, d_model, d_k))
124 | self.w_k = nn.Parameter(torch.FloatTensor(n_heads, d_model, d_k))
125 | self.w_v = nn.Parameter(torch.FloatTensor(n_heads, d_model, d_v))
126 |
127 | self.attention = ScaledDotProductAttention(d_k, dropout)
128 |
129 | init.xavier_normal(self.w_q)
130 | init.xavier_normal(self.w_k)
131 | init.xavier_normal(self.w_v)
132 |
133 | def forward(self, q, k, v, attn_mask=None):
134 | (d_k, d_v, d_model, n_heads) = (self.d_k, self.d_v, self.d_model, self.n_heads)
135 | b_size = k.size(0)
136 |
137 | q_s = q.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_q x d_model]
138 | k_s = k.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_k x d_model]
139 | v_s = v.repeat(n_heads, 1, 1).view(n_heads, -1, d_model) # [n_heads x b_size * len_v x d_model]
140 |
141 | q_s = torch.bmm(q_s, self.w_q).view(b_size * n_heads, -1, d_k) # [b_size * n_heads x len_q x d_k]
142 | k_s = torch.bmm(k_s, self.w_k).view(b_size * n_heads, -1, d_k) # [b_size * n_heads x len_k x d_k]
143 | v_s = torch.bmm(v_s, self.w_v).view(b_size * n_heads, -1, d_v) # [b_size * n_heads x len_v x d_v]
144 |
145 | # perform attention, result_size = [b_size * n_heads x len_q x d_v]
146 | if attn_mask is not None:
147 | attn_mask = attn_mask.repeat(n_heads, 1, 1)
148 | outputs, attn = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)
149 |
150 | # return a list of tensors of shape [b_size x len_q x d_v] (length: n_heads)
151 | return torch.split(outputs, b_size, dim=0), attn
152 |
153 |
154 | class MultiHeadAttention(nn.Module):
155 | def __init__(self, d_k, d_v, d_model, n_heads, dropout):
156 | super(MultiHeadAttention, self).__init__()
157 | self.attention = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout)
158 | self.proj = Linear(n_heads * d_v, d_model)
159 | self.dropout = nn.Dropout(dropout)
160 | self.layer_norm = LayerNormalization(d_model)
161 |
162 | def forward(self, q, k, v, attn_mask):
163 | # q: [b_size x len_q x d_model]
164 | # k: [b_size x len_k x d_model]
165 | # v: [b_size x len_v x d_model] note (len_k == len_v)
166 | residual = q
167 | # outputs: a list of tensors of shape [b_size x len_q x d_v] (length: n_heads)
168 | outputs, attn = self.attention(q, k, v, attn_mask=attn_mask)
169 | # concatenate 'n_heads' multi-head attentions
170 | outputs = torch.cat(outputs, dim=-1)
171 | # project back to residual size, result_size = [b_size x len_q x d_model]
172 | outputs = self.proj(outputs)
173 | outputs = self.dropout(outputs)
174 |
175 | return self.layer_norm(residual + outputs), attn
176 |
177 |
178 | class _BahdanauAttention(nn.Module):
179 | def __init__(self, method, hidden_size):
180 | super(_BahdanauAttention, self).__init__()
181 | self.method = method
182 | self.hidden_size = hidden_size
183 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
184 | self.v = nn.Parameter(torch.rand(hidden_size))
185 | stdv = 1. / math.sqrt(self.v.size(0))
186 | self.v.data.normal_(mean=0, std=stdv)
187 |
188 | def forward(self, hidden, encoder_outputs, mask=None):
189 | """
190 | :param hidden:
191 | previous hidden state of the decoder, in shape (layers*directions,B,H)
192 | :param encoder_outputs:
193 | encoder outputs from Encoder, in shape (T,B,H)
194 | :param mask:
195 | used for masking. NoneType or tensor in shape (B) indicating sequence length
196 | :return
197 | attention energies in shape (B,T)
198 | """
199 | max_len = encoder_outputs.size(0)
200 | # this_batch_size = encoder_outputs.size(1)
201 | H = hidden.repeat(max_len, 1, 1).transpose(0, 1)
202 | # [B*T*H]
203 | encoder_outputs = encoder_outputs.transpose(0, 1)
204 | # compute attention score
205 | attn_energies = self.score(H, encoder_outputs)
206 | if mask is not None:
207 | attn_energies = attn_energies.masked_fill(mask, -1e18)
208 | # normalize with softmax
209 | return functional.softmax(attn_energies).unsqueeze(1)
210 |
211 | def score(self, hidden, encoder_outputs):
212 | # [B*T*2H]->[B*T*H]
213 | energy = functional.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2)))
214 | # [B*H*T]
215 | energy = energy.transpose(2, 1)
216 | # [B*1*H]
217 | v = self.v.repeat(encoder_outputs.data.shape[0], 1).unsqueeze(1)
218 | # [B*1*T]
219 | energy = torch.bmm(v, energy)
220 | # [B*T]
221 | return energy.squeeze(1)
222 |
223 |
224 | class BahdanauAttention(nn.Module):
225 | """Reused from https://github.com/chrisbangun/pytorch-seq2seq_with_attention/"""
226 |
227 | def __init__(self, hidden_dim=128, query_dim=128, memory_dim=128):
228 | super(BahdanauAttention, self).__init__()
229 |
230 | self.hidden_dim = hidden_dim
231 | self.query_dim = query_dim
232 | self.memory_dim = memory_dim
233 |         self.softmax = nn.Softmax()
234 |
235 | self.query_layer = nn.Linear(query_dim, hidden_dim, bias=False)
236 | self.memory_layer = nn.Linear(memory_dim, hidden_dim, bias=False)
237 | self.alignment_layer = nn.Linear(hidden_dim, 1, bias=False)
238 |
239 | def alignment_score(self, query, keys):
240 | query = self.query_layer(query)
241 | keys = self.memory_layer(keys)
242 |
243 |         extended_query = query.unsqueeze(1)
244 |         alignment = self.alignment_layer(functional.tanh(extended_query + keys))
245 | return alignment.squeeze(2)
246 |
247 | def forward(self, query, keys):
248 | alignment_score = self.alignment_score(query, keys)
249 | weight = functional.softmax(alignment_score)
250 | context = weight.unsqueeze(2) * keys
251 | total_context = context.sum(1)
252 | return total_context, alignment_score
253 |
--------------------------------------------------------------------------------
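A minimal sketch of stacking BiLSTM and MultiHeadAttention the way the classifiers below do; the dimensions and dummy tensors are illustrative assumptions:

import torch
from modules.layers.layers import BiLSTM, MultiHeadAttention

lstm = BiLSTM.create(embedding_size=768, hidden_dim=512, rnn_layers=1, dropout=0.3)
attn = MultiHeadAttention(d_k=64, d_v=64, d_model=512, n_heads=3, dropout=0.3)

embeddings = torch.randn(2, 16, 768)           # e.g. BERTEmbedder output
mask = torch.ones(2, 16, dtype=torch.long)     # 1 for real tokens, 0 for padding

output, _ = lstm(embeddings, mask)             # -> [2, 16, 512]
output, attn_weights = attn(output, output, output, attn_mask=None)
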
/modules/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/models/__init__.py
--------------------------------------------------------------------------------
/modules/models/classifiers.py:
--------------------------------------------------------------------------------
1 | from .bert_models import BERTNerModel
2 | from modules.layers.decoders import *
3 | from modules.layers.embedders import *
4 | from modules.layers.layers import *
5 |
6 |
7 | class BERTLinearsClassifier(BERTNerModel):
8 |
9 | def __init__(self, embeddings, linear, dropout, activation, device="cuda"):
10 | super(BERTLinearsClassifier, self).__init__()
11 | self.embeddings = embeddings
12 | self.linear = linear
13 | self.dropout = dropout
14 | self.activation = activation
15 | self.intent_loss = nn.CrossEntropyLoss()
16 | self.to(device)
17 |
18 | @staticmethod
19 | def pool(x, bs, is_max):
20 | """Pool the tensor along the seq_len dimension."""
21 | f = functional.adaptive_max_pool1d if is_max else functional.adaptive_avg_pool1d
22 | return f(x.permute(1, 2, 0), (1,)).view(bs, -1)
23 |
24 | def forward(self, batch):
25 | input_embeddings = self.embeddings(batch)
26 | output = self.dropout(input_embeddings).transpose(0, 1)
27 | sl, bs, _ = output.size()
28 | output = self.pool(output, bs, True)
29 | output = self.linear(output)
30 | return self.activation(output).argmax(-1)
31 |
32 | def score(self, batch):
33 | input_embeddings = self.embeddings(batch)
34 | output = self.dropout(input_embeddings).transpose(0, 1)
35 | sl, bs, _ = output.size()
36 | output = self.pool(output, bs, True)
37 | output = self.linear(output)
38 | return self.intent_loss(self.activation(output), batch[-1])
39 |
40 | @classmethod
41 | def create(cls,
42 | intent_size,
43 | # BertEmbedder params
44 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True,
45 | # Decoder params
46 | embedding_size=768, clf_dropout=0.3, num_hiddens=2,
47 | activation="tanh",
48 | # Global params
49 | device="cuda"):
50 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze)
51 | linear = Linears(embedding_size, intent_size, [embedding_size // 2**idx for idx in range(num_hiddens)])
52 | dropout = nn.Dropout(clf_dropout)
53 | activation = getattr(functional, activation)
54 | return cls(embeddings, linear, dropout, activation, device)
55 |
56 |
57 | class BERTLinearClassifier(BERTNerModel):
58 |
59 | def __init__(self, embeddings, linear, dropout, activation, device="cuda"):
60 | super(BERTLinearClassifier, self).__init__()
61 | self.embeddings = embeddings
62 | self.linear = linear
63 | self.dropout = dropout
64 | self.activation = activation
65 | self.intent_loss = nn.CrossEntropyLoss()
66 | self.to(device)
67 |
68 | @staticmethod
69 | def pool(x, bs, is_max):
70 | """Pool the tensor along the seq_len dimension."""
71 | f = functional.adaptive_max_pool1d if is_max else functional.adaptive_avg_pool1d
72 | return f(x.permute(1, 2, 0), (1,)).view(bs, -1)
73 |
74 | def forward(self, batch):
75 | input_embeddings = self.embeddings(batch)
76 | output = self.dropout(input_embeddings).transpose(0, 1)
77 | sl, bs, _ = output.size()
78 | output = self.pool(output, bs, True)
79 | output = self.linear(output)
80 | return self.activation(output).argmax(-1)
81 |
82 | def score(self, batch):
83 | input_embeddings = self.embeddings(batch)
84 | output = self.dropout(input_embeddings).transpose(0, 1)
85 | sl, bs, _ = output.size()
86 | output = self.pool(output, bs, True)
87 | output = self.linear(output)
88 | return self.intent_loss(self.activation(output), batch[-1])
89 |
90 | @classmethod
91 | def create(cls,
92 | intent_size,
93 | # BertEmbedder params
94 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True,
95 | # Decoder params
96 | embedding_size=768, clf_dropout=0.3,
97 | activation="sigmoid",
98 | # Global params
99 | device="cuda"):
100 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze)
101 | linear = Linear(embedding_size, intent_size)
102 | dropout = nn.Dropout(clf_dropout)
103 | activation = getattr(functional, activation)
104 | return cls(embeddings, linear, dropout, activation, device)
105 |
106 |
107 | class BERTBaseClassifier(BERTNerModel):
108 |
109 | def __init__(self, embeddings, clf, device="cuda"):
110 | super(BERTBaseClassifier, self).__init__()
111 | self.embeddings = embeddings
112 | self.clf = clf
113 | self.to(device)
114 |
115 | def forward(self, batch):
116 | input_embeddings = self.embeddings(batch)
117 | return self.clf(input_embeddings)
118 |
119 | def score(self, batch):
120 | input_, labels_mask, input_type_ids, cls_ids = batch
121 | input_embeddings = self.embeddings(batch)
122 | return self.clf.score(input_embeddings, cls_ids)
123 |
124 | @classmethod
125 | def create(cls,
126 | intent_size,
127 | # BertEmbedder params
128 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True,
129 | # Decoder params
130 | embedding_size=768, clf_dropout=0.3,
131 | # Global params
132 | device="cuda"):
133 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze)
134 | clf = ClassDecoder(intent_size, embedding_size, clf_dropout)
135 | return cls(embeddings, clf, device)
136 |
137 |
138 | class BERTBiLSTMAttnClassifier(BERTNerModel):
139 |
140 | def __init__(self, embeddings, lstm, attn, clf, device="cuda"):
141 | super(BERTBiLSTMAttnClassifier, self).__init__()
142 | self.embeddings = embeddings
143 | self.lstm = lstm
144 | self.attn = attn
145 | self.clf = clf
146 | self.to(device)
147 |
148 | def forward(self, batch):
149 | input_, labels_mask, input_type_ids = batch[:3]
150 | input_embeddings = self.embeddings(batch)
151 | output, _ = self.lstm.forward(input_embeddings, labels_mask)
152 | output, _ = self.attn(output, output, output, None)
153 | return self.clf(output)
154 |
155 | def score(self, batch):
156 | input_, labels_mask, input_type_ids = batch[:3]
157 | input_embeddings = self.embeddings(batch)
158 | output, _ = self.lstm.forward(input_embeddings, labels_mask)
159 | output, _ = self.attn(output, output, output, None)
160 | return self.clf.score(output, batch[-1])
161 |
162 | @classmethod
163 | def create(cls,
164 | intent_size,
165 | # BertEmbedder params
166 | model_name='bert-base-multilingual-cased', mode="weighted", is_freeze=True,
167 | # Decoder params
168 | clf_dropout=0.3,
169 | # BiLSTM
170 | hidden_dim=512, rnn_layers=1, lstm_dropout=0.3,
171 | # Attn params
172 | embedding_size=768, key_dim=64, val_dim=64, num_heads=3, attn_dropout=0.3,
173 | # Global params
174 | device="cuda"):
175 | embeddings = BERTEmbedder.create(model_name=model_name, device=device, mode=mode, is_freeze=is_freeze)
176 | lstm = BiLSTM.create(
177 | embedding_size=embedding_size, hidden_dim=hidden_dim, rnn_layers=rnn_layers, dropout=lstm_dropout)
178 | attn = MultiHeadAttention(key_dim, val_dim, hidden_dim, num_heads, attn_dropout)
179 | clf = ClassDecoder(intent_size, hidden_dim, clf_dropout)
180 | return cls(embeddings, lstm, attn, clf, device)
181 |
--------------------------------------------------------------------------------
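A minimal usage sketch (not part of the repository) for the classifier heads above, assuming the repo root is on PYTHONPATH and a CUDA device is available; intent_size and the hyperparameters are illustrative placeholders taken from the create() defaults shown in classifiers.py.

# Sketch only: build a sentence-level classifier with the create() signature above.
from modules.models.classifiers import BERTBaseClassifier

clf_model = BERTBaseClassifier.create(
    intent_size=10,                             # assumption: 10 target classes
    model_name='bert-base-multilingual-cased',  # default used throughout the repo
    mode="weighted", is_freeze=True,
    embedding_size=768, clf_dropout=0.3,
    device="cuda")
# forward(batch) returns the decoder's class predictions; score(batch) returns the training loss.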
/modules/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-forever/ner-bert/b75a903c35acbd36ff5f26c525e3596294f36815/modules/train/__init__.py
--------------------------------------------------------------------------------
/modules/train/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/optimization.py
3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch optimization for BERT model."""
17 |
18 | import math
19 | import torch
20 | from torch.optim import Optimizer
21 | from torch.optim.optimizer import required
22 | from torch.nn.utils import clip_grad_norm_
23 |
24 |
25 | def warmup_cosine(x, warmup=0.002):
26 | if x < warmup:
27 | return x/warmup
28 |     return 0.5 * (1.0 + math.cos(math.pi * x))
29 |
30 |
31 | def warmup_constant(x, warmup=0.002):
32 | if x < warmup:
33 | return x/warmup
34 | return 1.0
35 |
36 |
37 | def warmup_linear(x, warmup=0.002):
38 | if x < warmup:
39 | return x/warmup
40 | return 1.0 - x
41 |
42 |
43 | SCHEDULES = {
44 | 'warmup_cosine': warmup_cosine,
45 | 'warmup_constant': warmup_constant,
46 | 'warmup_linear': warmup_linear,
47 | }
48 |
49 |
50 | class BertAdam(Optimizer):
51 | """Implements BERT version of Adam algorithm with weight decay fix.
52 | Params:
53 | lr: learning rate
54 |         warmup: portion of t_total for the warmup, -1 means no warmup. Default: 0.1
55 | t_total: total number of training steps for the learning
56 | rate schedule, -1 means constant learning rate. Default: -1
57 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
58 |         b1: Adam's b1. Default: 0.8
59 |         b2: Adam's b2. Default: 0.999
60 |         e: Adam's epsilon. Default: 1e-6
61 | weight_decay: Weight decay. Default: 0.01
62 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
63 | """
64 | def __init__(self, model, lr=required, warmup=0.1, t_total=-1, schedule='warmup_linear',
65 | b1=0.8, b2=0.999, e=1e-6, weight_decay=0.01,
66 | max_grad_norm=1.0):
67 | if lr is not required and lr < 0.0:
68 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
69 | if schedule not in SCHEDULES:
70 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
71 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
72 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
73 | if not 0.0 <= b1 < 1.0:
74 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
75 | if not 0.0 <= b2 < 1.0:
76 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
77 | if not e >= 0.0:
78 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
79 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
80 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
81 | max_grad_norm=max_grad_norm)
82 | # Prepare optimizer
83 | param_optimizer = list(model.named_parameters())
84 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
85 | optimizer_grouped_parameters = [
86 |             {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
87 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
88 | ]
89 | super(BertAdam, self).__init__(optimizer_grouped_parameters, defaults)
90 | self.global_step = 1
91 | self.t_total = t_total
92 |
93 | def get_lr(self):
94 | lr = []
95 | for group in self.param_groups:
96 | for p in group['params']:
97 | state = self.state[p]
98 | if len(state) == 0:
99 | return [0]
100 | if group['t_total'] != -1:
101 | schedule_fct = SCHEDULES[group['schedule']]
102 | lr_scheduled = group['lr'] * schedule_fct(state['step'] / group['t_total'], group['warmup'])
103 | else:
104 | lr_scheduled = group['lr']
105 | lr.append(lr_scheduled)
106 | return lr
107 |
108 | def update_lr(self):
109 | if 0 < self.t_total:
110 | lr_this_step = self.defaults["lr"] * warmup_linear(self.global_step / self.t_total, self.defaults["warmup"])
111 | for param_group in self.param_groups:
112 | param_group['lr'] = lr_this_step
113 |
114 | def step(self, closure=None):
115 | """Performs a single optimization step.
116 | Arguments:
117 | closure (callable, optional): A closure that reevaluates the model
118 | and returns the loss.
119 | """
120 | self.update_lr()
121 | loss = None
122 | if closure is not None:
123 | loss = closure()
124 |
125 | for group in self.param_groups:
126 | for p in group['params']:
127 | if p.grad is None:
128 | continue
129 | grad = p.grad.data
130 | if grad.is_sparse:
131 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
132 |
133 | state = self.state[p]
134 |
135 | # State initialization
136 | if len(state) == 0:
137 | state['step'] = 0
138 | # Exponential moving average of gradient values
139 | state['next_m'] = torch.zeros_like(p.data)
140 | # Exponential moving average of squared gradient values
141 | state['next_v'] = torch.zeros_like(p.data)
142 |
143 | next_m, next_v = state['next_m'], state['next_v']
144 | beta1, beta2 = group['b1'], group['b2']
145 |
146 | # Add grad clipping
147 | if group['max_grad_norm'] > 0:
148 | clip_grad_norm_(p, group['max_grad_norm'])
149 |
150 | # Decay the first and second moment running average coefficient
151 | # In-place operations to update the averages at the same time
152 |                 next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
153 |                 next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
154 | update = next_m / (next_v.sqrt() + group['e'])
155 |
156 | # Just adding the square of the weights to the loss function is *not*
157 | # the correct way of using L2 regularization/weight decay with Adam,
158 | # since that will interact with the m and v parameters in strange ways.
159 | #
160 | # Instead we want to decay the weights in a manner that doesn't interact
161 | # with the m/v parameters. This is equivalent to adding the square
162 | # of the weights to the loss with plain (non-momentum) SGD.
163 | if group['weight_decay'] > 0.0:
164 | update += group['weight_decay'] * p.data
165 |
166 | if group['t_total'] != -1:
167 | schedule_fct = SCHEDULES[group['schedule']]
168 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
169 | else:
170 | lr_scheduled = group['lr']
171 |
172 | update_with_lr = lr_scheduled * update
173 | p.data.add_(-update_with_lr)
174 |
175 | state['step'] += 1
176 |
177 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
178 | # No bias correction
179 | # bias_correction1 = 1 - beta1 ** state['step']
180 | # bias_correction2 = 1 - beta2 ** state['step']
181 | self.global_step += 1
182 | return loss
183 |
--------------------------------------------------------------------------------
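A small sketch (not from the source) of how the warmup_linear schedule above scales the base learning rate; lr, t_total and warmup are arbitrary illustrative values.

# Sketch: mirrors the lr scaling done in BertAdam.update_lr() and step().
from modules.train.optimization import warmup_linear

base_lr, t_total, warmup = 5e-5, 1000, 0.1
for step in (10, 100, 500, 900):
    # linear ramp-up over the first warmup fraction of steps, then linear decay
    print(step, base_lr * warmup_linear(step / t_total, warmup))

# NerLearner builds the optimizer directly from a model in the same way:
# optimizer = BertAdam(model, lr=base_lr, t_total=t_total, warmup=warmup)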
/modules/train/train.py:
--------------------------------------------------------------------------------
1 | from modules import tqdm
2 | from sklearn_crfsuite.metrics import flat_classification_report
3 | import logging
4 | import torch
5 | from .optimization import BertAdam
6 | from modules.analyze_utils.plot_metrics import get_mean_max_metric
7 | from modules.data.bert_data import get_data_loader_for_predict
8 |
9 |
10 | def train_step(dl, model, optimizer, num_epoch=1):
11 | model.train()
12 | epoch_loss = 0
13 | idx = 0
14 | pr = tqdm(dl, total=len(dl), leave=False)
15 | for batch in pr:
16 | idx += 1
17 | model.zero_grad()
18 | loss = model.score(batch)
19 | loss.backward()
20 | optimizer.step()
21 | optimizer.zero_grad()
22 | loss = loss.data.cpu().tolist()
23 | epoch_loss += loss
24 | pr.set_description("train loss: {}".format(epoch_loss / idx))
25 | torch.cuda.empty_cache()
26 | logging.info("\nepoch {}, average train epoch loss={:.5}\n".format(
27 | num_epoch, epoch_loss / idx))
28 |
29 |
30 | def transformed_result(preds, mask, id2label, target_all=None, pad_idx=0):
31 | preds_cpu = []
32 | targets_cpu = []
33 | lc = len(id2label)
34 | if target_all is not None:
35 | for batch_p, batch_t, batch_m in zip(preds, target_all, mask):
36 | for pred, true_, bm in zip(batch_p, batch_t, batch_m):
37 | sent = []
38 | sent_t = []
39 | bm = bm.sum().cpu().data.tolist()
40 | for p, t in zip(pred[:bm], true_[:bm]):
41 | p = p.cpu().data.tolist()
42 | p = p if p < lc else pad_idx
43 | sent.append(p)
44 | sent_t.append(t.cpu().data.tolist())
45 | preds_cpu.append([id2label[w] for w in sent])
46 | targets_cpu.append([id2label[w] for w in sent_t])
47 | else:
48 | for batch_p, batch_m in zip(preds, mask):
49 |
50 | for pred, bm in zip(batch_p, batch_m):
51 | assert len(pred) == len(bm)
52 | bm = bm.sum().cpu().data.tolist()
53 | sent = pred[:bm].cpu().data.tolist()
54 | preds_cpu.append([id2label[w] for w in sent])
55 | if target_all is not None:
56 | return preds_cpu, targets_cpu
57 | else:
58 | return preds_cpu
59 |
60 |
61 | def transformed_result_cls(preds, target_all, cls2label, return_target=True):
62 | preds_cpu = []
63 | targets_cpu = []
64 | for batch_p, batch_t in zip(preds, target_all):
65 | for pred, true_ in zip(batch_p, batch_t):
66 | preds_cpu.append(cls2label[pred.cpu().data.tolist()])
67 | if return_target:
68 | targets_cpu.append(cls2label[true_.cpu().data.tolist()])
69 | if return_target:
70 | return preds_cpu, targets_cpu
71 | return preds_cpu
72 |
73 |
74 | def validate_step(dl, model, id2label, sup_labels, id2cls=None):
75 | model.eval()
76 | idx = 0
77 | preds_cpu, targets_cpu = [], []
78 | preds_cpu_cls, targets_cpu_cls = [], []
79 | for batch in tqdm(dl, total=len(dl), leave=False):
80 | idx += 1
81 | labels_mask, labels_ids = batch[1], batch[3]
82 | preds = model.forward(batch)
83 | if id2cls is not None:
84 | preds, preds_cls = preds
85 | preds_cpu_, targets_cpu_ = transformed_result_cls([preds_cls], [batch[-1]], id2cls)
86 | preds_cpu_cls.extend(preds_cpu_)
87 | targets_cpu_cls.extend(targets_cpu_)
88 | preds_cpu_, targets_cpu_ = transformed_result([preds], [labels_mask], id2label, [labels_ids])
89 | preds_cpu.extend(preds_cpu_)
90 | targets_cpu.extend(targets_cpu_)
91 | clf_report = flat_classification_report(targets_cpu, preds_cpu, labels=sup_labels, digits=3)
92 | if id2cls is not None:
93 | clf_report_cls = flat_classification_report([targets_cpu_cls], [preds_cpu_cls], digits=3)
94 | return clf_report, clf_report_cls
95 | return clf_report
96 |
97 |
98 | def predict(dl, model, id2label, id2cls=None):
99 | model.eval()
100 | idx = 0
101 | preds_cpu = []
102 | preds_cpu_cls = []
103 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Predicting"):
104 | idx += 1
105 | labels_mask, labels_ids = batch[1], batch[3]
106 | preds = model.forward(batch)
107 | if id2cls is not None:
108 | preds, preds_cls = preds
109 | preds_cpu_ = transformed_result_cls([preds_cls], [preds_cls], id2cls, False)
110 | preds_cpu_cls.extend(preds_cpu_)
111 |
112 | preds_cpu_ = transformed_result([preds], [labels_mask], id2label)
113 | preds_cpu.extend(preds_cpu_)
114 | if id2cls is not None:
115 | return preds_cpu, preds_cpu_cls
116 | return preds_cpu
117 |
118 |
119 | class NerLearner(object):
120 |
121 | def __init__(self, model, data, best_model_path, lr=0.001, betas=[0.8, 0.9], clip=1.0,
122 | verbose=True, sup_labels=None, t_total=-1, warmup=0.1, weight_decay=0.01,
123 | validate_every=1, schedule="warmup_linear", e=1e-6):
124 | logging.basicConfig(level=logging.INFO)
125 | self.model = model
126 | self.optimizer = BertAdam(model, lr, t_total=t_total, b1=betas[0], b2=betas[1], max_grad_norm=clip)
127 | self.optimizer_defaults = dict(
128 | model=model, lr=lr, warmup=warmup, t_total=t_total, schedule=schedule,
129 | b1=betas[0], b2=betas[1], e=e, weight_decay=weight_decay,
130 | max_grad_norm=clip)
131 |
132 | self.lr = lr
133 | self.betas = betas
134 | self.clip = clip
135 | self.sup_labels = sup_labels
136 | self.t_total = t_total
137 | self.warmup = warmup
138 | self.weight_decay = weight_decay
139 | self.validate_every = validate_every
140 | self.schedule = schedule
141 | self.data = data
142 | self.e = e
143 | if sup_labels is None:
144 | sup_labels = data.train_ds.idx2label[4:]
145 | self.sup_labels = sup_labels
146 | self.best_model_path = best_model_path
147 | self.verbose = verbose
148 | self.history = []
149 | self.cls_history = []
150 | self.epoch = 0
151 | self.best_target_metric = 0.
152 |
153 | def fit(self, epochs=100, resume_history=True, target_metric="f1"):
154 | if not resume_history:
155 | self.optimizer_defaults["t_total"] = epochs * len(self.data.train_dl)
156 | self.optimizer = BertAdam(**self.optimizer_defaults)
157 | self.history = []
158 | self.cls_history = []
159 | self.epoch = 0
160 | self.best_target_metric = 0.
161 | elif self.verbose:
162 |             logging.info("Resuming training... Current epoch {}.".format(self.epoch))
163 | try:
164 | for _ in range(epochs):
165 | self.epoch += 1
166 | self.fit_one_cycle(self.epoch, target_metric)
167 | except KeyboardInterrupt:
168 | pass
169 |
170 | def fit_one_cycle(self, epoch, target_metric="f1"):
171 | train_step(self.data.train_dl, self.model, self.optimizer, epoch)
172 | if epoch % self.validate_every == 0:
173 | if self.data.train_ds.is_cls:
174 | rep, rep_cls = validate_step(
175 | self.data.valid_dl, self.model, self.data.train_ds.idx2label, self.sup_labels,
176 | self.data.train_ds.idx2cls)
177 | self.cls_history.append(rep_cls)
178 | else:
179 | rep = validate_step(
180 | self.data.valid_dl, self.model, self.data.train_ds.idx2label, self.sup_labels)
181 | self.history.append(rep)
182 | idx, metric = get_mean_max_metric(self.history, target_metric, True)
183 | if self.verbose:
184 | logging.info("on epoch {} by max_{}: {}".format(idx, target_metric, metric))
185 | print(self.history[-1])
186 | if self.data.train_ds.is_cls:
187 |                     logging.info("on epoch {} classification report:".format(epoch))
188 | print(self.cls_history[-1])
189 | # Store best model
190 | if self.best_target_metric < metric:
191 | self.best_target_metric = metric
192 | if self.verbose:
193 | logging.info("Saving new best model...")
194 | self.save_model()
195 |
196 | def predict(self, dl=None, df_path=None, df=None):
197 | if dl is None:
198 | dl = get_data_loader_for_predict(self.data, df_path, df)
199 | if self.data.train_ds.is_cls:
200 | return predict(dl, self.model, self.data.train_ds.idx2label, self.data.train_ds.idx2cls)
201 | return predict(dl, self.model, self.data.train_ds.idx2label)
202 |
203 | def save_model(self, path=None):
204 | path = path if path else self.best_model_path
205 | torch.save(self.model.state_dict(), path)
206 |
207 | def load_model(self, path=None):
208 | path = path if path else self.best_model_path
209 | self.model.load_state_dict(torch.load(path))
210 |
--------------------------------------------------------------------------------
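An end-to-end sketch (not in the repo) of driving the NER training loop above; model and data stand for a BERT NER model and a modules.data.bert_data bundle built earlier (e.g. in one of the fre/conll2003 notebooks), and the paths and hyperparameters are placeholders.

# Sketch only: wire a model and data bundle into NerLearner and train.
from modules.train.train import NerLearner

num_epochs = 10
learner = NerLearner(
    model, data,                        # assumptions: model + bert_data learner bundle
    best_model_path="models/best.cpt",  # hypothetical checkpoint path
    lr=0.001, clip=1.0,
    t_total=num_epochs * len(data.train_dl))
learner.fit(epochs=num_epochs, target_metric="f1")
preds = learner.predict(df_path="data/test.csv")  # hypothetical evaluation file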
/modules/train/train_clf.py:
--------------------------------------------------------------------------------
1 | from modules import tqdm
2 | from sklearn_crfsuite.metrics import flat_classification_report
3 | import logging
4 | import torch
5 | from .optimization import BertAdam
6 | from modules.analyze_utils.plot_metrics import get_mean_max_metric
7 | from modules.data.bert_data_clf import get_data_loader_for_predict
8 |
9 |
10 | def train_step(dl, model, optimizer, num_epoch=1):
11 | model.train()
12 | epoch_loss = 0
13 | idx = 0
14 | pr = tqdm(dl, total=len(dl), leave=False)
15 | for batch in pr:
16 | idx += 1
17 | model.zero_grad()
18 | loss = model.score(batch)
19 | loss.backward()
20 | optimizer.step()
21 | optimizer.zero_grad()
22 | loss = loss.data.cpu().tolist()
23 | epoch_loss += loss
24 | pr.set_description("train loss: {}".format(epoch_loss / idx))
25 | torch.cuda.empty_cache()
26 | logging.info("\nepoch {}, average train epoch loss={:.5}\n".format(
27 | num_epoch, epoch_loss / idx))
28 |
29 |
30 | def transformed_result_cls(preds, target_all, cls2label, return_target=True):
31 | preds_cpu = []
32 | targets_cpu = []
33 | for batch_p, batch_t in zip(preds, target_all):
34 | for pred, true_ in zip(batch_p, batch_t):
35 | preds_cpu.append(cls2label[pred.cpu().data.tolist()])
36 | if return_target:
37 | targets_cpu.append(cls2label[true_.cpu().data.tolist()])
38 | if return_target:
39 | return preds_cpu, targets_cpu
40 | return preds_cpu
41 |
42 |
43 | def validate_step(dl, model, id2cls):
44 | model.eval()
45 | idx = 0
46 | preds_cpu_cls, targets_cpu_cls = [], []
47 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Validation"):
48 | idx += 1
49 | preds_cls = model.forward(batch)
50 | preds_cpu_, targets_cpu_ = transformed_result_cls([preds_cls], [batch[-1]], id2cls)
51 | preds_cpu_cls.extend(preds_cpu_)
52 | targets_cpu_cls.extend(targets_cpu_)
53 | clf_report_cls = flat_classification_report([targets_cpu_cls], [preds_cpu_cls], digits=4)
54 | return clf_report_cls
55 |
56 |
57 | def predict(dl, model, id2cls):
58 | model.eval()
59 | idx = 0
60 | preds_cpu_cls = []
61 | for batch in tqdm(dl, total=len(dl), leave=False, desc="Predicting"):
62 | idx += 1
63 | preds_cls = model.forward(batch)
64 | preds_cpu_ = transformed_result_cls([preds_cls], [preds_cls], id2cls, False)
65 | preds_cpu_cls.extend(preds_cpu_)
66 |
67 | return preds_cpu_cls
68 |
69 |
70 | class NerLearner(object):
71 |
72 | def __init__(self, model, data, best_model_path, lr=0.001, betas=[0.8, 0.9], clip=1.0,
73 | verbose=True, t_total=-1, warmup=0.1, weight_decay=0.01,
74 | validate_every=1, schedule="warmup_linear", e=1e-6):
75 | logging.basicConfig(level=logging.INFO)
76 | self.model = model
77 | self.optimizer = BertAdam(model, lr, t_total=t_total, b1=betas[0], b2=betas[1], max_grad_norm=clip)
78 | self.optimizer_defaults = dict(
79 | model=model, lr=lr, warmup=warmup, t_total=t_total, schedule=schedule,
80 | b1=betas[0], b2=betas[1], e=e, weight_decay=weight_decay,
81 | max_grad_norm=clip)
82 |
83 | self.lr = lr
84 | self.betas = betas
85 | self.clip = clip
86 | self.t_total = t_total
87 | self.warmup = warmup
88 | self.weight_decay = weight_decay
89 | self.validate_every = validate_every
90 | self.schedule = schedule
91 | self.data = data
92 | self.e = e
93 | self.best_model_path = best_model_path
94 | self.verbose = verbose
95 | self.cls_history = []
96 | self.epoch = 0
97 | self.best_target_metric = 0.
98 |
99 | def fit(self, epochs=100, resume_history=True, target_metric="f1"):
100 | if not resume_history:
101 | self.optimizer_defaults["t_total"] = epochs * len(self.data.train_dl)
102 | self.optimizer = BertAdam(**self.optimizer_defaults)
103 | self.cls_history = []
104 | self.epoch = 0
105 | self.best_target_metric = 0.
106 | elif self.verbose:
107 |             logging.info("Resuming training... Current epoch {}.".format(self.epoch))
108 | try:
109 | for _ in range(epochs):
110 | self.epoch += 1
111 | self.fit_one_cycle(self.epoch, target_metric)
112 | except KeyboardInterrupt:
113 | pass
114 |
115 | def fit_one_cycle(self, epoch, target_metric="f1"):
116 | train_step(self.data.train_dl, self.model, self.optimizer, epoch)
117 | if epoch % self.validate_every == 0:
118 | rep_cls = validate_step(self.data.valid_dl, self.model, self.data.train_ds.idx2cls)
119 | self.cls_history.append(rep_cls)
120 | idx, metric = get_mean_max_metric(self.cls_history, target_metric, True)
121 | if self.verbose:
122 | logging.info("on epoch {} by max_{}: {}".format(idx, target_metric, metric))
123 | print(self.cls_history[-1])
124 |
125 | # Store best model
126 | if self.best_target_metric < metric:
127 | self.best_target_metric = metric
128 | if self.verbose:
129 | logging.info("Saving new best model...")
130 | self.save_model()
131 |
132 | def predict(self, dl=None, df_path=None, df=None):
133 | if dl is None:
134 | dl, ds = get_data_loader_for_predict(self.data, df_path, df)
135 | return predict(dl, self.model, self.data.train_ds.idx2cls)
136 |
137 | def save_model(self, path=None):
138 | path = path if path else self.best_model_path
139 | torch.save(self.model.state_dict(), path)
140 |
141 | def load_model(self, path=None):
142 | path = path if path else self.best_model_path
143 | self.model.load_state_dict(torch.load(path))
144 |
--------------------------------------------------------------------------------
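The classification-only learner mirrors the NER one; a brief sketch (placeholder names, not from the repo) of its fit/predict flow, which maps predictions back through idx2cls:

# Sketch: train_clf variant, returning class labels rather than tag sequences.
from modules.train.train_clf import NerLearner as ClfLearner

clf_learner = ClfLearner(
    clf_model, clf_data,                    # assumptions: classifier model + bert_data_clf bundle
    best_model_path="models/best_clf.cpt",  # hypothetical path
    t_total=10 * len(clf_data.train_dl))
clf_learner.fit(epochs=10, target_metric="f1")
labels = clf_learner.predict(df_path="data/test_clf.csv")  # list of predicted class labels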
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy
4 | import bson
5 | import sys
6 |
7 |
8 | def ipython_info():
9 | ip = False
10 | if 'ipykernel' in sys.modules:
11 | ip = 'notebook'
12 | elif 'IPython' in sys.modules:
13 | ip = 'terminal'
14 | return ip
15 |
16 |
17 | def get_tqdm():
18 | ip = ipython_info()
19 | if ip == "terminal" or not ip:
20 | from tqdm import tqdm
21 | return tqdm
22 | else:
23 | try:
24 | from tqdm import tqdm_notebook
25 | return tqdm_notebook
26 |         except ImportError:
27 | from tqdm import tqdm
28 | return tqdm
29 |
30 |
31 | class JsonEncoder(json.JSONEncoder):
32 | def default(self, obj):
33 | if isinstance(obj, numpy.integer):
34 | return int(obj)
35 | elif isinstance(obj, numpy.floating):
36 | return float(obj)
37 | elif isinstance(obj, numpy.ndarray):
38 | return obj.tolist()
39 | elif isinstance(obj, bson.ObjectId):
40 | return str(obj)
41 | else:
42 | return super(JsonEncoder, self).default(obj)
43 |
44 |
45 | def jsonify(data):
46 | return json.dumps(data, cls=JsonEncoder)
47 |
48 |
49 | def read_config(config):
50 | if isinstance(config, str):
51 | with open(config, "r", encoding="utf-8") as f:
52 | config = json.load(f)
53 | return config
54 |
55 |
56 | def save_config(config, path):
57 | with open(path, "w") as file:
58 | json.dump(config, file, cls=JsonEncoder)
59 |
60 |
61 | def if_none(origin, other):
62 | return other if origin is None else origin
63 |
64 |
65 | def get_files_path_from_dir(path):
66 | f = []
67 | for dir_path, dir_names, filenames in os.walk(path):
68 | for f_name in filenames:
69 | f.append(dir_path + "/" + f_name)
70 | return f
71 |
--------------------------------------------------------------------------------
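A short sketch (not in the source) of why JsonEncoder/jsonify exist: plain json.dumps rejects numpy scalars and bson ObjectIds, while jsonify converts them first.

# Sketch: numpy values are coerced to plain Python types before serialization.
import numpy
from modules.utils import jsonify, if_none

print(jsonify({"f1": numpy.float64(0.91), "support": numpy.int64(42)}))
print(if_none(None, "fallback"))  # returns the second argument when the first is None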
/requirements.txt:
--------------------------------------------------------------------------------
1 | bson
2 | pandas
3 | scikit-learn
4 | sklearn-crfsuite
5 | tqdm
6 | rusenttokenize
7 | numpy
8 | nltk
9 | torch
10 | matplotlib
--------------------------------------------------------------------------------