├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── electra.png ├── electra_pytorch ├── __init__.py └── electra_pytorch.py ├── examples └── glue │ ├── download.py │ ├── metrics.py │ ├── processors.py │ ├── run.py │ └── utils.py ├── pretraining └── openwebtext │ ├── arg.py │ ├── dataset.py │ ├── preprocess.py │ ├── pretrain.py │ ├── small_discriminator.json │ ├── small_generator.json │ └── tokenization.py ├── setup.py └── tests └── test_electra_pytorch.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install setuptools wheel twine pytest torch 30 | python setup.py install 31 | - name: Test 32 | run: | 33 | python setup.py test 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Data 132 | data 133 | output -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Electra - Pytorch 4 | 5 | A simple working wrapper for fast pretraining of language models as detailed in this paper. 
It speeds up training (in comparison to normal masked language modeling) by a factor of 4x, and eventually reaches better performance if trained for even longer. Special thanks to Erik Nijkamp for taking the time to replicate the results for GLUE. 6 | 7 | ## Install 8 | 9 | ```bash 10 | $ pip install electra-pytorch 11 | ``` 12 | 13 | ## Usage 14 | 15 | The following example uses `reformer-pytorch`, which is available to be pip installed. 16 | 17 | ```python 18 | import torch 19 | from torch import nn 20 | from reformer_pytorch import ReformerLM 21 | 22 | from electra_pytorch import Electra 23 | 24 | # (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator 25 | 26 | generator = ReformerLM( 27 | num_tokens = 20000, 28 | emb_dim = 128, 29 | dim = 256, # smaller hidden dimension 30 | heads = 4, # less heads 31 | ff_mult = 2, # smaller feed forward intermediate dimension 32 | dim_head = 64, 33 | depth = 12, 34 | max_seq_len = 1024 35 | ) 36 | 37 | discriminator = ReformerLM( 38 | num_tokens = 20000, 39 | emb_dim = 128, 40 | dim = 1024, 41 | dim_head = 64, 42 | heads = 16, 43 | depth = 12, 44 | ff_mult = 4, 45 | max_seq_len = 1024 46 | ) 47 | 48 | # (2) weight tie the token and positional embeddings of generator and discriminator 49 | 50 | generator.token_emb = discriminator.token_emb 51 | generator.pos_emb = discriminator.pos_emb 52 | # weight tie any other embeddings if available, token type embeddings, etc. 53 | 54 | # (3) instantiate electra 55 | 56 | trainer = Electra( 57 | generator, 58 | discriminator, 59 | discr_dim = 1024, # the embedding dimension of the discriminator 60 | discr_layer = 'reformer', # the layer name in the discriminator, whose output would be used for predicting token is still the same or replaced 61 | mask_token_id = 2, # the token id reserved for masking 62 | pad_token_id = 0, # the token id for padding 63 | mask_prob = 0.15, # masking probability for masked language modeling 64 | mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep) 65 | ) 66 | 67 | # (4) train 68 | 69 | data = torch.randint(0, 20000, (1, 1024)) 70 | 71 | results = trainer(data) 72 | results.loss.backward() 73 | 74 | # after much training, the discriminator should have improved 75 | 76 | torch.save(discriminator, f'./pretrained-model.pt') 77 | ``` 78 | 79 | If you would rather not have the framework auto-magically intercept the hidden output of the discriminator, you can pass in the discriminator (with the extra linear [dim x 1]) by yourself with the following. 
80 | 81 | ```python 82 | import torch 83 | from torch import nn 84 | from reformer_pytorch import ReformerLM 85 | 86 | from electra_pytorch import Electra 87 | 88 | # (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator 89 | 90 | generator = ReformerLM( 91 | num_tokens = 20000, 92 | emb_dim = 128, 93 | dim = 256, # smaller hidden dimension 94 | heads = 4, # less heads 95 | ff_mult = 2, # smaller feed forward intermediate dimension 96 | dim_head = 64, 97 | depth = 12, 98 | max_seq_len = 1024 99 | ) 100 | 101 | discriminator = ReformerLM( 102 | num_tokens = 20000, 103 | emb_dim = 128, 104 | dim = 1024, 105 | dim_head = 64, 106 | heads = 16, 107 | depth = 12, 108 | ff_mult = 4, 109 | max_seq_len = 1024, 110 | return_embeddings = True 111 | ) 112 | 113 | # (2) weight tie the token and positional embeddings of generator and discriminator 114 | 115 | generator.token_emb = discriminator.token_emb 116 | generator.pos_emb = discriminator.pos_emb 117 | # weight tie any other embeddings if available, token type embeddings, etc. 118 | 119 | # (3) instantiate electra 120 | 121 | discriminator_with_adapter = nn.Sequential(discriminator, nn.Linear(1024, 1)) 122 | 123 | trainer = Electra( 124 | generator, 125 | discriminator_with_adapter, 126 | mask_token_id = 2, # the token id reserved for masking 127 | pad_token_id = 0, # the token id for padding 128 | mask_prob = 0.15, # masking probability for masked language modeling 129 | mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep) 130 | ) 131 | 132 | # (4) train 133 | 134 | data = torch.randint(0, 20000, (1, 1024)) 135 | 136 | results = trainer(data) 137 | results.loss.backward() 138 | 139 | # after much training, the discriminator should have improved 140 | 141 | torch.save(discriminator, f'./pretrained-model.pt') 142 | ``` 143 | 144 | ## Important details for successful training 145 | 146 | The generator should be roughly a quarter to at most one half of the discriminator's size for effective training. Any greater and the generator will be too good and the adversarial game collapses. This was done by reducing the hidden dimension, feed forward hidden dimension, and number of attention heads in the paper. 147 | 148 | ## Testing 149 | 150 | ```bash 151 | $ python setup.py test 152 | ``` 153 | 154 | ## Training 155 | 156 | 1. Download the [OpenWebText](https://github.com/jcpeterson/openwebtext) dataset. 157 | 158 | ```bash 159 | $ mkdir data 160 | $ cd data 161 | $ pip3 install gdown 162 | $ gdown --id 1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx 163 | $ tar -xf openwebtext.tar.xz 164 | $ wget https://storage.googleapis.com/electra-data/vocab.txt 165 | $ cd .. 166 | ``` 167 | 168 | 2. Tokenize dataset. 169 | 170 | ```bash 171 | $ python pretraining/openwebtext/preprocess.py 172 | ``` 173 | 174 | 3. Pre-train. 175 | 176 | ```bash 177 | $ python pretraining/openwebtext/pretrain.py 178 | ``` 179 | 180 | 4. Download GLUE dataset. 181 | 182 | ```bash 183 | $ python examples/glue/download.py 184 | ``` 185 | 186 | 5. Fine-tune on the MRPC sub-task of the GLUE benchmark. 187 | 188 | ```bash 189 | $ python examples/glue/run.py --model_name_or_path output/yyyy-mm-dd-hh-mm-ss/ckpt/200000 190 | ``` 191 | 192 | ## Citations 193 | 194 | ```bibtex 195 | @misc{clark2020electra, 196 | title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators}, 197 | author={Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. 
Manning}, 198 | year={2020}, 199 | eprint={2003.10555}, 200 | archivePrefix={arXiv}, 201 | primaryClass={cs.CL} 202 | } 203 | ``` 204 | -------------------------------------------------------------------------------- /electra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/electra-pytorch/5b8bae5c3575b7529891c1b878c24688f57d7ca1/electra.png -------------------------------------------------------------------------------- /electra_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from electra_pytorch.electra_pytorch import Electra 2 | -------------------------------------------------------------------------------- /electra_pytorch/electra_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | # constants 10 | 11 | Results = namedtuple('Results', [ 12 | 'loss', 13 | 'mlm_loss', 14 | 'disc_loss', 15 | 'gen_acc', 16 | 'disc_acc', 17 | 'disc_labels', 18 | 'disc_predictions' 19 | ]) 20 | 21 | # helpers 22 | 23 | def log(t, eps=1e-9): 24 | return torch.log(t + eps) 25 | 26 | def gumbel_noise(t): 27 | noise = torch.zeros_like(t).uniform_(0, 1) 28 | return -log(-log(noise)) 29 | 30 | def gumbel_sample(t, temperature = 1.): 31 | return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1) 32 | 33 | def prob_mask_like(t, prob): 34 | return torch.zeros_like(t).float().uniform_(0, 1) < prob 35 | 36 | def mask_with_tokens(t, token_ids): 37 | init_no_mask = torch.full_like(t, False, dtype=torch.bool) 38 | mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask) 39 | return mask 40 | 41 | def get_mask_subset_with_prob(mask, prob): 42 | batch, seq_len, device = *mask.shape, mask.device 43 | max_masked = math.ceil(prob * seq_len) 44 | 45 | num_tokens = mask.sum(dim=-1, keepdim=True) 46 | mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil()) 47 | mask_excess = mask_excess[:, :max_masked] 48 | 49 | rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9) 50 | _, sampled_indices = rand.topk(max_masked, dim=-1) 51 | sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0) 52 | 53 | new_mask = torch.zeros((batch, seq_len + 1), device=device) 54 | new_mask.scatter_(-1, sampled_indices, 1) 55 | return new_mask[:, 1:].bool() 56 | 57 | # hidden layer extractor class, for magically adding adapter to language model to be pretrained 58 | 59 | class HiddenLayerExtractor(nn.Module): 60 | def __init__(self, net, layer = -2): 61 | super().__init__() 62 | self.net = net 63 | self.layer = layer 64 | 65 | self.hidden = None 66 | self.hook_registered = False 67 | 68 | def _find_layer(self): 69 | if type(self.layer) == str: 70 | modules = dict([*self.net.named_modules()]) 71 | return modules.get(self.layer, None) 72 | elif type(self.layer) == int: 73 | children = [*self.net.children()] 74 | return children[self.layer] 75 | return None 76 | 77 | def _hook(self, _, __, output): 78 | self.hidden = output 79 | 80 | def _register_hook(self): 81 | layer = self._find_layer() 82 | assert layer is not None, f'hidden layer ({self.layer}) not found' 83 | handle = layer.register_forward_hook(self._hook) 84 | self.hook_registered = True 85 | 86 | def forward(self, x): 87 | if self.layer == -1: 88 | return self.net(x) 89 | 90 | if not 
self.hook_registered: 91 | self._register_hook() 92 | 93 | _ = self.net(x) 94 | hidden = self.hidden 95 | self.hidden = None 96 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output' 97 | return hidden 98 | 99 | # main electra class 100 | 101 | class Electra(nn.Module): 102 | def __init__( 103 | self, 104 | generator, 105 | discriminator, 106 | *, 107 | num_tokens = None, 108 | discr_dim = -1, 109 | discr_layer = -1, 110 | mask_prob = 0.15, 111 | replace_prob = 0.85, 112 | random_token_prob = 0., 113 | mask_token_id = 2, 114 | pad_token_id = 0, 115 | mask_ignore_token_ids = [], 116 | disc_weight = 50., 117 | gen_weight = 1., 118 | temperature = 1.): 119 | super().__init__() 120 | 121 | self.generator = generator 122 | self.discriminator = discriminator 123 | 124 | if discr_dim > 0: 125 | self.discriminator = nn.Sequential( 126 | HiddenLayerExtractor(discriminator, layer = discr_layer), 127 | nn.Linear(discr_dim, 1) 128 | ) 129 | 130 | # mlm related probabilities 131 | self.mask_prob = mask_prob 132 | self.replace_prob = replace_prob 133 | 134 | self.num_tokens = num_tokens 135 | self.random_token_prob = random_token_prob 136 | 137 | # token ids 138 | self.pad_token_id = pad_token_id 139 | self.mask_token_id = mask_token_id 140 | self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id]) 141 | 142 | # sampling temperature 143 | self.temperature = temperature 144 | 145 | # loss weights 146 | self.disc_weight = disc_weight 147 | self.gen_weight = gen_weight 148 | 149 | 150 | def forward(self, input, **kwargs): 151 | b, t = input.shape 152 | 153 | replace_prob = prob_mask_like(input, self.replace_prob) 154 | 155 | # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep]) 156 | # also do not include these special tokens in the tokens chosen at random 157 | no_mask = mask_with_tokens(input, self.mask_ignore_token_ids) 158 | mask = get_mask_subset_with_prob(~no_mask, self.mask_prob) 159 | 160 | # get mask indices 161 | mask_indices = torch.nonzero(mask, as_tuple=True) 162 | 163 | # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob) 164 | masked_input = input.clone().detach() 165 | 166 | # set inverse of mask to padding tokens for labels 167 | gen_labels = input.masked_fill(~mask, self.pad_token_id) 168 | 169 | # clone the mask, for potential modification if random tokens are involved 170 | # not to be mistakened for the mask above, which is for all tokens, whether not replaced nor replaced with random tokens 171 | masking_mask = mask.clone() 172 | 173 | # if random token probability > 0 for mlm 174 | if self.random_token_prob > 0: 175 | assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling' 176 | 177 | random_token_prob = prob_mask_like(input, self.random_token_prob) 178 | random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device) 179 | random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids) 180 | random_token_prob &= ~random_no_mask 181 | masked_input = torch.where(random_token_prob, random_tokens, masked_input) 182 | 183 | # remove random token prob mask from masking mask 184 | masking_mask = masking_mask & ~random_token_prob 185 | 186 | # [mask] input 187 | masked_input = masked_input.masked_fill(masking_mask * replace_prob, self.mask_token_id) 188 | 189 | # get generator output and get mlm loss 190 | 
logits = self.generator(masked_input, **kwargs) 191 | 192 | mlm_loss = F.cross_entropy( 193 | logits.transpose(1, 2), 194 | gen_labels, 195 | ignore_index = self.pad_token_id 196 | ) 197 | 198 | # use mask from before to select logits that need sampling 199 | sample_logits = logits[mask_indices] 200 | 201 | # sample 202 | sampled = gumbel_sample(sample_logits, temperature = self.temperature) 203 | 204 | # scatter the sampled values back to the input 205 | disc_input = input.clone() 206 | disc_input[mask_indices] = sampled.detach() 207 | 208 | # generate discriminator labels, with replaced as True and original as False 209 | disc_labels = (input != disc_input).float().detach() 210 | 211 | # get discriminator predictions of replaced / original 212 | non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True) 213 | 214 | # get discriminator output and binary cross entropy loss 215 | disc_logits = self.discriminator(disc_input, **kwargs) 216 | disc_logits = disc_logits.reshape_as(disc_labels) 217 | 218 | disc_loss = F.binary_cross_entropy_with_logits( 219 | disc_logits[non_padded_indices], 220 | disc_labels[non_padded_indices] 221 | ) 222 | 223 | # gather metrics 224 | with torch.no_grad(): 225 | gen_predictions = torch.argmax(logits, dim=-1) 226 | disc_predictions = torch.round((torch.sign(disc_logits) + 1.0) * 0.5) 227 | gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean() 228 | disc_acc = 0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean() + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean() 229 | 230 | # return weighted sum of losses 231 | return Results(self.gen_weight * mlm_loss + self.disc_weight * disc_loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions) 232 | -------------------------------------------------------------------------------- /examples/glue/download.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | 3 | Note: for legal reasons, we are unable to host MRPC. 4 | You can either use the version hosted by the SentEval team, which is already tokenized, 5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 7 | You should then rename and place specific files in a folder (see below for an example). 8 | 9 | mkdir MRPC 10 | cabextract MSRParaphraseCorpus.msi -d MRPC 11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 13 | rm MRPC/_* 14 | rm MSRParaphraseCorpus.msi 15 | 16 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 17 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 
18 | ''' 19 | 20 | import os 21 | import sys 22 | import shutil 23 | import argparse 24 | import tempfile 25 | import urllib.request 26 | import zipfile 27 | 28 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 29 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 30 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 31 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 32 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 33 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 34 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 35 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 36 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 37 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 38 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 39 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 40 | 41 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 42 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 43 | 44 | def download_and_extract(task, data_dir): 45 | print("Downloading and extracting %s..." 
% task) 46 | data_file = "%s.zip" % task 47 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 48 | with zipfile.ZipFile(data_file) as zip_ref: 49 | zip_ref.extractall(data_dir) 50 | os.remove(data_file) 51 | print("\tCompleted!") 52 | 53 | def format_mrpc(data_dir, path_to_data): 54 | print("Processing MRPC...") 55 | mrpc_dir = os.path.join(data_dir, "MRPC") 56 | if not os.path.isdir(mrpc_dir): 57 | os.mkdir(mrpc_dir) 58 | if path_to_data: 59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 61 | else: 62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 70 | 71 | dev_ids = [] 72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 73 | for row in ids_fh: 74 | dev_ids.append(row.strip().split('\t')) 75 | 76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 79 | header = data_fh.readline() 80 | train_fh.write(header) 81 | dev_fh.write(header) 82 | for row in data_fh: 83 | label, id1, id2, s1, s2 = row.strip().split('\t') 84 | if [id1, id2] in dev_ids: 85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 86 | else: 87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 88 | 89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 91 | header = data_fh.readline() 92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 93 | for idx, row in enumerate(data_fh): 94 | label, id1, id2, s1, s2 = row.strip().split('\t') 95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 96 | print("\tCompleted!") 97 | 98 | def download_diagnostic(data_dir): 99 | print("Downloading and extracting diagnostic...") 100 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 101 | os.mkdir(os.path.join(data_dir, "diagnostic")) 102 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 103 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 104 | print("\tCompleted!") 105 | return 106 | 107 | def get_tasks(task_names): 108 | task_names = task_names.split(',') 109 | if "all" in task_names: 110 | tasks = TASKS 111 | else: 112 | tasks = [] 113 | for task_name in task_names: 114 | assert task_name in TASKS, "Task %s not found!" 
% task_name 115 | tasks.append(task_name) 116 | return tasks 117 | 118 | def main(arguments): 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='./data/glue_data') 121 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 122 | type=str, default='all') 123 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 124 | type=str, default='') 125 | args = parser.parse_args(arguments) 126 | 127 | if not os.path.exists(args.data_dir): 128 | os.makedirs(args.data_dir) 129 | tasks = get_tasks(args.tasks) 130 | 131 | for task in tasks: 132 | if task == 'MRPC': 133 | format_mrpc(args.data_dir, args.path_to_mrpc) 134 | elif task == 'diagnostic': 135 | download_diagnostic(args.data_dir) 136 | else: 137 | download_and_extract(task, args.data_dir) 138 | 139 | 140 | if __name__ == '__main__': 141 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /examples/glue/metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
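# Note: the helpers below mirror the usual GLUE metric conventions (Matthews correlation for CoLA,
# accuracy + F1 for MRPC and QQP, Pearson/Spearman correlation for STS-B, plain accuracy for the
# remaining tasks); they are only defined when scipy and scikit-learn import successfully, which
# `is_sklearn_available()` reports.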
16 | 17 | try: 18 | from scipy.stats import pearsonr, spearmanr 19 | from sklearn.metrics import matthews_corrcoef, f1_score 20 | 21 | _has_sklearn = True 22 | except (AttributeError, ImportError): 23 | _has_sklearn = False 24 | 25 | 26 | def is_sklearn_available(): 27 | return _has_sklearn 28 | 29 | 30 | if _has_sklearn: 31 | 32 | def simple_accuracy(preds, labels): 33 | return (preds == labels).mean() 34 | 35 | def acc_and_f1(preds, labels): 36 | acc = simple_accuracy(preds, labels) 37 | f1 = f1_score(y_true=labels, y_pred=preds) 38 | return { 39 | "acc": acc, 40 | "f1": f1, 41 | "acc_and_f1": (acc + f1) / 2, 42 | } 43 | 44 | def pearson_and_spearman(preds, labels): 45 | pearson_corr = pearsonr(preds, labels)[0] 46 | spearman_corr = spearmanr(preds, labels)[0] 47 | return { 48 | "pearson": pearson_corr, 49 | "spearmanr": spearman_corr, 50 | "corr": (pearson_corr + spearman_corr) / 2, 51 | } 52 | 53 | def glue_compute_metrics(task_name, preds, labels): 54 | assert len(preds) == len(labels) 55 | if task_name == "cola": 56 | return {"mcc": matthews_corrcoef(labels, preds)} 57 | elif task_name == "sst-2": 58 | return {"acc": simple_accuracy(preds, labels)} 59 | elif task_name == "mrpc": 60 | return acc_and_f1(preds, labels) 61 | elif task_name == "sts-b": 62 | return pearson_and_spearman(preds, labels) 63 | elif task_name == "qqp": 64 | return acc_and_f1(preds, labels) 65 | elif task_name == "mnli": 66 | return {"acc": simple_accuracy(preds, labels)} 67 | elif task_name == "mnli-mm": 68 | return {"acc": simple_accuracy(preds, labels)} 69 | elif task_name == "qnli": 70 | return {"acc": simple_accuracy(preds, labels)} 71 | elif task_name == "rte": 72 | return {"acc": simple_accuracy(preds, labels)} 73 | elif task_name == "wnli": 74 | return {"acc": simple_accuracy(preds, labels)} 75 | elif task_name == "hans": 76 | return {"acc": simple_accuracy(preds, labels)} 77 | else: 78 | raise KeyError(task_name) 79 | 80 | def xnli_compute_metrics(task_name, preds, labels): 81 | assert len(preds) == len(labels) 82 | if task_name == "xnli": 83 | return {"acc": simple_accuracy(preds, labels)} 84 | else: 85 | raise KeyError(task_name) -------------------------------------------------------------------------------- /examples/glue/processors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
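# Note: each GLUE task below gets a `DataProcessor` subclass that reads the task's TSV files into
# `InputExample`s, and `glue_convert_examples_to_features` tokenizes, truncates and pads those
# examples into fixed-length `InputFeatures`; the `glue_processors` / `glue_output_modes` dicts at
# the bottom map task names to the right processor and to "classification" vs "regression".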
16 | """ GLUE processors and helpers """ 17 | 18 | import logging 19 | import os 20 | 21 | # from ...file_utils import is_tf_available 22 | from utils import DataProcessor, InputExample, InputFeatures 23 | 24 | is_tf_available = lambda: False 25 | 26 | if is_tf_available(): 27 | import tensorflow as tf 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def glue_convert_examples_to_features( 33 | examples, 34 | tokenizer, 35 | max_length=512, 36 | task=None, 37 | label_list=None, 38 | output_mode=None, 39 | pad_on_left=False, 40 | pad_token=0, 41 | pad_token_segment_id=0, 42 | mask_padding_with_zero=True, 43 | ): 44 | """ 45 | Loads a data file into a list of ``InputFeatures`` 46 | 47 | Args: 48 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples. 49 | tokenizer: Instance of a tokenizer that will tokenize the examples 50 | max_length: Maximum example length 51 | task: GLUE task 52 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method 53 | output_mode: String indicating the output mode. Either ``regression`` or ``classification`` 54 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 55 | pad_token: Padding token 56 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4) 57 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 58 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 59 | actual values) 60 | 61 | Returns: 62 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 63 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 64 | a list of task-specific ``InputFeatures`` which can be fed to the model. 65 | 66 | """ 67 | is_tf_dataset = False 68 | if is_tf_available() and isinstance(examples, tf.data.Dataset): 69 | is_tf_dataset = True 70 | 71 | if task is not None: 72 | processor = glue_processors[task]() 73 | if label_list is None: 74 | label_list = processor.get_labels() 75 | logger.info("Using label list %s for task %s" % (label_list, task)) 76 | if output_mode is None: 77 | output_mode = glue_output_modes[task] 78 | logger.info("Using output mode %s for task %s" % (output_mode, task)) 79 | 80 | label_map = {label: i for i, label in enumerate(label_list)} 81 | 82 | features = [] 83 | for (ex_index, example) in enumerate(examples): 84 | len_examples = 0 85 | if is_tf_dataset: 86 | example = processor.get_example_from_tensor_dict(example) 87 | example = processor.tfds_map(example) 88 | len_examples = tf.data.experimental.cardinality(examples) 89 | else: 90 | len_examples = len(examples) 91 | if ex_index % 10000 == 0: 92 | logger.info("Writing example %d/%d" % (ex_index, len_examples)) 93 | 94 | inputs = tokenizer.encode_plus( 95 | example.text_a, example.text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True, 96 | ) 97 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] 98 | 99 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 100 | # tokens are attended to. 101 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 102 | 103 | # Zero-pad up to the sequence length. 
104 | padding_length = max_length - len(input_ids) 105 | if pad_on_left: 106 | input_ids = ([pad_token] * padding_length) + input_ids 107 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 108 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 109 | else: 110 | input_ids = input_ids + ([pad_token] * padding_length) 111 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 112 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 113 | 114 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) 115 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format( 116 | len(attention_mask), max_length 117 | ) 118 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format( 119 | len(token_type_ids), max_length 120 | ) 121 | 122 | if output_mode == "classification": 123 | label = label_map[example.label] 124 | elif output_mode == "regression": 125 | label = float(example.label) 126 | else: 127 | raise KeyError(output_mode) 128 | 129 | if ex_index < 5: 130 | logger.info("*** Example ***") 131 | logger.info("guid: %s" % (example.guid)) 132 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 133 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 134 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 135 | logger.info("label: %s (id = %d)" % (example.label, label)) 136 | 137 | features.append( 138 | InputFeatures( 139 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label 140 | ) 141 | ) 142 | 143 | if is_tf_available() and is_tf_dataset: 144 | 145 | def gen(): 146 | for ex in features: 147 | yield ( 148 | { 149 | "input_ids": ex.input_ids, 150 | "attention_mask": ex.attention_mask, 151 | "token_type_ids": ex.token_type_ids, 152 | }, 153 | ex.label, 154 | ) 155 | 156 | return tf.data.Dataset.from_generator( 157 | gen, 158 | ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64), 159 | ( 160 | { 161 | "input_ids": tf.TensorShape([None]), 162 | "attention_mask": tf.TensorShape([None]), 163 | "token_type_ids": tf.TensorShape([None]), 164 | }, 165 | tf.TensorShape([]), 166 | ), 167 | ) 168 | 169 | return features 170 | 171 | 172 | class MrpcProcessor(DataProcessor): 173 | """Processor for the MRPC data set (GLUE version).""" 174 | 175 | def get_example_from_tensor_dict(self, tensor_dict): 176 | """See base class.""" 177 | return InputExample( 178 | tensor_dict["idx"].numpy(), 179 | tensor_dict["sentence1"].numpy().decode("utf-8"), 180 | tensor_dict["sentence2"].numpy().decode("utf-8"), 181 | str(tensor_dict["label"].numpy()), 182 | ) 183 | 184 | def get_train_examples(self, data_dir): 185 | """See base class.""" 186 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 187 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 188 | 189 | def get_dev_examples(self, data_dir): 190 | """See base class.""" 191 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 192 | 193 | def get_labels(self): 194 | """See base class.""" 195 | return ["0", "1"] 196 | 197 | def _create_examples(self, lines, set_type): 198 | """Creates examples for the training and dev sets.""" 199 | examples = [] 200 | for (i, line) in enumerate(lines): 201 | if 
i == 0: 202 | continue 203 | guid = "%s-%s" % (set_type, i) 204 | text_a = line[3] 205 | text_b = line[4] 206 | label = line[0] 207 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 208 | return examples 209 | 210 | 211 | class MnliProcessor(DataProcessor): 212 | """Processor for the MultiNLI data set (GLUE version).""" 213 | 214 | def get_example_from_tensor_dict(self, tensor_dict): 215 | """See base class.""" 216 | return InputExample( 217 | tensor_dict["idx"].numpy(), 218 | tensor_dict["premise"].numpy().decode("utf-8"), 219 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 220 | str(tensor_dict["label"].numpy()), 221 | ) 222 | 223 | def get_train_examples(self, data_dir): 224 | """See base class.""" 225 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 226 | 227 | def get_dev_examples(self, data_dir): 228 | """See base class.""" 229 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") 230 | 231 | def get_labels(self): 232 | """See base class.""" 233 | return ["contradiction", "entailment", "neutral"] 234 | 235 | def _create_examples(self, lines, set_type): 236 | """Creates examples for the training and dev sets.""" 237 | examples = [] 238 | for (i, line) in enumerate(lines): 239 | if i == 0: 240 | continue 241 | guid = "%s-%s" % (set_type, line[0]) 242 | text_a = line[8] 243 | text_b = line[9] 244 | label = line[-1] 245 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 246 | return examples 247 | 248 | 249 | class MnliMismatchedProcessor(MnliProcessor): 250 | """Processor for the MultiNLI Mismatched data set (GLUE version).""" 251 | 252 | def get_dev_examples(self, data_dir): 253 | """See base class.""" 254 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched") 255 | 256 | 257 | class ColaProcessor(DataProcessor): 258 | """Processor for the CoLA data set (GLUE version).""" 259 | 260 | def get_example_from_tensor_dict(self, tensor_dict): 261 | """See base class.""" 262 | return InputExample( 263 | tensor_dict["idx"].numpy(), 264 | tensor_dict["sentence"].numpy().decode("utf-8"), 265 | None, 266 | str(tensor_dict["label"].numpy()), 267 | ) 268 | 269 | def get_train_examples(self, data_dir): 270 | """See base class.""" 271 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 272 | 273 | def get_dev_examples(self, data_dir): 274 | """See base class.""" 275 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 276 | 277 | def get_labels(self): 278 | """See base class.""" 279 | return ["0", "1"] 280 | 281 | def _create_examples(self, lines, set_type): 282 | """Creates examples for the training and dev sets.""" 283 | examples = [] 284 | for (i, line) in enumerate(lines): 285 | guid = "%s-%s" % (set_type, i) 286 | text_a = line[3] 287 | label = line[1] 288 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 289 | return examples 290 | 291 | 292 | class Sst2Processor(DataProcessor): 293 | """Processor for the SST-2 data set (GLUE version).""" 294 | 295 | def get_example_from_tensor_dict(self, tensor_dict): 296 | """See base class.""" 297 | return InputExample( 298 | tensor_dict["idx"].numpy(), 299 | tensor_dict["sentence"].numpy().decode("utf-8"), 300 | None, 301 | str(tensor_dict["label"].numpy()), 302 | ) 303 | 304 | def get_train_examples(self, data_dir): 305 | """See 
base class.""" 306 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 307 | 308 | def get_dev_examples(self, data_dir): 309 | """See base class.""" 310 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 311 | 312 | def get_labels(self): 313 | """See base class.""" 314 | return ["0", "1"] 315 | 316 | def _create_examples(self, lines, set_type): 317 | """Creates examples for the training and dev sets.""" 318 | examples = [] 319 | for (i, line) in enumerate(lines): 320 | if i == 0: 321 | continue 322 | guid = "%s-%s" % (set_type, i) 323 | text_a = line[0] 324 | label = line[1] 325 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 326 | return examples 327 | 328 | 329 | class StsbProcessor(DataProcessor): 330 | """Processor for the STS-B data set (GLUE version).""" 331 | 332 | def get_example_from_tensor_dict(self, tensor_dict): 333 | """See base class.""" 334 | return InputExample( 335 | tensor_dict["idx"].numpy(), 336 | tensor_dict["sentence1"].numpy().decode("utf-8"), 337 | tensor_dict["sentence2"].numpy().decode("utf-8"), 338 | str(tensor_dict["label"].numpy()), 339 | ) 340 | 341 | def get_train_examples(self, data_dir): 342 | """See base class.""" 343 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 344 | 345 | def get_dev_examples(self, data_dir): 346 | """See base class.""" 347 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 348 | 349 | def get_labels(self): 350 | """See base class.""" 351 | return [None] 352 | 353 | def _create_examples(self, lines, set_type): 354 | """Creates examples for the training and dev sets.""" 355 | examples = [] 356 | for (i, line) in enumerate(lines): 357 | if i == 0: 358 | continue 359 | guid = "%s-%s" % (set_type, line[0]) 360 | text_a = line[7] 361 | text_b = line[8] 362 | label = line[-1] 363 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 364 | return examples 365 | 366 | 367 | class QqpProcessor(DataProcessor): 368 | """Processor for the QQP data set (GLUE version).""" 369 | 370 | def get_example_from_tensor_dict(self, tensor_dict): 371 | """See base class.""" 372 | return InputExample( 373 | tensor_dict["idx"].numpy(), 374 | tensor_dict["question1"].numpy().decode("utf-8"), 375 | tensor_dict["question2"].numpy().decode("utf-8"), 376 | str(tensor_dict["label"].numpy()), 377 | ) 378 | 379 | def get_train_examples(self, data_dir): 380 | """See base class.""" 381 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 382 | 383 | def get_dev_examples(self, data_dir): 384 | """See base class.""" 385 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 386 | 387 | def get_labels(self): 388 | """See base class.""" 389 | return ["0", "1"] 390 | 391 | def _create_examples(self, lines, set_type): 392 | """Creates examples for the training and dev sets.""" 393 | examples = [] 394 | for (i, line) in enumerate(lines): 395 | if i == 0: 396 | continue 397 | guid = "%s-%s" % (set_type, line[0]) 398 | try: 399 | text_a = line[3] 400 | text_b = line[4] 401 | label = line[5] 402 | except IndexError: 403 | continue 404 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 405 | return examples 406 | 407 | 408 | class QnliProcessor(DataProcessor): 409 | """Processor for the QNLI data set (GLUE version).""" 410 | 411 | def 
get_example_from_tensor_dict(self, tensor_dict): 412 | """See base class.""" 413 | return InputExample( 414 | tensor_dict["idx"].numpy(), 415 | tensor_dict["question"].numpy().decode("utf-8"), 416 | tensor_dict["sentence"].numpy().decode("utf-8"), 417 | str(tensor_dict["label"].numpy()), 418 | ) 419 | 420 | def get_train_examples(self, data_dir): 421 | """See base class.""" 422 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 423 | 424 | def get_dev_examples(self, data_dir): 425 | """See base class.""" 426 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") 427 | 428 | def get_labels(self): 429 | """See base class.""" 430 | return ["entailment", "not_entailment"] 431 | 432 | def _create_examples(self, lines, set_type): 433 | """Creates examples for the training and dev sets.""" 434 | examples = [] 435 | for (i, line) in enumerate(lines): 436 | if i == 0: 437 | continue 438 | guid = "%s-%s" % (set_type, line[0]) 439 | text_a = line[1] 440 | text_b = line[2] 441 | label = line[-1] 442 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 443 | return examples 444 | 445 | 446 | class RteProcessor(DataProcessor): 447 | """Processor for the RTE data set (GLUE version).""" 448 | 449 | def get_example_from_tensor_dict(self, tensor_dict): 450 | """See base class.""" 451 | return InputExample( 452 | tensor_dict["idx"].numpy(), 453 | tensor_dict["sentence1"].numpy().decode("utf-8"), 454 | tensor_dict["sentence2"].numpy().decode("utf-8"), 455 | str(tensor_dict["label"].numpy()), 456 | ) 457 | 458 | def get_train_examples(self, data_dir): 459 | """See base class.""" 460 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 461 | 462 | def get_dev_examples(self, data_dir): 463 | """See base class.""" 464 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 465 | 466 | def get_labels(self): 467 | """See base class.""" 468 | return ["entailment", "not_entailment"] 469 | 470 | def _create_examples(self, lines, set_type): 471 | """Creates examples for the training and dev sets.""" 472 | examples = [] 473 | for (i, line) in enumerate(lines): 474 | if i == 0: 475 | continue 476 | guid = "%s-%s" % (set_type, line[0]) 477 | text_a = line[1] 478 | text_b = line[2] 479 | label = line[-1] 480 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 481 | return examples 482 | 483 | 484 | class WnliProcessor(DataProcessor): 485 | """Processor for the WNLI data set (GLUE version).""" 486 | 487 | def get_example_from_tensor_dict(self, tensor_dict): 488 | """See base class.""" 489 | return InputExample( 490 | tensor_dict["idx"].numpy(), 491 | tensor_dict["sentence1"].numpy().decode("utf-8"), 492 | tensor_dict["sentence2"].numpy().decode("utf-8"), 493 | str(tensor_dict["label"].numpy()), 494 | ) 495 | 496 | def get_train_examples(self, data_dir): 497 | """See base class.""" 498 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 499 | 500 | def get_dev_examples(self, data_dir): 501 | """See base class.""" 502 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 503 | 504 | def get_labels(self): 505 | """See base class.""" 506 | return ["0", "1"] 507 | 508 | def _create_examples(self, lines, set_type): 509 | """Creates examples for the training and dev sets.""" 510 | examples = [] 511 | for (i, line) in enumerate(lines): 
512 | if i == 0: 513 | continue 514 | guid = "%s-%s" % (set_type, line[0]) 515 | text_a = line[1] 516 | text_b = line[2] 517 | label = line[-1] 518 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 519 | return examples 520 | 521 | 522 | glue_tasks_num_labels = { 523 | "cola": 2, 524 | "mnli": 3, 525 | "mrpc": 2, 526 | "sst-2": 2, 527 | "sts-b": 1, 528 | "qqp": 2, 529 | "qnli": 2, 530 | "rte": 2, 531 | "wnli": 2, 532 | } 533 | 534 | glue_processors = { 535 | "cola": ColaProcessor, 536 | "mnli": MnliProcessor, 537 | "mnli-mm": MnliMismatchedProcessor, 538 | "mrpc": MrpcProcessor, 539 | "sst-2": Sst2Processor, 540 | "sts-b": StsbProcessor, 541 | "qqp": QqpProcessor, 542 | "qnli": QnliProcessor, 543 | "rte": RteProcessor, 544 | "wnli": WnliProcessor, 545 | } 546 | 547 | glue_output_modes = { 548 | "cola": "classification", 549 | "mnli": "classification", 550 | "mnli-mm": "classification", 551 | "mrpc": "classification", 552 | "sst-2": "classification", 553 | "sts-b": "regression", 554 | "qqp": "classification", 555 | "qnli": "classification", 556 | "rte": "classification", 557 | "wnli": "classification", 558 | } -------------------------------------------------------------------------------- /examples/glue/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
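# Note: this is a distilled version of the Google/HuggingFace GLUE fine-tuning script; the
# `TokenizerAdapter` / `wrap_tokenizer` helpers below expose just enough of the tokenizer
# interface (encode_plus, convert_tokens_to_ids, vocab size) for the GLUE feature conversion
# to work with the tokenizer used during pretraining.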
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 17 | 18 | 19 | import argparse 20 | import glob 21 | import json 22 | import logging 23 | import os 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | import torch.nn as nn 29 | from torch.optim import AdamW 30 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 31 | from torch.utils.data.distributed import DistributedSampler 32 | from tqdm import tqdm, trange 33 | 34 | from metrics import glue_compute_metrics as compute_metrics 35 | from processors import glue_convert_examples_to_features as convert_examples_to_features 36 | from processors import glue_output_modes as output_modes 37 | from processors import glue_processors as processors 38 | from processors import glue_tasks_num_labels as task_num_labels 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | ################################################## 43 | # adapters for Google-like GLUE code 44 | 45 | class TokenizerAdapter: 46 | def __init__(self, tokenizer, pad_token, cls_token="[CLS]", sep_token="[SEP]"): 47 | self.tokenizer = tokenizer 48 | self.pad_token = pad_token 49 | self.cls_token = cls_token 50 | self.sep_token = sep_token 51 | 52 | def convert_tokens_to_ids(self, tokens): 53 | return self.tokenizer.convert_tokens_to_ids(tokens) 54 | 55 | 56 | def truncate_sequences( 57 | self, 58 | ids, 59 | pair_ids, 60 | num_tokens_to_remove, 61 | truncation_strategy, 62 | stride, 63 | ): 64 | 65 | assert len(ids) > num_tokens_to_remove 66 | window_len = min(len(ids), stride + num_tokens_to_remove) 67 | overflowing_tokens = ids[-window_len:] 68 | ids = ids[:-num_tokens_to_remove] 69 | 70 | return (ids, pair_ids, overflowing_tokens) 71 | 72 | def encode_plus(self, text, text_pair, add_special_tokens, max_length, return_token_type_ids): 73 | 74 | # Tokenization 75 | token_ids_0 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) 76 | len_ids = len(token_ids_0) 77 | if text_pair: 78 | token_ids_1 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_pair)) 79 | len_pair_ids = len(token_ids_1) 80 | else: 81 | token_ids_1 = None 82 | len_pair_ids = 0 83 | 84 | 85 | # Truncation 86 | assert add_special_tokens 87 | num_special_tokens_to_add = (2 if not text_pair else 3) 88 | total_len = len_ids + len_pair_ids + num_special_tokens_to_add 89 | if max_length and total_len > max_length: 90 | token_ids_0, token_ids_1, overflowing_tokens = self.truncate_sequences( 91 | token_ids_0, 92 | pair_ids=token_ids_1, 93 | num_tokens_to_remove=total_len - max_length, 94 | truncation_strategy='only_first', # TODO(nijkamp): is this the correct truncation strategy for all GLUE tasks? 
95 | stride=0, 96 | ) 97 | 98 | 99 | # Add special tokens 100 | cls = [self.tokenizer.vocab[self.cls_token]] 101 | sep = [self.tokenizer.vocab[self.sep_token]] 102 | 103 | if not text_pair: 104 | 105 | input_ids = cls + token_ids_0 + sep 106 | token_type_ids = len(cls + token_ids_0 + sep) * [0] 107 | 108 | else: 109 | 110 | input_ids = cls + token_ids_0 + sep + token_ids_1 + sep 111 | token_type_ids = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 112 | 113 | assert len(input_ids) <= max_length 114 | 115 | return {"input_ids": input_ids, "token_type_ids": token_type_ids} 116 | 117 | def __len__(self): 118 | return len(self.tokenizer.vocab) 119 | 120 | def save_pretrained(self, outputdir): 121 | pass 122 | 123 | def wrap_tokenizer(tokenizer, pad_token): 124 | return TokenizerAdapter(tokenizer, pad_token) 125 | 126 | 127 | ################################################## 128 | # distilled Google-like/HF glue code 129 | 130 | def set_seed(args): 131 | random.seed(args.seed) 132 | np.random.seed(args.seed) 133 | torch.manual_seed(args.seed) 134 | if args.n_gpu > 0: 135 | torch.cuda.manual_seed_all(args.seed) 136 | 137 | 138 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 139 | """ Create a schedule with a learning rate that decreases linearly after 140 | linearly increasing during a warmup period. 141 | """ 142 | 143 | def lr_lambda(current_step): 144 | if current_step < num_warmup_steps: 145 | return float(current_step) / float(max(1, num_warmup_steps)) 146 | return max( 147 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 148 | ) 149 | 150 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 151 | 152 | 153 | def train(args, train_dataset, model, tokenizer): 154 | """ Train the model """ 155 | 156 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 157 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 158 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 159 | 160 | if args.max_steps > 0: 161 | t_total = args.max_steps 162 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 163 | else: 164 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 165 | 166 | # Prepare optimizer and schedule (linear warmup and decay) 167 | no_decay = ["bias", "LayerNorm.weight"] 168 | optimizer_grouped_parameters = [ 169 | { 170 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 171 | "weight_decay": args.weight_decay, 172 | }, 173 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 174 | ] 175 | 176 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 177 | scheduler = get_linear_schedule_with_warmup( 178 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 179 | ) 180 | 181 | # Check if saved optimizer or scheduler states exist 182 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 183 | os.path.join(args.model_name_or_path, "scheduler.pt") 184 | ): 185 | # Load in optimizer and scheduler states 186 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 187 | 
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 188 | 189 | if args.fp16: 190 | try: 191 | from apex import amp 192 | except ImportError: 193 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 194 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 195 | 196 | # multi-gpu training (should be after apex fp16 initialization) 197 | if args.n_gpu > 1: 198 | model = torch.nn.DataParallel(model) 199 | 200 | # Distributed training (should be after apex fp16 initialization) 201 | if args.local_rank != -1: 202 | model = torch.nn.parallel.DistributedDataParallel( 203 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, 204 | ) 205 | 206 | # Train! 207 | logger.info("***** Running training *****") 208 | logger.info(" Num examples = %d", len(train_dataset)) 209 | logger.info(" Num Epochs = %d", args.num_train_epochs) 210 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 211 | logger.info( 212 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 213 | args.train_batch_size 214 | * args.gradient_accumulation_steps 215 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 216 | ) 217 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 218 | logger.info(" Total optimization steps = %d", t_total) 219 | 220 | global_step = 0 221 | epochs_trained = 0 222 | steps_trained_in_current_epoch = 0 223 | # Check if continuing training from a checkpoint 224 | if os.path.exists(args.model_name_or_path): 225 | # set global_step to global_step of last saved checkpoint from model path 226 | try: 227 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 228 | except ValueError: 229 | global_step = 0 230 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 231 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 232 | 233 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 234 | logger.info(" Continuing training from epoch %d", epochs_trained) 235 | logger.info(" Continuing training from global step %d", global_step) 236 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 237 | 238 | tr_loss, logging_loss = 0.0, 0.0 239 | model.zero_grad() 240 | train_iterator = trange( 241 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], 242 | ) 243 | set_seed(args) # Added here for reproductibility 244 | for _ in train_iterator: 245 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 246 | for step, batch in enumerate(epoch_iterator): 247 | 248 | # Skip past any already trained steps if resuming training 249 | if steps_trained_in_current_epoch > 0: 250 | steps_trained_in_current_epoch -= 1 251 | continue 252 | 253 | model.train() 254 | batch = tuple(t.to(args.device) for t in batch) 255 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 256 | inputs["token_type_ids"] = ( 257 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 258 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 259 | outputs = model(**inputs) 260 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 
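# The lines below average the loss across GPUs and divide it by
# gradient_accumulation_steps, so each optimizer.step() effectively covers
# per_gpu_train_batch_size * max(1, n_gpu) * gradient_accumulation_steps
# examples -- the "Total train batch size" logged above.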
261 | 262 | if args.n_gpu > 1: 263 | loss = loss.mean() # mean() to average on multi-gpu parallel training 264 | if args.gradient_accumulation_steps > 1: 265 | loss = loss / args.gradient_accumulation_steps 266 | 267 | if args.fp16: 268 | with amp.scale_loss(loss, optimizer) as scaled_loss: 269 | scaled_loss.backward() 270 | else: 271 | loss.backward() 272 | 273 | if step % 10 == 0: 274 | print(step, loss.item()) 275 | 276 | tr_loss += loss.item() 277 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 278 | # last step in epoch but step is always smaller than gradient_accumulation_steps 279 | len(epoch_iterator) <= args.gradient_accumulation_steps 280 | and (step + 1) == len(epoch_iterator) 281 | ): 282 | if args.fp16: 283 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 284 | else: 285 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 286 | 287 | optimizer.step() 288 | scheduler.step() # Update learning rate schedule 289 | model.zero_grad() 290 | global_step += 1 291 | 292 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 293 | logs = {} 294 | if ( 295 | args.local_rank == -1 and args.evaluate_during_training 296 | ): # Only evaluate when single GPU otherwise metrics may not average well 297 | results = evaluate(args, model, tokenizer) 298 | for key, value in results.items(): 299 | eval_key = "eval_{}".format(key) 300 | logs[eval_key] = value 301 | 302 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps 303 | learning_rate_scalar = scheduler.get_lr()[0] 304 | logs["learning_rate"] = learning_rate_scalar 305 | logs["loss"] = loss_scalar 306 | logging_loss = tr_loss 307 | 308 | print(json.dumps({**logs, **{"step": global_step}})) 309 | 310 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 311 | # Save model checkpoint 312 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 313 | if not os.path.exists(output_dir): 314 | os.makedirs(output_dir) 315 | model_to_save = ( 316 | model.module if hasattr(model, "module") else model 317 | ) # Take care of distributed/parallel training 318 | model_to_save.save_pretrained(output_dir) 319 | tokenizer.save_pretrained(output_dir) 320 | 321 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 322 | logger.info("Saving model checkpoint to %s", output_dir) 323 | 324 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 325 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 326 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 327 | 328 | if args.max_steps > 0 and global_step > args.max_steps: 329 | epoch_iterator.close() 330 | break 331 | if args.max_steps > 0 and global_step > args.max_steps: 332 | train_iterator.close() 333 | break 334 | 335 | return global_step, tr_loss / global_step 336 | 337 | 338 | def evaluate(args, model, tokenizer, prefix=""): 339 | # Loop to handle MNLI double evaluation (matched, mis-matched) 340 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 341 | eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,) 342 | 343 | results = {} 344 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 345 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 346 | 347 | if not 
os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 348 | os.makedirs(eval_output_dir) 349 | 350 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 351 | # Note that DistributedSampler samples randomly 352 | eval_sampler = SequentialSampler(eval_dataset) 353 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 354 | 355 | # multi-gpu eval 356 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 357 | model = torch.nn.DataParallel(model) 358 | 359 | # Eval! 360 | logger.info("***** Running evaluation {} *****".format(prefix)) 361 | logger.info(" Num examples = %d", len(eval_dataset)) 362 | logger.info(" Batch size = %d", args.eval_batch_size) 363 | eval_loss = 0.0 364 | nb_eval_steps = 0 365 | preds = None 366 | out_label_ids = None 367 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 368 | model.eval() 369 | batch = tuple(t.to(args.device) for t in batch) 370 | 371 | with torch.no_grad(): 372 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 373 | if args.model_type != "distilbert": 374 | inputs["token_type_ids"] = ( 375 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 376 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 377 | outputs = model(**inputs) 378 | tmp_eval_loss, logits = outputs[:2] 379 | 380 | eval_loss += tmp_eval_loss.mean().item() 381 | nb_eval_steps += 1 382 | if preds is None: 383 | preds = logits.detach().cpu().numpy() 384 | out_label_ids = inputs["labels"].detach().cpu().numpy() 385 | else: 386 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 387 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 388 | 389 | eval_loss = eval_loss / nb_eval_steps 390 | if args.output_mode == "classification": 391 | preds = np.argmax(preds, axis=1) 392 | print(preds) 393 | elif args.output_mode == "regression": 394 | preds = np.squeeze(preds) 395 | result = compute_metrics(eval_task, preds, out_label_ids) 396 | results.update(result) 397 | 398 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") 399 | with open(output_eval_file, "w") as writer: 400 | logger.info("***** Eval results {} *****".format(prefix)) 401 | for key in sorted(result.keys()): 402 | logger.info(" %s = %s", key, str(result[key])) 403 | writer.write("%s = %s\n" % (key, str(result[key]))) 404 | 405 | return results 406 | 407 | 408 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 409 | if args.local_rank not in [-1, 0] and not evaluate: 410 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 411 | 412 | processor = processors[task]() 413 | output_mode = output_modes[task] 414 | # Load data features from cache or dataset file 415 | cached_features_file = os.path.join( 416 | args.data_dir, 417 | "cached_{}_{}_{}_{}".format( 418 | "dev" if evaluate else "train", 419 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 420 | str(args.max_seq_length), 421 | str(task), 422 | ), 423 | ) 424 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 425 | logger.info("Loading features from cached file %s", cached_features_file) 426 | features = torch.load(cached_features_file) 427 | else: 428 | logger.info("Creating features from dataset file at %s", args.data_dir) 429 | label_list = processor.get_labels() 430 | if task in ["mnli", 
"mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]: 431 | # HACK(label indices are swapped in RoBERTa pretrained model) 432 | label_list[1], label_list[2] = label_list[2], label_list[1] 433 | examples = ( 434 | processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) 435 | ) 436 | features = convert_examples_to_features( 437 | examples, 438 | tokenizer, 439 | label_list=label_list, 440 | max_length=args.max_seq_length, 441 | output_mode=output_mode, 442 | pad_on_left=False, # pad on the left for xlnet 443 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 444 | pad_token_segment_id=0, 445 | ) 446 | if args.local_rank in [-1, 0]: 447 | logger.info("Saving features into cached file %s", cached_features_file) 448 | torch.save(features, cached_features_file) 449 | 450 | if args.local_rank == 0 and not evaluate: 451 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 452 | 453 | # Convert to Tensors and build dataset 454 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 455 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 456 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 457 | if output_mode == "classification": 458 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 459 | elif output_mode == "regression": 460 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 461 | 462 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 463 | return dataset 464 | 465 | 466 | 467 | 468 | # python run_glue.py \ 469 | # --model_name_or_path bert-base-uncased \ 470 | # --task_name $TASK_NAME \ 471 | # --do_train \ 472 | # --do_eval \ 473 | # --data_dir $GLUE_DIR/$TASK_NAME \ 474 | # --max_seq_length 128 \ 475 | # --per_gpu_train_batch_size 32 \ 476 | # --learning_rate 2e-5 \ 477 | # --num_train_epochs 3.0 \ 478 | # --output_dir /tmp/$TASK_NAME \ 479 | # --overwrite_output_dir \ 480 | # --cache_dir cache_glue_bert 481 | 482 | def main(task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'): 483 | parser = argparse.ArgumentParser() 484 | 485 | # Required parameters 486 | parser.add_argument( 487 | "--data_dir", 488 | default=f'data/glue_data/{task}', 489 | type=str, 490 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.", 491 | ) 492 | parser.add_argument( 493 | "--model_type", 494 | default="bert", 495 | type=str, 496 | ) 497 | parser.add_argument( 498 | "--model_name_or_path", 499 | default=ckpt, 500 | type=str, 501 | ) 502 | parser.add_argument( 503 | "--vocab_path", 504 | default='data/vocab.txt', 505 | type=str, 506 | ) 507 | parser.add_argument( 508 | "--task_name", 509 | default=task, 510 | type=str, 511 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys()), 512 | ) 513 | parser.add_argument( 514 | "--output_dir", 515 | default='output/glue', 516 | type=str, 517 | help="The output directory where the model predictions and checkpoints will be written.", 518 | ) 519 | 520 | # Other parameters 521 | parser.add_argument( 522 | "--cache_dir", 523 | default="", 524 | type=str, 525 | help="Where do you want to store the pre-trained models downloaded from s3", 526 | ) 527 | parser.add_argument( 528 | "--max_seq_length", 529 | default=128, 530 | type=int, 531 | help="The maximum total input sequence length after tokenization. Sequences longer " 532 | "than this will be truncated, sequences shorter will be padded.", 533 | ) 534 | parser.add_argument("--do_train", default=True, help="Whether to run training.") 535 | parser.add_argument("--do_eval", default=True, help="Whether to run eval on the dev set.") 536 | parser.add_argument( 537 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.", 538 | ) 539 | parser.add_argument( 540 | "--do_lower_case", default=True, help="Set this flag if you are using an uncased model.", 541 | ) 542 | 543 | parser.add_argument( 544 | "--per_gpu_train_batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.", 545 | ) 546 | parser.add_argument( 547 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", 548 | ) 549 | parser.add_argument( 550 | "--gradient_accumulation_steps", 551 | type=int, 552 | default=1, 553 | help="Number of updates steps to accumulate before performing a backward/update pass.", 554 | ) 555 | parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.") 556 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 557 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 558 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 559 | parser.add_argument( 560 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.", 561 | ) 562 | parser.add_argument( 563 | "--max_steps", 564 | default=-1, 565 | type=int, 566 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", 567 | ) 568 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 569 | 570 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 571 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 572 | parser.add_argument( 573 | "--eval_all_checkpoints", 574 | action="store_true", 575 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 576 | ) 577 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 578 | parser.add_argument( 579 | "--overwrite_output_dir", default=True, help="Overwrite the content of the output directory", 580 | ) 581 | parser.add_argument( 582 | "--overwrite_cache", default=True, help="Overwrite the cached training and evaluation sets", 583 | ) 584 | parser.add_argument("--seed", type=int, default=seed, help="random seed for initialization") 585 | 586 | parser.add_argument( 587 | "--fp16", 588 | action="store_true", 589 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 590 | ) 591 | parser.add_argument( 592 | "--fp16_opt_level", 593 | type=str, 594 | default="O1", 595 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 596 | "See details at https://nvidia.github.io/apex/amp.html", 597 | ) 598 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 599 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 600 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 601 | args = parser.parse_args() 602 | 603 | if ( 604 | os.path.exists(args.output_dir) 605 | and os.listdir(args.output_dir) 606 | and args.do_train 607 | and not args.overwrite_output_dir 608 | ): 609 | raise ValueError( 610 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 611 | args.output_dir 612 | ) 613 | ) 614 | 615 | # Setup distant debugging if needed 616 | if args.server_ip and args.server_port: 617 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 618 | import ptvsd 619 | 620 | print("Waiting for debugger attach") 621 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 622 | ptvsd.wait_for_attach() 623 | 624 | # Setup CUDA, GPU & distributed training 625 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 626 | args.n_gpu = 1 627 | args.device = device 628 | 629 | # Setup logging 630 | logging.basicConfig( 631 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 632 | datefmt="%m/%d/%Y %H:%M:%S", 633 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 634 | ) 635 | logger.warning( 636 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 637 | args.local_rank, 638 | device, 639 | args.n_gpu, 640 | bool(args.local_rank != -1), 641 | args.fp16, 642 | ) 643 | 644 | # Set seed 645 | set_seed(args) 646 | 647 | # Prepare GLUE task 648 | args.task_name = args.task_name.lower() 649 | if args.task_name not in processors: 650 | raise ValueError("Task not found: %s" % (args.task_name)) 651 | processor = processors[args.task_name]() 652 | args.output_mode = output_modes[args.task_name] 653 | label_list = processor.get_labels() 654 | num_labels = len(label_list) 655 | 656 | # Load pretrained model and tokenizer 657 | if args.local_rank not in [-1, 0]: 658 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 659 | 660 | from transformers import AutoConfig, AutoModelForSequenceClassification 661 | args.model_type = args.model_type.lower() 662 | config = AutoConfig.from_pretrained( 663 | args.model_name_or_path, 664 | num_labels=num_labels, 665 | finetuning_task=args.task_name, 666 | cache_dir=args.cache_dir if args.cache_dir else None, 667 | ) 668 | model = AutoModelForSequenceClassification.from_pretrained( 669 | args.model_name_or_path, 670 | from_tf=bool(".ckpt" in args.model_name_or_path), 671 | config=config, 672 | cache_dir=args.cache_dir if args.cache_dir else None, 673 | ) 674 | 675 | 676 | from pretraining.openwebtext.dataset import new_tokenizer 677 | tokenizer = wrap_tokenizer(new_tokenizer(args.vocab_path), pad_token='[PAD]') 678 | 679 | if args.local_rank == 0: 680 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 681 | 682 | model.to(args.device) 683 | 684 | logger.info("Training/evaluation parameters %s", args) 685 | 686 | # Training 687 | if args.do_train: 688 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 689 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 690 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 691 | 692 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 693 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 694 | # Create output directory if needed 695 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 696 | os.makedirs(args.output_dir) 697 | 698 | logger.info("Saving model checkpoint to %s", args.output_dir) 699 | # Save a trained model, configuration 
and tokenizer using `save_pretrained()`. 700 | # They can then be reloaded using `from_pretrained()` 701 | model_to_save = ( 702 | model.module if hasattr(model, "module") else model 703 | ) # Take care of distributed/parallel training 704 | model_to_save.save_pretrained(args.output_dir) 705 | tokenizer.save_pretrained(args.output_dir) 706 | 707 | # Good practice: save your training arguments together with the trained model 708 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 709 | 710 | # Load a trained model and vocabulary that you have fine-tuned 711 | model = model_to_save 712 | # TODO(nijkamp): we ignore model serialization 713 | # model = AutoModelForSequenceClassification.from_pretrained(args.output_dir) 714 | # tokenizer = AutoTokenizer.from_pretrained(args.output_dir) 715 | model.to(args.device) 716 | 717 | # Evaluation 718 | results = {} 719 | if args.do_eval and args.local_rank in [-1, 0]: 720 | # TODO(nijkamp): we ignore model serialization 721 | # tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 722 | checkpoints = [args.output_dir] 723 | if args.eval_all_checkpoints: 724 | checkpoints = list( 725 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) 726 | ) 727 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 728 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 729 | for checkpoint in checkpoints: 730 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" 731 | prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" 732 | 733 | # TODO(nijkamp): we ignore model serialization 734 | # model = AutoModelForSequenceClassification.from_pretrained(checkpoint) 735 | model.to(args.device) 736 | result = evaluate(args, model, tokenizer, prefix=prefix) 737 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) 738 | results.update(result) 739 | 740 | return results 741 | 742 | 743 | if __name__ == "__main__": 744 | main() -------------------------------------------------------------------------------- /examples/glue/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
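The file that follows (examples/glue/utils.py) defines the lightweight `InputExample`/`InputFeatures` containers, a `DataProcessor` base class, and a generic `SingleSentenceClassificationProcessor`. A small illustrative round-trip, assuming it is run next to utils.py (the text, label, and token ids are made up):

# Hedged sketch -- not part of utils.py.
from utils import InputExample, InputFeatures

ex = InputExample(guid="train-1", text_a="He ate the apple.", text_b=None, label="1")
print(ex.to_json_string())    # pretty-printed JSON via dataclasses.asdict

feat = InputFeatures(input_ids=[101, 2002, 102], attention_mask=[1, 1, 1], label=1)
print(feat.to_json_string())  # same serialization for the tokenized features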
16 | 17 | import copy 18 | import csv 19 | import dataclasses 20 | import json 21 | import logging 22 | from dataclasses import dataclass 23 | from typing import Optional 24 | 25 | # from ...file_utils import is_tf_available, is_torch_available 26 | 27 | is_torch_available = lambda: True 28 | is_tf_available = lambda: False 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | @dataclass(frozen=True) 35 | class InputExample: 36 | """ 37 | A single training/test example for simple sequence classification. 38 | 39 | Args: 40 | guid: Unique id for the example. 41 | text_a: string. The untokenized text of the first sequence. For single 42 | sequence tasks, only this sequence must be specified. 43 | text_b: (Optional) string. The untokenized text of the second sequence. 44 | Only must be specified for sequence pair tasks. 45 | label: (Optional) string. The label of the example. This should be 46 | specified for train and dev examples, but not for test examples. 47 | """ 48 | 49 | guid: str 50 | text_a: str 51 | text_b: Optional[str] = None 52 | label: Optional[str] = None 53 | 54 | def to_json_string(self): 55 | """Serializes this instance to a JSON string.""" 56 | return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" 57 | 58 | 59 | class InputFeatures(object): 60 | """ 61 | A single set of features of data. 62 | 63 | Args: 64 | input_ids: Indices of input sequence tokens in the vocabulary. 65 | attention_mask: Mask to avoid performing attention on padding token indices. 66 | Mask values selected in ``[0, 1]``: 67 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 68 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 69 | label: Label corresponding to the input 70 | """ 71 | 72 | def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None): 73 | self.input_ids = input_ids 74 | self.attention_mask = attention_mask 75 | self.token_type_ids = token_type_ids 76 | self.label = label 77 | 78 | def __repr__(self): 79 | return str(self.to_json_string()) 80 | 81 | def to_dict(self): 82 | """Serializes this instance to a Python dictionary.""" 83 | output = copy.deepcopy(self.__dict__) 84 | return output 85 | 86 | def to_json_string(self): 87 | """Serializes this instance to a JSON string.""" 88 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 89 | 90 | 91 | class DataProcessor(object): 92 | """Base class for data converters for sequence classification data sets.""" 93 | 94 | def get_example_from_tensor_dict(self, tensor_dict): 95 | """Gets an example from a dict with tensorflow tensors 96 | Args: 97 | tensor_dict: Keys and values should match the corresponding Glue 98 | tensorflow_dataset examples. 99 | """ 100 | raise NotImplementedError() 101 | 102 | def get_train_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the train set.""" 104 | raise NotImplementedError() 105 | 106 | def get_dev_examples(self, data_dir): 107 | """Gets a collection of `InputExample`s for the dev set.""" 108 | raise NotImplementedError() 109 | 110 | def get_labels(self): 111 | """Gets the list of labels for this data set.""" 112 | raise NotImplementedError() 113 | 114 | def tfds_map(self, example): 115 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 
116 | This method converts examples to the correct format.""" 117 | if len(self.get_labels()) > 1: 118 | example.label = self.get_labels()[int(example.label)] 119 | return example 120 | 121 | @classmethod 122 | def _read_tsv(cls, input_file, quotechar=None): 123 | """Reads a tab separated value file.""" 124 | with open(input_file, "r", encoding="utf-8-sig") as f: 125 | return list(csv.reader(f, delimiter="\t", quotechar=quotechar)) 126 | 127 | 128 | class SingleSentenceClassificationProcessor(DataProcessor): 129 | """ Generic processor for a single sentence classification data set.""" 130 | 131 | def __init__(self, labels=None, examples=None, mode="classification", verbose=False): 132 | self.labels = [] if labels is None else labels 133 | self.examples = [] if examples is None else examples 134 | self.mode = mode 135 | self.verbose = verbose 136 | 137 | def __len__(self): 138 | return len(self.examples) 139 | 140 | def __getitem__(self, idx): 141 | if isinstance(idx, slice): 142 | return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx]) 143 | return self.examples[idx] 144 | 145 | @classmethod 146 | def create_from_csv( 147 | cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs 148 | ): 149 | processor = cls(**kwargs) 150 | processor.add_examples_from_csv( 151 | file_name, 152 | split_name=split_name, 153 | column_label=column_label, 154 | column_text=column_text, 155 | column_id=column_id, 156 | skip_first_row=skip_first_row, 157 | overwrite_labels=True, 158 | overwrite_examples=True, 159 | ) 160 | return processor 161 | 162 | @classmethod 163 | def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs): 164 | processor = cls(**kwargs) 165 | processor.add_examples(texts_or_text_and_labels, labels=labels) 166 | return processor 167 | 168 | def add_examples_from_csv( 169 | self, 170 | file_name, 171 | split_name="", 172 | column_label=0, 173 | column_text=1, 174 | column_id=None, 175 | skip_first_row=False, 176 | overwrite_labels=False, 177 | overwrite_examples=False, 178 | ): 179 | lines = self._read_tsv(file_name) 180 | if skip_first_row: 181 | lines = lines[1:] 182 | texts = [] 183 | labels = [] 184 | ids = [] 185 | for (i, line) in enumerate(lines): 186 | texts.append(line[column_text]) 187 | labels.append(line[column_label]) 188 | if column_id is not None: 189 | ids.append(line[column_id]) 190 | else: 191 | guid = "%s-%s" % (split_name, i) if split_name else "%s" % i 192 | ids.append(guid) 193 | 194 | return self.add_examples( 195 | texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples 196 | ) 197 | 198 | def add_examples( 199 | self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False 200 | ): 201 | assert labels is None or len(texts_or_text_and_labels) == len(labels) 202 | assert ids is None or len(texts_or_text_and_labels) == len(ids) 203 | if ids is None: 204 | ids = [None] * len(texts_or_text_and_labels) 205 | if labels is None: 206 | labels = [None] * len(texts_or_text_and_labels) 207 | examples = [] 208 | added_labels = set() 209 | for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids): 210 | if isinstance(text_or_text_and_label, (tuple, list)) and label is None: 211 | text, label = text_or_text_and_label 212 | else: 213 | text = text_or_text_and_label 214 | added_labels.add(label) 215 | examples.append(InputExample(guid=guid, text_a=text, text_b=None, 
label=label)) 216 | 217 | # Update examples 218 | if overwrite_examples: 219 | self.examples = examples 220 | else: 221 | self.examples.extend(examples) 222 | 223 | # Update labels 224 | if overwrite_labels: 225 | self.labels = list(added_labels) 226 | else: 227 | self.labels = list(set(self.labels).union(added_labels)) 228 | 229 | return self.examples 230 | 231 | def get_features( 232 | self, 233 | tokenizer, 234 | max_length=None, 235 | pad_on_left=False, 236 | pad_token=0, 237 | mask_padding_with_zero=True, 238 | return_tensors=None, 239 | ): 240 | """ 241 | Convert examples in a list of ``InputFeatures`` 242 | 243 | Args: 244 | tokenizer: Instance of a tokenizer that will tokenize the examples 245 | max_length: Maximum example length 246 | task: GLUE task 247 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method 248 | output_mode: String indicating the output mode. Either ``regression`` or ``classification`` 249 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 250 | pad_token: Padding token 251 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 252 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 253 | actual values) 254 | 255 | Returns: 256 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 257 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 258 | a list of task-specific ``InputFeatures`` which can be fed to the model. 259 | 260 | """ 261 | if max_length is None: 262 | max_length = tokenizer.max_len 263 | 264 | label_map = {label: i for i, label in enumerate(self.labels)} 265 | 266 | all_input_ids = [] 267 | for (ex_index, example) in enumerate(self.examples): 268 | if ex_index % 10000 == 0: 269 | logger.info("Tokenizing example %d", ex_index) 270 | 271 | input_ids = tokenizer.encode( 272 | example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len), 273 | ) 274 | all_input_ids.append(input_ids) 275 | 276 | batch_length = max(len(input_ids) for input_ids in all_input_ids) 277 | 278 | features = [] 279 | for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)): 280 | if ex_index % 10000 == 0: 281 | logger.info("Writing example %d/%d" % (ex_index, len(self.examples))) 282 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 283 | # tokens are attended to. 284 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 285 | 286 | # Zero-pad up to the sequence length. 
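# (Padding here is to batch_length -- the longest example in this call to
#  get_features -- rather than to a fixed max_length, so padding_length can be
#  0 for the longest example.)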
287 | padding_length = batch_length - len(input_ids) 288 | if pad_on_left: 289 | input_ids = ([pad_token] * padding_length) + input_ids 290 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 291 | else: 292 | input_ids = input_ids + ([pad_token] * padding_length) 293 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 294 | 295 | assert len(input_ids) == batch_length, "Error with input length {} vs {}".format( 296 | len(input_ids), batch_length 297 | ) 298 | assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format( 299 | len(attention_mask), batch_length 300 | ) 301 | 302 | if self.mode == "classification": 303 | label = label_map[example.label] 304 | elif self.mode == "regression": 305 | label = float(example.label) 306 | else: 307 | raise ValueError(self.mode) 308 | 309 | if ex_index < 5 and self.verbose: 310 | logger.info("*** Example ***") 311 | logger.info("guid: %s" % (example.guid)) 312 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 313 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 314 | logger.info("label: %s (id = %d)" % (example.label, label)) 315 | 316 | features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label)) 317 | 318 | if return_tensors is None: 319 | return features 320 | elif return_tensors == "tf": 321 | if not is_tf_available(): 322 | raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported") 323 | import tensorflow as tf 324 | 325 | def gen(): 326 | for ex in features: 327 | yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label) 328 | 329 | dataset = tf.data.Dataset.from_generator( 330 | gen, 331 | ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64), 332 | ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])), 333 | ) 334 | return dataset 335 | elif return_tensors == "pt": 336 | if not is_torch_available(): 337 | raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported") 338 | import torch 339 | from torch.utils.data import TensorDataset 340 | 341 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 342 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 343 | if self.mode == "classification": 344 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 345 | elif self.mode == "regression": 346 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 347 | 348 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels) 349 | return dataset 350 | else: 351 | raise ValueError("return_tensors should be one of 'tf' or 'pt'") -------------------------------------------------------------------------------- /pretraining/openwebtext/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | 4 | __all__ = ('Arg', 'Int', 'Float', 'Bool', 'Str', 'Choice', 'parse_to') 5 | 6 | class Arg: 7 | def __init__(self, **kwargs): 8 | super().__init__() 9 | self.kwargs = kwargs 10 | 11 | 12 | class Int(Arg): 13 | def __init__(self, **kwargs): 14 | super().__init__(type=int, **kwargs) 15 | 16 | 17 | class Float(Arg): 18 | def __init__(self, **kwargs): 19 | super().__init__(type=float, **kwargs) 20 | 21 | 22 | class Bool(Arg): 23 | def __init__(self, **kwargs): 24 | 
super().__init__(type=bool, **kwargs) 25 | 26 | 27 | class Str(Arg): 28 | def __init__(self, **kwargs): 29 | super().__init__(type=str, **kwargs) 30 | 31 | 32 | class _MetaChoice(type): 33 | def __getitem__(self, item): 34 | return self(choices=list(item), type=item) 35 | 36 | 37 | class Choice(Arg, metaclass=_MetaChoice): 38 | def __init__(self, choices, **kwargs): 39 | super().__init__(choices=choices, **kwargs) 40 | 41 | 42 | def parse_to(container_class, **kwargs): 43 | def mangle_name(name): 44 | return '--' + name.replace('_', '-') 45 | 46 | parser = argparse.ArgumentParser(description=container_class.__doc__) 47 | for field in dataclasses.fields(container_class): 48 | name = field.name 49 | default = field.default 50 | value_or_class = field.type 51 | if isinstance(value_or_class, type): 52 | value = value_or_class(default=default) 53 | else: 54 | value = value_or_class 55 | value.kwargs['default'] = default 56 | parser.add_argument( 57 | mangle_name(name), **value.kwargs) 58 | 59 | arg_dict = parser.parse_args(**kwargs) 60 | return container_class(**vars(arg_dict)) -------------------------------------------------------------------------------- /pretraining/openwebtext/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from dataclasses import dataclass 5 | from itertools import chain 6 | from functools import partial 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.utils.data 13 | 14 | from openwebtext import tokenization 15 | 16 | 17 | class ExampleBuilder: 18 | """Given a stream of input text, creates pretraining examples.""" 19 | 20 | def __init__(self, vocab, max_length): 21 | self._vocab = vocab 22 | self._current_sentences = [] 23 | self._current_length = 0 24 | self._max_length = max_length 25 | self._target_length = max_length 26 | 27 | def add_line(self, bert_tokids): 28 | """Adds a line of text to the current example being built.""" 29 | # line = line.strip().replace("\n", " ") 30 | # if (not line) and self._current_length != 0: # empty lines separate docs 31 | # return self._create_example() 32 | # bert_tokens = self._tokenizer.tokenize(line) 33 | # bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens) 34 | self._current_sentences.append(bert_tokids) 35 | self._current_length += len(bert_tokids) 36 | if self._current_length >= self._target_length: 37 | return self._create_example() 38 | return None 39 | 40 | def _create_example(self): 41 | """Creates a pre-training example from the current list of sentences.""" 42 | # small chance to only have one segment as in classification tasks 43 | if random.random() < 0.1: 44 | first_segment_target_length = 100000 45 | else: 46 | # -3 due to not yet having [CLS]/[SEP] tokens in the input text 47 | first_segment_target_length = (self._target_length - 3) // 2 48 | 49 | first_segment = [] 50 | second_segment = [] 51 | for sentence in self._current_sentences: 52 | # the sentence goes to the first segment if (1) the first segment is 53 | # empty, (2) the sentence doesn't put the first segment over length or 54 | # (3) 50% of the time when it does put the first segment over length 55 | if (len(first_segment) == 0 or 56 | len(first_segment) + len(sentence) < first_segment_target_length or 57 | (len(second_segment) == 0 and 58 | len(first_segment) < first_segment_target_length and 59 | random.random() < 0.5)): 60 | first_segment += sentence 61 | else: 62 | second_segment += sentence 63 | 
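# (When the ~10% branch above set first_segment_target_length to 100000,
#  every sentence lands in first_segment and second_segment stays empty,
#  producing a single-segment example.)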
64 | # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens 65 | first_segment = first_segment[:self._max_length - 2] 66 | second_segment = second_segment[:max(0, self._max_length - len(first_segment) - 3)] 67 | 68 | # prepare to start building the next example 69 | self._current_sentences = [] 70 | self._current_length = 0 71 | # small chance for random-length instead of max_length-length example 72 | if random.random() < 0.05: 73 | self._target_length = random.randint(5, self._max_length) 74 | else: 75 | self._target_length = self._max_length 76 | 77 | return self._make_tf_example(first_segment, second_segment) 78 | 79 | def _make_tf_example(self, first_segment, second_segment): 80 | """Converts two "segments" of text into a tf.train.Example.""" 81 | vocab = self._vocab 82 | input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]] 83 | segment_ids = [0] * len(input_ids) 84 | if second_segment: 85 | input_ids += second_segment + [vocab["[SEP]"]] 86 | segment_ids += [1] * (len(second_segment) + 1) 87 | input_mask = [1] * len(input_ids) 88 | input_ids += [0] * (self._max_length - len(input_ids)) 89 | input_mask += [0] * (self._max_length - len(input_mask)) 90 | segment_ids += [0] * (self._max_length - len(segment_ids)) 91 | 92 | def create_int_feature(tensors): 93 | return torch.tensor(tensors) 94 | 95 | tf_example = { 96 | "input_ids": create_int_feature(input_ids), 97 | "input_mask": create_int_feature(input_mask), 98 | "segment_ids": create_int_feature(segment_ids) 99 | } 100 | return tf_example 101 | 102 | 103 | class OpenWebTextDataset(torch.utils.data.IterableDataset): 104 | def __init__(self, feature_set_paths, n_tensors_per_file): 105 | self.feature_set_paths = feature_set_paths 106 | self.n_tensors_per_file = n_tensors_per_file 107 | 108 | @staticmethod 109 | def parse_file(file_index): 110 | try: 111 | features = torch.load(str(file_index)) 112 | yield from features 113 | except RuntimeError: 114 | raise RuntimeError(f'Corrupted file {file_index}') 115 | 116 | def __len__(self): 117 | return len(self.feature_set_paths) * self.n_tensors_per_file 118 | 119 | def __iter__(self): 120 | return chain.from_iterable(map(self.parse_file, self.feature_set_paths)) 121 | 122 | 123 | class ExampleBuilderDataset(torch.utils.data.IterableDataset): 124 | def __init__(self, dataset, builder): 125 | self.dataset = dataset 126 | self.builder = builder 127 | 128 | def __len__(self): 129 | return len(self.dataset) 130 | 131 | def __iter__(self): 132 | def create_example(): 133 | while True: 134 | token_ids = list(next(self.dataset).cpu().numpy()) 135 | example = self.builder.add_line(token_ids) 136 | if example: 137 | return example 138 | 139 | while True: 140 | yield create_example() 141 | 142 | 143 | 144 | def cycle(iterable): 145 | while True: 146 | for x in iterable: 147 | yield x 148 | 149 | 150 | def new_tokenizer(vocab_file, do_lower_case=True): 151 | return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 152 | 153 | 154 | def parse_tokenizer(tokenizer, text): 155 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 156 | 157 | 158 | def create_tokenizer(vocab_file, do_lower_case=True): 159 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 160 | return partial(parse_tokenizer, tokenizer) 161 | 162 | 163 | def load_owt(owt_dir, n_tensors_per_file): 164 | owt_dir_path = Path(owt_dir) 165 | feature_set_paths = [owt_dir_path / feature_set_path for feature_set_path in 
os.listdir(owt_dir_path)] 166 | np.random.shuffle(feature_set_paths) 167 | assert len(feature_set_paths) > 0 168 | return OpenWebTextDataset(feature_set_paths, n_tensors_per_file=n_tensors_per_file) 169 | 170 | 171 | def wrap_example_builder(dataset, vocab, max_length): 172 | return ExampleBuilderDataset(cycle(iter(dataset)), ExampleBuilder(vocab, max_length)) 173 | -------------------------------------------------------------------------------- /pretraining/openwebtext/preprocess.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging 3 | import math 4 | import multiprocessing 5 | import os 6 | import random 7 | import tarfile 8 | from dataclasses import dataclass 9 | from itertools import chain 10 | from functools import partial 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | 15 | import torch 16 | import torch.utils.data 17 | 18 | from pretraining.openwebtext import arg 19 | from pretraining.openwebtext import tokenization 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def parse_tokenizer(tokenizer, text): 26 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 27 | 28 | 29 | def create_tokenizer(vocab_file, do_lower_case=True): 30 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 31 | return partial(parse_tokenizer, tokenizer) 32 | 33 | 34 | def preprocess_owt(tokenizer, src_dir, tmp_dir, trg_dir, n_dataset_building_processes, n_tensors_per_file, max_seq_length=None): 35 | # Preamble 36 | logger.info(f'Writing features to {trg_dir}.') 37 | os.makedirs(trg_dir, exist_ok=False) 38 | 39 | # Crunch files 40 | trg_dir = Path(trg_dir) 41 | src_dir = Path(src_dir) 42 | tmp_dir = Path(tmp_dir) 43 | archives = os.listdir(src_dir) 44 | n_archives_per_job = math.ceil(len(archives) / n_dataset_building_processes) 45 | job_archives = [ 46 | archives[i * n_archives_per_job : (i + 1) * n_archives_per_job] 47 | for i in range(n_dataset_building_processes) 48 | ] 49 | 50 | logger.info(f'Processing {len(archives)} archives.') 51 | assert len(archives) > 0 52 | 53 | if n_dataset_building_processes == 1: 54 | feature_set_paths = preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0) 55 | else: 56 | pool = multiprocessing.Pool(processes=n_dataset_building_processes) 57 | preprocess_owt_job_partial = partial(preprocess_owt_job, tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length) 58 | feature_sets = pool.map(preprocess_owt_job_partial, range(n_dataset_building_processes)) 59 | feature_set_paths = [file_path for feature_set in feature_sets for file_path in feature_set] 60 | 61 | return feature_set_paths 62 | 63 | 64 | def preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0): 65 | ''' 66 | OpenWebText is saved under the following format: 67 | openwebtext.zip 68 | |-> archive_xxx.zip 69 | |-> file_xxx.txt 70 | |-> file_xxz.txt 71 | ... 72 | |-> archive_xxz.zip 73 | |-> file_xxy.txt 74 | ... 75 | ... 
76 | ''' 77 | 78 | # Preamble 79 | os.makedirs(tmp_dir, exist_ok=True) 80 | 81 | # Process 82 | feature_index = 0 83 | feature_set_paths = [] 84 | features = [] 85 | for archive_id, archive in enumerate(job_archives[job_id]): 86 | if os.path.isdir(src_dir / archive): 87 | logger.info(f'Ignoring rogue directory {src_dir / archive}.') 88 | continue 89 | 90 | logger.info(f'Job {job_id}: Processing {archive_id}/{len(job_archives[job_id])} {src_dir / archive}.') 91 | 92 | with tarfile.open(src_dir / archive) as t: 93 | extracted_archive = tmp_dir / f'{archive}-extracted' 94 | t.extractall(extracted_archive) 95 | 96 | for file in os.listdir(extracted_archive): 97 | file_path = extracted_archive / file 98 | 99 | with open(file_path, 'r') as f: 100 | for line in f.readlines(): 101 | line = line.strip() 102 | if len(line) > 2: 103 | encoding = tokenizer(line) 104 | features.append(torch.tensor(encoding)) 105 | 106 | while len(features) > n_tensors_per_file: 107 | feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt' 108 | torch.save(features[:n_tensors_per_file], feature_set_path) 109 | features = features[n_tensors_per_file:] 110 | feature_index += 1 111 | feature_set_paths.append(feature_set_path) 112 | 113 | # Serialize 114 | if len(features) > 0: 115 | feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt' 116 | torch.save(features, feature_set_path) 117 | feature_set_paths.append(feature_set_path) 118 | 119 | return feature_set_paths 120 | 121 | 122 | @dataclass(frozen=True) 123 | class Args: 124 | src_dir: arg.Str = 'data/openwebtext' 125 | trg_dir: arg.Str = 'data/openwebtext_features' 126 | tmp_dir: arg.Str = '/tmp/owt' 127 | vocab_file: arg.Str = 'data/vocab.txt' 128 | n_dataset_building_processes: arg.Int = 32 129 | n_tensors_per_file: arg.Int = 2048 130 | max_seq_length: arg.Int = 128 131 | 132 | 133 | def main(): 134 | args = arg.parse_to(Args) 135 | 136 | logging.basicConfig( 137 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 138 | datefmt='%m/%d/%Y %H:%M:%S', 139 | level=logging.INFO 140 | ) 141 | 142 | tokenizer = create_tokenizer(args.vocab_file) 143 | preprocess_owt(tokenizer=tokenizer, src_dir=args.src_dir, tmp_dir=args.tmp_dir, trg_dir=args.trg_dir, n_dataset_building_processes=args.n_dataset_building_processes, n_tensors_per_file=args.n_tensors_per_file, max_seq_length=args.max_seq_length) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() -------------------------------------------------------------------------------- /pretraining/openwebtext/pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dir_path = os.path.dirname(os.path.realpath(__file__)) 5 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir)) 6 | sys.path.insert(0, parent_dir_path) 7 | 8 | import random 9 | import logging 10 | from time import time 11 | from dataclasses import dataclass 12 | 13 | import numpy as np 14 | 15 | import torch 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.utils.data.dataloader import DataLoader 18 | 19 | from electra_pytorch import Electra 20 | 21 | from openwebtext import arg 22 | from openwebtext.dataset import load_owt, new_tokenizer, wrap_example_builder 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | ######################################################################################################## 27 | ## args 28 | 29 | @dataclass 30 | class Args: 31 | data_dir: arg.Str = 
'data/openwebtext_features' 32 | data_vocab_file: arg.Str = 'data/vocab.txt' 33 | data_n_tensors_per_file: arg.Int = 2048 34 | data_max_seq_length: arg.Int = 128 35 | 36 | gpu: arg.Int = 0 37 | gpu_enabled: arg.Bool = True 38 | gpu_deterministic: arg.Bool = False 39 | gpu_mixed_precision: arg.Bool = False 40 | distributed_port: arg.Int = 8888 41 | distributed_enabled: arg.Bool = True 42 | distributed_world_size: arg.Int = 4 43 | 44 | model_generator: arg.Str = 'pretraining/openwebtext/small_generator.json' 45 | model_discriminator: arg.Str = 'pretraining/openwebtext/small_discriminator.json' 46 | model_mask_prob: arg.Float = 0.15 47 | 48 | opt_lr: arg.Float = 5e-4 49 | opt_batch_size: arg.Int = 128 // (distributed_world_size if distributed_enabled else 1) 50 | opt_warmup_steps: arg.Int = 10_000 51 | opt_num_training_steps: arg.Int = 200_000 52 | 53 | step_log: arg.Int = 10 54 | step_ckpt: arg.Int = 10_000 55 | 56 | 57 | ######################################################################################################## 58 | ## train 59 | 60 | def train(rank, args): 61 | 62 | ####################### 63 | ## distributed 64 | 65 | if args.distributed_enabled: 66 | torch.distributed.init_process_group( 67 | backend='nccl', 68 | init_method='env://', 69 | world_size=args.distributed_world_size, 70 | rank=rank) 71 | if args.gpu_enabled: 72 | device = torch.device('cuda:{}'.format(rank)) 73 | else: 74 | device = torch.device('cpu') 75 | 76 | is_master = True if not args.distributed_enabled else args.distributed_enabled and rank == 0 77 | 78 | 79 | ####################### 80 | ## preamble 81 | 82 | set_gpus(rank) 83 | set_seed(rank) 84 | set_cuda(deterministic=args.gpu_deterministic) 85 | 86 | output_dir = f'{args.output_dir}/{rank}' 87 | os.makedirs(output_dir, exist_ok=False) 88 | 89 | setup_logging(filename=f'{output_dir}/output.log', console=is_master) 90 | 91 | 92 | ####################### 93 | ## dataset 94 | 95 | tokenizer = new_tokenizer(vocab_file=args.data_vocab_file) 96 | vocab_size = len(tokenizer.vocab) 97 | ds_train = wrap_example_builder(dataset=load_owt(owt_dir=args.data_dir, n_tensors_per_file=args.data_n_tensors_per_file), vocab=tokenizer.vocab, max_length=args.data_max_seq_length) 98 | 99 | pad_token_id = tokenizer.vocab['[PAD]'] 100 | mask_token_id = tokenizer.vocab['[MASK]'] 101 | cls_token_id = tokenizer.vocab['[CLS]'] 102 | sep_token_id = tokenizer.vocab['[SEP]'] 103 | 104 | assert pad_token_id == 0 105 | assert cls_token_id == 101 106 | assert sep_token_id == 102 107 | assert mask_token_id == 103 108 | 109 | def collate_batch(examples): 110 | input_ids = torch.nn.utils.rnn.pad_sequence([example['input_ids'] for example in examples], batch_first=True, padding_value=pad_token_id) 111 | input_mask = torch.nn.utils.rnn.pad_sequence([example['input_mask'] for example in examples], batch_first=True, padding_value=pad_token_id) 112 | segment_ids = torch.nn.utils.rnn.pad_sequence([example['segment_ids'] for example in examples], batch_first=True, padding_value=pad_token_id) 113 | return input_ids, input_mask, segment_ids 114 | 115 | def cycle(iterable): 116 | while True: 117 | for x in iterable: 118 | yield x 119 | 120 | ds_train_loader = iter(cycle(DataLoader(ds_train, batch_size=args.opt_batch_size, collate_fn=collate_batch))) 121 | 122 | 123 | ####################### 124 | ## model 125 | 126 | def to_distributed_model(model): 127 | return model if not args.distributed_enabled else torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], 
find_unused_parameters=True) 128 | 129 | def tie_weights(generator, discriminator): 130 | generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings 131 | generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings 132 | generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings 133 | 134 | class LogitsAdapter(torch.nn.Module): 135 | def __init__(self, adaptee): 136 | super().__init__() 137 | self.adaptee = adaptee 138 | 139 | def forward(self, *args, **kwargs): 140 | return self.adaptee(*args, **kwargs)[0] 141 | 142 | from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining 143 | 144 | generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator)) 145 | discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator)) 146 | 147 | tie_weights(generator, discriminator) 148 | 149 | model = to_distributed_model(Electra( 150 | LogitsAdapter(generator), 151 | LogitsAdapter(discriminator), 152 | num_tokens = vocab_size, 153 | mask_token_id = mask_token_id, 154 | pad_token_id = pad_token_id, 155 | mask_prob = args.model_mask_prob, 156 | mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']], 157 | random_token_prob = 0.0).to(device)) 158 | 159 | 160 | ####################### 161 | ## optimizer 162 | 163 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 164 | def lr_lambda(current_step): 165 | learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps))) 166 | learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps)) 167 | return learning_rate 168 | return LambdaLR(optimizer, lr_lambda, last_epoch) 169 | 170 | def get_params_without_weight_decay_ln(named_params, weight_decay): 171 | no_decay = ['bias', 'LayerNorm.weight'] 172 | optimizer_grouped_parameters = [ 173 | { 174 | 'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 175 | 'weight_decay': weight_decay, 176 | }, 177 | { 178 | 'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 179 | 'weight_decay': 0.0, 180 | }, 181 | ] 182 | return optimizer_grouped_parameters 183 | 184 | optimizer = torch.optim.AdamW(get_params_without_weight_decay_ln(model.named_parameters(), weight_decay=0.1), lr=args.opt_lr, betas=(0.9, 0.999), eps=1e-08) 185 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.opt_warmup_steps, num_training_steps=args.opt_num_training_steps) 186 | scaler = torch.cuda.amp.GradScaler(enabled=args.gpu_mixed_precision) 187 | 188 | 189 | ####################### 190 | ## train 191 | 192 | t, steps_s, eta_m = time(), 0., 0 193 | 194 | for step in range(args.opt_num_training_steps+1): 195 | input_ids, input_mask, segment_ids = next(ds_train_loader) 196 | 197 | input_ids = input_ids.to(device) 198 | input_mask = input_mask.to(device) 199 | segment_ids = segment_ids.to(device) 200 | 201 | assert input_ids.shape[1] <= args.data_max_seq_length 202 | 203 | optimizer.zero_grad() 204 | 205 | with torch.cuda.amp.autocast(enabled=args.gpu_mixed_precision): 206 | loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids) 207 | 208 | scaler.scale(loss).backward() 209 | scaler.unscale_(optimizer) 210 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 211 | scaler.step(optimizer) 212 | 
scaler.update() 213 | scheduler.step() 214 | 215 | metrics = { 216 | 'step': (step, '{:8d}'), 217 | 'loss': (loss.item(), '{:8.5f}'), 218 | 'loss_mlm': (loss_mlm.item(), '{:8.5f}'), 219 | 'loss_disc': (loss_disc.item(), '{:8.5f}'), 220 | 'acc_gen': (acc_gen.item(), '{:5.3f}'), 221 | 'acc_disc': (acc_disc.item(), '{:5.3f}'), 222 | 'lr': (scheduler.get_last_lr()[0], '{:8.7f}'), 223 | 'steps': (steps_s, '{:4.1f}/s'), 224 | 'eta': (eta_m, '{:4d}m'), 225 | } 226 | 227 | if step % args.step_log == 0: 228 | sep = ' ' * 2 229 | logger.info(sep.join([f'{k}: {v[1].format(v[0])}' for (k, v) in metrics.items()])) 230 | 231 | if step > 0 and step % 100 == 0: 232 | t2 = time() 233 | steps_s = 100. / (t2 - t) 234 | eta_m = int(((args.opt_num_training_steps - step) / steps_s) // 60) 235 | t = t2 236 | 237 | if step % 200 == 0: 238 | logger.info(np.array2string(disc_labels[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize)) 239 | logger.info(np.array2string(disc_pred[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize)) 240 | 241 | if step > 0 and step % args.step_ckpt == 0 and is_master: 242 | discriminator.electra.save_pretrained(f'{args.output_dir}/ckpt/{step}') 243 | 244 | ######################################################################################################## 245 | ## preamble 246 | 247 | def set_gpus(gpu): 248 | torch.cuda.set_device(gpu) 249 | 250 | 251 | def set_seed(seed): 252 | os.environ['PYTHONHASHSEED'] = str(seed) 253 | random.seed(seed) 254 | np.random.seed(seed) 255 | torch.manual_seed(seed) 256 | if torch.cuda.is_available(): 257 | torch.cuda.manual_seed(seed) 258 | torch.cuda.manual_seed_all(seed) 259 | 260 | 261 | def set_cuda(deterministic=True): 262 | if torch.cuda.is_available(): 263 | torch.backends.cudnn.deterministic = deterministic 264 | torch.backends.cudnn.benchmark = not deterministic 265 | 266 | 267 | def get_exp_id(file): 268 | return os.path.splitext(os.path.basename(file))[0] 269 | 270 | 271 | def get_output_dir(exp_id): 272 | import datetime 273 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 274 | output_dir = os.path.join('output/' + exp_id, t) 275 | os.makedirs(output_dir, exist_ok=True) 276 | return output_dir 277 | 278 | 279 | def setup_logging(filename, console=True): 280 | log_format = logging.Formatter("%(asctime)s : %(message)s") 281 | logger = logging.getLogger() 282 | logger.handlers = [] 283 | file_handler = logging.FileHandler(filename) 284 | file_handler.setFormatter(log_format) 285 | logger.addHandler(file_handler) 286 | if console: 287 | console_handler = logging.StreamHandler(sys.stdout) 288 | console_handler.setFormatter(log_format) 289 | logger.addHandler(console_handler) 290 | logger.setLevel(logging.INFO) 291 | return logger 292 | 293 | 294 | def copy_source(file, output_dir): 295 | import shutil 296 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 297 | 298 | 299 | ######################################################################################################## 300 | ## main 301 | 302 | def main(): 303 | 304 | # preamble 305 | exp_id = get_exp_id(__file__) 306 | output_dir = get_output_dir(exp_id) 307 | os.makedirs(output_dir, exist_ok=True) 308 | os.makedirs(f'{output_dir}/ckpt', exist_ok=False) 309 | copy_source(__file__, output_dir) 310 | 311 | # args 312 | args = arg.parse_to(Args) 313 | args.output_dir = output_dir 314 | args.exp_id = exp_id 315 | 316 | # distributed 317 | if args.distributed_enabled: 318 | os.environ['MASTER_ADDR'] = 'localhost' 319 | 
os.environ['MASTER_PORT'] = str(args.distributed_port) 320 | torch.multiprocessing.spawn(train, nprocs=args.distributed_world_size, args=(args,)) 321 | else: 322 | train(rank=args.gpu, args=args) 323 | 324 | 325 | if __name__ == '__main__': 326 | main() 327 | -------------------------------------------------------------------------------- /pretraining/openwebtext/small_discriminator.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "ElectraForPreTraining" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "embedding_size": 128, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 256, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 1024, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "electra", 15 | "num_attention_heads": 4, 16 | "num_hidden_layers": 12, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522 19 | } -------------------------------------------------------------------------------- /pretraining/openwebtext/small_generator.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "ElectraForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "embedding_size": 128, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 64, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 256, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "electra", 15 | "num_attention_heads": 1, 16 | "num_hidden_layers": 12, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522 19 | } -------------------------------------------------------------------------------- /pretraining/openwebtext/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tokenization classes, the same as used for BERT.""" 17 | 18 | 19 | import collections 20 | import unicodedata 21 | 22 | 23 | def convert_to_unicode(text): 24 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 25 | if isinstance(text, str): 26 | return text 27 | elif isinstance(text, bytes): 28 | return text.decode("utf-8", "ignore") 29 | else: 30 | raise ValueError("Unsupported string type: %s" % (type(text))) 31 | 32 | 33 | 34 | def printable_text(text): 35 | """Returns text encoded in a way suitable for print.""" 36 | 37 | # These functions want `str` for both Python2 and Python3, but in one case 38 | # it's a Unicode string and in the other it's a byte string. 
39 |   if isinstance(text, str):
40 |     return text
41 |   elif isinstance(text, bytes):
42 |     return text.decode("utf-8", "ignore")
43 |   else:
44 |     raise ValueError("Unsupported string type: %s" % (type(text)))
45 | 
46 | 
47 | 
48 | def load_vocab(vocab_file):
49 |   """Loads a vocabulary file into a dictionary."""
50 |   vocab = collections.OrderedDict()
51 |   index = 0
52 |   with open(vocab_file, "r") as reader:
53 |     while True:
54 |       token = convert_to_unicode(reader.readline())
55 |       if not token:
56 |         break
57 |       token = token.strip()
58 |       vocab[token] = index
59 |       index += 1
60 |   return vocab
61 | 
62 | 
63 | def convert_by_vocab(vocab, items):
64 |   """Converts a sequence of [tokens|ids] using the vocab."""
65 |   output = []
66 |   for item in items:
67 |     output.append(vocab[item])
68 |   return output
69 | 
70 | 
71 | def convert_tokens_to_ids(vocab, tokens):
72 |   return convert_by_vocab(vocab, tokens)
73 | 
74 | 
75 | def convert_ids_to_tokens(inv_vocab, ids):
76 |   return convert_by_vocab(inv_vocab, ids)
77 | 
78 | 
79 | def whitespace_tokenize(text):
80 |   """Runs basic whitespace cleaning and splitting on a piece of text."""
81 |   text = text.strip()
82 |   if not text:
83 |     return []
84 |   tokens = text.split()
85 |   return tokens
86 | 
87 | 
88 | class FullTokenizer(object):
89 |   """Runs end-to-end tokenization."""
90 | 
91 |   def __init__(self, vocab_file, do_lower_case=True):
92 |     self.vocab = load_vocab(vocab_file)
93 |     self.inv_vocab = {v: k for k, v in self.vocab.items()}
94 |     self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
95 |     self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
96 | 
97 |   def tokenize(self, text):
98 |     split_tokens = []
99 |     for token in self.basic_tokenizer.tokenize(text):
100 |       for sub_token in self.wordpiece_tokenizer.tokenize(token):
101 |         split_tokens.append(sub_token)
102 | 
103 |     return split_tokens
104 | 
105 |   def convert_tokens_to_ids(self, tokens):
106 |     return convert_by_vocab(self.vocab, tokens)
107 | 
108 |   def convert_ids_to_tokens(self, ids):
109 |     return convert_by_vocab(self.inv_vocab, ids)
110 | 
111 | 
112 | class BasicTokenizer(object):
113 |   """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
114 | 
115 |   def __init__(self, do_lower_case=True):
116 |     """Constructs a BasicTokenizer.
117 | 
118 |     Args:
119 |       do_lower_case: Whether to lower case the input.
120 |     """
121 |     self.do_lower_case = do_lower_case
122 | 
123 |   def tokenize(self, text):
124 |     """Tokenizes a piece of text."""
125 |     text = convert_to_unicode(text)
126 |     text = self._clean_text(text)
127 | 
128 |     # This was added on November 1st, 2018 for the multilingual and Chinese
129 |     # models. This is also applied to the English models now, but it doesn't
130 |     # matter since the English models were not trained on any Chinese data
131 |     # and generally don't have any Chinese data in them (there are Chinese
132 |     # characters in the vocabulary because Wikipedia does have some Chinese
133 |     # words in the English Wikipedia.).
134 |     text = self._tokenize_chinese_chars(text)
135 | 
136 |     orig_tokens = whitespace_tokenize(text)
137 |     split_tokens = []
138 |     for token in orig_tokens:
139 |       if self.do_lower_case:
140 |         token = token.lower()
141 |         token = self._run_strip_accents(token)
142 |       split_tokens.extend(self._run_split_on_punc(token))
143 | 
144 |     output_tokens = whitespace_tokenize(" ".join(split_tokens))
145 |     return output_tokens
146 | 
147 |   def _run_strip_accents(self, text):
148 |     """Strips accents from a piece of text."""
149 |     text = unicodedata.normalize("NFD", text)
150 |     output = []
151 |     for char in text:
152 |       cat = unicodedata.category(char)
153 |       if cat == "Mn":
154 |         continue
155 |       output.append(char)
156 |     return "".join(output)
157 | 
158 |   def _run_split_on_punc(self, text):
159 |     """Splits punctuation on a piece of text."""
160 |     chars = list(text)
161 |     i = 0
162 |     start_new_word = True
163 |     output = []
164 |     while i < len(chars):
165 |       char = chars[i]
166 |       if _is_punctuation(char):
167 |         output.append([char])
168 |         start_new_word = True
169 |       else:
170 |         if start_new_word:
171 |           output.append([])
172 |         start_new_word = False
173 |         output[-1].append(char)
174 |       i += 1
175 | 
176 |     return ["".join(x) for x in output]
177 | 
178 |   def _tokenize_chinese_chars(self, text):
179 |     """Adds whitespace around any CJK character."""
180 |     output = []
181 |     for char in text:
182 |       cp = ord(char)
183 |       if self._is_chinese_char(cp):
184 |         output.append(" ")
185 |         output.append(char)
186 |         output.append(" ")
187 |       else:
188 |         output.append(char)
189 |     return "".join(output)
190 | 
191 |   def _is_chinese_char(self, cp):
192 |     """Checks whether CP is the codepoint of a CJK character."""
193 |     # This defines a "chinese character" as anything in the CJK Unicode block:
194 |     # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
195 |     #
196 |     # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
197 |     # despite its name. The modern Korean Hangul alphabet is a different block,
198 |     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
199 |     # space-separated words, so they are not treated specially and handled
200 |     # like all of the other languages.
201 |     if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
202 |         (cp >= 0x3400 and cp <= 0x4DBF) or  #
203 |         (cp >= 0x20000 and cp <= 0x2A6DF) or  #
204 |         (cp >= 0x2A700 and cp <= 0x2B73F) or  #
205 |         (cp >= 0x2B740 and cp <= 0x2B81F) or  #
206 |         (cp >= 0x2B820 and cp <= 0x2CEAF) or
207 |         (cp >= 0xF900 and cp <= 0xFAFF) or  #
208 |         (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
209 |       return True
210 | 
211 |     return False
212 | 
213 |   def _clean_text(self, text):
214 |     """Performs invalid character removal and whitespace cleanup on text."""
215 |     output = []
216 |     for char in text:
217 |       cp = ord(char)
218 |       if cp == 0 or cp == 0xfffd or _is_control(char):
219 |         continue
220 |       if _is_whitespace(char):
221 |         output.append(" ")
222 |       else:
223 |         output.append(char)
224 |     return "".join(output)
225 | 
226 | 
227 | class WordpieceTokenizer(object):
228 |   """Runs WordPiece tokenization."""
229 | 
230 |   def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
231 |     self.vocab = vocab
232 |     self.unk_token = unk_token
233 |     self.max_input_chars_per_word = max_input_chars_per_word
234 | 
235 |   def tokenize(self, text):
236 |     """Tokenizes a piece of text into its word pieces.
237 | 
238 |     This uses a greedy longest-match-first algorithm to perform tokenization
239 |     using the given vocabulary.
240 | 
241 |     For example:
242 |       input = "unaffable"
243 |       output = ["un", "##aff", "##able"]
244 | 
245 |     Args:
246 |       text: A single token or whitespace separated tokens. This should have
247 |         already been passed through `BasicTokenizer`.
248 | 
249 |     Returns:
250 |       A list of wordpiece tokens.
251 |     """
252 | 
253 |     text = convert_to_unicode(text)
254 | 
255 |     output_tokens = []
256 |     for token in whitespace_tokenize(text):
257 |       chars = list(token)
258 |       if len(chars) > self.max_input_chars_per_word:
259 |         output_tokens.append(self.unk_token)
260 |         continue
261 | 
262 |       is_bad = False
263 |       start = 0
264 |       sub_tokens = []
265 |       while start < len(chars):
266 |         end = len(chars)
267 |         cur_substr = None
268 |         while start < end:
269 |           substr = "".join(chars[start:end])
270 |           if start > 0:
271 |             substr = "##" + substr
272 |           if substr in self.vocab:
273 |             cur_substr = substr
274 |             break
275 |           end -= 1
276 |         if cur_substr is None:
277 |           is_bad = True
278 |           break
279 |         sub_tokens.append(cur_substr)
280 |         start = end
281 | 
282 |       if is_bad:
283 |         output_tokens.append(self.unk_token)
284 |       else:
285 |         output_tokens.extend(sub_tokens)
286 |     return output_tokens
287 | 
288 | 
289 | def _is_whitespace(char):
290 |   """Checks whether `chars` is a whitespace character."""
291 |   # \t, \n, and \r are technically control characters but we treat them
292 |   # as whitespace since they are generally considered as such.
293 |   if char == " " or char == "\t" or char == "\n" or char == "\r":
294 |     return True
295 |   cat = unicodedata.category(char)
296 |   if cat == "Zs":
297 |     return True
298 |   return False
299 | 
300 | 
301 | def _is_control(char):
302 |   """Checks whether `chars` is a control character."""
303 |   # These are technically control characters but we count them as whitespace
304 |   # characters.
305 |   if char == "\t" or char == "\n" or char == "\r":
306 |     return False
307 |   cat = unicodedata.category(char)
308 |   if cat.startswith("C"):
309 |     return True
310 |   return False
311 | 
312 | 
313 | def _is_punctuation(char):
314 |   """Checks whether `chars` is a punctuation character."""
315 |   cp = ord(char)
316 |   # We treat all non-letter/number ASCII as punctuation.
317 |   # Characters such as "^", "$", and "`" are not in the Unicode
318 |   # Punctuation class but we treat them as punctuation anyways, for
319 |   # consistency.
320 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 321 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 322 | return True 323 | cat = unicodedata.category(char) 324 | if cat.startswith("P"): 325 | return True 326 | return False 327 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'electra-pytorch', 5 | packages = find_packages(), 6 | version = '0.1.2', 7 | license='MIT', 8 | description = 'Electra - Pytorch', 9 | author = 'Erik Nijkamp, Phil Wang', 10 | author_email = 'erik.nijkamp@gmail.com, lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/electra-pytorch', 12 | keywords = [ 13 | 'transformers', 14 | 'artificial intelligence', 15 | 'pretraining' 16 | ], 17 | install_requires=[ 18 | 'torch>=1.6.0', 19 | 'transformers==3.0.2', 20 | 'scipy', 21 | 'sklearn' 22 | ], 23 | setup_requires=[ 24 | 'pytest-runner' 25 | ], 26 | tests_require=[ 27 | 'pytest', 28 | 'reformer-pytorch' 29 | ], 30 | classifiers=[ 31 | 'Development Status :: 4 - Beta', 32 | 'Intended Audience :: Developers', 33 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 34 | 'License :: OSI Approved :: MIT License', 35 | 'Programming Language :: Python :: 3.7', 36 | ], 37 | ) -------------------------------------------------------------------------------- /tests/test_electra_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from reformer_pytorch import ReformerLM 4 | 5 | from electra_pytorch import Electra 6 | 7 | def test_electra(): 8 | generator = ReformerLM( 9 | num_tokens = 20000, 10 | dim = 512, 11 | depth = 1, 12 | max_seq_len = 1024 13 | ) 14 | 15 | discriminator = ReformerLM( 16 | num_tokens = 20000, 17 | dim = 512, 18 | depth = 2, 19 | max_seq_len = 1024 20 | ) 21 | 22 | generator.token_emb = discriminator.token_emb 23 | generator.pos_emb = discriminator.pos_emb 24 | 25 | trainer = Electra( 26 | generator, 27 | discriminator, 28 | num_tokens = 20000, 29 | discr_dim = 512, 30 | discr_layer = 'reformer', 31 | pad_token_id = 1, 32 | mask_ignore_token_ids = [2, 3] 33 | ) 34 | 35 | data = torch.randint(0, 20000, (1, 1024)) 36 | results = trainer(data) 37 | results.loss.backward() 38 | 39 | def test_electra_without_magic(): 40 | generator = ReformerLM( 41 | num_tokens = 20000, 42 | dim = 512, 43 | depth = 1, 44 | max_seq_len = 1024 45 | ) 46 | 47 | discriminator = ReformerLM( 48 | num_tokens = 20000, 49 | dim = 512, 50 | depth = 2, 51 | max_seq_len = 1024, 52 | return_embeddings = True 53 | ) 54 | 55 | generator.token_emb = discriminator.token_emb 56 | generator.pos_emb = discriminator.pos_emb 57 | 58 | 59 | discriminator_with_adapter = nn.Sequential( 60 | discriminator, 61 | nn.Linear(512, 1), 62 | nn.Sigmoid() 63 | ) 64 | 65 | trainer = Electra( 66 | generator, 67 | discriminator_with_adapter, 68 | num_tokens = 20000, 69 | pad_token_id = 1, 70 | mask_ignore_token_ids = [2, 3] 71 | ) 72 | 73 | data = torch.randint(0, 20000, (1, 1024)) 74 | results = trainer(data) 75 | results.loss.backward() 76 | --------------------------------------------------------------------------------
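Below is a minimal smoke-test sketch, not part of the repository, showing how pretrain.py wires the two small HuggingFace ELECTRA configs into the electra_pytorch wrapper. It assumes transformers==3.0.2 as pinned in setup.py, the small_*.json files above, the BERT token-id layout asserted in pretrain.py ([PAD]=0, [CLS]=101, [SEP]=102, [MASK]=103), and that it is run from the repository root.

import torch
from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining
from electra_pytorch import Electra

class LogitsAdapter(torch.nn.Module):
    # Same adapter as in pretrain.py: keep only the logits from the HF output tuple.
    def __init__(self, adaptee):
        super().__init__()
        self.adaptee = adaptee

    def forward(self, *args, **kwargs):
        return self.adaptee(*args, **kwargs)[0]

generator = ElectraForMaskedLM(AutoConfig.from_pretrained('pretraining/openwebtext/small_generator.json'))
discriminator = ElectraForPreTraining(AutoConfig.from_pretrained('pretraining/openwebtext/small_discriminator.json'))

# Tie the embedding tables, as tie_weights() does in pretrain.py.
generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings
generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings
generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings

electra = Electra(
    LogitsAdapter(generator),
    LogitsAdapter(discriminator),
    num_tokens = 30522,                  # vocab_size from the config files
    mask_token_id = 103,                 # [MASK]
    pad_token_id = 0,                    # [PAD]
    mask_prob = 0.15,
    mask_ignore_token_ids = [101, 102],  # [CLS], [SEP]
    random_token_prob = 0.0)

# One forward/backward pass on random token ids, mirroring a single training-loop step
# (ids sampled above 1000 to stay clear of the special-token range).
input_ids = torch.randint(1000, 30522, (2, 128))
input_mask = torch.ones_like(input_ids)
segment_ids = torch.zeros_like(input_ids)

loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = electra(
    input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
loss.backward()
print(loss.item(), loss_mlm.item(), loss_disc.item())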