├── .gitignore
├── LICENSE
├── README.md
├── README_data
│   ├── 2019-08-21-18-01-07.png
│   ├── 2019-08-22-14-16-49.png
│   ├── 2019-08-22-14-16-59.png
│   └── 2019-08-30-22-18-28.png
├── config
│   ├── bert_base.json
│   ├── eval.json
│   ├── non-uda.json
│   └── uda.json
├── data.zip
├── download.sh
├── load_data.py
├── main.py
├── models.py
├── train.py
└── utils
    ├── __init__.py
    ├── checkpoint.py
    ├── configuration.py
    ├── optim.py
    ├── tokenization.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | BERT_Base_Uncased
2 | results_045
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UDA (Unsupervised Data Augmentation) with BERT
2 | This is a PyTorch re-implementation of Google's UDA [[paper]](https://arxiv.org/abs/1904.12848)[[tensorflow]](https://github.com/google-research/uda), built on Kakao Brain's Pytorchic BERT [[pytorch]](https://github.com/dhlee347/pytorchic-bert).
3 |
4 | Model | UDA official | This repository
5 | -- | -- | --
6 | UDA (X) | 68% |
7 | UDA (O) | 90% | 88.45%
8 |
9 | (Max sequence length = 128, Train batch size = 8)
10 |
11 | 
12 |
13 |
14 | ## UDA
15 | > UDA (Unsupervised Data Augmentation) is a semi-supervised learning method which achieves SOTA results on a wide variety of language and vision tasks. With only 20 labeled examples, UDA outperforms the previous SOTA on IMDb, which was trained on 25,000 labeled examples. (error rate: BERT = 4.51, UDA = 4.20)
16 | 
17 | > * Unsupervised Data Augmentation for Consistency Training (2019 Google Brain, Q Xie et al.)
18 |
19 | #### - UDA with BERT
20 | UDA works on top of BERT: it acts as an assistant to BERT during training. So, in the picture above, model **M** is BERT.
21 |
22 | #### - Loss
23 | UDA's loss consists of a supervised loss and an unsupervised loss. The supervised loss is the usual cross-entropy loss, and the unsupervised loss is the KL-divergence between the model's outputs on an original example and on its augmented version. In this project, back translation is used for augmentation.
24 | The supervised loss and the unsupervised loss are added to form the total loss, which is then minimized by gradient descent. Note that no gradient is propagated through the original-example branch: the model's weights are updated only from the labeled data and the augmented unlabeled data.
25 |
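As an illustration, here is a minimal sketch of this combined loss (function name and batch handling are simplified, and the TSA and sharpening techniques described below are omitted). The full implementation is `get_loss` in [main.py](./main.py).

    import torch.nn.functional as F

    def uda_total_loss(sup_logits, labels, ori_logits, aug_logits, uda_coeff=1.0):
        # supervised loss: cross-entropy on the labeled examples
        sup_loss = F.cross_entropy(sup_logits, labels)

        # target distribution from the original (unaugmented) example;
        # detach so that no gradient flows through the original-example branch
        ori_prob = F.softmax(ori_logits.detach(), dim=-1)

        # unsupervised loss: KL-divergence between the original and augmented outputs
        aug_log_prob = F.log_softmax(aug_logits, dim=-1)
        unsup_loss = F.kl_div(aug_log_prob, ori_prob, reduction='none').sum(dim=-1).mean()

        # total loss that is actually minimized
        return sup_loss + uda_coeff * unsup_loss
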
26 | #### - TSA(Training Signal Annealing)
27 | There is a large gap between the amount of unlabeled data and that of labeled data, so it is easy to overfit the labeled data. TSA therefore masks out labeled examples whose predicted probability for the correct class is already higher than a threshold. The threshold is scheduled by a log, linear or exponential function.
28 | 
29 | 
30 |
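For reference, a minimal sketch of the three threshold schedules (mirroring `get_tsa_thresh` in [main.py](./main.py)); `start` is typically 1/num_classes and `end` is 1:

    import math

    def tsa_threshold(schedule, global_step, total_steps, start, end):
        # training progress in [0, 1]
        t = global_step / total_steps
        if schedule == 'linear_schedule':
            coeff = t
        elif schedule == 'exp_schedule':
            coeff = math.exp((t - 1) * 5)   # stays low for most of training, rises sharply at the end
        elif schedule == 'log_schedule':
            coeff = 1 - math.exp(-t * 5)    # rises sharply at the start, then flattens
        else:
            raise ValueError('unknown schedule: %s' % schedule)
        return coeff * (end - start) + start
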
31 | #### - Sharpening Predictions
32 | The KL-divergence loss between the original and augmented outputs tends to be very small, which can cause the total loss to be dominated by the supervised loss. Therefore, prediction-sharpening techniques are needed (a minimal sketch follows the list below).
33 |
34 | - Confidence-based masking : Masking out examples that the current model is not confident about. Specifically, in each minibatch, the consistency loss term is computed only on examples whose highest prediction probability exceeds a threshold.
35 | - Softmax temperature controlling : Used when computing the predictions on the original example. Specifically, the probability of the original example is computed as Softmax(l(x)/τ), where l(x) denotes the logits and τ is the temperature. A lower temperature corresponds to a sharper distribution. (UDA, 2019 Google Brain, Q Xie et al.)
36 |
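Below is a minimal sketch of the two techniques as described above (names and defaults are illustrative; the defaults correspond to `uda_confidence_thresh` and `uda_softmax_temp` in [config/uda.json](./config/uda.json)). Note that `get_loss` in [main.py](./main.py) applies the temperature to the augmented branch rather than to the original one.

    import torch.nn.functional as F

    def sharpened_consistency_loss(ori_logits, aug_logits,
                                   confidence_thresh=0.45, softmax_temp=0.85):
        # predictions on the original example (no gradient through this branch)
        ori_logits = ori_logits.detach()
        ori_prob = F.softmax(ori_logits, dim=-1)

        # confidence-based masking: keep only the examples whose highest
        # predicted probability exceeds the threshold
        loss_mask = (ori_prob.max(dim=-1)[0] > confidence_thresh).float()

        # softmax temperature controlling: Softmax(l(x)/tau) gives a sharper target
        sharpened_target = F.softmax(ori_logits / softmax_temp, dim=-1)

        # per-example KL-divergence between the sharpened target and the augmented prediction
        aug_log_prob = F.log_softmax(aug_logits, dim=-1)
        kl = F.kl_div(aug_log_prob, sharpened_target, reduction='none').sum(dim=-1)

        # average over the unmasked examples only
        return (kl * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
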
37 | ## Requirements
38 | **UDA** : python > 3.6, fire, tqdm, tensorboardX, tensorflow, pytorch, pandas, numpy
39 |
40 | ## Overview
41 |
42 | - [`download.sh`](./download.sh) : Downloads the pre-trained BERT model from Google's official BERT repository and the IMDb data file
43 | - [`load_data.py`](./load_data.py) : Loads the supervised and unsupervised data
44 | - [`models.py`](./models.py) : Model classes for a general transformer (from Pytorchic BERT's code)
45 | - [`main.py`](./main.py) : Entry point covering the default BERT and UDA (TSA, Sharpening) modes
46 | - [`train.py`](./train.py) : A custom training class (Trainer) adapted from Pytorchic BERT's code
47 | - ***utils***
48 |     - [`configuration.py`](./utils/configuration.py) : Builds the configuration from a JSON file
49 |     - [`checkpoint.py`](./utils/checkpoint.py) : Functions to load a model from a TensorFlow checkpoint file (from Pytorchic BERT's code)
50 |     - [`optim.py`](./utils/optim.py) : Optimizer (BERTAdam class) (from Pytorchic BERT's code)
51 |     - [`tokenization.py`](./utils/tokenization.py) : Tokenizers adapted from the original Google BERT code
52 |     - [`utils.py`](./utils/utils.py) : Custom utility functions adapted from Pytorchic BERT's code
53 |
54 | ## Pre-works
55 |
56 | #### - Download pre-trained BERT model and unzip IMDb data
57 | First, you have to download the pre-trained BERT_Base model from Google's BERT repository and unzip the IMDb data:
58 |
59 | bash download.sh
60 | After running it, you will find the pre-trained BERT_Base_Uncased model in the **/BERT_Base_Uncased** directory and the unzipped data in **/data**.
61 |
62 | This repository uses IMDb data that has already been pre-processed and augmented, extracted from the official [UDA](https://github.com/google-research/uda). If you want to use your own raw data, set `"need_prepro": true` in the config file.
63 |
64 | ## Example usage
65 | This project is broadly divided into two parts (fine-tuning and evaluation).
66 | **Caution** : **Before running the code, check and edit the config file.**
67 |
68 | 1. **Fine-tuning**
69 |     You can choose the train mode (train or train_eval) in non-uda.json or uda.json (default: train_eval).
70 | - Non UDA fine-tuning
71 |
72 | python main.py \
73 | --cfg='config/non-uda.json' \
74 | --model_cfg='config/bert_base.json'
75 |
76 | - UDA fine-tuning
77 |
78 | python main.py \
79 | --cfg='config/uda.json' \
80 | --model_cfg='config/bert_base.json'
81 |
82 | 2. **Evaluation**
83 |     - The evaluation code can also dump the results to a file. You can change the dump option in [main.py](./main.py); there are two modes (real-time print, or writing a tsv file).
84 |
85 | python main.py \
86 | --cfg='config/eval.json' \
87 | --model_cfg='config/bert_base.json'
88 |
89 |
90 | ## Acknowledgement
91 | Thanks to the reference code of [UDA](https://github.com/google-research/uda) and [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert), I was able to implement this project.
92 |
93 | ## TODO
94 | 1. It is known that further pre-training (continuing to pre-train an already pre-trained BERT on a domain-specific corpus) can improve performance, but this repository does not yet include pre-training code, so it will be added. If you want to do further pre-training now, you can use [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert)'s pretrain.py or any other BERT project.
95 |
96 | 2. Korean dataset version
--------------------------------------------------------------------------------
/README_data/2019-08-21-18-01-07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-21-18-01-07.png
--------------------------------------------------------------------------------
/README_data/2019-08-22-14-16-49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-22-14-16-49.png
--------------------------------------------------------------------------------
/README_data/2019-08-22-14-16-59.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-22-14-16-59.png
--------------------------------------------------------------------------------
/README_data/2019-08-30-22-18-28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-30-22-18-28.png
--------------------------------------------------------------------------------
/config/bert_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "dim": 768,
3 | "dim_ff": 3072,
4 | "n_layers": 12,
5 | "p_drop_attn": 0.1,
6 | "n_heads": 12,
7 | "p_drop_hidden": 0.1,
8 | "max_len": 512,
9 | "n_segments": 2,
10 | "vocab_size": 30522
11 | }
--------------------------------------------------------------------------------
/config/eval.json:
--------------------------------------------------------------------------------
1 | {
2 | "mode": "eval",
3 | "max_seq_length": 128,
4 | "eval_batch_size": 16,
5 | "do_lower_case": true,
6 | "data_parallel": true,
7 | "need_prepro": false,
8 | "model_file": "results_045/save/model_steps_6250.pt",
9 | "eval_data_dir": "data/imdb_sup_test.txt",
10 | "vocab":"BERT_Base_Uncased/vocab.txt",
11 | "task": "imdb"
12 | }
13 |
--------------------------------------------------------------------------------
/config/non-uda.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 42,
3 | "lr": 2e-5,
4 | "warmup": 0.1,
5 | "do_lower_case": true,
6 | "mode": "train_eval",
7 | "uda_mode": false,
8 |
9 | "total_steps": 10000,
10 | "max_seq_length": 128,
11 | "train_batch_size": 8,
12 | "eval_batch_size": 16,
13 |
14 | "data_parallel": true,
15 | "need_prepro": false,
16 | "sup_data_dir": "data/imdb_sup_train.txt",
17 | "eval_data_dir": "data/imdb_sup_test.txt",
18 |
19 | "model_file":null,
20 | "pretrain_file":"BERT_Base_Uncased/bert_model.ckpt",
21 | "vocab":"BERT_Base_Uncased/vocab.txt",
22 | "task": "imdb",
23 |
24 | "save_steps": 100,
25 | "check_steps": 250,
26 | "results_dir": "results_non",
27 |
28 | "is_position": false
29 | }
30 |
--------------------------------------------------------------------------------
/config/uda.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 42,
3 | "lr": 2e-5,
4 | "warmup": 0.1,
5 | "do_lower_case": true,
6 | "mode": "train_eval",
7 | "uda_mode": true,
8 |
9 | "total_steps": 10000,
10 | "max_seq_length": 128,
11 | "train_batch_size": 8,
12 | "eval_batch_size": 16,
13 |
14 | "unsup_ratio": 3,
15 | "uda_coeff": 1,
16 | "tsa": "linear_schedule",
17 | "uda_softmax_temp": 0.85,
18 | "uda_confidence_thresh": 0.45,
19 |
20 | "data_parallel": true,
21 | "need_prepro": false,
22 | "sup_data_dir": "data/imdb_sup_train.txt",
23 | "unsup_data_dir": "data/imdb_unsup_train.txt",
24 | "eval_data_dir": "data/imdb_sup_test.txt",
25 |
26 | "model_file":null,
27 | "pretrain_file":"BERT_Base_Uncased/bert_model.ckpt",
28 | "vocab":"BERT_Base_Uncased/vocab.txt",
29 | "task": "imdb",
30 |
31 | "save_steps": 100,
32 | "check_steps": 250,
33 | "results_dir": "results",
34 |
35 | "is_position": false
36 | }
37 |
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/data.zip
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google UDA Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #!/bin/bash
16 |
17 | # **** download pretrained models ****
18 | wget storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
19 | unzip uncased_L-12_H-768_A-12.zip && rm uncased_L-12_H-768_A-12.zip
20 | mv uncased_L-12_H-768_A-12 BERT_Base_Uncased
21 |
22 | # **** unzip data ****
23 | unzip data.zip && rm data.zip
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # (Strongly inspired by Dong-Hyun Lee, Kakao Brain)
3 | #
4 | # This file has been modified by SanghunYun, Korea University.
5 | # Little modification at Tokenizing, AddSpecialTokensWithTruncation, TokenIndexing
6 | # and CsvDataset, load_data has been newly written.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 |
21 | import ast
22 | import csv
23 | import itertools
24 |
25 | import pandas as pd    # only used when need_prepro is False (data already preprocessed)
26 | from tqdm import tqdm
27 |
28 | import torch
29 | from torch.utils.data import Dataset, DataLoader
30 |
31 | from utils import tokenization
32 | from utils.utils import truncate_tokens_pair
33 |
34 |
35 | class CsvDataset(Dataset):
36 | labels = None
37 | def __init__(self, file, need_prepro, pipeline, max_len, mode, d_type):
38 | Dataset.__init__(self)
39 | self.cnt = 0
40 |
41 | # need preprocessing
42 | if need_prepro:
43 | with open(file, 'r', encoding='utf-8') as f:
44 | lines = csv.reader(f, delimiter='\t', quotechar='"')
45 |
46 | # supervised dataset
47 | if d_type == 'sup':
48 | # if mode == 'eval':
49 | # sentences = []
50 | data = []
51 |
52 | for instance in self.get_sup(lines):
53 | # if mode == 'eval':
54 | # sentences.append([instance[1]])
55 | for proc in pipeline:
56 | instance = proc(instance, d_type)
57 | data.append(instance)
58 |
59 | self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]
60 | # if mode == 'eval':
61 | # self.tensors.append(sentences)
62 |
63 | # unsupervised dataset
64 | elif d_type == 'unsup':
65 | data = {'ori':[], 'aug':[]}
66 | for ori, aug in self.get_unsup(lines):
67 | for proc in pipeline:
68 | ori = proc(ori, d_type)
69 | aug = proc(aug, d_type)
70 | self.cnt += 1
71 | # if self.cnt == 10:
72 | # break
73 | data['ori'].append(ori) # drop label_id
74 | data['aug'].append(aug) # drop label_id
75 | ori_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['ori'])]
76 | aug_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['aug'])]
77 | self.tensors = ori_tensor + aug_tensor
78 | # already preprocessed
79 | else:
80 | f = open(file, 'r', encoding='utf-8')
81 | data = pd.read_csv(f, sep='\t')
82 |
83 | # supervised dataset
84 | if d_type == 'sup':
85 | # input_ids, segment_ids(input_type_ids), input_mask, input_label
86 | input_columns = ['input_ids', 'input_type_ids', 'input_mask', 'label_ids']
87 | self.tensors = [torch.tensor(data[c].apply(lambda x: ast.literal_eval(x)), dtype=torch.long) \
88 | for c in input_columns[:-1]]
89 | self.tensors.append(torch.tensor(data[input_columns[-1]], dtype=torch.long))
90 |
91 | # unsupervised dataset
92 | elif d_type == 'unsup':
93 | input_columns = ['ori_input_ids', 'ori_input_type_ids', 'ori_input_mask',
94 | 'aug_input_ids', 'aug_input_type_ids', 'aug_input_mask']
95 | self.tensors = [torch.tensor(data[c].apply(lambda x: ast.literal_eval(x)), dtype=torch.long) \
96 | for c in input_columns]
97 |
98 | else:
99 |                 raise ValueError("d_type error: d_type must be 'sup' or 'unsup'")
100 |
101 | def __len__(self):
102 | return self.tensors[0].size(0)
103 |
104 | def __getitem__(self, index):
105 | return tuple(tensor[index] for tensor in self.tensors)
106 |
107 | def get_sup(self, lines):
108 | raise NotImplementedError
109 |
110 | def get_unsup(self, lines):
111 | raise NotImplementedError
112 |
113 |
114 | class Pipeline():
115 | def __init__(self):
116 | super().__init__()
117 |
118 | def __call__(self, instance):
119 | raise NotImplementedError
120 |
121 |
122 | class Tokenizing(Pipeline):
123 | def __init__(self, preprocessor, tokenize):
124 | super().__init__()
125 | self.preprocessor = preprocessor
126 | self.tokenize = tokenize
127 |
128 | def __call__(self, instance, d_type):
129 | label, text_a, text_b = instance
130 |
131 | label = self.preprocessor(label) if label else None
132 | tokens_a = self.tokenize(self.preprocessor(text_a))
133 | tokens_b = self.tokenize(self.preprocessor(text_b)) if text_b else []
134 |
135 | return (label, tokens_a, tokens_b)
136 |
137 |
138 | class AddSpecialTokensWithTruncation(Pipeline):
139 | def __init__(self, max_len=512):
140 | super().__init__()
141 | self.max_len = max_len
142 |
143 | def __call__(self, instance, d_type):
144 | label, tokens_a, tokens_b = instance
145 |
146 | # -3 special tokens for [CLS] text_a [SEP] text_b [SEP]
147 | # -2 special tokens for [CLS] text_a [SEP]
148 | _max_len = self.max_len - 3 if tokens_b else self.max_len - 2
149 | truncate_tokens_pair(tokens_a, tokens_b, _max_len)
150 |
151 | # Add Special Tokens
152 | tokens_a = ['[CLS]'] + tokens_a + ['[SEP]']
153 | tokens_b = tokens_b + ['[SEP]'] if tokens_b else []
154 |
155 | return (label, tokens_a, tokens_b)
156 |
157 |
158 | class TokenIndexing(Pipeline):
159 | def __init__(self, indexer, labels, max_len=512):
160 | super().__init__()
161 | self.indexer = indexer # function : tokens to indexes
162 | # map from a label name to a label index
163 | self.label_map = {name: i for i, name in enumerate(labels)}
164 | self.max_len = max_len
165 |
166 | def __call__(self, instance, d_type):
167 | label, tokens_a, tokens_b = instance
168 |
169 | input_ids = self.indexer(tokens_a + tokens_b)
170 | segment_ids = [0]*len(tokens_a) + [1]*len(tokens_b) # type_ids
171 | input_mask = [1]*(len(tokens_a) + len(tokens_b))
172 | label_id = self.label_map[label] if label else None
173 |
174 | # zero padding
175 | n_pad = self.max_len - len(input_ids)
176 | input_ids.extend([0]*n_pad)
177 | segment_ids.extend([0]*n_pad)
178 | input_mask.extend([0]*n_pad)
179 |
180 | if label_id != None:
181 | return (input_ids, segment_ids, input_mask, label_id)
182 | else:
183 | return (input_ids, segment_ids, input_mask)
184 |
185 |
186 | def dataset_class(task):
187 | table = {'imdb': IMDB}
188 | return table[task]
189 |
190 |
191 | class IMDB(CsvDataset):
192 | labels = ('0', '1')
193 | def __init__(self, file, need_prepro, pipeline=[], max_len=128, mode='train', d_type='sup'):
194 | super().__init__(file, need_prepro, pipeline, max_len, mode, d_type)
195 |
196 | def get_sup(self, lines):
197 | for line in itertools.islice(lines, 0, None):
198 |             yield line[7], line[6], []    # label, text_a, text_b (empty)
199 | # yield None, line[6], []
200 |
201 | def get_unsup(self, lines):
202 | for line in itertools.islice(lines, 0, None):
203 | yield (None, line[1], []), (None, line[2], []) # ko, en
204 |
205 |
206 | class load_data:
207 | def __init__(self, cfg):
208 | self.cfg = cfg
209 |
210 | self.TaskDataset = dataset_class(cfg.task)
211 | self.pipeline = None
212 | if cfg.need_prepro:
213 | tokenizer = tokenization.FullTokenizer(vocab_file=cfg.vocab, do_lower_case=cfg.do_lower_case)
214 | self.pipeline = [Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize),
215 | AddSpecialTokensWithTruncation(cfg.max_seq_length),
216 | TokenIndexing(tokenizer.convert_tokens_to_ids, self.TaskDataset.labels, cfg.max_seq_length)]
217 |
218 | if cfg.mode == 'train':
219 | self.sup_data_dir = cfg.sup_data_dir
220 | self.sup_batch_size = cfg.train_batch_size
221 | self.shuffle = True
222 | elif cfg.mode == 'train_eval':
223 | self.sup_data_dir = cfg.sup_data_dir
224 | self.eval_data_dir= cfg.eval_data_dir
225 | self.sup_batch_size = cfg.train_batch_size
226 | self.eval_batch_size = cfg.eval_batch_size
227 | self.shuffle = True
228 | elif cfg.mode == 'eval':
229 | self.sup_data_dir = cfg.eval_data_dir
230 | self.sup_batch_size = cfg.eval_batch_size
231 |             self.shuffle = False    # do not shuffle in eval mode
232 |
233 | if cfg.uda_mode: # Only uda_mode
234 | self.unsup_data_dir = cfg.unsup_data_dir
235 | self.unsup_batch_size = cfg.train_batch_size * cfg.unsup_ratio
236 |
237 | def sup_data_iter(self):
238 | sup_dataset = self.TaskDataset(self.sup_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, self.cfg.mode, 'sup')
239 | sup_data_iter = DataLoader(sup_dataset, batch_size=self.sup_batch_size, shuffle=self.shuffle)
240 |
241 | return sup_data_iter
242 |
243 | def unsup_data_iter(self):
244 | unsup_dataset = self.TaskDataset(self.unsup_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, self.cfg.mode, 'unsup')
245 | unsup_data_iter = DataLoader(unsup_dataset, batch_size=self.unsup_batch_size, shuffle=self.shuffle)
246 |
247 | return unsup_data_iter
248 |
249 | def eval_data_iter(self):
250 | eval_dataset = self.TaskDataset(self.eval_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, 'eval', 'sup')
251 | eval_data_iter = DataLoader(eval_dataset, batch_size=self.eval_batch_size, shuffle=False)
252 |
253 | return eval_data_iter
254 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import fire
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | import models
22 | import train
23 | from load_data import load_data
24 | from utils.utils import set_seeds, get_device, _get_device, torch_device_one
25 | from utils import optim, configuration
26 |
27 |
28 | # TSA
29 | def get_tsa_thresh(schedule, global_step, num_train_steps, start, end):
30 | training_progress = torch.tensor(float(global_step) / float(num_train_steps))
31 | if schedule == 'linear_schedule':
32 | threshold = training_progress
33 | elif schedule == 'exp_schedule':
34 | scale = 5
35 | threshold = torch.exp((training_progress - 1) * scale)
36 | elif schedule == 'log_schedule':
37 | scale = 5
38 | threshold = 1 - torch.exp((-training_progress) * scale)
39 | output = threshold * (end - start) + start
40 | return output.to(_get_device())
41 |
42 |
43 | def main(cfg, model_cfg):
44 | # Load Configuration
45 | cfg = configuration.params.from_json(cfg) # Train or Eval cfg
46 | model_cfg = configuration.model.from_json(model_cfg) # BERT_cfg
47 | set_seeds(cfg.seed)
48 |
49 | # Load Data & Create Criterion
50 | data = load_data(cfg)
51 | if cfg.uda_mode:
52 | unsup_criterion = nn.KLDivLoss(reduction='none')
53 | data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode=='train' \
54 | else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()] # train_eval
55 | else:
56 | data_iter = [data.sup_data_iter()]
57 | sup_criterion = nn.CrossEntropyLoss(reduction='none')
58 |
59 | # Load Model
60 | model = models.Classifier(model_cfg, len(data.TaskDataset.labels))
61 |
62 | # Create trainer
63 | trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model), get_device())
64 |
65 | # Training
66 | def get_loss(model, sup_batch, unsup_batch, global_step):
67 |
68 | # logits -> prob(softmax) -> log_prob(log_softmax)
69 |
70 | # batch
71 | input_ids, segment_ids, input_mask, label_ids = sup_batch
72 | if unsup_batch:
73 | ori_input_ids, ori_segment_ids, ori_input_mask, \
74 | aug_input_ids, aug_segment_ids, aug_input_mask = unsup_batch
75 |
76 | input_ids = torch.cat((input_ids, aug_input_ids), dim=0)
77 | segment_ids = torch.cat((segment_ids, aug_segment_ids), dim=0)
78 | input_mask = torch.cat((input_mask, aug_input_mask), dim=0)
79 |
80 | # logits
81 | logits = model(input_ids, segment_ids, input_mask)
82 |
83 | # sup loss
84 | sup_size = label_ids.shape[0]
85 | sup_loss = sup_criterion(logits[:sup_size], label_ids) # shape : train_batch_size
86 | if cfg.tsa:
87 | tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
88 | larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh # prob = exp(log_prob), prob > tsa_threshold
89 | # larger_than_threshold = torch.sum( F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids] , dim=-1) > tsa_threshold
90 | loss_mask = torch.ones_like(label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
91 | sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
92 | else:
93 | sup_loss = torch.mean(sup_loss)
94 |
95 | # unsup loss
96 | if unsup_batch:
97 | # ori
98 | with torch.no_grad():
99 | ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
100 | ori_prob = F.softmax(ori_logits, dim=-1) # KLdiv target
101 | # ori_log_prob = F.log_softmax(ori_logits, dim=-1)
102 |
103 | # confidence-based masking
104 | if cfg.uda_confidence_thresh != -1:
105 | unsup_loss_mask = torch.max(ori_prob, dim=-1)[0] > cfg.uda_confidence_thresh
106 | unsup_loss_mask = unsup_loss_mask.type(torch.float32)
107 | else:
108 | unsup_loss_mask = torch.ones(len(logits) - sup_size, dtype=torch.float32)
109 | unsup_loss_mask = unsup_loss_mask.to(_get_device())
110 |
111 | # aug
112 | # softmax temperature controlling
113 | uda_softmax_temp = cfg.uda_softmax_temp if cfg.uda_softmax_temp > 0 else 1.
114 | aug_log_prob = F.log_softmax(logits[sup_size:] / uda_softmax_temp, dim=-1)
115 |
116 | # KLdiv loss
117 | """
118 | nn.KLDivLoss (kl_div)
119 | input : log_prob (log_softmax)
120 | target : prob (softmax)
121 | https://pytorch.org/docs/stable/nn.html
122 |
123 |             unsup_loss is divided by the number of unmasked examples (the sum of unsup_loss_mask),
124 |             which differs from the official Google UDA implementation,
125 |             where unsup_loss is divided by the total number of unsupervised examples
126 |             https://github.com/google-research/uda/blob/master/text/uda.py#L175
127 | """
128 | unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)
129 | unsup_loss = torch.sum(unsup_loss * unsup_loss_mask, dim=-1) / torch.max(torch.sum(unsup_loss_mask, dim=-1), torch_device_one())
130 | final_loss = sup_loss + cfg.uda_coeff*unsup_loss
131 |
132 | return final_loss, sup_loss, unsup_loss
133 | return sup_loss, None, None
134 |
135 | # evaluation
136 | def get_acc(model, batch):
137 | # input_ids, segment_ids, input_mask, label_id, sentence = batch
138 | input_ids, segment_ids, input_mask, label_id = batch
139 | logits = model(input_ids, segment_ids, input_mask)
140 | _, label_pred = logits.max(1)
141 |
142 | result = (label_pred == label_id).float()
143 | accuracy = result.mean()
144 | # output_dump.logs(sentence, label_pred, label_id) # output dump
145 |
146 | return accuracy, result
147 |
148 | if cfg.mode == 'train':
149 | trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)
150 |
151 | if cfg.mode == 'train_eval':
152 | trainer.train(get_loss, get_acc, cfg.model_file, cfg.pretrain_file)
153 |
154 | if cfg.mode == 'eval':
155 | results = trainer.eval(get_acc, cfg.model_file, None)
156 | total_accuracy = torch.cat(results).mean().item()
157 | print('Accuracy :' , total_accuracy)
158 |
159 |
160 | if __name__ == '__main__':
161 | fire.Fire(main)
162 | #main('config/uda.json', 'config/bert_base.json')
163 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
2 | # (Strongly inspired by original Google BERT code and Hugging Face's code)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | """ Transformer Model Classes & Config Class """
18 |
19 | import math
20 | import json
21 | from typing import NamedTuple
22 |
23 | import numpy as np
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | from utils.utils import split_last, merge_last
29 |
30 |
31 | class Config(NamedTuple):
32 | "Configuration for BERT model"
33 | vocab_size: int = None # Size of Vocabulary
34 | dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder
35 |     n_layers: int = 12 # Number of Hidden Layers
36 |     n_heads: int = 12 # Number of Heads in Multi-Headed Attention Layers
37 | dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net
38 | #activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers
39 | p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers
40 | p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers
41 | max_len: int = 512 # Maximum Length for Positional Embeddings
42 | n_segments: int = 2 # Number of Sentence Segments
43 |
44 | @classmethod
45 | def from_json(cls, file):
46 | return cls(**json.load(open(file, "r")))
47 |
48 |
49 | def gelu(x):
50 | "Implementation of the gelu activation function by Hugging Face"
51 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
52 |
53 |
54 | class LayerNorm(nn.Module):
55 | "A layernorm module in the TF style (epsilon inside the square root)."
56 | def __init__(self, cfg, variance_epsilon=1e-12):
57 | super().__init__()
58 | self.gamma = nn.Parameter(torch.ones(cfg.dim))
59 | self.beta = nn.Parameter(torch.zeros(cfg.dim))
60 | self.variance_epsilon = variance_epsilon
61 |
62 | def forward(self, x):
63 | u = x.mean(-1, keepdim=True)
64 | s = (x - u).pow(2).mean(-1, keepdim=True)
65 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
66 | return self.gamma * x + self.beta
67 |
68 |
69 | class Embeddings(nn.Module):
70 | "The embedding module from word, position and token_type embeddings."
71 | def __init__(self, cfg):
72 | super().__init__()
73 | self.tok_embed = nn.Embedding(cfg.vocab_size, cfg.dim) # token embedding
74 | self.pos_embed = nn.Embedding(cfg.max_len, cfg.dim) # position embedding
75 | self.seg_embed = nn.Embedding(cfg.n_segments, cfg.dim) # segment(token type) embedding
76 |
77 | self.norm = LayerNorm(cfg)
78 | self.drop = nn.Dropout(cfg.p_drop_hidden)
79 |
80 | def forward(self, x, seg):
81 | seq_len = x.size(1)
82 | pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
83 |         pos = pos.unsqueeze(0).expand_as(x) # (S,) -> (1, S) -> (B, S); position ids are generated here rather than passed in
84 |
85 | e = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
86 | return self.drop(self.norm(e))
87 |
88 |
89 | class MultiHeadedSelfAttention(nn.Module):
90 | """ Multi-Headed Dot Product Attention """
91 | def __init__(self, cfg):
92 | super().__init__()
93 | self.proj_q = nn.Linear(cfg.dim, cfg.dim)
94 | self.proj_k = nn.Linear(cfg.dim, cfg.dim)
95 | self.proj_v = nn.Linear(cfg.dim, cfg.dim)
96 | self.drop = nn.Dropout(cfg.p_drop_attn)
97 | self.scores = None # for visualization
98 | self.n_heads = cfg.n_heads
99 |
100 | def forward(self, x, mask):
101 | """
102 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
103 | mask : (B(batch_size) x S(seq_len))
104 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
105 | """
106 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
107 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
108 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2)
109 | for x in [q, k, v])
110 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
111 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
112 | if mask is not None:
113 | mask = mask[:, None, None, :].float()
114 | scores -= 10000.0 * (1.0 - mask)
115 | scores = self.drop(F.softmax(scores, dim=-1))
116 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
117 | h = (scores @ v).transpose(1, 2).contiguous()
118 | # -merge-> (B, S, D)
119 | h = merge_last(h, 2)
120 | self.scores = scores
121 | return h
122 |
123 |
124 | class PositionWiseFeedForward(nn.Module):
125 | """ FeedForward Neural Networks for each position """
126 | def __init__(self, cfg):
127 | super().__init__()
128 | self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff)
129 | self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim)
130 | #self.activ = lambda x: activ_fn(cfg.activ_fn, x)
131 |
132 | def forward(self, x):
133 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
134 | return self.fc2(gelu(self.fc1(x)))
135 |
136 |
137 | class Block(nn.Module):
138 | """ Transformer Block """
139 | def __init__(self, cfg):
140 | super().__init__()
141 | self.attn = MultiHeadedSelfAttention(cfg)
142 | self.proj = nn.Linear(cfg.dim, cfg.dim)
143 | self.norm1 = LayerNorm(cfg)
144 | self.pwff = PositionWiseFeedForward(cfg)
145 | self.norm2 = LayerNorm(cfg)
146 | self.drop = nn.Dropout(cfg.p_drop_hidden)
147 |
148 | def forward(self, x, mask):
149 | h = self.attn(x, mask)
150 | h = self.norm1(x + self.drop(self.proj(h)))
151 | h = self.norm2(h + self.drop(self.pwff(h)))
152 | return h
153 |
154 |
155 | class Transformer(nn.Module):
156 | """ Transformer with Self-Attentive Blocks"""
157 | def __init__(self, cfg):
158 | super().__init__()
159 | self.embed = Embeddings(cfg)
160 |         self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)]) # repeated n_layers times
161 |
162 | def forward(self, x, seg, mask):
163 | h = self.embed(x, seg)
164 | for block in self.blocks:
165 | h = block(h, mask)
166 | return h
167 |
168 |
169 | class Classifier(nn.Module):
170 | """ Classifier with Transformer """
171 | def __init__(self, cfg, n_labels):
172 | super().__init__()
173 | self.transformer = Transformer(cfg)
174 | self.fc = nn.Linear(cfg.dim, cfg.dim)
175 | self.activ = nn.Tanh()
176 | self.drop = nn.Dropout(cfg.p_drop_hidden)
177 | self.classifier = nn.Linear(cfg.dim, n_labels)
178 |
179 | def forward(self, input_ids, segment_ids, input_mask):
180 | h = self.transformer(input_ids, segment_ids, input_mask)
181 | # only use the first h in the sequence
182 |         pooled_h = self.activ(self.fc(h[:, 0])) # take only the first token ([CLS]) of the sequence
183 | logits = self.classifier(self.drop(pooled_h))
184 | return logits
185 |
186 | class Opinion_extract(nn.Module):
187 | """ Opinion_extraction """
188 | def __init__(self, cfg, max_len, n_labels):
189 | super().__init__()
190 | self.transformer = Transformer(cfg)
191 | self.fc = nn.Linear(cfg.dim, cfg.dim)
192 | self.activ = nn.Tanh()
193 | self.drop = nn.Dropout(cfg.p_drop_hidden)
194 | self.extract = nn.Linear(cfg.dim, n_labels)
195 | self.sigmoid = nn.Sigmoid()
196 |
197 | def forward(self, input_ids, segment_ids, input_mask):
198 | h = self.transformer(input_ids, segment_ids, input_mask)
199 |         # extract over the whole sequence length (h[:, 1:-1] drops the first and last tokens)
200 | h = self.drop(self.activ(self.fc(h[:, 1:-1])))
201 | seq_h = self.extract(h)
202 | seq_h = seq_h.squeeze()
203 | return self.sigmoid(seq_h)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # (Strongly inspired by Dong-Hyun Lee, Kakao Brain)
3 | #
4 | # Except for the load and save functions, the code in this file has been modified and added by
5 | # SanghunYun, Korea University, for UDA.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import os
20 | import json
21 | from copy import deepcopy
22 | from typing import NamedTuple
23 | from tqdm import tqdm
24 |
25 | import torch
26 | import torch.nn as nn
27 |
28 | from utils import checkpoint
29 | # from utils.logger import Logger
30 | from tensorboardX import SummaryWriter
31 | from utils.utils import output_logging
32 |
33 |
34 | class Trainer(object):
35 | """Training Helper class"""
36 | def __init__(self, cfg, model, data_iter, optimizer, device):
37 | self.cfg = cfg
38 | self.model = model
39 | self.optimizer = optimizer
40 | self.device = device
41 |
42 | # data iter
43 | if len(data_iter) == 1:
44 | self.sup_iter = data_iter[0]
45 | elif len(data_iter) == 2:
46 | self.sup_iter = self.repeat_dataloader(data_iter[0])
47 | self.unsup_iter = self.repeat_dataloader(data_iter[1])
48 | elif len(data_iter) == 3:
49 | self.sup_iter = self.repeat_dataloader(data_iter[0])
50 | self.unsup_iter = self.repeat_dataloader(data_iter[1])
51 | self.eval_iter = data_iter[2]
52 |
53 | def train(self, get_loss, get_acc, model_file, pretrain_file):
54 | """ train uda"""
55 |
56 | # tensorboardX logging
57 | if self.cfg.results_dir:
58 | logger = SummaryWriter(log_dir=os.path.join(self.cfg.results_dir, 'logs'))
59 |
60 | self.model.train()
61 | self.load(model_file, pretrain_file) # between model_file and pretrain_file, only one model will be loaded
62 | model = self.model.to(self.device)
63 | if self.cfg.data_parallel: # Parallel GPU mode
64 | model = nn.DataParallel(model)
65 |
66 | global_step = 0
67 | loss_sum = 0.
68 | max_acc = [0., 0] # acc, step
69 |
70 | # Progress bar is set by unsup or sup data
71 | # uda_mode == True --> sup_iter is repeated
72 | # uda_mode == False --> sup_iter is not repeated
73 | iter_bar = tqdm(self.unsup_iter, total=self.cfg.total_steps) if self.cfg.uda_mode \
74 | else tqdm(self.sup_iter, total=self.cfg.total_steps)
75 | for i, batch in enumerate(iter_bar):
76 |
77 | # Device assignment
78 | if self.cfg.uda_mode:
79 | sup_batch = [t.to(self.device) for t in next(self.sup_iter)]
80 | unsup_batch = [t.to(self.device) for t in batch]
81 | else:
82 | sup_batch = [t.to(self.device) for t in batch]
83 | unsup_batch = None
84 |
85 | # update
86 | self.optimizer.zero_grad()
87 | final_loss, sup_loss, unsup_loss = get_loss(model, sup_batch, unsup_batch, global_step)
88 | final_loss.backward()
89 | self.optimizer.step()
90 |
91 | # print loss
92 | global_step += 1
93 | loss_sum += final_loss.item()
94 | if self.cfg.uda_mode:
95 | iter_bar.set_description('final=%5.3f unsup=%5.3f sup=%5.3f'\
96 | % (final_loss.item(), unsup_loss.item(), sup_loss.item()))
97 | else:
98 | iter_bar.set_description('loss=%5.3f' % (final_loss.item()))
99 |
100 | # logging
101 | if self.cfg.uda_mode:
102 | logger.add_scalars('data/scalar_group',
103 | {'final_loss': final_loss.item(),
104 | 'sup_loss': sup_loss.item(),
105 | 'unsup_loss': unsup_loss.item(),
106 | 'lr': self.optimizer.get_lr()[0]
107 | }, global_step)
108 | else:
109 | logger.add_scalars('data/scalar_group',
110 | {'sup_loss': final_loss.item()}, global_step)
111 |
112 | if global_step % self.cfg.save_steps == 0:
113 | self.save(global_step)
114 |
115 | if get_acc and global_step % self.cfg.check_steps == 0 and global_step > 4999:
116 | results = self.eval(get_acc, None, model)
117 | total_accuracy = torch.cat(results).mean().item()
118 | logger.add_scalars('data/scalar_group', {'eval_acc' : total_accuracy}, global_step)
119 | if max_acc[0] < total_accuracy:
120 | self.save(global_step)
121 | max_acc = total_accuracy, global_step
122 | print('Accuracy : %5.3f' % total_accuracy)
123 | print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d' %(max_acc[0], max_acc[1], global_step), end='\n\n')
124 |
125 | if self.cfg.total_steps and self.cfg.total_steps < global_step:
126 | print('The total steps have been reached')
127 | print('Average Loss %5.3f' % (loss_sum/(i+1)))
128 | if get_acc:
129 | results = self.eval(get_acc, None, model)
130 | total_accuracy = torch.cat(results).mean().item()
131 | logger.add_scalars('data/scalar_group', {'eval_acc' : total_accuracy}, global_step)
132 | if max_acc[0] < total_accuracy:
133 | max_acc = total_accuracy, global_step
134 | print('Accuracy :', total_accuracy)
135 | print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d' %(max_acc[0], max_acc[1], global_step), end='\n\n')
136 | self.save(global_step)
137 | return
138 | return global_step
139 |
140 | def eval(self, evaluate, model_file, model):
141 | """ evaluation function """
142 | if model_file:
143 | self.model.eval()
144 | self.load(model_file, None)
145 | model = self.model.to(self.device)
146 | if self.cfg.data_parallel:
147 | model = nn.DataParallel(model)
148 |
149 | results = []
150 | iter_bar = tqdm(self.sup_iter) if model_file \
151 | else tqdm(deepcopy(self.eval_iter))
152 | for batch in iter_bar:
153 | batch = [t.to(self.device) for t in batch]
154 |
155 | with torch.no_grad():
156 | accuracy, result = evaluate(model, batch)
157 | results.append(result)
158 |
159 | iter_bar.set_description('Eval Acc=%5.3f' % accuracy)
160 | return results
161 |
162 | def load(self, model_file, pretrain_file):
163 | """ between model_file and pretrain_file, only one model will be loaded """
164 | if model_file:
165 | print('Loading the model from', model_file)
166 | if torch.cuda.is_available():
167 | self.model.load_state_dict(torch.load(model_file))
168 | else:
169 | self.model.load_state_dict(torch.load(model_file, map_location='cpu'))
170 |
171 | elif pretrain_file:
172 | print('Loading the pretrained model from', pretrain_file)
173 | if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
174 | checkpoint.load_model(self.model.transformer, pretrain_file)
175 | elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
176 | self.model.transformer.load_state_dict(
177 | {key[12:]: value
178 | for key, value in torch.load(pretrain_file).items()
179 | if key.startswith('transformer')}
180 | ) # load only transformer parts
181 |
182 | def save(self, i):
183 | """ save model """
184 | if not os.path.isdir(os.path.join(self.cfg.results_dir, 'save')):
185 | os.makedirs(os.path.join(self.cfg.results_dir, 'save'))
186 | torch.save(self.model.state_dict(),
187 | os.path.join(self.cfg.results_dir, 'save', 'model_steps_'+str(i)+'.pt'))
188 |
189 | def repeat_dataloader(self, iterable):
190 | """ repeat dataloader """
191 | while True:
192 | for x in iterable:
193 | yield x
194 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/utils/__init__.py
--------------------------------------------------------------------------------
/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import tensorflow as tf
17 | import torch
18 |
19 | def load_param(checkpoint_file, conversion_table):
20 | """
21 | Load parameters in pytorch model from checkpoint file according to conversion_table
22 | checkpoint_file : pretrained checkpoint model file in tensorflow
23 |     conversion_table : { pytorch tensor in a model : checkpoint variable name }
24 | """
25 |
26 | for pyt_param, tf_param_name in conversion_table.items():
27 | tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)
28 |
29 | # for weight (kernel), a transpose is needed --> PyTorch and TensorFlow store kernels differently
30 | if tf_param_name.endswith('kernel'):
31 | tf_param = np.transpose(tf_param)
32 |
33 | assert pyt_param.size() == tf_param.shape, \
34 | 'Dim Mismatch: %s vs %s ; %s' % (tuple(pyt_param.size()), tf_param.shape, tf_param_name)
35 |
36 | # assign pytorch tensor from tensorflow param
37 | pyt_param.data = torch.from_numpy(tf_param)
38 |
39 |
40 | def load_model(model, checkpoint_file):
41 | """Load the pytorch model from checkpoint file"""
42 |
43 | # Embedding layer
44 | e, p = model.embed, 'bert/embeddings/'
45 | load_param(checkpoint_file, {
46 | e.tok_embed.weight: p+'word_embeddings',
47 | e.pos_embed.weight: p+'position_embeddings',
48 | e.seg_embed.weight: p+'token_type_embeddings',
49 | e.norm.gamma: p+'LayerNorm/gamma',
50 | e.norm.beta: p+'LayerNorm/beta'
51 | })
52 |
53 | # Transformer blocks
54 | for i in range(len(model.blocks)):
55 | b, p = model.blocks[i], "bert/encoder/layer_%d/"%i
56 | load_param(checkpoint_file, {
57 | b.attn.proj_q.weight: p+"attention/self/query/kernel",
58 | b.attn.proj_q.bias: p+"attention/self/query/bias",
59 | b.attn.proj_k.weight: p+"attention/self/key/kernel",
60 | b.attn.proj_k.bias: p+"attention/self/key/bias",
61 | b.attn.proj_v.weight: p+"attention/self/value/kernel",
62 | b.attn.proj_v.bias: p+"attention/self/value/bias",
63 | b.proj.weight: p+"attention/output/dense/kernel",
64 | b.proj.bias: p+"attention/output/dense/bias",
65 | b.pwff.fc1.weight: p+"intermediate/dense/kernel",
66 | b.pwff.fc1.bias: p+"intermediate/dense/bias",
67 | b.pwff.fc2.weight: p+"output/dense/kernel",
68 | b.pwff.fc2.bias: p+"output/dense/bias",
69 | b.norm1.gamma: p+"attention/output/LayerNorm/gamma",
70 | b.norm1.beta: p+"attention/output/LayerNorm/beta",
71 | b.norm2.gamma: p+"output/LayerNorm/gamma",
72 | b.norm2.beta: p+"output/LayerNorm/beta",
73 | })
74 |
75 |
--------------------------------------------------------------------------------
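load_model above is essentially a conversion table: each PyTorch parameter is paired with the name of the matching variable inside the TensorFlow checkpoint, and load_param copies the values across, transposing dense kernels because TensorFlow stores them as (in_features, out_features) while torch.nn.Linear keeps (out_features, in_features). A self-contained sketch of that copy step on a toy linear layer (the kernel array below is fabricated; in the real code it would come from tf.train.load_variable):

import numpy as np
import torch
import torch.nn as nn

# Pretend this array came from tf.train.load_variable(ckpt, '.../dense/kernel'):
# TensorFlow keeps dense kernels as (in_features, out_features).
tf_kernel = np.arange(12, dtype=np.float32).reshape(3, 4)

layer = nn.Linear(3, 4)                      # torch weight shape is (out, in) = (4, 3)
pyt_param = layer.weight

tf_param = np.transpose(tf_kernel)           # (3, 4) -> (4, 3), mirroring load_param
assert pyt_param.size() == tf_param.shape    # same dimension check as in load_param
pyt_param.data = torch.from_numpy(tf_param)  # overwrite the weights in place
print(layer.weight.shape)                    # torch.Size([4, 3])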
/utils/configuration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
3 | # (Strongly inspired by original Google BERT code and Hugging Face's code)
4 | #
5 | # SanghunYun, Korea University referred to Dong-Hyun Lee, Kakao Brain's code (class model)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import json
21 | from typing import NamedTuple
22 |
23 |
24 | class params(NamedTuple):
25 |
26 | ############################ guide #############################
27 | # lr (learning rate) : fine-tune (2e-5), further-train (1.5e-4 ~ 2e-5)
28 | # mode : train, eval, test
29 | # uda_mode : True, False
30 | # total_steps : n_epochs * n_examples / 3
31 | # max_seq_length : 128, 256, 512
32 | # unsup_ratio : more than 3
33 | # uda_softmax_temp : more than 0.5
34 | # uda_confidence_thresh : 0 ~ 1
35 | # tsa : linear_schedule
36 | ################################################################
37 |
38 | # train
39 | seed: int = 1421
40 | lr: float = 2e-5 # lr_scheduled = lr * factor
41 | # n_epochs: int = 3
42 | warmup: float = 0.1 # warmup steps = total_steps * warmup
43 | do_lower_case: bool = True
44 | mode: str = None # train, eval, test
45 | uda_mode: bool = False # True, False
46 |
47 | total_steps: int = 100000 # total_steps >= n_epochs * n_examples / 3
48 | max_seq_length: int = 128
49 | train_batch_size: int = 32
50 | eval_batch_size: int = 8
51 |
52 | # unsup
53 | unsup_ratio: int = 0 # unsup_batch_size = unsup_ratio * sup_batch_size
54 | uda_coeff: int = 1 # total_loss = sup_loss + uda_coeff*unsup_loss
55 | tsa: str = 'linear_schedule' # log, linear, exp
56 | uda_softmax_temp: float = -1 # 0 ~ 1
57 | uda_confidence_thresh: float = -1 # 0 ~ 1
58 |
59 | # data
60 | data_parallel: bool = True
61 | need_prepro: bool = False # True if the raw data still needs preprocessing
62 | sup_data_dir: str = None
63 | unsup_data_dir: str = None
64 | eval_data_dir: str = None
65 | n_sup: int = None
66 | n_unsup: int = None
67 |
68 | model_file: str = None # fine-tuned
69 | pretrain_file: str = None # pre-trained
70 | vocab: str = None
71 | task: str = None
72 |
73 | # results
74 | save_steps: int = 100
75 | check_steps: int = 10
76 | results_dir: str = None
77 |
78 | # appendix
79 | is_position: bool = False # appendix not used
80 |
81 | @classmethod
82 | def from_json(cls, file):
83 | return cls(**json.load(open(file, 'r')))
84 |
85 |
86 | class pretrain(NamedTuple):
87 | seed: int = 3232
88 | batch_size: int = 32
89 | lr: float = 1.5e-4
90 | n_epochs: int = 100
91 | warmup: float = 0.1
92 | save_steps: int = 100
93 | total_steps: int = 100000
94 | results_dir : str = None
95 |
96 | # do not change
97 | uda_mode: bool = False
98 |
99 | @classmethod
100 | def from_json(cls, file):
101 | return cls(**json.load(open(file, 'r')))
102 |
103 |
104 |
105 | class model(NamedTuple):
106 | "Configuration for BERT model"
107 | vocab_size: int = None # Size of Vocabulary
108 | dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder
109 | n_layers: int = 12 # Number of Hidden Layers
110 | n_heads: int = 12 # Number of Heads in Multi-Headed Attention Layers
111 | dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net
112 | #activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers
113 | p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers
114 | p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers
115 | max_len: int = 512 # Maximum Length for Positional Embeddings
116 | n_segments: int = 2 # Number of Sentence Segments
117 |
118 | @classmethod
119 | def from_json(cls, file):
120 | return cls(**json.load(open(file, 'r')))
--------------------------------------------------------------------------------
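Because params, pretrain, and model are NamedTuples, a configuration is an immutable value object: from_json simply unpacks a JSON object into the constructor, and individual fields can only be changed by building a modified copy with _replace. A small sketch of that usage on a trimmed-down copy of the params class (the JSON string and override below are illustrative, not the contents of the shipped config files):

import json
from typing import NamedTuple

class params(NamedTuple):                 # trimmed-down copy for illustration only
    lr: float = 2e-5
    mode: str = None
    uda_mode: bool = False
    total_steps: int = 100000

    @classmethod
    def from_json(cls, file):
        return cls(**json.load(open(file, 'r')))

# from_json does the same unpacking from a file; here we parse an in-memory string instead:
cfg = params(**json.loads('{"mode": "train", "uda_mode": true}'))
print(cfg.mode, cfg.uda_mode, cfg.lr)     # train True 2e-05

eval_cfg = cfg._replace(mode='eval')      # NamedTuples are immutable; _replace returns a copy
print(eval_cfg.mode)                      # eval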
/utils/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
2 | # and Dong-Hyun Lee, Kakao Brain.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | """ a slightly modified version of Hugging Face's BERTAdam class """
18 |
19 | import math
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.nn.utils import clip_grad_norm_
23 |
24 | def warmup_cosine(x, warmup=0.002):
25 | if x < warmup:
26 | return x/warmup
27 | return 0.5 * (1.0 + math.cos(math.pi * x)) # x is a plain float, so use math.cos
28 |
29 | def warmup_constant(x, warmup=0.002):
30 | if x < warmup:
31 | return x/warmup
32 | return 1.0
33 |
34 | def warmup_linear(x, warmup=0.002):
35 | if x < warmup:
36 | return x/warmup
37 | return 1.0 - x
38 |
39 | SCHEDULES = {
40 | 'warmup_cosine':warmup_cosine,
41 | 'warmup_constant':warmup_constant,
42 | 'warmup_linear':warmup_linear,
43 | }
44 |
45 | class BertAdam(Optimizer):
46 | """Implements BERT version of Adam algorithm with weight decay fix.
47 | Params:
48 | lr: learning rate
49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
50 | t_total: total number of training steps for the learning
51 | rate schedule, -1 means constant learning rate. Default: -1
52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
53 | b1: Adams b1. Default: 0.9
54 | b2: Adams b2. Default: 0.999
55 | e: Adams epsilon. Default: 1e-6
56 | weight_decay_rate: Weight decay. Default: 0.01
57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
58 | """
59 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
60 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
61 | max_grad_norm=1.0):
62 | assert lr > 0.0, "Learning rate: %f - should be > 0.0" % (lr)
63 | assert schedule in SCHEDULES, "Invalid schedule : %s" % (schedule)
64 | assert 0.0 <= warmup < 1.0 or warmup == -1.0, \
65 | "Warmup %f - should be in 0.0 ~ 1.0 or -1 (no warm up)" % (warmup)
66 | assert 0.0 <= b1 < 1.0, "b1: %f - should be in 0.0 ~ 1.0" % (b1)
67 | assert 0.0 <= b2 < 1.0, "b2: %f - should be in 0.0 ~ 1.0" % (b2)
68 | assert e > 0.0, "epsilon: %f - should be > 0.0" % (e)
69 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
70 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
71 | max_grad_norm=max_grad_norm)
72 | super(BertAdam, self).__init__(params, defaults)
73 |
74 | def get_lr(self):
75 | """ get learning rate in training """
76 | lr = []
77 | for group in self.param_groups:
78 | for p in group['params']:
79 | state = self.state[p]
80 | if not state:
81 | return [0]
82 | if group['t_total'] != -1:
83 | schedule_fct = SCHEDULES[group['schedule']]
84 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
85 | else:
86 | lr_scheduled = group['lr']
87 | lr.append(lr_scheduled)
88 | return lr
89 |
90 | def step(self, closure=None):
91 | """Performs a single optimization step.
92 |
93 | Arguments:
94 | closure (callable, optional): A closure that reevaluates the model
95 | and returns the loss.
96 | """
97 | loss = None
98 | if closure is not None:
99 | loss = closure()
100 |
101 | for group in self.param_groups:
102 | for p in group['params']:
103 | if p.grad is None:
104 | continue
105 | grad = p.grad.data
106 | if grad.is_sparse:
107 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
108 |
109 | state = self.state[p]
110 |
111 | # State initialization
112 | if not state:
113 | state['step'] = 0
114 | # Exponential moving average of gradient values
115 | state['next_m'] = torch.zeros_like(p.data)
116 | # Exponential moving average of squared gradient values
117 | state['next_v'] = torch.zeros_like(p.data)
118 |
119 | next_m, next_v = state['next_m'], state['next_v']
120 | beta1, beta2 = group['b1'], group['b2']
121 |
122 | # Add grad clipping
123 | if group['max_grad_norm'] > 0:
124 | clip_grad_norm_(p, group['max_grad_norm'])
125 |
126 | # Decay the first and second moment running average coefficient
127 | # In-place operations to update the averages at the same time
128 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
129 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
130 | update = next_m / (next_v.sqrt() + group['e'])
131 |
132 | # Just adding the square of the weights to the loss function is *not*
133 | # the correct way of using L2 regularization/weight decay with Adam,
134 | # since that will interact with the m and v parameters in strange ways.
135 | #
136 | # Instead we want to decay the weights in a manner that doesn't interact
137 | # with the m/v parameters. This is equivalent to adding the square
138 | # of the weights to the loss with plain (non-momentum) SGD.
139 | if group['weight_decay_rate'] > 0.0:
140 | update += group['weight_decay_rate'] * p.data
141 |
142 | if group['t_total'] != -1:
143 | schedule_fct = SCHEDULES[group['schedule']]
144 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
145 | else:
146 | lr_scheduled = group['lr']
147 |
148 | update_with_lr = lr_scheduled * update
149 | p.data.add_(-update_with_lr)
150 |
151 | state['step'] += 1
152 |
153 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
154 | # No bias correction
155 | # bias_correction1 = 1 - beta1 ** state['step']
156 | # bias_correction2 = 1 - beta2 ** state['step']
157 |
158 | return loss
159 |
160 |
161 |
162 | def optim4GPU(cfg, model):
163 | """ optimizer for GPU training """
164 | param_optimizer = list(model.named_parameters())
165 | no_decay = ['bias', 'gamma', 'beta']
166 | optimizer_grouped_parameters = [
167 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
168 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}]
169 | return BertAdam(optimizer_grouped_parameters,
170 | lr=cfg.lr,
171 | warmup=cfg.warmup,
172 | t_total=cfg.total_steps)
173 |
--------------------------------------------------------------------------------
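BertAdam rescales the base learning rate on every step: the progress fraction x = step / t_total is fed to one of the SCHEDULES functions, so warmup_linear ramps the factor from 0 to 1 over the warmup fraction and then decays it linearly back towards 0. A stand-alone check of that factor (the step counts and learning rate below are illustrative):

def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (step / t_total)
    if x < warmup:
        return x / warmup          # linear ramp-up to a factor of 1.0
    return 1.0 - x                 # then linear decay towards 0.0

t_total, warmup, lr = 10000, 0.1, 2e-5
for step in (0, 500, 1000, 5000, 9999):
    factor = warmup_linear(step / t_total, warmup)
    print(step, round(lr * factor, 10))   # lr_scheduled = lr * factor, as in BertAdam.step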
/utils/tokenization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | """ Tokenization classes (It's exactly the same code as Google BERT code """
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import unicodedata
24 | import re
25 | import six
26 |
27 |
28 | def convert_to_unicode(text):
29 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
30 | if six.PY3:
31 | if isinstance(text, str):
32 | return text
33 | elif isinstance(text, bytes):
34 | return text.decode("utf-8", "ignore")
35 | else:
36 | raise ValueError("Unsupported string type: %s" % (type(text)))
37 | elif six.PY2:
38 | if isinstance(text, str):
39 | return text.decode("utf-8", "ignore")
40 | elif isinstance(text, unicode):
41 | return text
42 | else:
43 | raise ValueError("Unsupported string type: %s" % (type(text)))
44 | else:
45 | raise ValueError("Not running on Python2 or Python 3?")
46 |
47 |
48 | def printable_text(text):
49 | """Returns text encoded in a way suitable for print or `tf.logging`."""
50 |
51 | # These functions want `str` for both Python2 and Python3, but in one case
52 | # it's a Unicode string and in the other it's a byte string.
53 | if six.PY3:
54 | if isinstance(text, str):
55 | return text
56 | elif isinstance(text, bytes):
57 | return text.decode("utf-8", "ignore")
58 | else:
59 | raise ValueError("Unsupported string type: %s" % (type(text)))
60 | elif six.PY2:
61 | if isinstance(text, str):
62 | return text
63 |
64 | elif isinstance(text, unicode):
65 | return text.encode("utf-8")
66 | else:
67 | raise ValueError("Unsupported string type: %s" % (type(text)))
68 | else:
69 | raise ValueError("Not running on Python2 or Python 3?")
70 |
71 |
72 | def load_vocab(vocab_file):
73 | """Loads a vocabulary file into a dictionary."""
74 | vocab = collections.OrderedDict()
75 | index = 0
76 | with open(vocab_file, "r") as reader:
77 | while True:
78 | token = convert_to_unicode(reader.readline())
79 | if not token:
80 | break
81 | token = token.strip() # sanghun: strip the trailing newline, keeping only the token text
82 | vocab[token] = index # assign an index to each token
83 | index += 1
84 | return vocab
85 |
86 |
87 | def convert_tokens_to_ids(vocab, tokens):
88 | """Converts a sequence of tokens into ids using the vocab."""
89 | ids = []
90 | for token in tokens:
91 | ids.append(vocab[token])
92 | return ids
93 |
94 |
95 | def whitespace_tokenize(text):
96 | """Runs basic whitespace cleaning and splitting on a peice of text."""
97 | text = text.strip()
98 | if not text:
99 | return []
100 | tokens = text.split()
101 | return tokens
102 |
103 |
104 | class FullTokenizer(object):
105 | """Runs end-to-end tokenziation."""
106 |
107 | def __init__(self, vocab_file, do_lower_case=True):
108 | self.vocab = load_vocab(vocab_file) # vocab -> idx
109 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
110 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
111 |
112 | def tokenize(self, text):
113 | split_tokens = []
114 | for token in self.basic_tokenizer.tokenize(text):
115 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
116 | split_tokens.append(sub_token)
117 |
118 | return split_tokens
119 |
120 | def convert_tokens_to_ids(self, tokens):
121 | return convert_tokens_to_ids(self.vocab, tokens)
122 |
123 | def convert_to_unicode(self, text):
124 | return convert_to_unicode(text)
125 |
126 |
127 |
128 | class BasicTokenizer(object):
129 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
130 |
131 | def __init__(self, do_lower_case=True):
132 | """Constructs a BasicTokenizer.
133 |
134 | Args:
135 | do_lower_case: Whether to lower case the input.
136 | """
137 | self.do_lower_case = do_lower_case
138 |
139 | def tokenize(self, text):
140 | """Tokenizes a piece of text."""
141 | text = convert_to_unicode(text)
142 | text = self._clean_text(text)
143 | orig_tokens = whitespace_tokenize(text)
144 | split_tokens = []
145 | for token in orig_tokens:
146 | if self.do_lower_case:
147 | token = token.lower()
148 | token = self._run_strip_accents(token)
149 | split_tokens.extend(self._run_split_on_punc(token))
150 |
151 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
152 | return output_tokens
153 |
154 | def _run_strip_accents(self, text):
155 | """Strips accents from a piece of text."""
156 |
157 | korean = "%s-%s%s-%s" % (chr(0xac00), chr(0xd7a3),
158 | chr(0x3131), chr(0x3163))
159 | if re.search("[%s]+" % korean, text):
160 | return "".join(
161 | substr if re.search("^[%s]+$" % korean, substr)
162 | else self._run_strip_accents(substr)
163 | for substr in re.findall("[%s]+|[^%s]+" % (korean, korean), text)
164 | )
165 |
166 | text = unicodedata.normalize("NFD", text)
167 | output = []
168 | for char in text:
169 | cat = unicodedata.category(char)
170 | if cat == "Mn":
171 | continue
172 | output.append(char)
173 | return "".join(output)
174 |
175 | def _run_split_on_punc(self, text):
176 | """Splits punctuation on a piece of text."""
177 | chars = list(text)
178 | i = 0
179 | start_new_word = True
180 | output = []
181 | while i < len(chars):
182 | char = chars[i]
183 | if _is_punctuation(char):
184 | output.append([char])
185 | start_new_word = True
186 | else:
187 | if start_new_word:
188 | output.append([])
189 | start_new_word = False
190 | output[-1].append(char)
191 | i += 1
192 |
193 | return ["".join(x) for x in output]
194 |
195 | def _clean_text(self, text):
196 | """Performs invalid character removal and whitespace cleanup on text."""
197 | output = []
198 | for char in text:
199 | cp = ord(char)
200 | if cp == 0 or cp == 0xfffd or _is_control(char):
201 | continue
202 | if _is_whitespace(char):
203 | output.append(" ")
204 | else:
205 | output.append(char)
206 | return "".join(output)
207 |
208 |
209 | class WordpieceTokenizer(object):
210 | """Runs WordPiece tokenization."""
211 |
212 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
213 | self.vocab = vocab
214 | self.unk_token = unk_token # UNK token: replaces unknown / low-frequency words
215 | self.max_input_chars_per_word = max_input_chars_per_word
216 |
217 | def tokenize(self, text):
218 | """Tokenizes a piece of text into its word pieces.
219 |
220 | This uses a greedy longest-match-first algorithm to perform tokenization
221 | using the given vocabulary.
222 |
223 | For example:
224 | input = "unaffable"
225 | output = ["un", "##aff", "##able"]
226 |
227 | Args:
228 | text: A single token or whitespace separated tokens. This should have
229 | already been passed through `BasicTokenizer`.
230 |
231 | Returns:
232 | A list of wordpiece tokens.
233 | """
234 |
235 | text = convert_to_unicode(text)
236 |
237 | output_tokens = []
238 | for token in whitespace_tokenize(text):
239 | chars = list(token)
240 | if len(chars) > self.max_input_chars_per_word:
241 | output_tokens.append(self.unk_token)
242 | continue
243 |
244 | is_bad = False
245 | start = 0
246 | sub_tokens = []
247 | while start < len(chars):
248 | end = len(chars)
249 | cur_substr = None
250 | while start < end:
251 | substr = "".join(chars[start:end])
252 | if start > 0:
253 | substr = "##" + substr
254 | if substr in self.vocab:
255 | cur_substr = substr
256 | break
257 | end -= 1
258 | if cur_substr is None:
259 | is_bad = True
260 | break
261 | sub_tokens.append(cur_substr)
262 | start = end
263 |
264 | if is_bad:
265 | output_tokens.append(self.unk_token)
266 | else:
267 | output_tokens.extend(sub_tokens)
268 | return output_tokens
269 |
270 |
271 | def _is_whitespace(char):
272 | """Checks whether `chars` is a whitespace character."""
273 | # \t, \n, and \r are technically control characters but we treat them
274 | # as whitespace since they are generally considered as such.
275 | if char == " " or char == "\t" or char == "\n" or char == "\r":
276 | return True
277 | cat = unicodedata.category(char)
278 | if cat == "Zs":
279 | return True
280 | return False
281 |
282 |
283 | def _is_control(char):
284 | """Checks whether `chars` is a control character."""
285 | # These are technically control characters but we count them as whitespace
286 | # characters.
287 | if char == "\t" or char == "\n" or char == "\r":
288 | return False
289 | cat = unicodedata.category(char)
290 | if cat.startswith("C"):
291 | return True
292 | return False
293 |
294 |
295 | def _is_punctuation(char):
296 | """Checks whether `chars` is a punctuation character."""
297 | cp = ord(char)
298 | # We treat all non-letter/number ASCII as punctuation.
299 | # Characters such as "^", "$", and "`" are not in the Unicode
300 | # Punctuation class but we treat them as punctuation anyways, for
301 | # consistency.
302 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
303 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
304 | return True
305 | cat = unicodedata.category(char)
306 | if cat.startswith("P"):
307 | return True
308 | return False
309 |
--------------------------------------------------------------------------------
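WordpieceTokenizer.tokenize is a greedy longest-match-first search: at each position it keeps the longest substring found in the vocabulary (prefixed with '##' when it is not word-initial) and emits the [UNK] token when no split of the word works. The same logic in a self-contained sketch with a toy vocabulary (the vocabulary below is made up for illustration):

def wordpiece(token, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one token, as in WordpieceTokenizer."""
    chars, start, pieces = list(token), 0, []
    while start < len(chars):
        end, cur = len(chars), None
        while start < end:
            sub = "".join(chars[start:end])
            if start > 0:
                sub = "##" + sub
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:        # no matching piece -> the whole token becomes [UNK]
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", toy_vocab))   # ['un', '##aff', '##able']
print(wordpiece("xyz", toy_vocab))         # ['[UNK]']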
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
3 | #
4 | # This file has been modified by SanghunYun, Korea University
5 | # to add the _get_device function and the output_logging class.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import os
21 | import csv
22 | import random
23 | import logging
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
29 | def torch_device_one():
30 | return torch.tensor(1.).to(_get_device())
31 |
32 | def set_seeds(seed):
33 | "set random seeds"
34 | random.seed(seed)
35 | np.random.seed(seed)
36 | torch.manual_seed(seed)
37 | torch.cuda.manual_seed_all(seed)
38 |
39 | def get_device():
40 | "get device (CPU or GPU)"
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | n_gpu = torch.cuda.device_count()
43 | print("%s (%d GPUs)" % (device, n_gpu))
44 | return device
45 |
46 | def _get_device():
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 | return device
49 |
50 | def split_last(x, shape):
51 | "split the last dimension to given shape"
52 | shape = list(shape)
53 | assert shape.count(-1) <= 1
54 | if -1 in shape:
55 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
56 | return x.view(*x.size()[:-1], *shape)
57 |
58 | def merge_last(x, n_dims):
59 | "merge the last n_dims to a dimension"
60 | s = x.size()
61 | assert n_dims > 1 and n_dims < len(s)
62 | return x.view(*s[:-n_dims], -1)
63 |
64 | def truncate_tokens_pair(tokens_a, tokens_b, max_len):
65 | while True:
66 | if len(tokens_a) + len(tokens_b) <= max_len:
67 | break
68 | if len(tokens_a) > len(tokens_b):
69 | tokens_a.pop()
70 | else:
71 | tokens_b.pop()
72 |
73 | def get_random_word(vocab_words):
74 | i = random.randint(0, len(vocab_words)-1)
75 | return vocab_words[i]
76 |
77 | def get_logger(name, log_path):
78 | "get logger"
79 | logger = logging.getLogger(name)
80 | formatter = logging.Formatter(
81 | '[ %(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s')
82 |
83 | if not os.path.isfile(log_path):
84 | open(log_path, "w+").close() # create the log file without leaking a file handle
85 |
86 | fileHandler = logging.FileHandler(log_path)
87 | fileHandler.setFormatter(formatter)
88 | logger.addHandler(fileHandler)
89 |
90 | #streamHandler = logging.StreamHandler()
91 | #streamHandler.setFormatter(formatter)
92 | #logger.addHandler(streamHandler)
93 |
94 | logger.setLevel(logging.DEBUG)
95 | return logger
96 |
97 |
98 | class output_logging(object):
99 | def __init__(self, mode, real_time=False, dump_dir=None):
100 | self.mode = mode
101 | self.real_time = real_time
102 | self.dump_dir = dump_dir
103 |
104 | if dump_dir:
105 | self.dump = open(os.path.join(dump_dir, 'logs/output.tsv'), 'w', encoding='utf-8', newline='')
106 | self.wr = csv.writer(self.dump, delimiter='\t')
107 |
108 | # header
109 | if mode == 'eval':
110 | self.wr.writerow(['Ground_truth', 'Predict', 'sentence'])
111 | elif mode == 'test':
112 | self.wr.writerow(['Predict', 'sentence'])
113 |
114 | def __del__(self):
115 | if self.dump_dir:
116 | self.dump.close()
117 |
118 | def logs(self, sentence, pred, ground_truth=None):
119 | if self.real_time:
120 | if self.mode == 'eval':
121 | for p, g, s in zip(pred, ground_truth, sentence):
122 | print('Ground_truth | Predict')
123 | print(int(g), ' ', int(p))
124 | print(s, end='\n\n')
125 | elif self.mode == 'test':
126 | for p, s in zip(pred, sentence):
127 | print('predict : ', int(p))
128 | print(s, end='\n\n')
129 |
130 | if self.dump_dir:
131 | if self.mode == 'eval':
132 | for p, g, s in zip(pred, ground_truth, sentence):
133 | self.wr.writerow([int(g), int(p), s]) # keep the Ground_truth, Predict column order of the header
134 | elif self.mode == 'test':
135 | for p, s in zip(pred, sentence):
136 | self.wr.writerow([int(p), s])
137 |
--------------------------------------------------------------------------------
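split_last and merge_last above are the reshaping helpers used to switch between a flat hidden dimension and a per-head view in multi-head attention: split_last infers the size marked -1 and appends the new axes, while merge_last flattens the trailing axes back into one. A quick demonstration on a dummy activation tensor (the batch, sequence, and head sizes are illustrative):

import numpy as np
import torch

def split_last(x, shape):
    "split the last dimension to the given shape (-1 is inferred)"
    shape = list(shape)
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)

def merge_last(x, n_dims):
    "merge the last n_dims back into a single dimension"
    s = x.size()
    return x.view(*s[:-n_dims], -1)

h = torch.randn(2, 4, 768)                 # (batch, seq_len, hidden)
per_head = split_last(h, (12, -1))         # (2, 4, 12, 64): 12 heads of width 64
back = merge_last(per_head, 2)             # (2, 4, 768) again
print(per_head.shape, back.shape)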