├── .gitignore ├── LICENSE ├── README.md ├── README_data ├── 2019-08-21-18-01-07.png ├── 2019-08-22-14-16-49.png ├── 2019-08-22-14-16-59.png └── 2019-08-30-22-18-28.png ├── config ├── bert_base.json ├── eval.json ├── non-uda.json └── uda.json ├── data.zip ├── download.sh ├── load_data.py ├── main.py ├── models.py ├── train.py └── utils ├── __init__.py ├── checkpoint.py ├── configuration.py ├── optim.py ├── tokenization.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | BERT_Base_Uncased 2 | results_045 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UDA(Unsupervised Data Augmentation) with BERT 2 | This is re-implementation of Google's UDA [[paper]](https://arxiv.org/abs/1904.12848)[[tensorflow]](https://github.com/google-research/uda) in pytorch with Kakao Brain's Pytorchic BERT[[pytorch]](https://github.com/dhlee347/pytorchic-bert). 
3 | 
4 | Model  | UDA official | This repository
5 | -- | -- | --
6 | UDA (X) | 68% |  
7 | UDA (O) | 90% | 88.45%
8 | 
9 | (Max sequence length = 128, Train batch size = 8)
10 | 
11 | ![](README_data/2019-08-30-22-18-28.png)
12 | 
13 | 
14 | ## UDA
15 | > UDA(Unsupervised Data Augmentation) is a semi-supervised learning method which achieves SOTA results on a wide variety of language and vision tasks. With only 20 labeled examples, UDA outperforms the previous SOTA on IMDb trained on 25,000 labeled examples. (BERT=4.51, UDA=4.20, error rate)
16 | ![](README_data/2019-08-21-18-01-07.png)
17 | > * Unsupervised Data Augmentation for Consistency Training (2019 Google Brain, Q Xie et al.)
18 | 
19 | #### - UDA with BERT
20 | UDA works on top of BERT: it acts as an auxiliary training signal for BERT, so the model **M** in the picture above is BERT.
21 | 
22 | #### - Loss
23 | UDA combines a supervised loss and an unsupervised loss. The supervised loss is the ordinary cross-entropy loss; the unsupervised loss is the KL-divergence between the model outputs for an original example and for its augmented version. In this project, back translation is used as the augmentation technique.
24 | The two losses are added to form the total loss, which is then minimized by gradient descent. Note that no gradient flows through the original (unaugmented) example branch: the model's weights are updated only from the labeled data and the augmented unlabeled data, as sketched below.
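For readers who prefer code, here is a minimal sketch of how the total loss is assembled, closely following `get_loss` in this repository's `main.py` (TSA and the sharpening techniques described below are omitted here); the batch layout and names are illustrative only:

```python
import torch
import torch.nn.functional as F

def uda_total_loss(model, sup_batch, unsup_batch, uda_coeff=1.0):
    # Supervised branch: ordinary cross-entropy on the labeled batch.
    input_ids, segment_ids, input_mask, label_ids = sup_batch
    sup_logits = model(input_ids, segment_ids, input_mask)
    sup_loss = F.cross_entropy(sup_logits, label_ids)

    # Unsupervised branch: consistency between an original example and its
    # back-translated version. No gradient flows through the original branch.
    (ori_ids, ori_seg, ori_mask), (aug_ids, aug_seg, aug_mask) = unsup_batch
    with torch.no_grad():
        ori_prob = F.softmax(model(ori_ids, ori_seg, ori_mask), dim=-1)
    aug_log_prob = F.log_softmax(model(aug_ids, aug_seg, aug_mask), dim=-1)
    unsup_loss = F.kl_div(aug_log_prob, ori_prob, reduction='batchmean')

    # Total loss = supervised + uda_coeff * unsupervised (see config/uda.json).
    return sup_loss + uda_coeff * unsup_loss
```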
25 | 
26 | #### - TSA(Training Signal Annealing)
27 | There is a large gap between the amount of unlabeled data and that of labeled data, so it is easy to overfit to the labeled data. TSA therefore masks out labeled examples whose predicted probability for the correct class is higher than a threshold, and that threshold is annealed over training by a log, linear, or exponential schedule (sketched after the two figures below).
28 | ![](README_data/2019-08-22-14-16-49.png)
29 | ![](README_data/2019-08-22-14-16-59.png)
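For reference, the threshold schedules shown in the figures are computed essentially as in `get_tsa_thresh` of this repository's `main.py` (device handling omitted); `start` is typically `1/num_classes` and `end` is 1:

```python
import torch

def get_tsa_thresh(schedule, global_step, num_train_steps, start, end):
    # Training progress runs from 0 to 1 over the course of training.
    progress = torch.tensor(float(global_step) / float(num_train_steps))
    if schedule == 'linear_schedule':
        threshold = progress
    elif schedule == 'exp_schedule':
        threshold = torch.exp((progress - 1) * 5)   # releases training signal late
    elif schedule == 'log_schedule':
        threshold = 1 - torch.exp(-progress * 5)    # releases training signal early
    return threshold * (end - start) + start        # e.g. start = 1/K, end = 1
```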
30 | 
31 | #### - Sharpening Predictions
32 | The KL-divergence between the original and augmented outputs can be too small to provide a useful signal on its own, so the total loss ends up dominated by the supervised loss. Prediction-sharpening techniques are therefore needed:
33 | 
34 | - Confidence-based masking : masking out examples that the current model is not confident about. Specifically, in each minibatch the consistency loss term is computed only on examples whose highest predicted probability exceeds a threshold.
35 | - Softmax temperature controlling : used when computing the predictions on the original example. Specifically, the probability of the original example is computed as Softmax(l(x)/τ), where l(x) denotes the logits and τ is the temperature. A lower temperature corresponds to a sharper distribution. (UDA, 2019 Google Brain, Q Xie et al.)
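A condensed sketch of both techniques as described above (default values taken from `config/uda.json`; variable names are illustrative). Note that `main.py` in this repository applies the temperature to the augmented branch rather than the original one, but the idea is the same:

```python
import torch
import torch.nn.functional as F

def consistency_loss(ori_logits, aug_logits, confidence_thresh=0.45, softmax_temp=0.85):
    with torch.no_grad():
        # Softmax temperature controlling: tau < 1 sharpens the target
        # distribution predicted on the original (unaugmented) example.
        ori_prob = F.softmax(ori_logits / softmax_temp, dim=-1)
        # Confidence-based masking: keep only examples whose highest
        # predicted probability exceeds the threshold.
        mask = (F.softmax(ori_logits, dim=-1).max(dim=-1)[0] > confidence_thresh).float()

    aug_log_prob = F.log_softmax(aug_logits, dim=-1)
    # Per-example KL divergence, averaged over the unmasked examples only.
    kl = F.kl_div(aug_log_prob, ori_prob, reduction='none').sum(dim=-1)
    return (kl * mask).sum() / mask.sum().clamp(min=1.0)
```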
36 | 
37 | ## Requirements
38 | **UDA** : python > 3.6, fire, tqdm, tensorboardX, tensorflow, pytorch, pandas, numpy
39 | 
40 | ## Overview
41 | 
42 | - [`download.sh`](./download.sh) : Downloads the pre-trained BERT model from Google's official BERT repository and the IMDb data file
43 | - [`load_data.py`](./load_data.py) : Loads the supervised and unsupervised data
44 | - [`models.py`](./models.py) : Model classes for a general transformer (from Pytorchic BERT's code)
45 | - [`main.py`](./main.py) : Entry point covering the default BERT and UDA (TSA, sharpening) modes
46 | - [`train.py`](./train.py) : A custom training class (Trainer) adapted from Pytorchic BERT's code
47 | - ***utils***
48 |   - [`configuration.py`](./utils/configuration.py) : Sets up a configuration from a JSON file
49 |   - [`checkpoint.py`](./utils/checkpoint.py) : Functions to load a model from a TensorFlow checkpoint file (from Pytorchic BERT's code)
50 |   - [`optim.py`](./utils/optim.py) : Optimizer (BERTAdam class) (from Pytorchic BERT's code)
51 |   - [`tokenization.py`](./utils/tokenization.py) : Tokenizers adapted from the original Google BERT code
52 |   - [`utils.py`](./utils/utils.py) : Custom utility functions adapted from Pytorchic BERT's code
53 | 
54 | ## Pre-works
55 | 
56 | #### - Download pre-trained BERT model and unzip IMDb data
57 | First, download the pre-trained BERT_base model from Google's BERT repository and unzip the IMDb data:
58 | 
59 |     bash download.sh
60 | After running it, the pre-trained BERT_Base_Uncased model is available in the **/BERT_Base_Uncased** directory and the data in **/data**.
61 | 
62 | This project uses IMDb data that has already been pre-processed and augmented, extracted from the official [UDA](https://github.com/google-research/uda). If you want to use your own raw data, set need_prepro = True, for example as shown below.
63 | 
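A hypothetical fragment of the training config (e.g. `config/uda.json`) with preprocessing enabled — only the relevant keys are shown, and the data file names are placeholders for your own raw files:

```json
{
    "need_prepro": true,
    "sup_data_dir": "data/your_sup_train.txt",
    "unsup_data_dir": "data/your_unsup_train.txt",
    "vocab": "BERT_Base_Uncased/vocab.txt"
}
```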
64 | ## Example usage
65 | This project is broadly divided into two parts: fine-tuning and evaluation.
66 | **Caution** : **Before running the code, you have to check and edit the config file.**
67 | 
68 | 1. **Fine-tuning**
69 |     You can choose the train mode (train or train_eval) in non-uda.json or uda.json (default : train_eval).
70 |     - Non-UDA fine-tuning
71 | 
72 |           python main.py \
73 |               --cfg='config/non-uda.json' \
74 |               --model_cfg='config/bert_base.json'
75 | 
76 |     - UDA fine-tuning
77 | 
78 |           python main.py \
79 |               --cfg='config/uda.json' \
80 |               --model_cfg='config/bert_base.json'
81 | 
82 | 2. **Evaluation**
83 |     - The evaluation code can also dump a results file. You can change the dump option in [main.py](./main.py); there are two modes (real-time print, or writing a tsv file).
84 | 
85 |           python main.py \
86 |               --cfg='config/eval.json' \
87 |               --model_cfg='config/bert_base.json'
88 | 
89 | 
90 | ## Acknowledgement
91 | This implementation was made possible thanks to [UDA](https://github.com/google-research/uda) and [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert).
92 | 
93 | ## TODO
94 | 1. It is known that further training (additional pre-training on a task-specific corpus, starting from the already pre-trained BERT) can improve performance, but this repository does not yet include pre-training code; it will be added. If you want to do further training now, you can use [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert)'s pretrain.py or any other BERT project.
95 | 
96 | 2. Korean dataset version
--------------------------------------------------------------------------------
/README_data/2019-08-21-18-01-07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-21-18-01-07.png
--------------------------------------------------------------------------------
/README_data/2019-08-22-14-16-49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-22-14-16-49.png
--------------------------------------------------------------------------------
/README_data/2019-08-22-14-16-59.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-22-14-16-59.png
--------------------------------------------------------------------------------
/README_data/2019-08-30-22-18-28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/README_data/2019-08-30-22-18-28.png
--------------------------------------------------------------------------------
/config/bert_base.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dim": 768,
3 |     "dim_ff": 3072,
4 |     "n_layers": 12,
5 |     "p_drop_attn": 0.1,
6 |     "n_heads": 12,
7 |     "p_drop_hidden": 0.1,
8 |     "max_len": 512,
9 |     "n_segments": 2,
10 |     "vocab_size": 30522
11 | }
--------------------------------------------------------------------------------
/config/eval.json:
--------------------------------------------------------------------------------
1 | {
2 |     "mode": "eval",
3 |     "max_seq_length": 128,
4 |     "eval_batch_size": 16,
5 |     "do_lower_case": true,
6 |     "data_parallel": true,
7 |     "need_prepro": false,
8 |     "model_file": "results_045/save/model_steps_6250.pt",
9 |     "eval_data_dir": "data/imdb_sup_test.txt",
10 |     "vocab":"BERT_Base_Uncased/vocab.txt",
11 |     "task": "imdb"
12 | }
13 | 
-------------------------------------------------------------------------------- /config/non-uda.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 42, 3 | "lr": 2e-5, 4 | "warmup": 0.1, 5 | "do_lower_case": true, 6 | "mode": "train_eval", 7 | "uda_mode": false, 8 | 9 | "total_steps": 10000, 10 | "max_seq_length": 128, 11 | "train_batch_size": 8, 12 | "eval_batch_size": 16, 13 | 14 | "data_parallel": true, 15 | "need_prepro": false, 16 | "sup_data_dir": "data/imdb_sup_train.txt", 17 | "eval_data_dir": "data/imdb_sup_test.txt", 18 | 19 | "model_file":null, 20 | "pretrain_file":"BERT_Base_Uncased/bert_model.ckpt", 21 | "vocab":"BERT_Base_Uncased/vocab.txt", 22 | "task": "imdb", 23 | 24 | "save_steps": 100, 25 | "check_steps": 250, 26 | "results_dir": "results_non", 27 | 28 | "is_position": false 29 | } 30 | -------------------------------------------------------------------------------- /config/uda.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 42, 3 | "lr": 2e-5, 4 | "warmup": 0.1, 5 | "do_lower_case": true, 6 | "mode": "train_eval", 7 | "uda_mode": true, 8 | 9 | "total_steps": 10000, 10 | "max_seq_length": 128, 11 | "train_batch_size": 8, 12 | "eval_batch_size": 16, 13 | 14 | "unsup_ratio": 3, 15 | "uda_coeff": 1, 16 | "tsa": "linear_schedule", 17 | "uda_softmax_temp": 0.85, 18 | "uda_confidence_thresh": 0.45, 19 | 20 | "data_parallel": true, 21 | "need_prepro": false, 22 | "sup_data_dir": "data/imdb_sup_train.txt", 23 | "unsup_data_dir": "data/imdb_unsup_train.txt", 24 | "eval_data_dir": "data/imdb_sup_test.txt", 25 | 26 | "model_file":null, 27 | "pretrain_file":"BERT_Base_Uncased/bert_model.ckpt", 28 | "vocab":"BERT_Base_Uncased/vocab.txt", 29 | "task": "imdb", 30 | 31 | "save_steps": 100, 32 | "check_steps": 250, 33 | "results_dir": "results", 34 | 35 | "is_position": false 36 | } 37 | -------------------------------------------------------------------------------- /data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/data.zip -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google UDA Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | #!/bin/bash 16 | 17 | # **** download pretrained models **** 18 | wget storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip 19 | unzip uncased_L-12_H-768_A-12.zip && rm uncased_L-12_H-768_A-12.zip 20 | mv uncased_L-12_H-768_A-12 BERT_Base_Uncased 21 | 22 | # **** unzip data **** 23 | unzip data.zip && rm data.zip -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SanghunYun, Korea University. 2 | # (Strongly inspired by Dong-Hyun Lee, Kakao Brain) 3 | # 4 | # This file has been modified by SanghunYun, Korea Univeristy. 5 | # Little modification at Tokenizing, AddSpecialTokensWithTruncation, TokenIndexing 6 | # and CsvDataset, load_data has been newly written. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | 21 | import ast 22 | import csv 23 | import itertools 24 | 25 | import pandas as pd # only import when no need_to_preprocessing 26 | from tqdm import tqdm 27 | 28 | import torch 29 | from torch.utils.data import Dataset, DataLoader 30 | 31 | from utils import tokenization 32 | from utils.utils import truncate_tokens_pair 33 | 34 | 35 | class CsvDataset(Dataset): 36 | labels = None 37 | def __init__(self, file, need_prepro, pipeline, max_len, mode, d_type): 38 | Dataset.__init__(self) 39 | self.cnt = 0 40 | 41 | # need preprocessing 42 | if need_prepro: 43 | with open(file, 'r', encoding='utf-8') as f: 44 | lines = csv.reader(f, delimiter='\t', quotechar='"') 45 | 46 | # supervised dataset 47 | if d_type == 'sup': 48 | # if mode == 'eval': 49 | # sentences = [] 50 | data = [] 51 | 52 | for instance in self.get_sup(lines): 53 | # if mode == 'eval': 54 | # sentences.append([instance[1]]) 55 | for proc in pipeline: 56 | instance = proc(instance, d_type) 57 | data.append(instance) 58 | 59 | self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)] 60 | # if mode == 'eval': 61 | # self.tensors.append(sentences) 62 | 63 | # unsupervised dataset 64 | elif d_type == 'unsup': 65 | data = {'ori':[], 'aug':[]} 66 | for ori, aug in self.get_unsup(lines): 67 | for proc in pipeline: 68 | ori = proc(ori, d_type) 69 | aug = proc(aug, d_type) 70 | self.cnt += 1 71 | # if self.cnt == 10: 72 | # break 73 | data['ori'].append(ori) # drop label_id 74 | data['aug'].append(aug) # drop label_id 75 | ori_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['ori'])] 76 | aug_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['aug'])] 77 | self.tensors = ori_tensor + aug_tensor 78 | # already preprocessed 79 | else: 80 | f = open(file, 'r', encoding='utf-8') 81 | data = pd.read_csv(f, sep='\t') 82 | 83 | # supervised dataset 84 | if d_type == 'sup': 85 | # input_ids, segment_ids(input_type_ids), input_mask, input_label 86 | input_columns = ['input_ids', 'input_type_ids', 'input_mask', 'label_ids'] 87 | self.tensors = [torch.tensor(data[c].apply(lambda x: 
ast.literal_eval(x)), dtype=torch.long) \ 88 | for c in input_columns[:-1]] 89 | self.tensors.append(torch.tensor(data[input_columns[-1]], dtype=torch.long)) 90 | 91 | # unsupervised dataset 92 | elif d_type == 'unsup': 93 | input_columns = ['ori_input_ids', 'ori_input_type_ids', 'ori_input_mask', 94 | 'aug_input_ids', 'aug_input_type_ids', 'aug_input_mask'] 95 | self.tensors = [torch.tensor(data[c].apply(lambda x: ast.literal_eval(x)), dtype=torch.long) \ 96 | for c in input_columns] 97 | 98 | else: 99 | raise "d_type error. (d_type have to sup or unsup)" 100 | 101 | def __len__(self): 102 | return self.tensors[0].size(0) 103 | 104 | def __getitem__(self, index): 105 | return tuple(tensor[index] for tensor in self.tensors) 106 | 107 | def get_sup(self, lines): 108 | raise NotImplementedError 109 | 110 | def get_unsup(self, lines): 111 | raise NotImplementedError 112 | 113 | 114 | class Pipeline(): 115 | def __init__(self): 116 | super().__init__() 117 | 118 | def __call__(self, instance): 119 | raise NotImplementedError 120 | 121 | 122 | class Tokenizing(Pipeline): 123 | def __init__(self, preprocessor, tokenize): 124 | super().__init__() 125 | self.preprocessor = preprocessor 126 | self.tokenize = tokenize 127 | 128 | def __call__(self, instance, d_type): 129 | label, text_a, text_b = instance 130 | 131 | label = self.preprocessor(label) if label else None 132 | tokens_a = self.tokenize(self.preprocessor(text_a)) 133 | tokens_b = self.tokenize(self.preprocessor(text_b)) if text_b else [] 134 | 135 | return (label, tokens_a, tokens_b) 136 | 137 | 138 | class AddSpecialTokensWithTruncation(Pipeline): 139 | def __init__(self, max_len=512): 140 | super().__init__() 141 | self.max_len = max_len 142 | 143 | def __call__(self, instance, d_type): 144 | label, tokens_a, tokens_b = instance 145 | 146 | # -3 special tokens for [CLS] text_a [SEP] text_b [SEP] 147 | # -2 special tokens for [CLS] text_a [SEP] 148 | _max_len = self.max_len - 3 if tokens_b else self.max_len - 2 149 | truncate_tokens_pair(tokens_a, tokens_b, _max_len) 150 | 151 | # Add Special Tokens 152 | tokens_a = ['[CLS]'] + tokens_a + ['[SEP]'] 153 | tokens_b = tokens_b + ['[SEP]'] if tokens_b else [] 154 | 155 | return (label, tokens_a, tokens_b) 156 | 157 | 158 | class TokenIndexing(Pipeline): 159 | def __init__(self, indexer, labels, max_len=512): 160 | super().__init__() 161 | self.indexer = indexer # function : tokens to indexes 162 | # map from a label name to a label index 163 | self.label_map = {name: i for i, name in enumerate(labels)} 164 | self.max_len = max_len 165 | 166 | def __call__(self, instance, d_type): 167 | label, tokens_a, tokens_b = instance 168 | 169 | input_ids = self.indexer(tokens_a + tokens_b) 170 | segment_ids = [0]*len(tokens_a) + [1]*len(tokens_b) # type_ids 171 | input_mask = [1]*(len(tokens_a) + len(tokens_b)) 172 | label_id = self.label_map[label] if label else None 173 | 174 | # zero padding 175 | n_pad = self.max_len - len(input_ids) 176 | input_ids.extend([0]*n_pad) 177 | segment_ids.extend([0]*n_pad) 178 | input_mask.extend([0]*n_pad) 179 | 180 | if label_id != None: 181 | return (input_ids, segment_ids, input_mask, label_id) 182 | else: 183 | return (input_ids, segment_ids, input_mask) 184 | 185 | 186 | def dataset_class(task): 187 | table = {'imdb': IMDB} 188 | return table[task] 189 | 190 | 191 | class IMDB(CsvDataset): 192 | labels = ('0', '1') 193 | def __init__(self, file, need_prepro, pipeline=[], max_len=128, mode='train', d_type='sup'): 194 | super().__init__(file, need_prepro, 
pipeline, max_len, mode, d_type) 195 | 196 | def get_sup(self, lines): 197 | for line in itertools.islice(lines, 0, None): 198 | yield line[7], line[6], [] # label, text_a, None 199 | # yield None, line[6], [] 200 | 201 | def get_unsup(self, lines): 202 | for line in itertools.islice(lines, 0, None): 203 | yield (None, line[1], []), (None, line[2], []) # ko, en 204 | 205 | 206 | class load_data: 207 | def __init__(self, cfg): 208 | self.cfg = cfg 209 | 210 | self.TaskDataset = dataset_class(cfg.task) 211 | self.pipeline = None 212 | if cfg.need_prepro: 213 | tokenizer = tokenization.FullTokenizer(vocab_file=cfg.vocab, do_lower_case=cfg.do_lower_case) 214 | self.pipeline = [Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize), 215 | AddSpecialTokensWithTruncation(cfg.max_seq_length), 216 | TokenIndexing(tokenizer.convert_tokens_to_ids, self.TaskDataset.labels, cfg.max_seq_length)] 217 | 218 | if cfg.mode == 'train': 219 | self.sup_data_dir = cfg.sup_data_dir 220 | self.sup_batch_size = cfg.train_batch_size 221 | self.shuffle = True 222 | elif cfg.mode == 'train_eval': 223 | self.sup_data_dir = cfg.sup_data_dir 224 | self.eval_data_dir= cfg.eval_data_dir 225 | self.sup_batch_size = cfg.train_batch_size 226 | self.eval_batch_size = cfg.eval_batch_size 227 | self.shuffle = True 228 | elif cfg.mode == 'eval': 229 | self.sup_data_dir = cfg.eval_data_dir 230 | self.sup_batch_size = cfg.eval_batch_size 231 | self.shuffle = False # Not shuffel when eval mode 232 | 233 | if cfg.uda_mode: # Only uda_mode 234 | self.unsup_data_dir = cfg.unsup_data_dir 235 | self.unsup_batch_size = cfg.train_batch_size * cfg.unsup_ratio 236 | 237 | def sup_data_iter(self): 238 | sup_dataset = self.TaskDataset(self.sup_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, self.cfg.mode, 'sup') 239 | sup_data_iter = DataLoader(sup_dataset, batch_size=self.sup_batch_size, shuffle=self.shuffle) 240 | 241 | return sup_data_iter 242 | 243 | def unsup_data_iter(self): 244 | unsup_dataset = self.TaskDataset(self.unsup_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, self.cfg.mode, 'unsup') 245 | unsup_data_iter = DataLoader(unsup_dataset, batch_size=self.unsup_batch_size, shuffle=self.shuffle) 246 | 247 | return unsup_data_iter 248 | 249 | def eval_data_iter(self): 250 | eval_dataset = self.TaskDataset(self.eval_data_dir, self.cfg.need_prepro, self.pipeline, self.cfg.max_seq_length, 'eval', 'sup') 251 | eval_data_iter = DataLoader(eval_dataset, batch_size=self.eval_batch_size, shuffle=False) 252 | 253 | return eval_data_iter 254 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SanghunYun, Korea University. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import fire 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import models 22 | import train 23 | from load_data import load_data 24 | from utils.utils import set_seeds, get_device, _get_device, torch_device_one 25 | from utils import optim, configuration 26 | 27 | 28 | # TSA 29 | def get_tsa_thresh(schedule, global_step, num_train_steps, start, end): 30 | training_progress = torch.tensor(float(global_step) / float(num_train_steps)) 31 | if schedule == 'linear_schedule': 32 | threshold = training_progress 33 | elif schedule == 'exp_schedule': 34 | scale = 5 35 | threshold = torch.exp((training_progress - 1) * scale) 36 | elif schedule == 'log_schedule': 37 | scale = 5 38 | threshold = 1 - torch.exp((-training_progress) * scale) 39 | output = threshold * (end - start) + start 40 | return output.to(_get_device()) 41 | 42 | 43 | def main(cfg, model_cfg): 44 | # Load Configuration 45 | cfg = configuration.params.from_json(cfg) # Train or Eval cfg 46 | model_cfg = configuration.model.from_json(model_cfg) # BERT_cfg 47 | set_seeds(cfg.seed) 48 | 49 | # Load Data & Create Criterion 50 | data = load_data(cfg) 51 | if cfg.uda_mode: 52 | unsup_criterion = nn.KLDivLoss(reduction='none') 53 | data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode=='train' \ 54 | else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()] # train_eval 55 | else: 56 | data_iter = [data.sup_data_iter()] 57 | sup_criterion = nn.CrossEntropyLoss(reduction='none') 58 | 59 | # Load Model 60 | model = models.Classifier(model_cfg, len(data.TaskDataset.labels)) 61 | 62 | # Create trainer 63 | trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model), get_device()) 64 | 65 | # Training 66 | def get_loss(model, sup_batch, unsup_batch, global_step): 67 | 68 | # logits -> prob(softmax) -> log_prob(log_softmax) 69 | 70 | # batch 71 | input_ids, segment_ids, input_mask, label_ids = sup_batch 72 | if unsup_batch: 73 | ori_input_ids, ori_segment_ids, ori_input_mask, \ 74 | aug_input_ids, aug_segment_ids, aug_input_mask = unsup_batch 75 | 76 | input_ids = torch.cat((input_ids, aug_input_ids), dim=0) 77 | segment_ids = torch.cat((segment_ids, aug_segment_ids), dim=0) 78 | input_mask = torch.cat((input_mask, aug_input_mask), dim=0) 79 | 80 | # logits 81 | logits = model(input_ids, segment_ids, input_mask) 82 | 83 | # sup loss 84 | sup_size = label_ids.shape[0] 85 | sup_loss = sup_criterion(logits[:sup_size], label_ids) # shape : train_batch_size 86 | if cfg.tsa: 87 | tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1) 88 | larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh # prob = exp(log_prob), prob > tsa_threshold 89 | # larger_than_threshold = torch.sum( F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids] , dim=-1) > tsa_threshold 90 | loss_mask = torch.ones_like(label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32)) 91 | sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one()) 92 | else: 93 | sup_loss = torch.mean(sup_loss) 94 | 95 | # unsup loss 96 | if unsup_batch: 97 | # ori 98 | with torch.no_grad(): 99 | ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask) 100 | ori_prob = F.softmax(ori_logits, dim=-1) # KLdiv target 101 | # ori_log_prob = F.log_softmax(ori_logits, dim=-1) 102 | 103 | # confidence-based masking 104 | if cfg.uda_confidence_thresh != -1: 105 | 
unsup_loss_mask = torch.max(ori_prob, dim=-1)[0] > cfg.uda_confidence_thresh 106 | unsup_loss_mask = unsup_loss_mask.type(torch.float32) 107 | else: 108 | unsup_loss_mask = torch.ones(len(logits) - sup_size, dtype=torch.float32) 109 | unsup_loss_mask = unsup_loss_mask.to(_get_device()) 110 | 111 | # aug 112 | # softmax temperature controlling 113 | uda_softmax_temp = cfg.uda_softmax_temp if cfg.uda_softmax_temp > 0 else 1. 114 | aug_log_prob = F.log_softmax(logits[sup_size:] / uda_softmax_temp, dim=-1) 115 | 116 | # KLdiv loss 117 | """ 118 | nn.KLDivLoss (kl_div) 119 | input : log_prob (log_softmax) 120 | target : prob (softmax) 121 | https://pytorch.org/docs/stable/nn.html 122 | 123 | unsup_loss is divied by number of unsup_loss_mask 124 | it is different from the google UDA official 125 | The official unsup_loss is divided by total 126 | https://github.com/google-research/uda/blob/master/text/uda.py#L175 127 | """ 128 | unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1) 129 | unsup_loss = torch.sum(unsup_loss * unsup_loss_mask, dim=-1) / torch.max(torch.sum(unsup_loss_mask, dim=-1), torch_device_one()) 130 | final_loss = sup_loss + cfg.uda_coeff*unsup_loss 131 | 132 | return final_loss, sup_loss, unsup_loss 133 | return sup_loss, None, None 134 | 135 | # evaluation 136 | def get_acc(model, batch): 137 | # input_ids, segment_ids, input_mask, label_id, sentence = batch 138 | input_ids, segment_ids, input_mask, label_id = batch 139 | logits = model(input_ids, segment_ids, input_mask) 140 | _, label_pred = logits.max(1) 141 | 142 | result = (label_pred == label_id).float() 143 | accuracy = result.mean() 144 | # output_dump.logs(sentence, label_pred, label_id) # output dump 145 | 146 | return accuracy, result 147 | 148 | if cfg.mode == 'train': 149 | trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file) 150 | 151 | if cfg.mode == 'train_eval': 152 | trainer.train(get_loss, get_acc, cfg.model_file, cfg.pretrain_file) 153 | 154 | if cfg.mode == 'eval': 155 | results = trainer.eval(get_acc, cfg.model_file, None) 156 | total_accuracy = torch.cat(results).mean().item() 157 | print('Accuracy :' , total_accuracy) 158 | 159 | 160 | if __name__ == '__main__': 161 | fire.Fire(main) 162 | #main('config/uda.json', 'config/bert_base.json') 163 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 2 | # (Strongly inspired by original Google BERT code and Hugging Face's code) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | """ Transformer Model Classes & Config Class """ 18 | 19 | import math 20 | import json 21 | from typing import NamedTuple 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from utils.utils import split_last, merge_last 29 | 30 | 31 | class Config(NamedTuple): 32 | "Configuration for BERT model" 33 | vocab_size: int = None # Size of Vocabulary 34 | dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder 35 | n_layers: int = 12 # Numher of Hidden Layers 36 | n_heads: int = 12 # Numher of Heads in Multi-Headed Attention Layers 37 | dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net 38 | #activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers 39 | p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers 40 | p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers 41 | max_len: int = 512 # Maximum Length for Positional Embeddings 42 | n_segments: int = 2 # Number of Sentence Segments 43 | 44 | @classmethod 45 | def from_json(cls, file): 46 | return cls(**json.load(open(file, "r"))) 47 | 48 | 49 | def gelu(x): 50 | "Implementation of the gelu activation function by Hugging Face" 51 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 52 | 53 | 54 | class LayerNorm(nn.Module): 55 | "A layernorm module in the TF style (epsilon inside the square root)." 56 | def __init__(self, cfg, variance_epsilon=1e-12): 57 | super().__init__() 58 | self.gamma = nn.Parameter(torch.ones(cfg.dim)) 59 | self.beta = nn.Parameter(torch.zeros(cfg.dim)) 60 | self.variance_epsilon = variance_epsilon 61 | 62 | def forward(self, x): 63 | u = x.mean(-1, keepdim=True) 64 | s = (x - u).pow(2).mean(-1, keepdim=True) 65 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 66 | return self.gamma * x + self.beta 67 | 68 | 69 | class Embeddings(nn.Module): 70 | "The embedding module from word, position and token_type embeddings." 
71 | def __init__(self, cfg): 72 | super().__init__() 73 | self.tok_embed = nn.Embedding(cfg.vocab_size, cfg.dim) # token embedding 74 | self.pos_embed = nn.Embedding(cfg.max_len, cfg.dim) # position embedding 75 | self.seg_embed = nn.Embedding(cfg.n_segments, cfg.dim) # segment(token type) embedding 76 | 77 | self.norm = LayerNorm(cfg) 78 | self.drop = nn.Dropout(cfg.p_drop_hidden) 79 | 80 | def forward(self, x, seg): 81 | seq_len = x.size(1) 82 | pos = torch.arange(seq_len, dtype=torch.long, device=x.device) 83 | pos = pos.unsqueeze(0).expand_as(x) # (S,) -> (1, S) -> (B, S) 이렇게 외부에서 생성되는 값 84 | 85 | e = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg) 86 | return self.drop(self.norm(e)) 87 | 88 | 89 | class MultiHeadedSelfAttention(nn.Module): 90 | """ Multi-Headed Dot Product Attention """ 91 | def __init__(self, cfg): 92 | super().__init__() 93 | self.proj_q = nn.Linear(cfg.dim, cfg.dim) 94 | self.proj_k = nn.Linear(cfg.dim, cfg.dim) 95 | self.proj_v = nn.Linear(cfg.dim, cfg.dim) 96 | self.drop = nn.Dropout(cfg.p_drop_attn) 97 | self.scores = None # for visualization 98 | self.n_heads = cfg.n_heads 99 | 100 | def forward(self, x, mask): 101 | """ 102 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 103 | mask : (B(batch_size) x S(seq_len)) 104 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 105 | """ 106 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 107 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 108 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) 109 | for x in [q, k, v]) 110 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 111 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 112 | if mask is not None: 113 | mask = mask[:, None, None, :].float() 114 | scores -= 10000.0 * (1.0 - mask) 115 | scores = self.drop(F.softmax(scores, dim=-1)) 116 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 117 | h = (scores @ v).transpose(1, 2).contiguous() 118 | # -merge-> (B, S, D) 119 | h = merge_last(h, 2) 120 | self.scores = scores 121 | return h 122 | 123 | 124 | class PositionWiseFeedForward(nn.Module): 125 | """ FeedForward Neural Networks for each position """ 126 | def __init__(self, cfg): 127 | super().__init__() 128 | self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff) 129 | self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim) 130 | #self.activ = lambda x: activ_fn(cfg.activ_fn, x) 131 | 132 | def forward(self, x): 133 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 134 | return self.fc2(gelu(self.fc1(x))) 135 | 136 | 137 | class Block(nn.Module): 138 | """ Transformer Block """ 139 | def __init__(self, cfg): 140 | super().__init__() 141 | self.attn = MultiHeadedSelfAttention(cfg) 142 | self.proj = nn.Linear(cfg.dim, cfg.dim) 143 | self.norm1 = LayerNorm(cfg) 144 | self.pwff = PositionWiseFeedForward(cfg) 145 | self.norm2 = LayerNorm(cfg) 146 | self.drop = nn.Dropout(cfg.p_drop_hidden) 147 | 148 | def forward(self, x, mask): 149 | h = self.attn(x, mask) 150 | h = self.norm1(x + self.drop(self.proj(h))) 151 | h = self.norm2(h + self.drop(self.pwff(h))) 152 | return h 153 | 154 | 155 | class Transformer(nn.Module): 156 | """ Transformer with Self-Attentive Blocks""" 157 | def __init__(self, cfg): 158 | super().__init__() 159 | self.embed = Embeddings(cfg) 160 | self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)]) # h 번 반복 161 | 162 | def forward(self, x, seg, mask): 163 | h = self.embed(x, seg) 164 | for block in self.blocks: 165 | h = 
block(h, mask)
166 |         return h
167 | 
168 | 
169 | class Classifier(nn.Module):
170 |     """ Classifier with Transformer """
171 |     def __init__(self, cfg, n_labels):
172 |         super().__init__()
173 |         self.transformer = Transformer(cfg)
174 |         self.fc = nn.Linear(cfg.dim, cfg.dim)
175 |         self.activ = nn.Tanh()
176 |         self.drop = nn.Dropout(cfg.p_drop_hidden)
177 |         self.classifier = nn.Linear(cfg.dim, n_labels)
178 | 
179 |     def forward(self, input_ids, segment_ids, input_mask):
180 |         h = self.transformer(input_ids, segment_ids, input_mask)
181 |         # only use the first h in the sequence
182 |         pooled_h = self.activ(self.fc(h[:, 0]))    # pool only the leading [CLS] token
183 |         logits = self.classifier(self.drop(pooled_h))
184 |         return logits
185 | 
186 | class Opinion_extract(nn.Module):
187 |     """ Opinion_extraction """
188 |     def __init__(self, cfg, max_len, n_labels):
189 |         super().__init__()
190 |         self.transformer = Transformer(cfg)
191 |         self.fc = nn.Linear(cfg.dim, cfg.dim)
192 |         self.activ = nn.Tanh()
193 |         self.drop = nn.Dropout(cfg.p_drop_hidden)
194 |         self.extract = nn.Linear(cfg.dim, n_labels)
195 |         self.sigmoid = nn.Sigmoid()
196 | 
197 |     def forward(self, input_ids, segment_ids, input_mask):
198 |         h = self.transformer(input_ids, segment_ids, input_mask)
199 |         # use every token in the sequence except the first and last
200 |         h = self.drop(self.activ(self.fc(h[:, 1:-1])))
201 |         seq_h = self.extract(h)
202 |         seq_h = seq_h.squeeze()
203 |         return self.sigmoid(seq_h)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # (Strongly inspired by Dong-Hyun Lee, Kakao Brain)
3 | #
4 | # Except for the load and save functions, the whole code of this file has been modified and added by
5 | # SanghunYun, Korea University, for UDA.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License. 
18 | 19 | import os 20 | import json 21 | from copy import deepcopy 22 | from typing import NamedTuple 23 | from tqdm import tqdm 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | from utils import checkpoint 29 | # from utils.logger import Logger 30 | from tensorboardX import SummaryWriter 31 | from utils.utils import output_logging 32 | 33 | 34 | class Trainer(object): 35 | """Training Helper class""" 36 | def __init__(self, cfg, model, data_iter, optimizer, device): 37 | self.cfg = cfg 38 | self.model = model 39 | self.optimizer = optimizer 40 | self.device = device 41 | 42 | # data iter 43 | if len(data_iter) == 1: 44 | self.sup_iter = data_iter[0] 45 | elif len(data_iter) == 2: 46 | self.sup_iter = self.repeat_dataloader(data_iter[0]) 47 | self.unsup_iter = self.repeat_dataloader(data_iter[1]) 48 | elif len(data_iter) == 3: 49 | self.sup_iter = self.repeat_dataloader(data_iter[0]) 50 | self.unsup_iter = self.repeat_dataloader(data_iter[1]) 51 | self.eval_iter = data_iter[2] 52 | 53 | def train(self, get_loss, get_acc, model_file, pretrain_file): 54 | """ train uda""" 55 | 56 | # tensorboardX logging 57 | if self.cfg.results_dir: 58 | logger = SummaryWriter(log_dir=os.path.join(self.cfg.results_dir, 'logs')) 59 | 60 | self.model.train() 61 | self.load(model_file, pretrain_file) # between model_file and pretrain_file, only one model will be loaded 62 | model = self.model.to(self.device) 63 | if self.cfg.data_parallel: # Parallel GPU mode 64 | model = nn.DataParallel(model) 65 | 66 | global_step = 0 67 | loss_sum = 0. 68 | max_acc = [0., 0] # acc, step 69 | 70 | # Progress bar is set by unsup or sup data 71 | # uda_mode == True --> sup_iter is repeated 72 | # uda_mode == False --> sup_iter is not repeated 73 | iter_bar = tqdm(self.unsup_iter, total=self.cfg.total_steps) if self.cfg.uda_mode \ 74 | else tqdm(self.sup_iter, total=self.cfg.total_steps) 75 | for i, batch in enumerate(iter_bar): 76 | 77 | # Device assignment 78 | if self.cfg.uda_mode: 79 | sup_batch = [t.to(self.device) for t in next(self.sup_iter)] 80 | unsup_batch = [t.to(self.device) for t in batch] 81 | else: 82 | sup_batch = [t.to(self.device) for t in batch] 83 | unsup_batch = None 84 | 85 | # update 86 | self.optimizer.zero_grad() 87 | final_loss, sup_loss, unsup_loss = get_loss(model, sup_batch, unsup_batch, global_step) 88 | final_loss.backward() 89 | self.optimizer.step() 90 | 91 | # print loss 92 | global_step += 1 93 | loss_sum += final_loss.item() 94 | if self.cfg.uda_mode: 95 | iter_bar.set_description('final=%5.3f unsup=%5.3f sup=%5.3f'\ 96 | % (final_loss.item(), unsup_loss.item(), sup_loss.item())) 97 | else: 98 | iter_bar.set_description('loss=%5.3f' % (final_loss.item())) 99 | 100 | # logging 101 | if self.cfg.uda_mode: 102 | logger.add_scalars('data/scalar_group', 103 | {'final_loss': final_loss.item(), 104 | 'sup_loss': sup_loss.item(), 105 | 'unsup_loss': unsup_loss.item(), 106 | 'lr': self.optimizer.get_lr()[0] 107 | }, global_step) 108 | else: 109 | logger.add_scalars('data/scalar_group', 110 | {'sup_loss': final_loss.item()}, global_step) 111 | 112 | if global_step % self.cfg.save_steps == 0: 113 | self.save(global_step) 114 | 115 | if get_acc and global_step % self.cfg.check_steps == 0 and global_step > 4999: 116 | results = self.eval(get_acc, None, model) 117 | total_accuracy = torch.cat(results).mean().item() 118 | logger.add_scalars('data/scalar_group', {'eval_acc' : total_accuracy}, global_step) 119 | if max_acc[0] < total_accuracy: 120 | self.save(global_step) 121 | max_acc = 
total_accuracy, global_step 122 | print('Accuracy : %5.3f' % total_accuracy) 123 | print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d' %(max_acc[0], max_acc[1], global_step), end='\n\n') 124 | 125 | if self.cfg.total_steps and self.cfg.total_steps < global_step: 126 | print('The total steps have been reached') 127 | print('Average Loss %5.3f' % (loss_sum/(i+1))) 128 | if get_acc: 129 | results = self.eval(get_acc, None, model) 130 | total_accuracy = torch.cat(results).mean().item() 131 | logger.add_scalars('data/scalar_group', {'eval_acc' : total_accuracy}, global_step) 132 | if max_acc[0] < total_accuracy: 133 | max_acc = total_accuracy, global_step 134 | print('Accuracy :', total_accuracy) 135 | print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d' %(max_acc[0], max_acc[1], global_step), end='\n\n') 136 | self.save(global_step) 137 | return 138 | return global_step 139 | 140 | def eval(self, evaluate, model_file, model): 141 | """ evaluation function """ 142 | if model_file: 143 | self.model.eval() 144 | self.load(model_file, None) 145 | model = self.model.to(self.device) 146 | if self.cfg.data_parallel: 147 | model = nn.DataParallel(model) 148 | 149 | results = [] 150 | iter_bar = tqdm(self.sup_iter) if model_file \ 151 | else tqdm(deepcopy(self.eval_iter)) 152 | for batch in iter_bar: 153 | batch = [t.to(self.device) for t in batch] 154 | 155 | with torch.no_grad(): 156 | accuracy, result = evaluate(model, batch) 157 | results.append(result) 158 | 159 | iter_bar.set_description('Eval Acc=%5.3f' % accuracy) 160 | return results 161 | 162 | def load(self, model_file, pretrain_file): 163 | """ between model_file and pretrain_file, only one model will be loaded """ 164 | if model_file: 165 | print('Loading the model from', model_file) 166 | if torch.cuda.is_available(): 167 | self.model.load_state_dict(torch.load(model_file)) 168 | else: 169 | self.model.load_state_dict(torch.load(model_file, map_location='cpu')) 170 | 171 | elif pretrain_file: 172 | print('Loading the pretrained model from', pretrain_file) 173 | if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow 174 | checkpoint.load_model(self.model.transformer, pretrain_file) 175 | elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch 176 | self.model.transformer.load_state_dict( 177 | {key[12:]: value 178 | for key, value in torch.load(pretrain_file).items() 179 | if key.startswith('transformer')} 180 | ) # load only transformer parts 181 | 182 | def save(self, i): 183 | """ save model """ 184 | if not os.path.isdir(os.path.join(self.cfg.results_dir, 'save')): 185 | os.makedirs(os.path.join(self.cfg.results_dir, 'save')) 186 | torch.save(self.model.state_dict(), 187 | os.path.join(self.cfg.results_dir, 'save', 'model_steps_'+str(i)+'.pt')) 188 | 189 | def repeat_dataloader(self, iterable): 190 | """ repeat dataloader """ 191 | while True: 192 | for x in iterable: 193 | yield x 194 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanghunYun/UDA_pytorch/0ba5cf8d8a6f698e19a295119f084a17dfa7a1e3/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import torch 18 | 19 | def load_param(checkpoint_file, conversion_table): 20 | """ 21 | Load parameters in pytorch model from checkpoint file according to conversion_table 22 | checkpoint_file : pretrained checkpoint model file in tensorflow 23 | cnoversion_table : { pytorch tensor in a model : checkpoint variable name } 24 | """ 25 | 26 | for pyt_param, tf_param_name in conversion_table.items(): 27 | tf_param = tf.train.load_variable(checkpoint_file, tf_param_name) 28 | 29 | # for weight(kernel), we should do transpose --> pytorch, tensorflow 다름 30 | if tf_param_name.endswith('kernel'): 31 | tf_param = np.transpose(tf_param) 32 | 33 | assert pyt_param.size() == tf_param.shape, \ 34 | 'Dim Mismatch: %s vs %s ; %s' % (tuple(pyt_param.size()), tf_param.shape, tf_param_name) 35 | 36 | # assign pytorch tensor from tensorflow param 37 | pyt_param.data = torch.from_numpy(tf_param) 38 | 39 | 40 | def load_model(model, checkpoint_file): 41 | """Load the pytorch model from checkpoint file""" 42 | 43 | # Embedding layer 44 | e, p = model.embed, 'bert/embeddings/' 45 | load_param(checkpoint_file, { 46 | e.tok_embed.weight: p+'word_embeddings', 47 | e.pos_embed.weight: p+'position_embeddings', 48 | e.seg_embed.weight: p+'token_type_embeddings', 49 | e.norm.gamma: p+'LayerNorm/gamma', 50 | e.norm.beta: p+'LayerNorm/beta' 51 | }) 52 | 53 | # Transformer blocks 54 | for i in range(len(model.blocks)): 55 | b, p = model.blocks[i], "bert/encoder/layer_%d/"%i 56 | load_param(checkpoint_file, { 57 | b.attn.proj_q.weight: p+"attention/self/query/kernel", 58 | b.attn.proj_q.bias: p+"attention/self/query/bias", 59 | b.attn.proj_k.weight: p+"attention/self/key/kernel", 60 | b.attn.proj_k.bias: p+"attention/self/key/bias", 61 | b.attn.proj_v.weight: p+"attention/self/value/kernel", 62 | b.attn.proj_v.bias: p+"attention/self/value/bias", 63 | b.proj.weight: p+"attention/output/dense/kernel", 64 | b.proj.bias: p+"attention/output/dense/bias", 65 | b.pwff.fc1.weight: p+"intermediate/dense/kernel", 66 | b.pwff.fc1.bias: p+"intermediate/dense/bias", 67 | b.pwff.fc2.weight: p+"output/dense/kernel", 68 | b.pwff.fc2.bias: p+"output/dense/bias", 69 | b.norm1.gamma: p+"attention/output/LayerNorm/gamma", 70 | b.norm1.beta: p+"attention/output/LayerNorm/beta", 71 | b.norm2.gamma: p+"output/LayerNorm/gamma", 72 | b.norm2.beta: p+"output/LayerNorm/beta", 73 | }) 74 | 75 | -------------------------------------------------------------------------------- /utils/configuration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SanghunYun, Korea University. 2 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 
/utils/configuration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 SanghunYun, Korea University.
2 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
3 | # (Strongly inspired by original Google BERT code and Hugging Face's code)
4 | #
5 | # SanghunYun, Korea University referred to Dong-Hyun Lee, Kakao Brain's code (class model)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | 
20 | import json
21 | from typing import NamedTuple
22 | 
23 | 
24 | class params(NamedTuple):
25 | 
26 |     ############################ guide #############################
27 |     # lr (learning rate)    : fine_tune (2e-5), further-train (1.5e-4 ~ 2e-5)
28 |     # mode                  : train, eval, test
29 |     # uda_mode              : True, False
30 |     # total_steps           : n_epochs * n_examples / 3
31 |     # max_seq_length        : 128, 256, 512
32 |     # unsup_ratio           : more than 3
33 |     # uda_softmax_temp      : more than 0.5
34 |     # uda_confidence_thresh : ??
35 |     # tsa                   : linear_schedule
36 |     ################################################################
37 | 
38 |     # train
39 |     seed: int = 1421
40 |     lr: float = 2e-5            # lr_scheduled = lr * factor
41 |     # n_epochs: int = 3
42 |     warmup: float = 0.1         # warmup steps = total_steps * warmup
43 |     do_lower_case: bool = True
44 |     mode: str = None            # train, eval, test
45 |     uda_mode: bool = False      # True, False
46 | 
47 |     total_steps: int = 100000   # total_steps >= n_epochs * n_examples / 3
48 |     max_seq_length: int = 128
49 |     train_batch_size: int = 32
50 |     eval_batch_size: int = 8
51 | 
52 |     # unsup
53 |     unsup_ratio: int = 0        # unsup_batch_size = unsup_ratio * sup_batch_size
54 |     uda_coeff: int = 1          # total_loss = sup_loss + uda_coeff * unsup_loss
55 |     tsa: str = 'linear_schedule'        # log, linear, exp
56 |     uda_softmax_temp: float = -1        # 0 ~ 1
57 |     uda_confidence_thresh: float = -1   # 0 ~ 1
58 | 
59 |     # data
60 |     data_parallel: bool = True
61 |     need_prepro: bool = False   # True if the raw data still needs preprocessing
62 |     sup_data_dir: str = None
63 |     unsup_data_dir: str = None
64 |     eval_data_dir: str = None
65 |     n_sup: int = None
66 |     n_unsup: int = None
67 | 
68 |     model_file: str = None      # fine-tuned
69 |     pretrain_file: str = None   # pre-trained
70 |     vocab: str = None
71 |     task: str = None
72 | 
73 |     # results
74 |     save_steps: int = 100
75 |     check_steps: int = 10
76 |     results_dir: str = None
77 | 
78 |     # appendix
79 |     is_position: bool = False   # appendix not used
80 | 
81 |     @classmethod
82 |     def from_json(cls, file):
83 |         return cls(**json.load(open(file, 'r')))
84 | 
85 | 
86 | class pretrain(NamedTuple):
87 |     seed: int = 3232
88 |     batch_size: int = 32
89 |     lr: float = 1.5e-4
90 |     n_epochs: int = 100
91 |     warmup: float = 0.1
92 |     save_steps: int = 100
93 |     total_steps: int = 100000
94 |     results_dir: str = None
95 | 
96 |     # do not change
97 |     uda_mode: bool = False
98 | 
99 |     @classmethod
100 |     def from_json(cls, file):
101 |         return cls(**json.load(open(file, 'r')))
102 | 
103 | 
104 | 
105 | class model(NamedTuple):
106 |     "Configuration for BERT model"
107 |     vocab_size: int = None      # Size of Vocabulary
108 |     dim: int = 768              # Dimension of Hidden Layer in Transformer Encoder
109 |     n_layers: int = 12          # Number of Hidden Layers
110 |     n_heads: int = 12           # Number of Heads in Multi-Headed Attention Layers
111 |     dim_ff: int = 768*4         # Dimension of Intermediate Layers in Positionwise Feedforward Net
112 |     #activ_fn: str = "gelu"     # Non-linear Activation Function Type in Hidden Layers
113 |     p_drop_hidden: float = 0.1  # Probability of Dropout of various Hidden Layers
114 |     p_drop_attn: float = 0.1    # Probability of Dropout of Attention Layers
115 |     max_len: int = 512          # Maximum Length for Positional Embeddings
116 |     n_segments: int = 2         # Number of Sentence Segments
117 | 
118 |     @classmethod
119 |     def from_json(cls, file):
120 |         return cls(**json.load(open(file, 'r')))
--------------------------------------------------------------------------------
/utils/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # and Dong-Hyun Lee, Kakao Brain.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
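Editor's note on utils/configuration.py above: the three `NamedTuple` classes are populated from the JSON files under config/ via the `from_json` classmethods. A short usage sketch; the exact keys inside config/uda.json and config/bert_base.json are not reproduced here, so treat the fields printed below as assumptions about what those files set:

    # Loading run and model configuration the way main.py is expected to,
    # via the from_json classmethods defined in utils/configuration.py.
    from utils.configuration import params, model

    cfg = params.from_json('config/uda.json')             # training hyperparameters
    model_cfg = model.from_json('config/bert_base.json')  # BERT-Base architecture

    print(cfg.uda_mode, cfg.lr, cfg.total_steps, cfg.tsa)
    print(model_cfg.dim, model_cfg.n_layers, model_cfg.n_heads)

    # NamedTuples are immutable; derive a variant with _replace instead of mutating:
    smoke_test_cfg = cfg._replace(total_steps=100, check_steps=10)

Because `from_json` expands the JSON with `cls(**...)`, every key in the file must correspond to a declared field; unknown keys raise a `TypeError`, which makes configuration mistakes fail early.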
15 | 16 | 17 | """ a slightly modified version of Hugging Face's BERTAdam class """ 18 | 19 | import math 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.nn.utils import clip_grad_norm_ 23 | 24 | def warmup_cosine(x, warmup=0.002): 25 | if x < warmup: 26 | return x/warmup 27 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 28 | 29 | def warmup_constant(x, warmup=0.002): 30 | if x < warmup: 31 | return x/warmup 32 | return 1.0 33 | 34 | def warmup_linear(x, warmup=0.002): 35 | if x < warmup: 36 | return x/warmup 37 | return 1.0 - x 38 | 39 | SCHEDULES = { 40 | 'warmup_cosine':warmup_cosine, 41 | 'warmup_constant':warmup_constant, 42 | 'warmup_linear':warmup_linear, 43 | } 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adams b1. Default: 0.9 54 | b2: Adams b2. Default: 0.999 55 | e: Adams epsilon. Default: 1e-6 56 | weight_decay_rate: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 58 | """ 59 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 61 | max_grad_norm=1.0): 62 | assert lr > 0.0, "Learning rate: %f - should be > 0.0" % (lr) 63 | assert schedule in SCHEDULES, "Invalid schedule : %s" % (schedule) 64 | assert 0.0 <= warmup < 1.0 or warmup == -1.0, \ 65 | "Warmup %f - should be in 0.0 ~ 1.0 or -1 (no warm up)" % (warmup) 66 | assert 0.0 <= b1 < 1.0, "b1: %f - should be in 0.0 ~ 1.0" % (b1) 67 | assert 0.0 <= b2 < 1.0, "b2: %f - should be in 0.0 ~ 1.0" % (b2) 68 | assert e > 0.0, "epsilon: %f - should be > 0.0" % (e) 69 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 70 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 71 | max_grad_norm=max_grad_norm) 72 | super(BertAdam, self).__init__(params, defaults) 73 | 74 | def get_lr(self): 75 | """ get learning rate in training """ 76 | lr = [] 77 | for group in self.param_groups: 78 | for p in group['params']: 79 | state = self.state[p] 80 | if not state: 81 | return [0] 82 | if group['t_total'] != -1: 83 | schedule_fct = SCHEDULES[group['schedule']] 84 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 85 | else: 86 | lr_scheduled = group['lr'] 87 | lr.append(lr_scheduled) 88 | return lr 89 | 90 | def step(self, closure=None): 91 | """Performs a single optimization step. 92 | 93 | Arguments: 94 | closure (callable, optional): A closure that reevaluates the model 95 | and returns the loss. 
96 | """ 97 | loss = None 98 | if closure is not None: 99 | loss = closure() 100 | 101 | for group in self.param_groups: 102 | for p in group['params']: 103 | if p.grad is None: 104 | continue 105 | grad = p.grad.data 106 | if grad.is_sparse: 107 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 108 | 109 | state = self.state[p] 110 | 111 | # State initialization 112 | if not state: 113 | state['step'] = 0 114 | # Exponential moving average of gradient values 115 | state['next_m'] = torch.zeros_like(p.data) 116 | # Exponential moving average of squared gradient values 117 | state['next_v'] = torch.zeros_like(p.data) 118 | 119 | next_m, next_v = state['next_m'], state['next_v'] 120 | beta1, beta2 = group['b1'], group['b2'] 121 | 122 | # Add grad clipping 123 | if group['max_grad_norm'] > 0: 124 | clip_grad_norm_(p, group['max_grad_norm']) 125 | 126 | # Decay the first and second moment running average coefficient 127 | # In-place operations to update the averages at the same time 128 | next_m.mul_(beta1).add_(1 - beta1, grad) 129 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 130 | update = next_m / (next_v.sqrt() + group['e']) 131 | 132 | # Just adding the square of the weights to the loss function is *not* 133 | # the correct way of using L2 regularization/weight decay with Adam, 134 | # since that will interact with the m and v parameters in strange ways. 135 | # 136 | # Instead we want to decay the weights in a manner that doesn't interact 137 | # with the m/v parameters. This is equivalent to adding the square 138 | # of the weights to the loss with plain (non-momentum) SGD. 139 | if group['weight_decay_rate'] > 0.0: 140 | update += group['weight_decay_rate'] * p.data 141 | 142 | if group['t_total'] != -1: 143 | schedule_fct = SCHEDULES[group['schedule']] 144 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 145 | else: 146 | lr_scheduled = group['lr'] 147 | 148 | update_with_lr = lr_scheduled * update 149 | p.data.add_(-update_with_lr) 150 | 151 | state['step'] += 1 152 | 153 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 154 | # No bias correction 155 | # bias_correction1 = 1 - beta1 ** state['step'] 156 | # bias_correction2 = 1 - beta2 ** state['step'] 157 | 158 | return loss 159 | 160 | 161 | 162 | def optim4GPU(cfg, model): 163 | """ optimizer for GPU training """ 164 | param_optimizer = list(model.named_parameters()) 165 | no_decay = ['bias', 'gamma', 'beta'] 166 | optimizer_grouped_parameters = [ 167 | {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01}, 168 | {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}] 169 | return BertAdam(optimizer_grouped_parameters, 170 | lr=cfg.lr, 171 | warmup=cfg.warmup, 172 | t_total=cfg.total_steps) 173 | -------------------------------------------------------------------------------- /utils/tokenization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
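Editor's note on utils/optim.py above: `warmup_linear` scales the learning rate up linearly over the first `warmup` fraction of `t_total` and then decays it linearly toward zero, and `optim4GPU` builds two parameter groups before handing them to `BertAdam` with `cfg.warmup` and `cfg.total_steps`. (One caveat visible in the code: `optim4GPU` compares full dotted parameter names with `n not in no_decay`, so names such as `transformer.blocks.0.norm1.gamma` never equal the bare suffixes and effectively every parameter receives weight decay.) A small sketch with a toy model standing in for the repo's BERT classifier:

    # Sketch of driving BertAdam the way train.py would; the toy model, data,
    # and loss are stand-ins -- only the optimizer wiring mirrors utils/optim.py.
    import torch
    import torch.nn as nn
    from utils.optim import BertAdam, warmup_linear

    net = nn.Linear(10, 2)
    optimizer = BertAdam(net.parameters(), lr=2e-5,
                         warmup=0.1, t_total=1000, schedule='warmup_linear')

    for step in range(3):
        loss = net(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()        # Adam update, weight decay, gradient clipping, LR schedule
        optimizer.zero_grad()

    # The schedule factor on its own: linear ramp until 10% of training, then 1 - x.
    print(warmup_linear(0.05, warmup=0.1))  # 0.5, half-way through warmup
    print(warmup_linear(0.50, warmup=0.1))  # 0.5, on the decay leg (1 - 0.5)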
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """ Tokenization classes (It's exactly the same code as Google BERT code """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import re 25 | import six 26 | 27 | 28 | def convert_to_unicode(text): 29 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 30 | if six.PY3: 31 | if isinstance(text, str): 32 | return text 33 | elif isinstance(text, bytes): 34 | return text.decode("utf-8", "ignore") 35 | else: 36 | raise ValueError("Unsupported string type: %s" % (type(text))) 37 | elif six.PY2: 38 | if isinstance(text, str): 39 | return text.decode("utf-8", "ignore") 40 | elif isinstance(text, unicode): 41 | return text 42 | else: 43 | raise ValueError("Unsupported string type: %s" % (type(text))) 44 | else: 45 | raise ValueError("Not running on Python2 or Python 3?") 46 | 47 | 48 | def printable_text(text): 49 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 50 | 51 | # These functions want `str` for both Python2 and Python3, but in one case 52 | # it's a Unicode string and in the other it's a byte string. 53 | if six.PY3: 54 | if isinstance(text, str): 55 | return text 56 | elif isinstance(text, bytes): 57 | return text.decode("utf-8", "ignore") 58 | else: 59 | raise ValueError("Unsupported string type: %s" % (type(text))) 60 | elif six.PY2: 61 | if isinstance(text, str): 62 | return text 63 | 64 | elif isinstance(text, unicode): 65 | return text.encode("utf-8") 66 | else: 67 | raise ValueError("Unsupported string type: %s" % (type(text))) 68 | else: 69 | raise ValueError("Not running on Python2 or Python 3?") 70 | 71 | 72 | def load_vocab(vocab_file): 73 | """Loads a vocabulary file into a dictionary.""" 74 | vocab = collections.OrderedDict() 75 | index = 0 76 | with open(vocab_file, "r") as reader: 77 | while True: 78 | token = convert_to_unicode(reader.readline()) 79 | if not token: 80 | break 81 | token = token.strip() # sanghun 특수문자 제거 문장만 추출 82 | vocab[token] = index # index 부여 83 | index += 1 84 | return vocab 85 | 86 | 87 | def convert_tokens_to_ids(vocab, tokens): 88 | """Converts a sequence of tokens into ids using the vocab.""" 89 | ids = [] 90 | for token in tokens: 91 | ids.append(vocab[token]) 92 | return ids 93 | 94 | 95 | def whitespace_tokenize(text): 96 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 97 | text = text.strip() 98 | if not text: 99 | return [] 100 | tokens = text.split() 101 | return tokens 102 | 103 | 104 | class FullTokenizer(object): 105 | """Runs end-to-end tokenziation.""" 106 | 107 | def __init__(self, vocab_file, do_lower_case=True): 108 | self.vocab = load_vocab(vocab_file) # vocab -> idx 109 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 110 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 111 | 112 | def tokenize(self, text): 113 | split_tokens = [] 114 | for token in self.basic_tokenizer.tokenize(text): 115 | for sub_token in 
self.wordpiece_tokenizer.tokenize(token): 116 | split_tokens.append(sub_token) 117 | 118 | return split_tokens 119 | 120 | def convert_tokens_to_ids(self, tokens): 121 | return convert_tokens_to_ids(self.vocab, tokens) 122 | 123 | def convert_to_unicode(self, text): 124 | return convert_to_unicode(text) 125 | 126 | 127 | 128 | class BasicTokenizer(object): 129 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 130 | 131 | def __init__(self, do_lower_case=True): 132 | """Constructs a BasicTokenizer. 133 | 134 | Args: 135 | do_lower_case: Whether to lower case the input. 136 | """ 137 | self.do_lower_case = do_lower_case 138 | 139 | def tokenize(self, text): 140 | """Tokenizes a piece of text.""" 141 | text = convert_to_unicode(text) 142 | text = self._clean_text(text) 143 | orig_tokens = whitespace_tokenize(text) 144 | split_tokens = [] 145 | for token in orig_tokens: 146 | if self.do_lower_case: 147 | token = token.lower() 148 | token = self._run_strip_accents(token) 149 | split_tokens.extend(self._run_split_on_punc(token)) 150 | 151 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 152 | return output_tokens 153 | 154 | def _run_strip_accents(self, text): 155 | """Strips accents from a piece of text.""" 156 | 157 | korean = "%s-%s%s-%s" % (chr(0xac00), chr(0xd7a3), 158 | chr(0x3131), chr(0x3163)) 159 | if re.search("[%s]+" % korean, text): 160 | return "".join( 161 | substr if re.search("^[%s]+$" % korean, substr) 162 | else self._run_strip_accents(substr) 163 | for substr in re.findall("[%s]+|[^%s]+" % (korean, korean), text) 164 | ) 165 | 166 | text = unicodedata.normalize("NFD", text) 167 | output = [] 168 | for char in text: 169 | cat = unicodedata.category(char) 170 | if cat == "Mn": 171 | continue 172 | output.append(char) 173 | return "".join(output) 174 | 175 | def _run_split_on_punc(self, text): 176 | """Splits punctuation on a piece of text.""" 177 | chars = list(text) 178 | i = 0 179 | start_new_word = True 180 | output = [] 181 | while i < len(chars): 182 | char = chars[i] 183 | if _is_punctuation(char): 184 | output.append([char]) 185 | start_new_word = True 186 | else: 187 | if start_new_word: 188 | output.append([]) 189 | start_new_word = False 190 | output[-1].append(char) 191 | i += 1 192 | 193 | return ["".join(x) for x in output] 194 | 195 | def _clean_text(self, text): 196 | """Performs invalid character removal and whitespace cleanup on text.""" 197 | output = [] 198 | for char in text: 199 | cp = ord(char) 200 | if cp == 0 or cp == 0xfffd or _is_control(char): 201 | continue 202 | if _is_whitespace(char): 203 | output.append(" ") 204 | else: 205 | output.append(char) 206 | return "".join(output) 207 | 208 | 209 | class WordpieceTokenizer(object): 210 | """Runs WordPiece tokenization.""" 211 | 212 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 213 | self.vocab = vocab 214 | self.unk_token = unk_token # UNK token : unknown 출현빈도가 낮은 단어 대체 215 | self.max_input_chars_per_word = max_input_chars_per_word 216 | 217 | def tokenize(self, text): 218 | """Tokenizes a piece of text into its word pieces. 219 | 220 | This uses a greedy longest-match-first algorithm to perform tokenization 221 | using the given vocabulary. 222 | 223 | For example: 224 | input = "unaffable" 225 | output = ["un", "##aff", "##able"] 226 | 227 | Args: 228 | text: A single token or whitespace separated tokens. This should have 229 | already been passed through `BasicTokenizer. 230 | 231 | Returns: 232 | A list of wordpiece tokens. 
233 | """ 234 | 235 | text = convert_to_unicode(text) 236 | 237 | output_tokens = [] 238 | for token in whitespace_tokenize(text): 239 | chars = list(token) 240 | if len(chars) > self.max_input_chars_per_word: 241 | output_tokens.append(self.unk_token) 242 | continue 243 | 244 | is_bad = False 245 | start = 0 246 | sub_tokens = [] 247 | while start < len(chars): 248 | end = len(chars) 249 | cur_substr = None 250 | while start < end: 251 | substr = "".join(chars[start:end]) 252 | if start > 0: 253 | substr = "##" + substr 254 | if substr in self.vocab: 255 | cur_substr = substr 256 | break 257 | end -= 1 258 | if cur_substr is None: 259 | is_bad = True 260 | break 261 | sub_tokens.append(cur_substr) 262 | start = end 263 | 264 | if is_bad: 265 | output_tokens.append(self.unk_token) 266 | else: 267 | output_tokens.extend(sub_tokens) 268 | return output_tokens 269 | 270 | 271 | def _is_whitespace(char): 272 | """Checks whether `chars` is a whitespace character.""" 273 | # \t, \n, and \r are technically contorl characters but we treat them 274 | # as whitespace since they are generally considered as such. 275 | if char == " " or char == "\t" or char == "\n" or char == "\r": 276 | return True 277 | cat = unicodedata.category(char) 278 | if cat == "Zs": 279 | return True 280 | return False 281 | 282 | 283 | def _is_control(char): 284 | """Checks whether `chars` is a control character.""" 285 | # These are technically control characters but we count them as whitespace 286 | # characters. 287 | if char == "\t" or char == "\n" or char == "\r": 288 | return False 289 | cat = unicodedata.category(char) 290 | if cat.startswith("C"): 291 | return True 292 | return False 293 | 294 | 295 | def _is_punctuation(char): 296 | """Checks whether `chars` is a punctuation character.""" 297 | cp = ord(char) 298 | # We treat all non-letter/number ASCII as punctuation. 299 | # Characters such as "^", "$", and "`" are not in the Unicode 300 | # Punctuation class but we treat them as punctuation anyways, for 301 | # consistency. 302 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 303 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 304 | return True 305 | cat = unicodedata.category(char) 306 | if cat.startswith("P"): 307 | return True 308 | return False 309 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SanghunYun, Korea University. 2 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 3 | # 4 | # This file has been modified by SanghunYun, Korea University 5 | # for add fucntion of _get_device and class of output_logging. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
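Editor's note on utils/tokenization.py above: `FullTokenizer` chains `BasicTokenizer` (cleaning, lower-casing, punctuation splitting) with `WordpieceTokenizer` (greedy longest-match-first against the vocabulary). A usage sketch; the vocab path is an assumption about where download.sh places the BERT-Base uncased files, and any WordPiece vocab.txt would work:

    # Sketch of end-to-end tokenization with FullTokenizer from utils/tokenization.py.
    # The vocab path below is assumed, not guaranteed by the repo layout.
    from utils.tokenization import FullTokenizer

    tokenizer = FullTokenizer(vocab_file='BERT_Base_Uncased/vocab.txt',
                              do_lower_case=True)

    tokens = tokenizer.tokenize("The movie was unaffable, honestly.")
    # BasicTokenizer output: "the movie was unaffable , honestly ."
    # WordpieceTokenizer then splits rare words, e.g. "unaffable" -> "un", "##aff", "##able"
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)

Tokens absent from the vocabulary (or longer than `max_input_chars_per_word`) come back as the `[UNK]` token, as implemented in `WordpieceTokenizer.tokenize` above.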
18 | 19 | 20 | import os 21 | import csv 22 | import random 23 | import logging 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def torch_device_one(): 30 | return torch.tensor(1.).to(_get_device()) 31 | 32 | def set_seeds(seed): 33 | "set random seeds" 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed_all(seed) 38 | 39 | def get_device(): 40 | "get device (CPU or GPU)" 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | n_gpu = torch.cuda.device_count() 43 | print("%s (%d GPUs)" % (device, n_gpu)) 44 | return device 45 | 46 | def _get_device(): 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | return device 49 | 50 | def split_last(x, shape): 51 | "split the last dimension to given shape" 52 | shape = list(shape) 53 | assert shape.count(-1) <= 1 54 | if -1 in shape: 55 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 56 | return x.view(*x.size()[:-1], *shape) 57 | 58 | def merge_last(x, n_dims): 59 | "merge the last n_dims to a dimension" 60 | s = x.size() 61 | assert n_dims > 1 and n_dims < len(s) 62 | return x.view(*s[:-n_dims], -1) 63 | 64 | def truncate_tokens_pair(tokens_a, tokens_b, max_len): 65 | while True: 66 | if len(tokens_a) + len(tokens_b) <= max_len: 67 | break 68 | if len(tokens_a) > len(tokens_b): 69 | tokens_a.pop() 70 | else: 71 | tokens_b.pop() 72 | 73 | def get_random_word(vocab_words): 74 | i = random.randint(0, len(vocab_words)-1) 75 | return vocab_words[i] 76 | 77 | def get_logger(name, log_path): 78 | "get logger" 79 | logger = logging.getLogger(name) 80 | fomatter = logging.Formatter( 81 | '[ %(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s') 82 | 83 | if not os.path.isfile(log_path): 84 | f = open(log_path, "w+") 85 | 86 | fileHandler = logging.FileHandler(log_path) 87 | fileHandler.setFormatter(fomatter) 88 | logger.addHandler(fileHandler) 89 | 90 | #streamHandler = logging.StreamHandler() 91 | #streamHandler.setFormatter(fomatter) 92 | #logger.addHandler(streamHandler) 93 | 94 | logger.setLevel(logging.DEBUG) 95 | return logger 96 | 97 | 98 | class output_logging(object): 99 | def __init__(self, mode, real_time=False, dump_dir=None): 100 | self.mode = mode 101 | self.real_time = real_time 102 | self.dump_dir = dump_dir if dump_dir else None 103 | 104 | if dump_dir: 105 | self.dump = open(os.path.join(dump_dir, 'logs/output.tsv'), 'w', encoding='utf-8', newline='') 106 | self.wr = csv.writer(self.dump, delimiter='\t') 107 | 108 | # header 109 | if mode == 'eval': 110 | self.wr.writerow(['Ground_truth', 'Predcit', 'sentence']) 111 | elif mode == 'test': 112 | self.wr.writerow(['Predict', 'sentence']) 113 | 114 | def __del__(self): 115 | if self.dump_dir: 116 | self.dump.close() 117 | 118 | def logs(self, sentence, pred, ground_turth=None): 119 | if self.real_time: 120 | if self.mode == 'eval': 121 | for p, g, s in zip(pred, ground_turth, sentence): 122 | print('Ground_truth | Predict') 123 | print(int(g), ' ', int(p)) 124 | print(s, end='\n\n') 125 | elif self.mode == 'test': 126 | for p, s in zip(pred, sentence): 127 | print('predict : ', int(p)) 128 | print(s, end='\n\n') 129 | 130 | if self.dump_dir: 131 | if self.mode == 'eval': 132 | for p, g, s in zip(pred, ground_turth, sentence): 133 | self.wr.writerow([int(p), int(g), s]) 134 | elif self.mode == 'test': 135 | for p, s in zip(pred, sentence): 136 | self.wr.writerow([int(p), s]) 137 | 
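Editor's note on the reshape helpers in utils/utils.py above: `split_last` and `merge_last` are the usual multi-head-attention bookkeeping, presumably used by models.py to split the hidden size into heads and merge it back. A quick self-contained check of their behavior, with BERT-Base-like shapes chosen only for illustration:

    # Quick check of the helpers defined in utils/utils.py above.
    import torch
    from utils.utils import set_seeds, split_last, merge_last

    set_seeds(1421)                      # the default seed from utils/configuration.py

    x = torch.randn(8, 128, 768)         # (batch, seq_len, hidden)
    h = split_last(x, (12, -1))          # -> (8, 128, 12, 64): 12 heads, head size inferred
    y = merge_last(h, 2)                 # -> (8, 128, 768): inverse of split_last

    assert h.shape == (8, 128, 12, 64)
    assert torch.equal(y, x)

Passing -1 in the target shape lets `split_last` infer the remaining dimension from `x.size(-1)`, mirroring `torch.view` semantics.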
--------------------------------------------------------------------------------