├── .gitignore
├── LICENSE
├── README.md
├── colab_quickstart.ipynb
├── data_download.sh
├── data_prep.ipynb
├── run_model.ipynb
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository is now deprecated. Please use [Simple Transformers](https://github.com/ThilinaRajapakse/simpletransformers) instead.
2 |
3 | # Update Notice
4 |
5 | The underlying [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) library by HuggingFace has been updated substantially since this repo was created. As such, this repo might not be compatible with the current version of the Hugging Face Transformers library. This repo will not be updated further.
6 |
7 | **I recommend using [Simple Transformers](https://github.com/ThilinaRajapakse/simpletransformers) (based on the updated Hugging Face library) as it is regularly maintained, feature rich, and (much) easier to use.**
8 |
9 |
10 | # Pytorch-Transformers-Classification
11 |
12 |
13 | This repository is based on the [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) library by HuggingFace. It is intended as a starting point for anyone who wishes to use Transformer models in text classification tasks.
14 |
15 | Please refer to this [Medium article](https://medium.com/p/https-medium-com-chaturangarajapakshe-text-classification-with-transformer-models-d370944b50ca?source=email-6b1e2355088e--writer.postDistributed&sk=f21ffeb66c03a9804572d7063f57c04e) for further information on how this project works.
16 |
17 | Check out the new library [simpletransformers](https://github.com/ThilinaRajapakse/simpletransformers) for one-line training and evaluation!
18 |
19 | Table of contents
20 | =================
21 |
22 |
23 |    * [Setup](#setup)
24 |       * [Simple Transformers](#simple-transformers---ready-to-use-library)
25 |       * [Quickstart using Colab](#quickstart-using-colab)
26 |       * [With Conda](#with-conda)
27 |    * [Usage](#usage)
28 |       * [Yelp Demo](#yelp-demo)
29 |       * [Current Pretrained Models](#current-pretrained-models)
30 |       * [Custom Datasets](#custom-datasets)
31 |       * [Evaluation Metrics](#evaluation-metrics)
32 |    * [Acknowledgements](#acknowledgements)
33 |
34 |
35 | ## Setup
36 |
37 | ### Simple Transformers - Ready to use library
38 |
39 | If you want to go directly to training, evaluating, and predicting with Transformer models, take a look at the [Simple Transformers](https://github.com/ThilinaRajapakse/simpletransformers) library. It's the easiest way to use Transformers for text classification, requiring only 3 lines of code. It's based on this repo, but it is designed to let you use Transformers without worrying about the low-level details. However, ease of use comes at the cost of less control (and visibility) over how everything works.
40 |
41 | ### Quickstart using Colab
42 |
43 | Try this [Google Colab Notebook](colab_quickstart.ipynb) for a quick preview. You can run all cells without any modifications to see how everything works. However, due to the 12-hour time limit on Colab instances, the dataset has been undersampled from 500,000 samples to about 5,000 samples. With such a tiny sample size, everything should complete in about 10 minutes.
44 |
45 | ### With Conda
46 |
47 | 1. Install Anaconda or Miniconda Package Manager from [here](https://www.anaconda.com/distribution/)
48 | 2. Create a new virtual environment and install packages.
49 | `conda create -n transformers python pandas tqdm jupyter`
50 | `conda activate transformers`
51 | If using CUDA:
52 | `conda install pytorch cudatoolkit=10.0 -c pytorch`
53 | Otherwise:
54 | `conda install pytorch cpuonly -c pytorch`
55 | `conda install -c anaconda scipy`
56 | `conda install -c anaconda scikit-learn`
57 | `pip install pytorch-transformers`
58 | `pip install tensorboardX`
59 | 3. Clone repo.
60 | `git clone https://github.com/ThilinaRajapakse/pytorch-transformers-classification.git`
61 |
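Once the environment is set up, a quick sanity check can confirm that the key packages import and whether CUDA is visible (a minimal sketch, not part of the original setup steps):

```python
# Environment sanity check (optional).
import torch
import pytorch_transformers  # verifies that the pip install worked

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # False is expected for the cpuonly install
```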
62 | ## Usage
63 |
64 | ### Yelp Demo
65 |
66 | This demonstration uses the Yelp Reviews dataset.
67 |
68 | Linux users can execute [data_download.sh](data_download.sh) to download and set up the data files.
69 |
70 | If you are doing it manually (a Python sketch of these steps is shown after the list):
71 |
72 | 1. Download [Yelp Reviews Dataset](https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz).
73 | 2. Extract `train.csv` and `test.csv` and place them in the directory `data/`.
74 |
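Equivalently, the manual steps can be scripted in Python. This is a minimal sketch mirroring [data_download.sh](data_download.sh); the URL and the `data/` layout are the ones listed above:

```python
# Download and unpack the Yelp Review Polarity dataset into data/.
import tarfile
import urllib.request
from pathlib import Path

DATA_URL = "https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz"
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

archive_path = data_dir / "yelp_review_polarity_csv.tgz"
urllib.request.urlretrieve(DATA_URL, archive_path)

with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall(data_dir)  # extracts into data/yelp_review_polarity_csv/

# Move train.csv and test.csv up to data/, where data_prep.ipynb expects them.
extracted_dir = data_dir / "yelp_review_polarity_csv"
for name in ("train.csv", "test.csv"):
    (extracted_dir / name).rename(data_dir / name)
archive_path.unlink()  # clean up the downloaded archive
```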
75 | Once the download is complete, you can run the [data_prep.ipynb](data_prep.ipynb) notebook to get the data ready for training.
76 |
77 | Finally, you can run the [run_model.ipynb](run_model.ipynb) notebook to fine-tune a Transformer model on the Yelp Dataset and evaluate the results.
78 |
79 | ### Current Pretrained Models
80 |
81 | The table below shows the currently available model types and their models. You can use any of these by setting the `model_type` and `model_name` in the `args` dictionary (a short example follows the table). For more information about pretrained models, see [HuggingFace docs](https://huggingface.co/pytorch-transformers/pretrained_models.html).
82 |
83 | | Architecture | Model Type | Model Name | Details |
84 | | :------------- |:----------| :-------------| :-----------------------------|
85 | | BERT | bert | bert-base-uncased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased English text. |
86 | | BERT | bert | bert-large-uncased | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text. |
87 | | BERT | bert | bert-base-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased English text. |
88 | | BERT | bert | bert-large-cased | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text. |
89 | | BERT | bert | bert-base-multilingual-uncased | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased text in the top 102 languages with the largest Wikipedias |
90 | | BERT | bert | bert-base-multilingual-cased | (New, recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased text in the top 104 languages with the largest Wikipedias |
91 | | BERT | bert | bert-base-chinese | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased Chinese Simplified and Traditional text. |
92 | | BERT | bert | bert-base-german-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased German text by Deepset.ai |
93 | | BERT | bert | bert-large-uncased-whole-word-masking | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text using Whole-Word-Masking |
94 | | BERT | bert | bert-large-cased-whole-word-masking | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text using Whole-Word-Masking |
95 | | BERT | bert | bert-large-uncased-whole-word-masking-finetuned-squad | 24-layer, 1024-hidden, 16-heads, 340M parameters. The bert-large-uncased-whole-word-masking model fine-tuned on SQuAD |
96 | | BERT | bert | bert-large-cased-whole-word-masking-finetuned-squad | 24-layer, 1024-hidden, 16-heads, 340M parameters. The bert-large-cased-whole-word-masking model fine-tuned on SQuAD |
97 | | BERT | bert | bert-base-cased-finetuned-mrpc | 12-layer, 768-hidden, 12-heads, 110M parameters. The bert-base-cased model fine-tuned on MRPC |
98 | | XLNet | xlnet | xlnet-base-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. XLNet English model |
99 | | XLNet | xlnet | xlnet-large-cased | 24-layer, 1024-hidden, 16-heads, 340M parameters. XLNet Large English model |
100 | | XLM | xlm | xlm-mlm-en-2048 | 12-layer, 2048-hidden, 16-heads. XLM English model |
101 | | XLM | xlm | xlm-mlm-ende-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-German Multi-language model |
102 | | XLM | xlm | xlm-mlm-enfr-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-French Multi-language model |
103 | | XLM | xlm | xlm-mlm-enro-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-Romanian Multi-language model |
104 | | XLM | xlm | xlm-mlm-xnli15-1024 | 12-layer, 1024-hidden, 8-heads. XLM Model pre-trained with MLM on the 15 XNLI languages |
105 | | XLM | xlm | xlm-mlm-tlm-xnli15-1024 | 12-layer, 1024-hidden, 8-heads. XLM Model pre-trained with MLM + TLM on the 15 XNLI languages |
106 | | XLM | xlm | xlm-clm-enfr-1024 | 12-layer, 1024-hidden, 8-heads. XLM English model trained with CLM (Causal Language Modeling) |
107 | | XLM | xlm | xlm-clm-ende-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
108 | | RoBERTa | roberta | roberta-base | 125M parameters. RoBERTa using the BERT-base architecture |
109 | | RoBERTa | roberta | roberta-large | 24-layer, 1024-hidden, 16-heads, 355M parameters. RoBERTa using the BERT-large architecture |
110 | | RoBERTa | roberta | roberta-large-mnli | 24-layer, 1024-hidden, 16-heads, 355M parameters. roberta-large fine-tuned on MNLI. |
111 |
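For example, switching the demo to XLNet only requires changing two entries in the `args` dictionary defined in [run_model.ipynb](run_model.ipynb). This is a minimal sketch; every other setting stays as it is in the notebook:

```python
# Pick a model type and a pretrained model name from the table above.
args = {
    'model_type': 'xlnet',             # one of: bert, xlnet, xlm, roberta
    'model_name': 'xlnet-base-cased',  # any compatible model name from the table
    # ... remaining settings (data_dir, output_dir, max_seq_length, ...) unchanged
}
```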
112 | ### Custom Datasets
113 |
114 | When working with your own datasets, you can create a script/notebook similar to [data_prep.ipynb](data_prep.ipynb) that will convert the dataset to a Pytorch-Transformers-ready format.
115 |
116 | The data needs to be in `tsv` format, with four columns and no header.
117 |
118 | This is the required structure (a minimal conversion sketch follows the list).
119 |
120 | - `guid`: An ID for the row.
121 | - `label`: The label for the row (should be an int).
122 | - `alpha`: A column of the same letter for all rows. Not used in classification but still expected by the `DataProcessor`.
123 | - `text`: The sentence or sequence of text.
124 |
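As a rough illustration, a conversion could look like the sketch below. The input file and its column names (`my_dataset.csv`, `label`, `text`) are hypothetical; only the output format matters:

```python
# Convert a raw CSV with 'label' and 'text' columns into the headerless
# four-column TSV expected by the data processors.
import pandas as pd

raw_df = pd.read_csv('my_dataset.csv')  # placeholder input file

train_df = pd.DataFrame({
    'guid': range(len(raw_df)),
    'label': raw_df['label'].astype(int),
    'alpha': ['a'] * len(raw_df),
    'text': raw_df['text'].replace(r'\n', ' ', regex=True),  # keep each example on one line
})

train_df.to_csv('data/train.tsv', sep='\t', index=False, header=False)
```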
125 | ### Evaluation Metrics
126 |
127 | The evaluation process in the [run_model.ipynb](run_model.ipynb) notebook outputs the confusion matrix and the Matthews correlation coefficient. If you wish to add more evaluation metrics, simply edit the `get_eval_reports()` function in the notebook. This function takes the predictions and the ground-truth labels as parameters, so you can add any custom metric calculations to the function as required.
128 |
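For instance, accuracy and F1 could be reported alongside the existing values. The sketch below is modeled on the simpler `get_eval_report()` helper in [colab_quickstart.ipynb](colab_quickstart.ipynb); the function in [run_model.ipynb](run_model.ipynb) may differ slightly:

```python
# Return extra sklearn metrics from the evaluation report.
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             matthews_corrcoef)

def get_eval_report(labels, preds):
    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
    return {
        "mcc": matthews_corrcoef(labels, preds),
        "acc": accuracy_score(labels, preds),  # added metric
        "f1": f1_score(labels, preds),         # added metric (assumes binary labels)
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
    }
```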
129 | ## Acknowledgements
130 |
131 | None of this would have been possible without the hard work by the HuggingFace team in developing the [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) library.
132 |
--------------------------------------------------------------------------------
/colab_quickstart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "colab": {},
18 | "colab_type": "code",
19 | "id": "MrIwJ5QDz-e9"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "!pip install pytorch-transformers"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {
30 | "colab": {},
31 | "colab_type": "code",
32 | "id": "0YLoS0hWz-ch"
33 | },
34 | "outputs": [],
35 | "source": [
36 | "%%writefile utils.py \n",
37 | "\n",
38 | "# coding=utf-8\n",
39 | "# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n",
40 | "# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.\n",
41 | "#\n",
42 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
43 | "# you may not use this file except in compliance with the License.\n",
44 | "# You may obtain a copy of the License at\n",
45 | "#\n",
46 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
47 | "#\n",
48 | "# Unless required by applicable law or agreed to in writing, software\n",
49 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
50 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
51 | "# See the License for the specific language governing permissions and\n",
52 | "# limitations under the License.\n",
53 | "\"\"\" BERT classification fine-tuning: utilities to work with GLUE tasks \"\"\"\n",
54 | "\n",
55 | "from __future__ import absolute_import, division, print_function\n",
56 | "\n",
57 | "import csv\n",
58 | "import logging\n",
59 | "import os\n",
60 | "import sys\n",
61 | "from io import open\n",
62 | "\n",
63 | "from scipy.stats import pearsonr, spearmanr\n",
64 | "from sklearn.metrics import matthews_corrcoef, f1_score\n",
65 | "\n",
66 | "from multiprocessing import Pool, cpu_count\n",
67 | "from tqdm import tqdm\n",
68 | "\n",
69 | "logger = logging.getLogger(__name__)\n",
70 | "csv.field_size_limit(2147483647)\n",
71 | "\n",
72 | "class InputExample(object):\n",
73 | " \"\"\"A single training/test example for simple sequence classification.\"\"\"\n",
74 | "\n",
75 | " def __init__(self, guid, text_a, text_b=None, label=None):\n",
76 | " \"\"\"Constructs a InputExample.\n",
77 | "\n",
78 | " Args:\n",
79 | " guid: Unique id for the example.\n",
80 | " text_a: string. The untokenized text of the first sequence. For single\n",
81 | " sequence tasks, only this sequence must be specified.\n",
82 | " text_b: (Optional) string. The untokenized text of the second sequence.\n",
83 | " Only must be specified for sequence pair tasks.\n",
84 | " label: (Optional) string. The label of the example. This should be\n",
85 | " specified for train and dev examples, but not for test examples.\n",
86 | " \"\"\"\n",
87 | " self.guid = guid\n",
88 | " self.text_a = text_a\n",
89 | " self.text_b = text_b\n",
90 | " self.label = label\n",
91 | "\n",
92 | "\n",
93 | "class InputFeatures(object):\n",
94 | " \"\"\"A single set of features of data.\"\"\"\n",
95 | "\n",
96 | " def __init__(self, input_ids, input_mask, segment_ids, label_id):\n",
97 | " self.input_ids = input_ids\n",
98 | " self.input_mask = input_mask\n",
99 | " self.segment_ids = segment_ids\n",
100 | " self.label_id = label_id\n",
101 | "\n",
102 | "\n",
103 | "class DataProcessor(object):\n",
104 | " \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
105 | "\n",
106 | " def get_train_examples(self, data_dir):\n",
107 | " \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
108 | " raise NotImplementedError()\n",
109 | "\n",
110 | " def get_dev_examples(self, data_dir):\n",
111 | " \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
112 | " raise NotImplementedError()\n",
113 | "\n",
114 | " def get_labels(self):\n",
115 | " \"\"\"Gets the list of labels for this data set.\"\"\"\n",
116 | " raise NotImplementedError()\n",
117 | "\n",
118 | " @classmethod\n",
119 | " def _read_tsv(cls, input_file, quotechar=None):\n",
120 | " \"\"\"Reads a tab separated value file.\"\"\"\n",
121 | " with open(input_file, \"r\", encoding=\"utf-8-sig\") as f:\n",
122 | " reader = csv.reader(f, delimiter=\"\\t\", quotechar=quotechar)\n",
123 | " lines = []\n",
124 | " for line in reader:\n",
125 | " if sys.version_info[0] == 2:\n",
126 | " line = list(unicode(cell, 'utf-8') for cell in line)\n",
127 | " lines.append(line)\n",
128 | " return lines\n",
129 | "\n",
130 | "\n",
131 | "class BinaryProcessor(DataProcessor):\n",
132 | " \"\"\"Processor for the binary data sets\"\"\"\n",
133 | "\n",
134 | " def get_train_examples(self, data_dir):\n",
135 | " \"\"\"See base class.\"\"\"\n",
136 | " return self._create_examples(\n",
137 | " self._read_tsv(os.path.join(data_dir, \"train.tsv\")), \"train\")\n",
138 | "\n",
139 | " def get_dev_examples(self, data_dir):\n",
140 | " \"\"\"See base class.\"\"\"\n",
141 | " return self._create_examples(\n",
142 | " self._read_tsv(os.path.join(data_dir, \"dev.tsv\")), \"dev\")\n",
143 | "\n",
144 | " def get_labels(self):\n",
145 | " \"\"\"See base class.\"\"\"\n",
146 | " return [\"0\", \"1\"]\n",
147 | "\n",
148 | " def _create_examples(self, lines, set_type):\n",
149 | " \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
150 | " examples = []\n",
151 | " for (i, line) in enumerate(lines):\n",
152 | " guid = \"%s-%s\" % (set_type, i)\n",
153 | " text_a = line[3]\n",
154 | " label = line[1]\n",
155 | " examples.append(\n",
156 | " InputExample(guid=guid, text_a=text_a, text_b=None, label=label))\n",
157 | " return examples\n",
158 | "\n",
159 | "\n",
160 | "def convert_example_to_feature(example_row, pad_token=0,\n",
161 | "sequence_a_segment_id=0, sequence_b_segment_id=1,\n",
162 | "cls_token_segment_id=1, pad_token_segment_id=0,\n",
163 | "mask_padding_with_zero=True):\n",
164 | " example, label_map, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id = example_row\n",
165 | "\n",
166 | " tokens_a = tokenizer.tokenize(example.text_a)\n",
167 | "\n",
168 | " tokens_b = None\n",
169 | " if example.text_b:\n",
170 | " tokens_b = tokenizer.tokenize(example.text_b)\n",
171 | " # Modifies `tokens_a` and `tokens_b` in place so that the total\n",
172 | " # length is less than the specified length.\n",
173 | " # Account for [CLS], [SEP], [SEP] with \"- 3\"\n",
174 | " _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)\n",
175 | " else:\n",
176 | " # Account for [CLS] and [SEP] with \"- 2\"\n",
177 | " if len(tokens_a) > max_seq_length - 2:\n",
178 | " tokens_a = tokens_a[:(max_seq_length - 2)]\n",
179 | "\n",
180 | " # The convention in BERT is:\n",
181 | " # (a) For sequence pairs:\n",
182 | " # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]\n",
183 | " # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1\n",
184 | " # (b) For single sequences:\n",
185 | " # tokens: [CLS] the dog is hairy . [SEP]\n",
186 | " # type_ids: 0 0 0 0 0 0 0\n",
187 | " #\n",
188 | " # Where \"type_ids\" are used to indicate whether this is the first\n",
189 | " # sequence or the second sequence. The embedding vectors for `type=0` and\n",
190 | " # `type=1` were learned during pre-training and are added to the wordpiece\n",
191 | " # embedding vector (and position vector). This is not *strictly* necessary\n",
192 | " # since the [SEP] token unambiguously separates the sequences, but it makes\n",
193 | " # it easier for the model to learn the concept of sequences.\n",
194 | " #\n",
195 | " # For classification tasks, the first vector (corresponding to [CLS]) is\n",
196 | " # used as as the \"sentence vector\". Note that this only makes sense because\n",
197 | " # the entire model is fine-tuned.\n",
198 | " tokens = tokens_a + [sep_token]\n",
199 | " segment_ids = [sequence_a_segment_id] * len(tokens)\n",
200 | "\n",
201 | " if tokens_b:\n",
202 | " tokens += tokens_b + [sep_token]\n",
203 | " segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)\n",
204 | "\n",
205 | " if cls_token_at_end:\n",
206 | " tokens = tokens + [cls_token]\n",
207 | " segment_ids = segment_ids + [cls_token_segment_id]\n",
208 | " else:\n",
209 | " tokens = [cls_token] + tokens\n",
210 | " segment_ids = [cls_token_segment_id] + segment_ids\n",
211 | "\n",
212 | " input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
213 | "\n",
214 | " # The mask has 1 for real tokens and 0 for padding tokens. Only real\n",
215 | " # tokens are attended to.\n",
216 | " input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)\n",
217 | "\n",
218 | " # Zero-pad up to the sequence length.\n",
219 | " padding_length = max_seq_length - len(input_ids)\n",
220 | " if pad_on_left:\n",
221 | " input_ids = ([pad_token] * padding_length) + input_ids\n",
222 | " input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask\n",
223 | " segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids\n",
224 | " else:\n",
225 | " input_ids = input_ids + ([pad_token] * padding_length)\n",
226 | " input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)\n",
227 | " segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)\n",
228 | "\n",
229 | " assert len(input_ids) == max_seq_length\n",
230 | " assert len(input_mask) == max_seq_length\n",
231 | " assert len(segment_ids) == max_seq_length\n",
232 | "\n",
233 | " if output_mode == \"classification\":\n",
234 | " label_id = label_map[example.label]\n",
235 | " elif output_mode == \"regression\":\n",
236 | " label_id = float(example.label)\n",
237 | " else:\n",
238 | " raise KeyError(output_mode)\n",
239 | "\n",
240 | " return InputFeatures(input_ids=input_ids,\n",
241 | " input_mask=input_mask,\n",
242 | " segment_ids=segment_ids,\n",
243 | " label_id=label_id)\n",
244 | " \n",
245 | "\n",
246 | "def convert_examples_to_features(examples, label_list, max_seq_length,\n",
247 | " tokenizer, output_mode,\n",
248 | " cls_token_at_end=False, pad_on_left=False,\n",
249 | " cls_token='[CLS]', sep_token='[SEP]', pad_token=0,\n",
250 | " sequence_a_segment_id=0, sequence_b_segment_id=1,\n",
251 | " cls_token_segment_id=1, pad_token_segment_id=0,\n",
252 | " mask_padding_with_zero=True,\n",
253 | " process_count=cpu_count() - 2):\n",
254 | " \"\"\" Loads a data file into a list of `InputBatch`s\n",
255 | " `cls_token_at_end` define the location of the CLS token:\n",
256 | " - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]\n",
257 | " - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]\n",
258 | " `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)\n",
259 | " \"\"\"\n",
260 | "\n",
261 | " label_map = {label : i for i, label in enumerate(label_list)}\n",
262 | "\n",
263 | " examples = [(example, label_map, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id) for example in examples]\n",
264 | "\n",
265 | " with Pool(process_count) as p:\n",
266 | " features = list(tqdm(p.imap(convert_example_to_feature, examples, chunksize=100), total=len(examples)))\n",
267 | "\n",
268 | " return features\n",
269 | "\n",
270 | "\n",
271 | "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
272 | " \"\"\"Truncates a sequence pair in place to the maximum length.\"\"\"\n",
273 | "\n",
274 | " # This is a simple heuristic which will always truncate the longer sequence\n",
275 | " # one token at a time. This makes more sense than truncating an equal percent\n",
276 | " # of tokens from each, since if one sequence is very short then each token\n",
277 | " # that's truncated likely contains more information than a longer sequence.\n",
278 | " while True:\n",
279 | " total_length = len(tokens_a) + len(tokens_b)\n",
280 | " if total_length <= max_length:\n",
281 | " break\n",
282 | " if len(tokens_a) > len(tokens_b):\n",
283 | " tokens_a.pop()\n",
284 | " else:\n",
285 | " tokens_b.pop()\n",
286 | "\n",
287 | "\n",
288 | "processors = {\n",
289 | " \"binary\": BinaryProcessor\n",
290 | "}\n",
291 | "\n",
292 | "output_modes = {\n",
293 | " \"binary\": \"classification\"\n",
294 | "}\n",
295 | "\n",
296 | "GLUE_TASKS_NUM_LABELS = {\n",
297 | " \"binary\": 2\n",
298 | "}\n"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {
305 | "colab": {},
306 | "colab_type": "code",
307 | "id": "iDJUEfruz-aF"
308 | },
309 | "outputs": [],
310 | "source": [
311 | "!mkdir data\n",
312 | "!wget https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz -O data/data.tgz"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {
319 | "colab": {},
320 | "colab_type": "code",
321 | "id": "aepMB1Z_2u0V"
322 | },
323 | "outputs": [],
324 | "source": [
325 | "!tar -xvzf data/data.tgz -C data/\n",
326 | "!mv data/yelp_review_polarity_csv/* data/\n",
327 | "!rm -r data/yelp_review_polarity_csv/\n",
328 | "!rm data/data.tgz"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": null,
334 | "metadata": {
335 | "colab": {},
336 | "colab_type": "code",
337 | "id": "qYv4tGNs4Tg8"
338 | },
339 | "outputs": [],
340 | "source": [
341 | "import pandas as pd\n",
342 | "from tqdm import tqdm_notebook\n",
343 | "\n",
344 | "prefix = 'data/'"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {
351 | "colab": {},
352 | "colab_type": "code",
353 | "id": "BiC1S1xq4XnN"
354 | },
355 | "outputs": [],
356 | "source": [
357 | "train_df = pd.read_csv(prefix + 'train.csv', header=None)\n",
358 | "train_df.head()"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "metadata": {
365 | "colab": {},
366 | "colab_type": "code",
367 | "id": "um-kSl434Xk6"
368 | },
369 | "outputs": [],
370 | "source": [
371 | "test_df = pd.read_csv(prefix + 'test.csv', header=None)\n",
372 | "test_df.head()"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {
379 | "colab": {},
380 | "colab_type": "code",
381 | "id": "_Xe5tm-2441r"
382 | },
383 | "outputs": [],
384 | "source": [
385 | "train_df[0] = (train_df[0] == 2).astype(int)\n",
386 | "test_df[0] = (test_df[0] == 2).astype(int)"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": null,
392 | "metadata": {
393 | "colab": {},
394 | "colab_type": "code",
395 | "id": "hVsnPNzr4XjB"
396 | },
397 | "outputs": [],
398 | "source": [
399 | "train_df = pd.DataFrame({\n",
400 | " 'id':range(len(train_df)),\n",
401 | " 'label':train_df[0],\n",
402 | " 'alpha':['a']*train_df.shape[0],\n",
403 | " 'text': train_df[1].replace(r'\\n', ' ', regex=True)\n",
404 | "})\n",
405 | "\n",
406 | "train_df.head()"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": null,
412 | "metadata": {
413 | "colab": {},
414 | "colab_type": "code",
415 | "id": "IkUFgZeh4Xf2"
416 | },
417 | "outputs": [],
418 | "source": [
419 | "dev_df = pd.DataFrame({\n",
420 | " 'id':range(len(test_df)),\n",
421 | " 'label':test_df[0],\n",
422 | " 'alpha':['a']*test_df.shape[0],\n",
423 | " 'text': test_df[1].replace(r'\\n', ' ', regex=True)\n",
424 | "})\n",
425 | "\n",
426 | "dev_df.head()"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {
433 | "colab": {},
434 | "colab_type": "code",
435 | "id": "pFlvoH234XdO"
436 | },
437 | "outputs": [],
438 | "source": [
439 | "train_df.to_csv('data/train.tsv', sep='\\t', index=False, header=False, columns=train_df.columns)\n",
440 | "dev_df.to_csv('data/dev.tsv', sep='\\t', index=False, header=False, columns=dev_df.columns)"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": null,
446 | "metadata": {
447 | "colab": {},
448 | "colab_type": "code",
449 | "id": "VeXuXWylz7BD"
450 | },
451 | "outputs": [],
452 | "source": [
453 | "from __future__ import absolute_import, division, print_function\n",
454 | "\n",
455 | "import glob\n",
456 | "import logging\n",
457 | "import os\n",
458 | "import random\n",
459 | "import json\n",
460 | "\n",
461 | "import numpy as np\n",
462 | "import torch\n",
463 | "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
464 | " TensorDataset)\n",
465 | "import random\n",
466 | "from torch.utils.data.distributed import DistributedSampler\n",
467 | "from tqdm import tqdm_notebook, trange\n",
468 | "\n",
469 | "\n",
470 | "from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer,\n",
471 | " XLMConfig, XLMForSequenceClassification, XLMTokenizer, \n",
472 | " XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer,\n",
473 | " RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)\n",
474 | "\n",
475 | "from pytorch_transformers import AdamW, WarmupLinearSchedule\n",
476 | "\n",
477 | "from utils import (convert_examples_to_features,\n",
478 | " output_modes, processors)\n",
479 | "\n",
480 | "logging.basicConfig(level=logging.INFO)\n",
481 | "logger = logging.getLogger(__name__)"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "metadata": {
488 | "colab": {},
489 | "colab_type": "code",
490 | "id": "F93_pIopz7BG"
491 | },
492 | "outputs": [],
493 | "source": [
494 | "args = {\n",
495 | " 'data_dir': 'data/',\n",
496 | " 'model_type': 'roberta',\n",
497 | " 'model_name': 'roberta-base',\n",
498 | " 'task_name': 'binary',\n",
499 | " 'output_dir': 'outputs/',\n",
500 | " 'cache_dir': 'cache/',\n",
501 | " 'do_train': True,\n",
502 | " 'do_eval': True,\n",
503 | " 'fp16': False,\n",
504 | " 'fp16_opt_level': 'O1',\n",
505 | " 'max_seq_length': 128,\n",
506 | " 'output_mode': 'classification',\n",
507 | " 'train_batch_size': 8,\n",
508 | " 'eval_batch_size': 8,\n",
509 | "\n",
510 | " 'gradient_accumulation_steps': 1,\n",
511 | " 'num_train_epochs': 1,\n",
512 | " 'weight_decay': 0,\n",
513 | " 'learning_rate': 4e-5,\n",
514 | " 'adam_epsilon': 1e-8,\n",
515 | " 'warmup_steps': 0,\n",
516 | " 'max_grad_norm': 1.0,\n",
517 | "\n",
518 | " 'logging_steps': 50,\n",
519 | " 'evaluate_during_training': False,\n",
520 | " 'save_steps': 2000,\n",
521 | " 'eval_all_checkpoints': True,\n",
522 | "\n",
523 | " 'overwrite_output_dir': False,\n",
524 | " 'reprocess_input_data': False,\n",
525 | " 'notes': 'Using Yelp Reviews dataset'\n",
526 | "}\n",
527 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": null,
533 | "metadata": {
534 | "colab": {},
535 | "colab_type": "code",
536 | "id": "atGwIw3iz7BJ"
537 | },
538 | "outputs": [],
539 | "source": [
540 | "args"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": null,
546 | "metadata": {
547 | "colab": {},
548 | "colab_type": "code",
549 | "id": "Uzr2RwGLz7BL"
550 | },
551 | "outputs": [],
552 | "source": [
553 | "with open('args.json', 'w') as f:\n",
554 | " json.dump(args, f)"
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": null,
560 | "metadata": {
561 | "colab": {},
562 | "colab_type": "code",
563 | "id": "ymjmIyOhz7BN"
564 | },
565 | "outputs": [],
566 | "source": [
567 | "if os.path.exists(args['output_dir']) and os.listdir(args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:\n",
568 | " raise ValueError(\"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(args['output_dir']))"
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": null,
574 | "metadata": {
575 | "colab": {},
576 | "colab_type": "code",
577 | "id": "LAHYiiLMz7BP"
578 | },
579 | "outputs": [],
580 | "source": [
581 | "MODEL_CLASSES = {\n",
582 | " 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),\n",
583 | " 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),\n",
584 | " 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),\n",
585 | " 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)\n",
586 | "}\n",
587 | "\n",
588 | "config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "metadata": {
595 | "colab": {},
596 | "colab_type": "code",
597 | "id": "qm5AguwFz7BR"
598 | },
599 | "outputs": [],
600 | "source": [
601 | "config = config_class.from_pretrained(args['model_name'], num_labels=2, finetuning_task=args['task_name'])\n",
602 | "tokenizer = tokenizer_class.from_pretrained(args['model_name'])"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {
609 | "colab": {},
610 | "colab_type": "code",
611 | "id": "IGZHNvKAz7BU"
612 | },
613 | "outputs": [],
614 | "source": [
615 | "model = model_class.from_pretrained(args['model_name'])"
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": null,
621 | "metadata": {
622 | "colab": {},
623 | "colab_type": "code",
624 | "id": "xyxKpk_6z7BW"
625 | },
626 | "outputs": [],
627 | "source": [
628 | "model.to(device);"
629 | ]
630 | },
631 | {
632 | "cell_type": "code",
633 | "execution_count": null,
634 | "metadata": {
635 | "colab": {},
636 | "colab_type": "code",
637 | "id": "bsPmRyGE8GnR"
638 | },
639 | "outputs": [],
640 | "source": [
641 | "device"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": null,
647 | "metadata": {
648 | "colab": {},
649 | "colab_type": "code",
650 | "id": "Xe4P94Bfz7Ba"
651 | },
652 | "outputs": [],
653 | "source": [
654 | "task = args['task_name']\n",
655 | "\n",
656 | "processor = processors[task]()\n",
657 | "label_list = processor.get_labels()\n",
658 | "num_labels = len(label_list)"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": null,
664 | "metadata": {
665 | "colab": {},
666 | "colab_type": "code",
667 | "id": "xqr_fwM3z7Bd"
668 | },
669 | "outputs": [],
670 | "source": [
671 | "def load_and_cache_examples(task, tokenizer, evaluate=False, undersample_scale_factor=0.01):\n",
672 | " processor = processors[task]()\n",
673 | " output_mode = args['output_mode']\n",
674 | " \n",
675 | " mode = 'dev' if evaluate else 'train'\n",
676 | " cached_features_file = os.path.join(args['data_dir'], f\"cached_{mode}_{args['model_name']}_{args['max_seq_length']}_{task}\")\n",
677 | " \n",
678 | " if os.path.exists(cached_features_file) and not args['reprocess_input_data']:\n",
679 | " logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
680 | " features = torch.load(cached_features_file)\n",
681 | " \n",
682 | " else:\n",
683 | " logger.info(\"Creating features from dataset file at %s\", args['data_dir'])\n",
684 | " label_list = processor.get_labels()\n",
685 | " examples = processor.get_dev_examples(args['data_dir']) if evaluate else processor.get_train_examples(args['data_dir'])\n",
686 | " print(len(examples))\n",
687 | " examples = [example for example in examples if np.random.rand() < undersample_scale_factor]\n",
688 | " print(len(examples))\n",
689 | " \n",
690 | " features = convert_examples_to_features(examples, label_list, args['max_seq_length'], tokenizer, output_mode,\n",
691 | " cls_token_at_end=bool(args['model_type'] in ['xlnet']), # xlnet has a cls token at the end\n",
692 | " cls_token=tokenizer.cls_token,\n",
693 | " sep_token=tokenizer.sep_token,\n",
694 | " cls_token_segment_id=2 if args['model_type'] in ['xlnet'] else 0,\n",
695 | " pad_on_left=bool(args['model_type'] in ['xlnet']), # pad on the left for xlnet\n",
696 | " pad_token_segment_id=4 if args['model_type'] in ['xlnet'] else 0,\n",
697 | " process_count=2)\n",
698 | " \n",
699 | " logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
700 | " torch.save(features, cached_features_file)\n",
701 | " \n",
702 | " all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
703 | " all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n",
704 | " all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)\n",
705 | " if output_mode == \"classification\":\n",
706 | " all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)\n",
707 | " elif output_mode == \"regression\":\n",
708 | " all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)\n",
709 | "\n",
710 | " dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)\n",
711 | " return dataset"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": null,
717 | "metadata": {
718 | "colab": {},
719 | "colab_type": "code",
720 | "id": "oCul6vvCz7Bg"
721 | },
722 | "outputs": [],
723 | "source": [
724 | "def train(train_dataset, model, tokenizer):\n",
725 | " train_sampler = RandomSampler(train_dataset)\n",
726 | " train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args['train_batch_size'])\n",
727 | " \n",
728 | " t_total = len(train_dataloader) // args['gradient_accumulation_steps'] * args['num_train_epochs']\n",
729 | " \n",
730 | " no_decay = ['bias', 'LayerNorm.weight']\n",
731 | " optimizer_grouped_parameters = [\n",
732 | " {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args['weight_decay']},\n",
733 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
734 | " ]\n",
735 | " optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon'])\n",
736 | " scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args['warmup_steps'], t_total=t_total)\n",
737 | " \n",
738 | " if args['fp16']:\n",
739 | " try:\n",
740 | " from apex import amp\n",
741 | " except ImportError:\n",
742 | " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
743 | " model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])\n",
744 | " \n",
745 | " logger.info(\"***** Running training *****\")\n",
746 | " logger.info(\" Num examples = %d\", len(train_dataset))\n",
747 | " logger.info(\" Num Epochs = %d\", args['num_train_epochs'])\n",
748 | " logger.info(\" Total train batch size = %d\", args['train_batch_size'])\n",
749 | " logger.info(\" Gradient Accumulation steps = %d\", args['gradient_accumulation_steps'])\n",
750 | " logger.info(\" Total optimization steps = %d\", t_total)\n",
751 | "\n",
752 | " global_step = 0\n",
753 | " tr_loss, logging_loss = 0.0, 0.0\n",
754 | " model.zero_grad()\n",
755 | " train_iterator = trange(int(args['num_train_epochs']), desc=\"Epoch\")\n",
756 | " \n",
757 | " for _ in train_iterator:\n",
758 | " epoch_iterator = tqdm_notebook(train_dataloader, desc=\"Iteration\")\n",
759 | " for step, batch in enumerate(epoch_iterator):\n",
760 | " model.train()\n",
761 | " batch = tuple(t.to(device) for t in batch)\n",
762 | " inputs = {'input_ids': batch[0],\n",
763 | " 'attention_mask': batch[1],\n",
764 | " 'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None, # XLM don't use segment_ids\n",
765 | " 'labels': batch[3]}\n",
766 | " outputs = model(**inputs)\n",
767 | " loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)\n",
768 | " print(\"\\r%f\" % loss, end='')\n",
769 | "\n",
770 | " if args['gradient_accumulation_steps'] > 1:\n",
771 | " loss = loss / args['gradient_accumulation_steps']\n",
772 | "\n",
773 | " if args['fp16']:\n",
774 | " with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
775 | " scaled_loss.backward()\n",
776 | " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])\n",
777 | " \n",
778 | " else:\n",
779 | " loss.backward()\n",
780 | " torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm'])\n",
781 | "\n",
782 | " tr_loss += loss.item()\n",
783 | " if (step + 1) % args['gradient_accumulation_steps'] == 0:\n",
784 | " scheduler.step() # Update learning rate schedule\n",
785 | " optimizer.step()\n",
786 | " model.zero_grad()\n",
787 | " global_step += 1\n",
788 | "\n",
789 | " if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:\n",
790 | " # Log metrics\n",
791 | " if args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well\n",
792 | " results = evaluate(model, tokenizer)\n",
793 | "\n",
794 | " logging_loss = tr_loss\n",
795 | "\n",
796 | " if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:\n",
797 | " # Save model checkpoint\n",
798 | " output_dir = os.path.join(args['output_dir'], 'checkpoint-{}'.format(global_step))\n",
799 | " if not os.path.exists(output_dir):\n",
800 | " os.makedirs(output_dir)\n",
801 | " model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training\n",
802 | " model_to_save.save_pretrained(output_dir)\n",
803 | " logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
804 | "\n",
805 | "\n",
806 | " return global_step, tr_loss / global_step"
807 | ]
808 | },
809 | {
810 | "cell_type": "code",
811 | "execution_count": null,
812 | "metadata": {
813 | "colab": {},
814 | "colab_type": "code",
815 | "id": "tUvkEBZUz7Bk"
816 | },
817 | "outputs": [],
818 | "source": [
819 | "from sklearn.metrics import mean_squared_error, matthews_corrcoef, confusion_matrix\n",
820 | "from scipy.stats import pearsonr\n",
821 | "\n",
822 | "def get_mismatched(labels, preds):\n",
823 | " mismatched = labels != preds\n",
824 | " examples = processor.get_dev_examples(args['data_dir'])\n",
825 | " wrong = [i for (i, v) in zip(examples, mismatched) if v]\n",
826 | " \n",
827 | " return wrong\n",
828 | "\n",
829 | "def get_eval_report(labels, preds):\n",
830 | " mcc = matthews_corrcoef(labels, preds)\n",
831 | " tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()\n",
832 | " return {\n",
833 | " \"mcc\": mcc,\n",
834 | " \"tp\": tp,\n",
835 | " \"tn\": tn,\n",
836 | " \"fp\": fp,\n",
837 | " \"fn\": fn\n",
838 | " }, get_mismatched(labels, preds)\n",
839 | "\n",
840 | "def compute_metrics(task_name, preds, labels):\n",
841 | " assert len(preds) == len(labels)\n",
842 | " return get_eval_report(labels, preds)\n",
843 | "\n",
844 | "def evaluate(model, tokenizer, prefix=\"\"):\n",
845 | " # Loop to handle MNLI double evaluation (matched, mis-matched)\n",
846 | " eval_output_dir = args['output_dir']\n",
847 | "\n",
848 | " results = {}\n",
849 | " EVAL_TASK = args['task_name']\n",
850 | "\n",
851 | " eval_dataset = load_and_cache_examples(EVAL_TASK, tokenizer, evaluate=True)\n",
852 | " if not os.path.exists(eval_output_dir):\n",
853 | " os.makedirs(eval_output_dir)\n",
854 | "\n",
855 | "\n",
856 | " eval_sampler = SequentialSampler(eval_dataset)\n",
857 | " eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args['eval_batch_size'])\n",
858 | "\n",
859 | " # Eval!\n",
860 | " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
861 | " logger.info(\" Num examples = %d\", len(eval_dataset))\n",
862 | " logger.info(\" Batch size = %d\", args['eval_batch_size'])\n",
863 | " eval_loss = 0.0\n",
864 | " nb_eval_steps = 0\n",
865 | " preds = None\n",
866 | " out_label_ids = None\n",
867 | " for batch in tqdm_notebook(eval_dataloader, desc=\"Evaluating\"):\n",
868 | " model.eval()\n",
869 | " batch = tuple(t.to(device) for t in batch)\n",
870 | "\n",
871 | " with torch.no_grad():\n",
872 | " inputs = {'input_ids': batch[0],\n",
873 | " 'attention_mask': batch[1],\n",
874 | " 'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None, # XLM don't use segment_ids\n",
875 | " 'labels': batch[3]}\n",
876 | " outputs = model(**inputs)\n",
877 | " tmp_eval_loss, logits = outputs[:2]\n",
878 | "\n",
879 | " eval_loss += tmp_eval_loss.mean().item()\n",
880 | " nb_eval_steps += 1\n",
881 | " if preds is None:\n",
882 | " preds = logits.detach().cpu().numpy()\n",
883 | " out_label_ids = inputs['labels'].detach().cpu().numpy()\n",
884 | " else:\n",
885 | " preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)\n",
886 | " out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)\n",
887 | "\n",
888 | " eval_loss = eval_loss / nb_eval_steps\n",
889 | " if args['output_mode'] == \"classification\":\n",
890 | " preds = np.argmax(preds, axis=1)\n",
891 | " elif args['output_mode'] == \"regression\":\n",
892 | " preds = np.squeeze(preds)\n",
893 | " result, wrong = compute_metrics(EVAL_TASK, preds, out_label_ids)\n",
894 | " results.update(result)\n",
895 | "\n",
896 | " output_eval_file = os.path.join(eval_output_dir, \"eval_results.txt\")\n",
897 | " with open(output_eval_file, \"w\") as writer:\n",
898 | " logger.info(\"***** Eval results {} *****\".format(prefix))\n",
899 | " for key in sorted(result.keys()):\n",
900 | " logger.info(\" %s = %s\", key, str(result[key]))\n",
901 | " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
902 | "\n",
903 | " return results, wrong"
904 | ]
905 | },
906 | {
907 | "cell_type": "code",
908 | "execution_count": null,
909 | "metadata": {
910 | "colab": {},
911 | "colab_type": "code",
912 | "id": "MlaeXY9sz7Bm"
913 | },
914 | "outputs": [],
915 | "source": [
916 | "# IMPORTANT #\n",
917 |     "# Because of Google Colab's 12-hour session limit and the time it takes to convert the full dataset into features,\n",
918 |     "# load_and_cache_examples() has been modified to randomly undersample the training data to 10% of its original size (undersample_scale_factor=0.1).\n",
919 | "\n",
920 | "if args['do_train']:\n",
921 | " train_dataset = load_and_cache_examples(task, tokenizer, undersample_scale_factor=0.1)\n",
922 | " global_step, tr_loss = train(train_dataset, model, tokenizer)\n",
923 | " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)"
924 | ]
925 | },
926 | {
927 | "cell_type": "code",
928 | "execution_count": null,
929 | "metadata": {
930 | "colab": {},
931 | "colab_type": "code",
932 | "id": "1On6YjIULf7v"
933 | },
934 | "outputs": [],
935 | "source": [
936 | "if args['do_train']:\n",
937 | " if not os.path.exists(args['output_dir']):\n",
938 | " os.makedirs(args['output_dir'])\n",
939 | " logger.info(\"Saving model checkpoint to %s\", args['output_dir'])\n",
940 | " \n",
941 | " model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training\n",
942 | " model_to_save.save_pretrained(args['output_dir'])\n",
943 | " tokenizer.save_pretrained(args['output_dir'])\n",
944 | " torch.save(args, os.path.join(args['output_dir'], 'training_args.bin'))\n"
945 | ]
946 | },
947 | {
948 | "cell_type": "code",
949 | "execution_count": null,
950 | "metadata": {
951 | "colab": {},
952 | "colab_type": "code",
953 | "id": "tqiWWPA0z7Bo"
954 | },
955 | "outputs": [],
956 | "source": [
957 | "results = {}\n",
958 | "if args['do_eval']:\n",
959 | " checkpoints = [args['output_dir']]\n",
960 | " if args['eval_all_checkpoints']:\n",
961 | " checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + '/**/' + WEIGHTS_NAME, recursive=True)))\n",
962 | " logging.getLogger(\"pytorch_transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
963 | " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
964 | " for checkpoint in checkpoints:\n",
965 | " global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else \"\"\n",
966 | " model = model_class.from_pretrained(checkpoint)\n",
967 | " model.to(device)\n",
968 | " result, wrong_preds = evaluate(model, tokenizer, prefix=global_step)\n",
969 | " result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())\n",
970 | " results.update(result)\n"
971 | ]
972 | },
973 | {
974 | "cell_type": "code",
975 | "execution_count": null,
976 | "metadata": {
977 | "colab": {},
978 | "colab_type": "code",
979 | "id": "AMb25x63z7Bq"
980 | },
981 | "outputs": [],
982 | "source": [
983 | "results"
984 | ]
985 | },
986 | {
987 | "cell_type": "code",
988 | "execution_count": null,
989 | "metadata": {
990 | "colab": {},
991 | "colab_type": "code",
992 | "id": "eyvWYNjRLHrI"
993 | },
994 | "outputs": [],
995 | "source": []
996 | }
997 | ],
998 | "metadata": {
999 | "accelerator": "GPU",
1000 | "colab": {
1001 | "collapsed_sections": [],
1002 | "include_colab_link": true,
1003 | "name": "Copy of train.ipynb",
1004 | "provenance": [],
1005 | "version": "0.3.2"
1006 | },
1007 | "kernelspec": {
1008 | "display_name": "Python 3",
1009 | "language": "python",
1010 | "name": "python3"
1011 | },
1012 | "language_info": {
1013 | "codemirror_mode": {
1014 | "name": "ipython",
1015 | "version": 3
1016 | },
1017 | "file_extension": ".py",
1018 | "mimetype": "text/x-python",
1019 | "name": "python",
1020 | "nbconvert_exporter": "python",
1021 | "pygments_lexer": "ipython3",
1022 | "version": "3.7.3"
1023 | }
1024 | },
1025 | "nbformat": 4,
1026 | "nbformat_minor": 4
1027 | }
1028 |
--------------------------------------------------------------------------------
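Note on the undersampling used above: the Colab notebook calls load_and_cache_examples() with an undersample_scale_factor keyword that the stock function in run_model.ipynb does not accept. The exact modification is not shown in this dump; the following is only a minimal sketch of how such a keyword might be applied, and the helper name and fixed seed are assumptions of this sketch.

import random

def undersample(examples, undersample_scale_factor=0.1, seed=42):
    # Hypothetical helper: randomly keep roughly `undersample_scale_factor`
    # of the examples before they are converted into features.
    random.seed(seed)
    keep = max(1, int(len(examples) * undersample_scale_factor))
    return random.sample(examples, keep)

# Inside the modified load_and_cache_examples(), right after the examples
# are read from disk, one could apply:
#     examples = undersample(examples, undersample_scale_factor)
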
/data_download.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | wget https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz -O data/data.tgz
3 | tar -xvzf data/data.tgz -C data/
4 | mv data/yelp_review_polarity_csv/* data/
5 | rm -r data/yelp_review_polarity_csv/
6 | rm data/data.tgz
7 |
--------------------------------------------------------------------------------
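The shell script above assumes wget and tar are available. For environments where they are not (for example a plain Windows setup), the same download-and-extract step can be reproduced in Python; this is an optional sketch, not part of the repository.

import os
import tarfile
import urllib.request

URL = "https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz"

os.makedirs("data", exist_ok=True)
archive_path = os.path.join("data", "data.tgz")

# Download the Yelp review polarity archive.
urllib.request.urlretrieve(URL, archive_path)

# Extract train.csv and test.csv directly into data/, flattening the
# yelp_review_polarity_csv/ wrapper directory the archive contains.
with tarfile.open(archive_path, "r:gz") as tar:
    for member in tar.getmembers():
        if member.isfile() and member.name.endswith(".csv"):
            member.name = os.path.basename(member.name)
            tar.extract(member, path="data")

os.remove(archive_path)
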
/data_prep.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "from tqdm import tqdm_notebook\n",
11 | "\n",
12 | "prefix = 'data/'"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "train_df = pd.read_csv(prefix + 'train.csv', header=None)\n",
22 | "train_df.head()"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "test_df = pd.read_csv(prefix + 'test.csv', header=None)\n",
32 | "test_df.head()"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 |     "train_df[0] = (train_df[0] == 2).astype(int)  # Yelp polarity: label 1 = negative, 2 = positive -> map to 0/1\n",
42 | "test_df[0] = (test_df[0] == 2).astype(int)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "train_df = pd.DataFrame({\n",
52 | " 'id':range(len(train_df)),\n",
53 | " 'label':train_df[0],\n",
54 | " 'alpha':['a']*train_df.shape[0],\n",
55 | " 'text': train_df[1].replace(r'\\n', ' ', regex=True)\n",
56 | "})\n",
57 | "\n",
58 | "train_df.head()"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "dev_df = pd.DataFrame({\n",
68 | " 'id':range(len(test_df)),\n",
69 | " 'label':test_df[0],\n",
70 | " 'alpha':['a']*test_df.shape[0],\n",
71 | " 'text': test_df[1].replace(r'\\n', ' ', regex=True)\n",
72 | "})\n",
73 | "\n",
74 | "dev_df.head()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "train_df.to_csv('data/train.tsv', sep='\\t', index=False, header=False, columns=train_df.columns)\n",
84 | "dev_df.to_csv('data/dev.tsv', sep='\\t', index=False, header=False, columns=dev_df.columns)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": []
93 | }
94 | ],
95 | "metadata": {
96 | "kernelspec": {
97 | "display_name": "Python 3",
98 | "language": "python",
99 | "name": "python3"
100 | },
101 | "language_info": {
102 | "codemirror_mode": {
103 | "name": "ipython",
104 | "version": 3
105 | },
106 | "file_extension": ".py",
107 | "mimetype": "text/x-python",
108 | "name": "python",
109 | "nbconvert_exporter": "python",
110 | "pygments_lexer": "ipython3",
111 | "version": "3.7.3"
112 | }
113 | },
114 | "nbformat": 4,
115 | "nbformat_minor": 4
116 | }
117 |
--------------------------------------------------------------------------------
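The two TSVs written by data_prep.ipynb are later read by BinaryProcessor in utils.py, which expects four tab-separated columns per row (id, label, alpha, text) with no header, taking the label from column 1 and the text from column 3. A quick sanity check along those lines can catch formatting problems before training starts; the snippet below is a sketch, not part of the repository.

import csv

def check_tsv(path, expected_labels=("0", "1")):
    # Verify the id / label / alpha / text layout that BinaryProcessor
    # assumes: label in column 1, text in column 3.
    with open(path, encoding="utf-8") as f:
        for row_number, row in enumerate(csv.reader(f, delimiter="\t")):
            assert len(row) == 4, f"row {row_number}: expected 4 columns, got {len(row)}"
            assert row[1] in expected_labels, f"row {row_number}: unexpected label {row[1]!r}"
            assert row[3].strip(), f"row {row_number}: empty text field"

check_tsv("data/train.tsv")
check_tsv("data/dev.tsv")
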
/run_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from __future__ import absolute_import, division, print_function\n",
10 | "\n",
11 | "import glob\n",
12 | "import logging\n",
13 | "import os\n",
14 | "import random\n",
15 | "import json\n",
16 | "import math\n",
17 | "\n",
18 | "import numpy as np\n",
19 | "import torch\n",
20 | "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
21 | " TensorDataset)\n",
22 | "import random\n",
23 | "from torch.utils.data.distributed import DistributedSampler\n",
24 | "from tqdm import tqdm_notebook, trange\n",
25 | "from tensorboardX import SummaryWriter\n",
26 | "\n",
27 | "\n",
28 | "from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer,\n",
29 | " XLMConfig, XLMForSequenceClassification, XLMTokenizer, \n",
30 | " XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer,\n",
31 | " RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)\n",
32 | "\n",
33 | "from pytorch_transformers import AdamW, WarmupLinearSchedule\n",
34 | "\n",
35 | "from utils import (convert_examples_to_features,\n",
36 | " output_modes, processors)\n",
37 | "\n",
38 | "logging.basicConfig(level=logging.INFO)\n",
39 | "logger = logging.getLogger(__name__)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "args = {\n",
49 | " 'data_dir': 'data/',\n",
50 | " 'model_type': 'xlnet',\n",
51 | " 'model_name': 'xlnet-base-cased',\n",
52 | " 'task_name': 'binary',\n",
53 | " 'output_dir': 'outputs/',\n",
54 | " 'cache_dir': 'cache/',\n",
55 | " 'do_train': True,\n",
56 | " 'do_eval': True,\n",
57 | " 'fp16': True,\n",
58 | " 'fp16_opt_level': 'O1',\n",
59 | " 'max_seq_length': 128,\n",
60 | " 'output_mode': 'classification',\n",
61 | " 'train_batch_size': 8,\n",
62 | " 'eval_batch_size': 8,\n",
63 | "\n",
64 | " 'gradient_accumulation_steps': 1,\n",
65 | " 'num_train_epochs': 1,\n",
66 | " 'weight_decay': 0,\n",
67 | " 'learning_rate': 4e-5,\n",
68 | " 'adam_epsilon': 1e-8,\n",
69 | " 'warmup_ratio': 0.06,\n",
70 | " 'warmup_steps': 0,\n",
71 | " 'max_grad_norm': 1.0,\n",
72 | "\n",
73 | " 'logging_steps': 50,\n",
74 | " 'evaluate_during_training': False,\n",
75 | " 'save_steps': 2000,\n",
76 | " 'eval_all_checkpoints': True,\n",
77 | "\n",
78 | " 'overwrite_output_dir': False,\n",
79 | " 'reprocess_input_data': True,\n",
80 | " 'notes': 'Using Yelp Reviews dataset'\n",
81 | "}\n",
82 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "args"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "with open('args.json', 'w') as f:\n",
101 | " json.dump(args, f)"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "if os.path.exists(args['output_dir']) and os.listdir(args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:\n",
111 |     "    raise ValueError(\"Output directory ({}) already exists and is not empty. Set args['overwrite_output_dir'] to True to overwrite it.\".format(args['output_dir']))"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "MODEL_CLASSES = {\n",
121 | " 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),\n",
122 | " 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),\n",
123 | " 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),\n",
124 | " 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)\n",
125 | "}\n",
126 | "\n",
127 | "config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "config = config_class.from_pretrained(args['model_name'], num_labels=2, finetuning_task=args['task_name'])\n",
137 | "tokenizer = tokenizer_class.from_pretrained(args['model_name'])"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "model = model_class.from_pretrained(args['model_name'])"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "model.to(device);"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "task = args['task_name']\n",
165 | "\n",
166 | "if task in processors.keys() and task in output_modes.keys():\n",
167 | " processor = processors[task]()\n",
168 | " label_list = processor.get_labels()\n",
169 | " num_labels = len(label_list)\n",
170 | "else:\n",
171 | " raise KeyError(f'{task} not found in processors or in output_modes. Please check utils.py.')"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "def load_and_cache_examples(task, tokenizer, evaluate=False):\n",
181 | " processor = processors[task]()\n",
182 | " output_mode = args['output_mode']\n",
183 | " \n",
184 | " mode = 'dev' if evaluate else 'train'\n",
185 | " cached_features_file = os.path.join(args['data_dir'], f\"cached_{mode}_{args['model_name']}_{args['max_seq_length']}_{task}\")\n",
186 | " \n",
187 | " if os.path.exists(cached_features_file) and not args['reprocess_input_data']:\n",
188 | " logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
189 | " features = torch.load(cached_features_file)\n",
190 | " \n",
191 | " else:\n",
192 | " logger.info(\"Creating features from dataset file at %s\", args['data_dir'])\n",
193 | " label_list = processor.get_labels()\n",
194 | " examples = processor.get_dev_examples(args['data_dir']) if evaluate else processor.get_train_examples(args['data_dir'])\n",
195 | " \n",
196 |     "        if __name__ == \"__main__\":  # guard needed because convert_examples_to_features uses a multiprocessing Pool\n",
197 | " features = convert_examples_to_features(examples, label_list, args['max_seq_length'], tokenizer, output_mode,\n",
198 | " cls_token_at_end=bool(args['model_type'] in ['xlnet']), # xlnet has a cls token at the end\n",
199 | " cls_token=tokenizer.cls_token,\n",
200 | " cls_token_segment_id=2 if args['model_type'] in ['xlnet'] else 0,\n",
201 | " sep_token=tokenizer.sep_token,\n",
202 | " sep_token_extra=bool(args['model_type'] in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805\n",
203 | " pad_on_left=bool(args['model_type'] in ['xlnet']), # pad on the left for xlnet\n",
204 | " pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],\n",
205 | " pad_token_segment_id=4 if args['model_type'] in ['xlnet'] else 0)\n",
206 | " \n",
207 | " logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
208 | " torch.save(features, cached_features_file)\n",
209 | " \n",
210 | " all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
211 | " all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n",
212 | " all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)\n",
213 | " if output_mode == \"classification\":\n",
214 | " all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)\n",
215 | " elif output_mode == \"regression\":\n",
216 | " all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)\n",
217 | "\n",
218 | " dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)\n",
219 | " return dataset"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "def train(train_dataset, model, tokenizer):\n",
229 | " tb_writer = SummaryWriter()\n",
230 | " \n",
231 | " train_sampler = RandomSampler(train_dataset)\n",
232 | " train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args['train_batch_size'])\n",
233 | " \n",
234 | " t_total = len(train_dataloader) // args['gradient_accumulation_steps'] * args['num_train_epochs']\n",
235 | " \n",
236 | " no_decay = ['bias', 'LayerNorm.weight']\n",
237 | " optimizer_grouped_parameters = [\n",
238 | " {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args['weight_decay']},\n",
239 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
240 | " ]\n",
241 | " \n",
242 | " warmup_steps = math.ceil(t_total * args['warmup_ratio'])\n",
243 | " args['warmup_steps'] = warmup_steps if args['warmup_steps'] == 0 else args['warmup_steps']\n",
244 | " \n",
245 | " optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon'])\n",
246 | " scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args['warmup_steps'], t_total=t_total)\n",
247 | " \n",
248 | " if args['fp16']:\n",
249 | " try:\n",
250 | " from apex import amp\n",
251 | " except ImportError:\n",
252 | " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
253 | " model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])\n",
254 | " \n",
255 | " logger.info(\"***** Running training *****\")\n",
256 | " logger.info(\" Num examples = %d\", len(train_dataset))\n",
257 | " logger.info(\" Num Epochs = %d\", args['num_train_epochs'])\n",
258 | " logger.info(\" Total train batch size = %d\", args['train_batch_size'])\n",
259 | " logger.info(\" Gradient Accumulation steps = %d\", args['gradient_accumulation_steps'])\n",
260 | " logger.info(\" Total optimization steps = %d\", t_total)\n",
261 | "\n",
262 | " global_step = 0\n",
263 | " tr_loss, logging_loss = 0.0, 0.0\n",
264 | " model.zero_grad()\n",
265 | " train_iterator = trange(int(args['num_train_epochs']), desc=\"Epoch\")\n",
266 | " \n",
267 | " for _ in train_iterator:\n",
268 | " epoch_iterator = tqdm_notebook(train_dataloader, desc=\"Iteration\")\n",
269 | " for step, batch in enumerate(epoch_iterator):\n",
270 | " model.train()\n",
271 | " batch = tuple(t.to(device) for t in batch)\n",
272 | " inputs = {'input_ids': batch[0],\n",
273 | " 'attention_mask': batch[1],\n",
274 |     "                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,  # XLM doesn't use segment_ids\n",
275 | " 'labels': batch[3]}\n",
276 | " outputs = model(**inputs)\n",
277 |     "            loss = outputs[0]  # model outputs are always a tuple in pytorch-transformers (see doc)\n",
278 | " print(\"\\r%f\" % loss, end='')\n",
279 | "\n",
280 | " if args['gradient_accumulation_steps'] > 1:\n",
281 | " loss = loss / args['gradient_accumulation_steps']\n",
282 | "\n",
283 | " if args['fp16']:\n",
284 | " with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
285 | " scaled_loss.backward()\n",
286 | " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])\n",
287 | " \n",
288 | " else:\n",
289 | " loss.backward()\n",
290 | " torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm'])\n",
291 | "\n",
292 | " tr_loss += loss.item()\n",
293 | " if (step + 1) % args['gradient_accumulation_steps'] == 0:\n",
294 | " optimizer.step()\n",
295 | " scheduler.step() # Update learning rate schedule\n",
296 | " model.zero_grad()\n",
297 | " global_step += 1\n",
298 | "\n",
299 | " if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:\n",
300 | " # Log metrics\n",
301 | " if args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well\n",
302 | " results, _ = evaluate(model, tokenizer)\n",
303 | " for key, value in results.items():\n",
304 | " tb_writer.add_scalar('eval_{}'.format(key), value, global_step)\n",
305 | " tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)\n",
306 | " tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args['logging_steps'], global_step)\n",
307 | " logging_loss = tr_loss\n",
308 | "\n",
309 | " if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:\n",
310 | " # Save model checkpoint\n",
311 | " output_dir = os.path.join(args['output_dir'], 'checkpoint-{}'.format(global_step))\n",
312 | " if not os.path.exists(output_dir):\n",
313 | " os.makedirs(output_dir)\n",
314 | " model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training\n",
315 | " model_to_save.save_pretrained(output_dir)\n",
316 | " logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
317 | "\n",
318 | "\n",
319 | " return global_step, tr_loss / global_step"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {},
326 | "outputs": [],
327 | "source": [
328 | "from sklearn.metrics import mean_squared_error, matthews_corrcoef, confusion_matrix\n",
329 | "from scipy.stats import pearsonr\n",
330 | "\n",
331 | "def get_mismatched(labels, preds):\n",
332 | " mismatched = labels != preds\n",
333 | " examples = processor.get_dev_examples(args['data_dir'])\n",
334 | " wrong = [i for (i, v) in zip(examples, mismatched) if v]\n",
335 | " \n",
336 | " return wrong\n",
337 | "\n",
338 | "def get_eval_report(labels, preds):\n",
339 | " mcc = matthews_corrcoef(labels, preds)\n",
340 | " tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()\n",
341 | " return {\n",
342 | " \"mcc\": mcc,\n",
343 | " \"tp\": tp,\n",
344 | " \"tn\": tn,\n",
345 | " \"fp\": fp,\n",
346 | " \"fn\": fn\n",
347 | " }, get_mismatched(labels, preds)\n",
348 | "\n",
349 | "def compute_metrics(task_name, preds, labels):\n",
350 | " assert len(preds) == len(labels)\n",
351 | " return get_eval_report(labels, preds)\n",
352 | "\n",
353 | "def evaluate(model, tokenizer, prefix=\"\"):\n",
354 | " # Loop to handle MNLI double evaluation (matched, mis-matched)\n",
355 | " eval_output_dir = args['output_dir']\n",
356 | "\n",
357 | " results = {}\n",
358 | " EVAL_TASK = args['task_name']\n",
359 | "\n",
360 | " eval_dataset = load_and_cache_examples(EVAL_TASK, tokenizer, evaluate=True)\n",
361 | " if not os.path.exists(eval_output_dir):\n",
362 | " os.makedirs(eval_output_dir)\n",
363 | "\n",
364 | "\n",
365 | " eval_sampler = SequentialSampler(eval_dataset)\n",
366 | " eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args['eval_batch_size'])\n",
367 | "\n",
368 | " # Eval!\n",
369 | " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
370 | " logger.info(\" Num examples = %d\", len(eval_dataset))\n",
371 | " logger.info(\" Batch size = %d\", args['eval_batch_size'])\n",
372 | " eval_loss = 0.0\n",
373 | " nb_eval_steps = 0\n",
374 | " preds = None\n",
375 | " out_label_ids = None\n",
376 | " for batch in tqdm_notebook(eval_dataloader, desc=\"Evaluating\"):\n",
377 | " model.eval()\n",
378 | " batch = tuple(t.to(device) for t in batch)\n",
379 | "\n",
380 | " with torch.no_grad():\n",
381 | " inputs = {'input_ids': batch[0],\n",
382 | " 'attention_mask': batch[1],\n",
383 |     "                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,  # XLM doesn't use segment_ids\n",
384 | " 'labels': batch[3]}\n",
385 | " outputs = model(**inputs)\n",
386 | " tmp_eval_loss, logits = outputs[:2]\n",
387 | "\n",
388 | " eval_loss += tmp_eval_loss.mean().item()\n",
389 | " nb_eval_steps += 1\n",
390 | " if preds is None:\n",
391 | " preds = logits.detach().cpu().numpy()\n",
392 | " out_label_ids = inputs['labels'].detach().cpu().numpy()\n",
393 | " else:\n",
394 | " preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)\n",
395 | " out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)\n",
396 | "\n",
397 | " eval_loss = eval_loss / nb_eval_steps\n",
398 | " if args['output_mode'] == \"classification\":\n",
399 | " preds = np.argmax(preds, axis=1)\n",
400 | " elif args['output_mode'] == \"regression\":\n",
401 | " preds = np.squeeze(preds)\n",
402 | " result, wrong = compute_metrics(EVAL_TASK, preds, out_label_ids)\n",
403 | " results.update(result)\n",
404 | "\n",
405 | " output_eval_file = os.path.join(eval_output_dir, \"eval_results.txt\")\n",
406 | " with open(output_eval_file, \"w\") as writer:\n",
407 | " logger.info(\"***** Eval results {} *****\".format(prefix))\n",
408 | " for key in sorted(result.keys()):\n",
409 | " logger.info(\" %s = %s\", key, str(result[key]))\n",
410 | " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
411 | "\n",
412 | " return results, wrong"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "metadata": {},
419 | "outputs": [],
420 | "source": [
421 | "if args['do_train']:\n",
422 | " train_dataset = load_and_cache_examples(task, tokenizer)\n",
423 | " global_step, tr_loss = train(train_dataset, model, tokenizer)\n",
424 | " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": [
433 | "if args['do_train']:\n",
434 | " if not os.path.exists(args['output_dir']):\n",
435 | " os.makedirs(args['output_dir'])\n",
436 | " logger.info(\"Saving model checkpoint to %s\", args['output_dir'])\n",
437 | " \n",
438 | " model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training\n",
439 | " model_to_save.save_pretrained(args['output_dir'])\n",
440 | " tokenizer.save_pretrained(args['output_dir'])\n",
441 | " torch.save(args, os.path.join(args['output_dir'], 'training_args.bin')) "
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "results = {}\n",
451 | "if args['do_eval']:\n",
452 | " checkpoints = [args['output_dir']]\n",
453 | " if args['eval_all_checkpoints']:\n",
454 | " checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + '/**/' + WEIGHTS_NAME, recursive=True)))\n",
455 | " logging.getLogger(\"pytorch_transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
456 | " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
457 | " for checkpoint in checkpoints:\n",
458 | " global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else \"\"\n",
459 | " model = model_class.from_pretrained(checkpoint)\n",
460 | " model.to(device)\n",
461 | " result, wrong_preds = evaluate(model, tokenizer, prefix=global_step)\n",
462 | " result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())\n",
463 | " results.update(result)\n"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": null,
469 | "metadata": {},
470 | "outputs": [],
471 | "source": [
472 | "results"
473 | ]
474 | }
475 | ],
476 | "metadata": {
477 | "kernelspec": {
478 | "display_name": "Python 3",
479 | "language": "python",
480 | "name": "python3"
481 | },
482 | "language_info": {
483 | "codemirror_mode": {
484 | "name": "ipython",
485 | "version": 3
486 | },
487 | "file_extension": ".py",
488 | "mimetype": "text/x-python",
489 | "name": "python",
490 | "nbconvert_exporter": "python",
491 | "pygments_lexer": "ipython3",
492 | "version": "3.7.3"
493 | }
494 | },
495 | "nbformat": 4,
496 | "nbformat_minor": 4
497 | }
498 |
--------------------------------------------------------------------------------
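After run_model.ipynb has finished and the fine-tuned checkpoint has been written to outputs/, the saved model and tokenizer can be reloaded for ad-hoc predictions. This is a minimal sketch assuming the default XLNet settings from the args dict above (model_type 'xlnet', max_seq_length 128); it is not a cell in the notebook.

import torch
from pytorch_transformers import XLNetForSequenceClassification, XLNetTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reload the checkpoint saved by save_pretrained() at the end of training.
model = XLNetForSequenceClassification.from_pretrained("outputs/")
tokenizer = XLNetTokenizer.from_pretrained("outputs/")
model.to(device)
model.eval()

def predict(text, max_seq_length=128):
    # Returns 0 (negative) or 1 (positive) for a single review.
    tokens = tokenizer.tokenize(text)[: max_seq_length - 2]
    # XLNet places the SEP and CLS tokens at the end of the sequence.
    tokens = tokens + [tokenizer.sep_token, tokenizer.cls_token]
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)], device=device)
    with torch.no_grad():
        logits = model(input_ids)[0]
    return int(torch.argmax(logits, dim=1).item())

print(predict("The food was fantastic and the staff were lovely."))
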
/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import csv
21 | import logging
22 | import os
23 | import sys
24 | from io import open
25 |
26 | from scipy.stats import pearsonr, spearmanr
27 | from sklearn.metrics import matthews_corrcoef, f1_score
28 |
29 | from multiprocessing import Pool, cpu_count
30 | from tqdm import tqdm
31 |
32 | logger = logging.getLogger(__name__)
33 | csv.field_size_limit(2147483647)
34 |
35 | class InputExample(object):
36 | """A single training/test example for simple sequence classification."""
37 |
38 | def __init__(self, guid, text_a, text_b=None, label=None):
39 |         """Constructs an InputExample.
40 |
41 | Args:
42 | guid: Unique id for the example.
43 | text_a: string. The untokenized text of the first sequence. For single
44 | sequence tasks, only this sequence must be specified.
45 | text_b: (Optional) string. The untokenized text of the second sequence.
46 |                 Only needs to be specified for sequence pair tasks.
47 | label: (Optional) string. The label of the example. This should be
48 | specified for train and dev examples, but not for test examples.
49 | """
50 | self.guid = guid
51 | self.text_a = text_a
52 | self.text_b = text_b
53 | self.label = label
54 |
55 |
56 | class InputFeatures(object):
57 | """A single set of features of data."""
58 |
59 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
60 | self.input_ids = input_ids
61 | self.input_mask = input_mask
62 | self.segment_ids = segment_ids
63 | self.label_id = label_id
64 |
65 |
66 | class DataProcessor(object):
67 | """Base class for data converters for sequence classification data sets."""
68 |
69 | def get_train_examples(self, data_dir):
70 | """Gets a collection of `InputExample`s for the train set."""
71 | raise NotImplementedError()
72 |
73 | def get_dev_examples(self, data_dir):
74 | """Gets a collection of `InputExample`s for the dev set."""
75 | raise NotImplementedError()
76 |
77 | def get_labels(self):
78 | """Gets the list of labels for this data set."""
79 | raise NotImplementedError()
80 |
81 | @classmethod
82 | def _read_tsv(cls, input_file, quotechar=None):
83 | """Reads a tab separated value file."""
84 | with open(input_file, "r", encoding="utf-8-sig") as f:
85 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
86 | lines = []
87 | for line in reader:
88 | if sys.version_info[0] == 2:
89 | line = list(unicode(cell, 'utf-8') for cell in line)
90 | lines.append(line)
91 | return lines
92 |
93 |
94 | class BinaryProcessor(DataProcessor):
95 | """Processor for the binary data sets"""
96 |
97 | def get_train_examples(self, data_dir):
98 | """See base class."""
99 | return self._create_examples(
100 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
101 |
102 | def get_dev_examples(self, data_dir):
103 | """See base class."""
104 | return self._create_examples(
105 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
106 |
107 | def get_labels(self):
108 | """See base class."""
109 | return ["0", "1"]
110 |
111 | def _create_examples(self, lines, set_type):
112 | """Creates examples for the training and dev sets."""
113 | examples = []
114 | for (i, line) in enumerate(lines):
115 | guid = "%s-%s" % (set_type, i)
116 | text_a = line[3]
117 | label = line[1]
118 | examples.append(
119 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
120 | return examples
121 |
122 | def convert_example_to_feature(example_row, pad_token=0,
123 | sequence_a_segment_id=0, sequence_b_segment_id=1,
124 | cls_token_segment_id=1, pad_token_segment_id=0,
125 | mask_padding_with_zero=True, sep_token_extra=False):
126 | example, label_map, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id, sep_token_extra = example_row
127 |
128 | tokens_a = tokenizer.tokenize(example.text_a)
129 |
130 | tokens_b = None
131 | if example.text_b:
132 | tokens_b = tokenizer.tokenize(example.text_b)
133 | # Modifies `tokens_a` and `tokens_b` in place so that the total
134 | # length is less than the specified length.
135 |         # Account for [CLS], [SEP], [SEP] with "- 3" ("- 4" for RoBERTa).
136 | special_tokens_count = 4 if sep_token_extra else 3
137 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
138 | else:
139 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
140 | special_tokens_count = 3 if sep_token_extra else 2
141 | if len(tokens_a) > max_seq_length - special_tokens_count:
142 | tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
143 |
144 | # The convention in BERT is:
145 | # (a) For sequence pairs:
146 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
147 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
148 | # (b) For single sequences:
149 | # tokens: [CLS] the dog is hairy . [SEP]
150 | # type_ids: 0 0 0 0 0 0 0
151 | #
152 | # Where "type_ids" are used to indicate whether this is the first
153 | # sequence or the second sequence. The embedding vectors for `type=0` and
154 | # `type=1` were learned during pre-training and are added to the wordpiece
155 | # embedding vector (and position vector). This is not *strictly* necessary
156 | # since the [SEP] token unambiguously separates the sequences, but it makes
157 | # it easier for the model to learn the concept of sequences.
158 | #
159 | # For classification tasks, the first vector (corresponding to [CLS]) is
160 |     # used as the "sentence vector". Note that this only makes sense because
161 | # the entire model is fine-tuned.
162 | tokens = tokens_a + [sep_token]
163 | segment_ids = [sequence_a_segment_id] * len(tokens)
164 |
165 | if tokens_b:
166 | tokens += tokens_b + [sep_token]
167 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
168 |
169 | if cls_token_at_end:
170 | tokens = tokens + [cls_token]
171 | segment_ids = segment_ids + [cls_token_segment_id]
172 | else:
173 | tokens = [cls_token] + tokens
174 | segment_ids = [cls_token_segment_id] + segment_ids
175 |
176 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
177 |
178 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
179 | # tokens are attended to.
180 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
181 |
182 |
183 | # Zero-pad up to the sequence length.
184 | padding_length = max_seq_length - len(input_ids)
185 | if pad_on_left:
186 | input_ids = ([pad_token] * padding_length) + input_ids
187 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
188 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
189 | else:
190 | input_ids = input_ids + ([pad_token] * padding_length)
191 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
192 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
193 |
194 | assert len(input_ids) == max_seq_length
195 | assert len(input_mask) == max_seq_length
196 | assert len(segment_ids) == max_seq_length
197 |
198 | if output_mode == "classification":
199 | label_id = label_map[example.label]
200 | elif output_mode == "regression":
201 | label_id = float(example.label)
202 | else:
203 | raise KeyError(output_mode)
204 |
205 | return InputFeatures(input_ids=input_ids,
206 | input_mask=input_mask,
207 | segment_ids=segment_ids,
208 | label_id=label_id)
209 |
210 |
211 | def convert_examples_to_features(examples, label_list, max_seq_length,
212 | tokenizer, output_mode,
213 | cls_token_at_end=False, sep_token_extra=False, pad_on_left=False,
214 | cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
215 | sequence_a_segment_id=0, sequence_b_segment_id=1,
216 | cls_token_segment_id=1, pad_token_segment_id=0,
217 | mask_padding_with_zero=True,
218 |                                  process_count=max(1, cpu_count() - 2)):  # keep at least one worker on small machines
219 |     """ Loads a data file into a list of `InputFeatures`
220 |         `cls_token_at_end` defines the location of the CLS token:
221 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
222 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
223 |         `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
224 | """
225 |
226 | label_map = {label : i for i, label in enumerate(label_list)}
227 |
228 | examples = [(example, label_map, max_seq_length, tokenizer, output_mode, cls_token_at_end, cls_token, sep_token, cls_token_segment_id, pad_on_left, pad_token_segment_id, sep_token_extra) for example in examples]
229 |
230 | with Pool(process_count) as p:
231 | features = list(tqdm(p.imap(convert_example_to_feature, examples, chunksize=500), total=len(examples)))
232 |
233 | return features
234 |
235 |
236 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
237 | """Truncates a sequence pair in place to the maximum length."""
238 |
239 | # This is a simple heuristic which will always truncate the longer sequence
240 | # one token at a time. This makes more sense than truncating an equal percent
241 | # of tokens from each, since if one sequence is very short then each token
242 | # that's truncated likely contains more information than a longer sequence.
243 | while True:
244 | total_length = len(tokens_a) + len(tokens_b)
245 | if total_length <= max_length:
246 | break
247 | if len(tokens_a) > len(tokens_b):
248 | tokens_a.pop()
249 | else:
250 | tokens_b.pop()
251 |
252 |
253 | processors = {
254 | "binary": BinaryProcessor
255 | }
256 |
257 | output_modes = {
258 | "binary": "classification"
259 | }
260 |
--------------------------------------------------------------------------------
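utils.py exposes a small registry at the bottom of the file: adding a new task only requires a DataProcessor subclass plus entries in the processors and output_modes dictionaries, along with a matching num_labels when the config is built in the notebook. A hypothetical three-class variant is sketched below; the class name and task key are illustrative only.

from utils import BinaryProcessor, output_modes, processors


class TernaryProcessor(BinaryProcessor):
    # Re-uses the id / label / alpha / text TSV layout of the binary task,
    # but with three label values.
    def get_labels(self):
        return ["0", "1", "2"]


processors["ternary"] = TernaryProcessor
output_modes["ternary"] = "classification"

# In run_model.ipynb the config would then need num_labels=3, e.g.:
# config = config_class.from_pretrained(args['model_name'], num_labels=3,
#                                       finetuning_task='ternary')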