├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── electra.png
├── electra_pytorch
│   ├── __init__.py
│   └── electra_pytorch.py
├── examples
│   └── glue
│       ├── download.py
│       ├── metrics.py
│       ├── processors.py
│       ├── run.py
│       └── utils.py
├── pretraining
│   └── openwebtext
│       ├── arg.py
│       ├── dataset.py
│       ├── preprocess.py
│       ├── pretrain.py
│       ├── small_discriminator.json
│       ├── small_generator.json
│       └── tokenization.py
├── setup.py
└── tests
    └── test_electra_pytorch.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies and run the tests with a matrix of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.8]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install setuptools wheel twine pytest torch
30 | python setup.py install
31 | - name: Test
32 | run: |
33 | python setup.py test
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Data
132 | data
133 | output
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Electra - Pytorch
4 |
5 | A simple working wrapper for fast pretraining of language models, as detailed in the [ELECTRA paper](https://arxiv.org/abs/2003.10555). It speeds up training (compared to normal masked language modeling) by a factor of 4x, and eventually reaches better performance if trained for even longer. Special thanks to Erik Nijkamp for taking the time to replicate the results for GLUE.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install electra-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | The following example uses `reformer-pytorch`, which can be installed with pip.
16 |
17 | ```python
18 | import torch
19 | from torch import nn
20 | from reformer_pytorch import ReformerLM
21 |
22 | from electra_pytorch import Electra
23 |
24 | # (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
25 |
26 | generator = ReformerLM(
27 | num_tokens = 20000,
28 | emb_dim = 128,
29 | dim = 256, # smaller hidden dimension
30 | heads = 4, # fewer heads
31 | ff_mult = 2, # smaller feed forward intermediate dimension
32 | dim_head = 64,
33 | depth = 12,
34 | max_seq_len = 1024
35 | )
36 |
37 | discriminator = ReformerLM(
38 | num_tokens = 20000,
39 | emb_dim = 128,
40 | dim = 1024,
41 | dim_head = 64,
42 | heads = 16,
43 | depth = 12,
44 | ff_mult = 4,
45 | max_seq_len = 1024
46 | )
47 |
48 | # (2) weight tie the token and positional embeddings of generator and discriminator
49 |
50 | generator.token_emb = discriminator.token_emb
51 | generator.pos_emb = discriminator.pos_emb
52 | # weight tie any other embeddings if available, token type embeddings, etc.
53 |
54 | # (3) instantiate electra
55 |
56 | trainer = Electra(
57 | generator,
58 | discriminator,
59 | discr_dim = 1024, # the embedding dimension of the discriminator
60 | discr_layer = 'reformer', # the layer name in the discriminator whose output is used to predict whether each token is original or replaced
61 | mask_token_id = 2, # the token id reserved for masking
62 | pad_token_id = 0, # the token id for padding
63 | mask_prob = 0.15, # masking probability for masked language modeling
64 | mask_ignore_token_ids = [] # token ids to exclude from masking, e.g. [cls], [sep]
65 | )
66 |
67 | # (4) train
68 |
69 | data = torch.randint(0, 20000, (1, 1024))
70 |
71 | results = trainer(data)
72 | results.loss.backward()
73 |
74 | # after much training, the discriminator should have improved
75 |
76 | torch.save(discriminator, f'./pretrained-model.pt')
77 | ```
78 |
79 | If you would rather not have the framework auto-magically intercept the hidden output of the discriminator, you can pass in the discriminator yourself, with the extra linear layer of shape [dim x 1] already attached, as follows.
80 |
81 | ```python
82 | import torch
83 | from torch import nn
84 | from reformer_pytorch import ReformerLM
85 |
86 | from electra_pytorch import Electra
87 |
88 | # (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
89 |
90 | generator = ReformerLM(
91 | num_tokens = 20000,
92 | emb_dim = 128,
93 | dim = 256, # smaller hidden dimension
94 | heads = 4, # fewer heads
95 | ff_mult = 2, # smaller feed forward intermediate dimension
96 | dim_head = 64,
97 | depth = 12,
98 | max_seq_len = 1024
99 | )
100 |
101 | discriminator = ReformerLM(
102 | num_tokens = 20000,
103 | emb_dim = 128,
104 | dim = 1024,
105 | dim_head = 64,
106 | heads = 16,
107 | depth = 12,
108 | ff_mult = 4,
109 | max_seq_len = 1024,
110 | return_embeddings = True
111 | )
112 |
113 | # (2) weight tie the token and positional embeddings of generator and discriminator
114 |
115 | generator.token_emb = discriminator.token_emb
116 | generator.pos_emb = discriminator.pos_emb
117 | # weight tie any other embeddings if available, token type embeddings, etc.
118 |
119 | # (3) instantiate electra
120 |
121 | discriminator_with_adapter = nn.Sequential(discriminator, nn.Linear(1024, 1))
122 |
123 | trainer = Electra(
124 | generator,
125 | discriminator_with_adapter,
126 | mask_token_id = 2, # the token id reserved for masking
127 | pad_token_id = 0, # the token id for padding
128 | mask_prob = 0.15, # masking probability for masked language modeling
129 | mask_ignore_token_ids = [] # token ids to exclude from masking, e.g. [cls], [sep]
130 | )
131 |
132 | # (4) train
133 |
134 | data = torch.randint(0, 20000, (1, 1024))
135 |
136 | results = trainer(data)
137 | results.loss.backward()
138 |
139 | # after much training, the discriminator should have improved
140 |
141 | torch.save(discriminator, f'./pretrained-model.pt')
142 | ```
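
Both examples above only use `results.loss`, but the returned `Results` namedtuple also carries the individual losses and a few metrics (`mlm_loss`, `disc_loss`, `gen_acc`, `disc_acc`, `disc_labels`, `disc_predictions`), which are handy for logging. A minimal sketch (the print statements are purely illustrative):

```python
results = trainer(data)
results.loss.backward()

# monitor the generator and discriminator separately
print(f'mlm loss: {results.mlm_loss.item():.4f} | disc loss: {results.disc_loss.item():.4f}')
print(f'gen acc: {results.gen_acc.item():.4f} | disc acc: {results.disc_acc.item():.4f}')
```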
143 |
144 | ## Important details for successful training
145 |
146 | The generator should be roughly a quarter to at most one half of the discriminator's size for effective training. Any larger and the generator will be too good, causing the adversarial game to collapse. In the paper, this was achieved by reducing the hidden dimension, the feed-forward hidden dimension, and the number of attention heads.
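
As a rough guide, the example configuration in the usage section follows this heuristic. A minimal sketch of how the generator hyperparameters relate to the discriminator's (the exact ratios are illustrative, not prescriptive):

```python
# discriminator hyperparameters, as in the usage example above
disc_dim, disc_heads, disc_ff_mult = 1024, 16, 4

# shrink the hidden dimension, number of heads and feed-forward multiplier,
# while keeping depth, embedding dimension and max sequence length the same
gen_dim = disc_dim // 4          # 256, a quarter of the discriminator's hidden dimension
gen_heads = disc_heads // 4      # 4, fewer attention heads
gen_ff_mult = disc_ff_mult // 2  # 2, smaller feed-forward intermediate dimension
```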
147 |
148 | ## Testing
149 |
150 | ```bash
151 | $ python setup.py test
152 | ```
153 |
154 | ## Training
155 |
156 | 1. Download the [OpenWebText](https://github.com/jcpeterson/openwebtext) dataset.
157 |
158 | ```bash
159 | $ mkdir data
160 | $ cd data
161 | $ pip3 install gdown
162 | $ gdown --id 1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx
163 | $ tar -xf openwebtext.tar.xz
164 | $ wget https://storage.googleapis.com/electra-data/vocab.txt
165 | $ cd ..
166 | ```
167 |
168 | 2. Tokenize dataset.
169 |
170 | ```bash
171 | $ python pretraining/openwebtext/preprocess.py
172 | ```
173 |
174 | 3. Pre-train.
175 |
176 | ```bash
177 | $ python pretraining/openwebtext/pretrain.py
178 | ```
179 |
180 | 4. Download GLUE dataset.
181 |
182 | ```bash
183 | $ python examples/glue/download.py
184 | ```
185 |
186 | 5. Fine-tune on the MRPC sub-task of the GLUE benchmark.
187 |
188 | ```bash
189 | $ python examples/glue/run.py --model_name_or_path output/yyyy-mm-dd-hh-mm-ss/ckpt/200000
190 | ```
191 |
192 | ## Citations
193 |
194 | ```bibtex
195 | @misc{clark2020electra,
196 | title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
197 | author={Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
198 | year={2020},
199 | eprint={2003.10555},
200 | archivePrefix={arXiv},
201 | primaryClass={cs.CL}
202 | }
203 | ```
204 |
--------------------------------------------------------------------------------
/electra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/electra-pytorch/5b8bae5c3575b7529891c1b878c24688f57d7ca1/electra.png
--------------------------------------------------------------------------------
/electra_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from electra_pytorch.electra_pytorch import Electra
2 |
--------------------------------------------------------------------------------
/electra_pytorch/electra_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import reduce
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | # constants
10 |
11 | Results = namedtuple('Results', [
12 | 'loss',
13 | 'mlm_loss',
14 | 'disc_loss',
15 | 'gen_acc',
16 | 'disc_acc',
17 | 'disc_labels',
18 | 'disc_predictions'
19 | ])
20 |
21 | # helpers
22 |
23 | def log(t, eps=1e-9):
24 | return torch.log(t + eps)
25 |
26 | def gumbel_noise(t):
27 | noise = torch.zeros_like(t).uniform_(0, 1)
28 | return -log(-log(noise))
29 |
30 | def gumbel_sample(t, temperature = 1.):
31 | return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)
32 |
33 | def prob_mask_like(t, prob):
34 | return torch.zeros_like(t).float().uniform_(0, 1) < prob
35 |
36 | def mask_with_tokens(t, token_ids):
37 | init_no_mask = torch.full_like(t, False, dtype=torch.bool)
38 | mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
39 | return mask
40 |
41 | def get_mask_subset_with_prob(mask, prob):
42 | batch, seq_len, device = *mask.shape, mask.device
43 | max_masked = math.ceil(prob * seq_len)
44 |
45 | num_tokens = mask.sum(dim=-1, keepdim=True)
46 | mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
47 | mask_excess = mask_excess[:, :max_masked]
48 |
49 | rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
50 | _, sampled_indices = rand.topk(max_masked, dim=-1)
51 | sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
52 |
53 | new_mask = torch.zeros((batch, seq_len + 1), device=device)
54 | new_mask.scatter_(-1, sampled_indices, 1)
55 | return new_mask[:, 1:].bool()
56 |
57 | # hidden layer extractor class, for magically adding adapter to language model to be pretrained
58 |
59 | class HiddenLayerExtractor(nn.Module):
60 | def __init__(self, net, layer = -2):
61 | super().__init__()
62 | self.net = net
63 | self.layer = layer
64 |
65 | self.hidden = None
66 | self.hook_registered = False
67 |
68 | def _find_layer(self):
69 | if type(self.layer) == str:
70 | modules = dict([*self.net.named_modules()])
71 | return modules.get(self.layer, None)
72 | elif type(self.layer) == int:
73 | children = [*self.net.children()]
74 | return children[self.layer]
75 | return None
76 |
77 | def _hook(self, _, __, output):
78 | self.hidden = output
79 |
80 | def _register_hook(self):
81 | layer = self._find_layer()
82 | assert layer is not None, f'hidden layer ({self.layer}) not found'
83 | handle = layer.register_forward_hook(self._hook)
84 | self.hook_registered = True
85 |
86 | def forward(self, x):
87 | if self.layer == -1:
88 | return self.net(x)
89 |
90 | if not self.hook_registered:
91 | self._register_hook()
92 |
93 | _ = self.net(x)
94 | hidden = self.hidden
95 | self.hidden = None
96 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
97 | return hidden
98 |
99 | # main electra class
100 |
101 | class Electra(nn.Module):
102 | def __init__(
103 | self,
104 | generator,
105 | discriminator,
106 | *,
107 | num_tokens = None,
108 | discr_dim = -1,
109 | discr_layer = -1,
110 | mask_prob = 0.15,
111 | replace_prob = 0.85,
112 | random_token_prob = 0.,
113 | mask_token_id = 2,
114 | pad_token_id = 0,
115 | mask_ignore_token_ids = [],
116 | disc_weight = 50.,
117 | gen_weight = 1.,
118 | temperature = 1.):
119 | super().__init__()
120 |
121 | self.generator = generator
122 | self.discriminator = discriminator
123 |
124 | if discr_dim > 0:
125 | self.discriminator = nn.Sequential(
126 | HiddenLayerExtractor(discriminator, layer = discr_layer),
127 | nn.Linear(discr_dim, 1)
128 | )
129 |
130 | # mlm related probabilities
131 | self.mask_prob = mask_prob
132 | self.replace_prob = replace_prob
133 |
134 | self.num_tokens = num_tokens
135 | self.random_token_prob = random_token_prob
136 |
137 | # token ids
138 | self.pad_token_id = pad_token_id
139 | self.mask_token_id = mask_token_id
140 | self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
141 |
142 | # sampling temperature
143 | self.temperature = temperature
144 |
145 | # loss weights
146 | self.disc_weight = disc_weight
147 | self.gen_weight = gen_weight
148 |
149 |
150 | def forward(self, input, **kwargs):
151 | b, t = input.shape
152 |
153 | replace_prob = prob_mask_like(input, self.replace_prob)
154 |
155 | # do not mask [pad] tokens, or any other tokens designated to be excluded ([cls], [sep])
156 | # also do not include these special tokens in the tokens chosen at random
157 | no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
158 | mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)
159 |
160 | # get mask indices
161 | mask_indices = torch.nonzero(mask, as_tuple=True)
162 |
163 | # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
164 | masked_input = input.clone().detach()
165 |
166 | # set inverse of mask to padding tokens for labels
167 | gen_labels = input.masked_fill(~mask, self.pad_token_id)
168 |
169 | # clone the mask, for potential modification if random tokens are involved
170 | # not to be mistaken for the mask above, which covers all selected tokens, whether they are kept, replaced with [mask], or replaced with random tokens
171 | masking_mask = mask.clone()
172 |
173 | # if random token probability > 0 for mlm
174 | if self.random_token_prob > 0:
175 | assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling'
176 |
177 | random_token_prob = prob_mask_like(input, self.random_token_prob)
178 | random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
179 | random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
180 | random_token_prob &= ~random_no_mask
181 | masked_input = torch.where(random_token_prob, random_tokens, masked_input)
182 |
183 | # remove random token prob mask from masking mask
184 | masking_mask = masking_mask & ~random_token_prob
185 |
186 | # [mask] input
187 | masked_input = masked_input.masked_fill(masking_mask * replace_prob, self.mask_token_id)
188 |
189 | # get generator output and get mlm loss
190 | logits = self.generator(masked_input, **kwargs)
191 |
192 | mlm_loss = F.cross_entropy(
193 | logits.transpose(1, 2),
194 | gen_labels,
195 | ignore_index = self.pad_token_id
196 | )
197 |
198 | # use mask from before to select logits that need sampling
199 | sample_logits = logits[mask_indices]
200 |
201 | # sample
202 | sampled = gumbel_sample(sample_logits, temperature = self.temperature)
203 |
204 | # scatter the sampled values back to the input
205 | disc_input = input.clone()
206 | disc_input[mask_indices] = sampled.detach()
207 |
208 | # generate discriminator labels, with replaced as True and original as False
209 | disc_labels = (input != disc_input).float().detach()
210 |
211 | # get discriminator predictions of replaced / original
212 | non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)
213 |
214 | # get discriminator output and binary cross entropy loss
215 | disc_logits = self.discriminator(disc_input, **kwargs)
216 | disc_logits = disc_logits.reshape_as(disc_labels)
217 |
218 | disc_loss = F.binary_cross_entropy_with_logits(
219 | disc_logits[non_padded_indices],
220 | disc_labels[non_padded_indices]
221 | )
222 |
223 | # gather metrics
224 | with torch.no_grad():
225 | gen_predictions = torch.argmax(logits, dim=-1)
226 | disc_predictions = torch.round((torch.sign(disc_logits) + 1.0) * 0.5)
227 | gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean()
228 | disc_acc = 0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean() + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean()
229 |
230 | # return weighted sum of losses
231 | return Results(self.gen_weight * mlm_loss + self.disc_weight * disc_loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions)
232 |
--------------------------------------------------------------------------------
/examples/glue/download.py:
--------------------------------------------------------------------------------
1 | ''' Script for downloading all GLUE data.
2 |
3 | Note: for legal reasons, we are unable to host MRPC.
4 | You can either use the version hosted by the SentEval team, which is already tokenized,
5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
7 | You should then rename and place specific files in a folder (see below for an example).
8 |
9 | mkdir MRPC
10 | cabextract MSRParaphraseCorpus.msi -d MRPC
11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
13 | rm MRPC/_*
14 | rm MSRParaphraseCorpus.msi
15 |
16 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
17 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
18 | '''
19 |
20 | import os
21 | import sys
22 | import shutil
23 | import argparse
24 | import tempfile
25 | import urllib.request
26 | import zipfile
27 |
28 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
29 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
30 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
31 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
32 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
33 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
34 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
35 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
36 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
37 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
38 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
39 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
40 |
41 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
42 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
43 |
44 | def download_and_extract(task, data_dir):
45 | print("Downloading and extracting %s..." % task)
46 | data_file = "%s.zip" % task
47 | urllib.request.urlretrieve(TASK2PATH[task], data_file)
48 | with zipfile.ZipFile(data_file) as zip_ref:
49 | zip_ref.extractall(data_dir)
50 | os.remove(data_file)
51 | print("\tCompleted!")
52 |
53 | def format_mrpc(data_dir, path_to_data):
54 | print("Processing MRPC...")
55 | mrpc_dir = os.path.join(data_dir, "MRPC")
56 | if not os.path.isdir(mrpc_dir):
57 | os.mkdir(mrpc_dir)
58 | if path_to_data:
59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
61 | else:
62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
70 |
71 | dev_ids = []
72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
73 | for row in ids_fh:
74 | dev_ids.append(row.strip().split('\t'))
75 |
76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \
77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
79 | header = data_fh.readline()
80 | train_fh.write(header)
81 | dev_fh.write(header)
82 | for row in data_fh:
83 | label, id1, id2, s1, s2 = row.strip().split('\t')
84 | if [id1, id2] in dev_ids:
85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
86 | else:
87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
88 |
89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \
90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
91 | header = data_fh.readline()
92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
93 | for idx, row in enumerate(data_fh):
94 | label, id1, id2, s1, s2 = row.strip().split('\t')
95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
96 | print("\tCompleted!")
97 |
98 | def download_diagnostic(data_dir):
99 | print("Downloading and extracting diagnostic...")
100 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
101 | os.mkdir(os.path.join(data_dir, "diagnostic"))
102 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
103 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
104 | print("\tCompleted!")
105 | return
106 |
107 | def get_tasks(task_names):
108 | task_names = task_names.split(',')
109 | if "all" in task_names:
110 | tasks = TASKS
111 | else:
112 | tasks = []
113 | for task_name in task_names:
114 | assert task_name in TASKS, "Task %s not found!" % task_name
115 | tasks.append(task_name)
116 | return tasks
117 |
118 | def main(arguments):
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='./data/glue_data')
121 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
122 | type=str, default='all')
123 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
124 | type=str, default='')
125 | args = parser.parse_args(arguments)
126 |
127 | if not os.path.exists(args.data_dir):
128 | os.makedirs(args.data_dir)
129 | tasks = get_tasks(args.tasks)
130 |
131 | for task in tasks:
132 | if task == 'MRPC':
133 | format_mrpc(args.data_dir, args.path_to_mrpc)
134 | elif task == 'diagnostic':
135 | download_diagnostic(args.data_dir)
136 | else:
137 | download_and_extract(task, args.data_dir)
138 |
139 |
140 | if __name__ == '__main__':
141 | sys.exit(main(sys.argv[1:]))
--------------------------------------------------------------------------------
/examples/glue/metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | try:
18 | from scipy.stats import pearsonr, spearmanr
19 | from sklearn.metrics import matthews_corrcoef, f1_score
20 |
21 | _has_sklearn = True
22 | except (AttributeError, ImportError):
23 | _has_sklearn = False
24 |
25 |
26 | def is_sklearn_available():
27 | return _has_sklearn
28 |
29 |
30 | if _has_sklearn:
31 |
32 | def simple_accuracy(preds, labels):
33 | return (preds == labels).mean()
34 |
35 | def acc_and_f1(preds, labels):
36 | acc = simple_accuracy(preds, labels)
37 | f1 = f1_score(y_true=labels, y_pred=preds)
38 | return {
39 | "acc": acc,
40 | "f1": f1,
41 | "acc_and_f1": (acc + f1) / 2,
42 | }
43 |
44 | def pearson_and_spearman(preds, labels):
45 | pearson_corr = pearsonr(preds, labels)[0]
46 | spearman_corr = spearmanr(preds, labels)[0]
47 | return {
48 | "pearson": pearson_corr,
49 | "spearmanr": spearman_corr,
50 | "corr": (pearson_corr + spearman_corr) / 2,
51 | }
52 |
53 | def glue_compute_metrics(task_name, preds, labels):
54 | assert len(preds) == len(labels)
55 | if task_name == "cola":
56 | return {"mcc": matthews_corrcoef(labels, preds)}
57 | elif task_name == "sst-2":
58 | return {"acc": simple_accuracy(preds, labels)}
59 | elif task_name == "mrpc":
60 | return acc_and_f1(preds, labels)
61 | elif task_name == "sts-b":
62 | return pearson_and_spearman(preds, labels)
63 | elif task_name == "qqp":
64 | return acc_and_f1(preds, labels)
65 | elif task_name == "mnli":
66 | return {"acc": simple_accuracy(preds, labels)}
67 | elif task_name == "mnli-mm":
68 | return {"acc": simple_accuracy(preds, labels)}
69 | elif task_name == "qnli":
70 | return {"acc": simple_accuracy(preds, labels)}
71 | elif task_name == "rte":
72 | return {"acc": simple_accuracy(preds, labels)}
73 | elif task_name == "wnli":
74 | return {"acc": simple_accuracy(preds, labels)}
75 | elif task_name == "hans":
76 | return {"acc": simple_accuracy(preds, labels)}
77 | else:
78 | raise KeyError(task_name)
79 |
80 | def xnli_compute_metrics(task_name, preds, labels):
81 | assert len(preds) == len(labels)
82 | if task_name == "xnli":
83 | return {"acc": simple_accuracy(preds, labels)}
84 | else:
85 | raise KeyError(task_name)
--------------------------------------------------------------------------------
/examples/glue/processors.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ GLUE processors and helpers """
17 |
18 | import logging
19 | import os
20 |
21 | # from ...file_utils import is_tf_available
22 | from utils import DataProcessor, InputExample, InputFeatures
23 |
24 | is_tf_available = lambda: False
25 |
26 | if is_tf_available():
27 | import tensorflow as tf
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | def glue_convert_examples_to_features(
33 | examples,
34 | tokenizer,
35 | max_length=512,
36 | task=None,
37 | label_list=None,
38 | output_mode=None,
39 | pad_on_left=False,
40 | pad_token=0,
41 | pad_token_segment_id=0,
42 | mask_padding_with_zero=True,
43 | ):
44 | """
45 | Loads a data file into a list of ``InputFeatures``
46 |
47 | Args:
48 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
49 | tokenizer: Instance of a tokenizer that will tokenize the examples
50 | max_length: Maximum example length
51 | task: GLUE task
52 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
53 | output_mode: String indicating the output mode. Either ``regression`` or ``classification``
54 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
55 | pad_token: Padding token
56 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
57 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
58 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
59 | actual values)
60 |
61 | Returns:
62 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
63 | containing the task-specific features. If the input is a list of ``InputExamples``, will return
64 | a list of task-specific ``InputFeatures`` which can be fed to the model.
65 |
66 | """
67 | is_tf_dataset = False
68 | if is_tf_available() and isinstance(examples, tf.data.Dataset):
69 | is_tf_dataset = True
70 |
71 | if task is not None:
72 | processor = glue_processors[task]()
73 | if label_list is None:
74 | label_list = processor.get_labels()
75 | logger.info("Using label list %s for task %s" % (label_list, task))
76 | if output_mode is None:
77 | output_mode = glue_output_modes[task]
78 | logger.info("Using output mode %s for task %s" % (output_mode, task))
79 |
80 | label_map = {label: i for i, label in enumerate(label_list)}
81 |
82 | features = []
83 | for (ex_index, example) in enumerate(examples):
84 | len_examples = 0
85 | if is_tf_dataset:
86 | example = processor.get_example_from_tensor_dict(example)
87 | example = processor.tfds_map(example)
88 | len_examples = tf.data.experimental.cardinality(examples)
89 | else:
90 | len_examples = len(examples)
91 | if ex_index % 10000 == 0:
92 | logger.info("Writing example %d/%d" % (ex_index, len_examples))
93 |
94 | inputs = tokenizer.encode_plus(
95 | example.text_a, example.text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True,
96 | )
97 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
98 |
99 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
100 | # tokens are attended to.
101 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
102 |
103 | # Zero-pad up to the sequence length.
104 | padding_length = max_length - len(input_ids)
105 | if pad_on_left:
106 | input_ids = ([pad_token] * padding_length) + input_ids
107 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
108 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
109 | else:
110 | input_ids = input_ids + ([pad_token] * padding_length)
111 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
112 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
113 |
114 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
115 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
116 | len(attention_mask), max_length
117 | )
118 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
119 | len(token_type_ids), max_length
120 | )
121 |
122 | if output_mode == "classification":
123 | label = label_map[example.label]
124 | elif output_mode == "regression":
125 | label = float(example.label)
126 | else:
127 | raise KeyError(output_mode)
128 |
129 | if ex_index < 5:
130 | logger.info("*** Example ***")
131 | logger.info("guid: %s" % (example.guid))
132 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
133 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
134 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
135 | logger.info("label: %s (id = %d)" % (example.label, label))
136 |
137 | features.append(
138 | InputFeatures(
139 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
140 | )
141 | )
142 |
143 | if is_tf_available() and is_tf_dataset:
144 |
145 | def gen():
146 | for ex in features:
147 | yield (
148 | {
149 | "input_ids": ex.input_ids,
150 | "attention_mask": ex.attention_mask,
151 | "token_type_ids": ex.token_type_ids,
152 | },
153 | ex.label,
154 | )
155 |
156 | return tf.data.Dataset.from_generator(
157 | gen,
158 | ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
159 | (
160 | {
161 | "input_ids": tf.TensorShape([None]),
162 | "attention_mask": tf.TensorShape([None]),
163 | "token_type_ids": tf.TensorShape([None]),
164 | },
165 | tf.TensorShape([]),
166 | ),
167 | )
168 |
169 | return features
170 |
171 |
172 | class MrpcProcessor(DataProcessor):
173 | """Processor for the MRPC data set (GLUE version)."""
174 |
175 | def get_example_from_tensor_dict(self, tensor_dict):
176 | """See base class."""
177 | return InputExample(
178 | tensor_dict["idx"].numpy(),
179 | tensor_dict["sentence1"].numpy().decode("utf-8"),
180 | tensor_dict["sentence2"].numpy().decode("utf-8"),
181 | str(tensor_dict["label"].numpy()),
182 | )
183 |
184 | def get_train_examples(self, data_dir):
185 | """See base class."""
186 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
187 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
188 |
189 | def get_dev_examples(self, data_dir):
190 | """See base class."""
191 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
192 |
193 | def get_labels(self):
194 | """See base class."""
195 | return ["0", "1"]
196 |
197 | def _create_examples(self, lines, set_type):
198 | """Creates examples for the training and dev sets."""
199 | examples = []
200 | for (i, line) in enumerate(lines):
201 | if i == 0:
202 | continue
203 | guid = "%s-%s" % (set_type, i)
204 | text_a = line[3]
205 | text_b = line[4]
206 | label = line[0]
207 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
208 | return examples
209 |
210 |
211 | class MnliProcessor(DataProcessor):
212 | """Processor for the MultiNLI data set (GLUE version)."""
213 |
214 | def get_example_from_tensor_dict(self, tensor_dict):
215 | """See base class."""
216 | return InputExample(
217 | tensor_dict["idx"].numpy(),
218 | tensor_dict["premise"].numpy().decode("utf-8"),
219 | tensor_dict["hypothesis"].numpy().decode("utf-8"),
220 | str(tensor_dict["label"].numpy()),
221 | )
222 |
223 | def get_train_examples(self, data_dir):
224 | """See base class."""
225 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
226 |
227 | def get_dev_examples(self, data_dir):
228 | """See base class."""
229 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
230 |
231 | def get_labels(self):
232 | """See base class."""
233 | return ["contradiction", "entailment", "neutral"]
234 |
235 | def _create_examples(self, lines, set_type):
236 | """Creates examples for the training and dev sets."""
237 | examples = []
238 | for (i, line) in enumerate(lines):
239 | if i == 0:
240 | continue
241 | guid = "%s-%s" % (set_type, line[0])
242 | text_a = line[8]
243 | text_b = line[9]
244 | label = line[-1]
245 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
246 | return examples
247 |
248 |
249 | class MnliMismatchedProcessor(MnliProcessor):
250 | """Processor for the MultiNLI Mismatched data set (GLUE version)."""
251 |
252 | def get_dev_examples(self, data_dir):
253 | """See base class."""
254 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
255 |
256 |
257 | class ColaProcessor(DataProcessor):
258 | """Processor for the CoLA data set (GLUE version)."""
259 |
260 | def get_example_from_tensor_dict(self, tensor_dict):
261 | """See base class."""
262 | return InputExample(
263 | tensor_dict["idx"].numpy(),
264 | tensor_dict["sentence"].numpy().decode("utf-8"),
265 | None,
266 | str(tensor_dict["label"].numpy()),
267 | )
268 |
269 | def get_train_examples(self, data_dir):
270 | """See base class."""
271 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
272 |
273 | def get_dev_examples(self, data_dir):
274 | """See base class."""
275 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
276 |
277 | def get_labels(self):
278 | """See base class."""
279 | return ["0", "1"]
280 |
281 | def _create_examples(self, lines, set_type):
282 | """Creates examples for the training and dev sets."""
283 | examples = []
284 | for (i, line) in enumerate(lines):
285 | guid = "%s-%s" % (set_type, i)
286 | text_a = line[3]
287 | label = line[1]
288 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
289 | return examples
290 |
291 |
292 | class Sst2Processor(DataProcessor):
293 | """Processor for the SST-2 data set (GLUE version)."""
294 |
295 | def get_example_from_tensor_dict(self, tensor_dict):
296 | """See base class."""
297 | return InputExample(
298 | tensor_dict["idx"].numpy(),
299 | tensor_dict["sentence"].numpy().decode("utf-8"),
300 | None,
301 | str(tensor_dict["label"].numpy()),
302 | )
303 |
304 | def get_train_examples(self, data_dir):
305 | """See base class."""
306 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
307 |
308 | def get_dev_examples(self, data_dir):
309 | """See base class."""
310 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
311 |
312 | def get_labels(self):
313 | """See base class."""
314 | return ["0", "1"]
315 |
316 | def _create_examples(self, lines, set_type):
317 | """Creates examples for the training and dev sets."""
318 | examples = []
319 | for (i, line) in enumerate(lines):
320 | if i == 0:
321 | continue
322 | guid = "%s-%s" % (set_type, i)
323 | text_a = line[0]
324 | label = line[1]
325 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
326 | return examples
327 |
328 |
329 | class StsbProcessor(DataProcessor):
330 | """Processor for the STS-B data set (GLUE version)."""
331 |
332 | def get_example_from_tensor_dict(self, tensor_dict):
333 | """See base class."""
334 | return InputExample(
335 | tensor_dict["idx"].numpy(),
336 | tensor_dict["sentence1"].numpy().decode("utf-8"),
337 | tensor_dict["sentence2"].numpy().decode("utf-8"),
338 | str(tensor_dict["label"].numpy()),
339 | )
340 |
341 | def get_train_examples(self, data_dir):
342 | """See base class."""
343 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
344 |
345 | def get_dev_examples(self, data_dir):
346 | """See base class."""
347 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
348 |
349 | def get_labels(self):
350 | """See base class."""
351 | return [None]
352 |
353 | def _create_examples(self, lines, set_type):
354 | """Creates examples for the training and dev sets."""
355 | examples = []
356 | for (i, line) in enumerate(lines):
357 | if i == 0:
358 | continue
359 | guid = "%s-%s" % (set_type, line[0])
360 | text_a = line[7]
361 | text_b = line[8]
362 | label = line[-1]
363 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
364 | return examples
365 |
366 |
367 | class QqpProcessor(DataProcessor):
368 | """Processor for the QQP data set (GLUE version)."""
369 |
370 | def get_example_from_tensor_dict(self, tensor_dict):
371 | """See base class."""
372 | return InputExample(
373 | tensor_dict["idx"].numpy(),
374 | tensor_dict["question1"].numpy().decode("utf-8"),
375 | tensor_dict["question2"].numpy().decode("utf-8"),
376 | str(tensor_dict["label"].numpy()),
377 | )
378 |
379 | def get_train_examples(self, data_dir):
380 | """See base class."""
381 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
382 |
383 | def get_dev_examples(self, data_dir):
384 | """See base class."""
385 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
386 |
387 | def get_labels(self):
388 | """See base class."""
389 | return ["0", "1"]
390 |
391 | def _create_examples(self, lines, set_type):
392 | """Creates examples for the training and dev sets."""
393 | examples = []
394 | for (i, line) in enumerate(lines):
395 | if i == 0:
396 | continue
397 | guid = "%s-%s" % (set_type, line[0])
398 | try:
399 | text_a = line[3]
400 | text_b = line[4]
401 | label = line[5]
402 | except IndexError:
403 | continue
404 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
405 | return examples
406 |
407 |
408 | class QnliProcessor(DataProcessor):
409 | """Processor for the QNLI data set (GLUE version)."""
410 |
411 | def get_example_from_tensor_dict(self, tensor_dict):
412 | """See base class."""
413 | return InputExample(
414 | tensor_dict["idx"].numpy(),
415 | tensor_dict["question"].numpy().decode("utf-8"),
416 | tensor_dict["sentence"].numpy().decode("utf-8"),
417 | str(tensor_dict["label"].numpy()),
418 | )
419 |
420 | def get_train_examples(self, data_dir):
421 | """See base class."""
422 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
423 |
424 | def get_dev_examples(self, data_dir):
425 | """See base class."""
426 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
427 |
428 | def get_labels(self):
429 | """See base class."""
430 | return ["entailment", "not_entailment"]
431 |
432 | def _create_examples(self, lines, set_type):
433 | """Creates examples for the training and dev sets."""
434 | examples = []
435 | for (i, line) in enumerate(lines):
436 | if i == 0:
437 | continue
438 | guid = "%s-%s" % (set_type, line[0])
439 | text_a = line[1]
440 | text_b = line[2]
441 | label = line[-1]
442 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
443 | return examples
444 |
445 |
446 | class RteProcessor(DataProcessor):
447 | """Processor for the RTE data set (GLUE version)."""
448 |
449 | def get_example_from_tensor_dict(self, tensor_dict):
450 | """See base class."""
451 | return InputExample(
452 | tensor_dict["idx"].numpy(),
453 | tensor_dict["sentence1"].numpy().decode("utf-8"),
454 | tensor_dict["sentence2"].numpy().decode("utf-8"),
455 | str(tensor_dict["label"].numpy()),
456 | )
457 |
458 | def get_train_examples(self, data_dir):
459 | """See base class."""
460 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
461 |
462 | def get_dev_examples(self, data_dir):
463 | """See base class."""
464 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
465 |
466 | def get_labels(self):
467 | """See base class."""
468 | return ["entailment", "not_entailment"]
469 |
470 | def _create_examples(self, lines, set_type):
471 | """Creates examples for the training and dev sets."""
472 | examples = []
473 | for (i, line) in enumerate(lines):
474 | if i == 0:
475 | continue
476 | guid = "%s-%s" % (set_type, line[0])
477 | text_a = line[1]
478 | text_b = line[2]
479 | label = line[-1]
480 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
481 | return examples
482 |
483 |
484 | class WnliProcessor(DataProcessor):
485 | """Processor for the WNLI data set (GLUE version)."""
486 |
487 | def get_example_from_tensor_dict(self, tensor_dict):
488 | """See base class."""
489 | return InputExample(
490 | tensor_dict["idx"].numpy(),
491 | tensor_dict["sentence1"].numpy().decode("utf-8"),
492 | tensor_dict["sentence2"].numpy().decode("utf-8"),
493 | str(tensor_dict["label"].numpy()),
494 | )
495 |
496 | def get_train_examples(self, data_dir):
497 | """See base class."""
498 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
499 |
500 | def get_dev_examples(self, data_dir):
501 | """See base class."""
502 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
503 |
504 | def get_labels(self):
505 | """See base class."""
506 | return ["0", "1"]
507 |
508 | def _create_examples(self, lines, set_type):
509 | """Creates examples for the training and dev sets."""
510 | examples = []
511 | for (i, line) in enumerate(lines):
512 | if i == 0:
513 | continue
514 | guid = "%s-%s" % (set_type, line[0])
515 | text_a = line[1]
516 | text_b = line[2]
517 | label = line[-1]
518 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
519 | return examples
520 |
521 |
522 | glue_tasks_num_labels = {
523 | "cola": 2,
524 | "mnli": 3,
525 | "mrpc": 2,
526 | "sst-2": 2,
527 | "sts-b": 1,
528 | "qqp": 2,
529 | "qnli": 2,
530 | "rte": 2,
531 | "wnli": 2,
532 | }
533 |
534 | glue_processors = {
535 | "cola": ColaProcessor,
536 | "mnli": MnliProcessor,
537 | "mnli-mm": MnliMismatchedProcessor,
538 | "mrpc": MrpcProcessor,
539 | "sst-2": Sst2Processor,
540 | "sts-b": StsbProcessor,
541 | "qqp": QqpProcessor,
542 | "qnli": QnliProcessor,
543 | "rte": RteProcessor,
544 | "wnli": WnliProcessor,
545 | }
546 |
547 | glue_output_modes = {
548 | "cola": "classification",
549 | "mnli": "classification",
550 | "mnli-mm": "classification",
551 | "mrpc": "classification",
552 | "sst-2": "classification",
553 | "sts-b": "regression",
554 | "qqp": "classification",
555 | "qnli": "classification",
556 | "rte": "classification",
557 | "wnli": "classification",
558 | }
--------------------------------------------------------------------------------
/examples/glue/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
17 |
18 |
19 | import argparse
20 | import glob
21 | import json
22 | import logging
23 | import os
24 | import random
25 |
26 | import numpy as np
27 | import torch
28 | import torch.nn as nn
29 | from torch.optim import AdamW
30 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
31 | from torch.utils.data.distributed import DistributedSampler
32 | from tqdm import tqdm, trange
33 |
34 | from metrics import glue_compute_metrics as compute_metrics
35 | from processors import glue_convert_examples_to_features as convert_examples_to_features
36 | from processors import glue_output_modes as output_modes
37 | from processors import glue_processors as processors
38 | from processors import glue_tasks_num_labels as task_num_labels
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 | ##################################################
43 | # adapters for Google-like GLUE code
44 |
45 | class TokenizerAdapter:
46 | def __init__(self, tokenizer, pad_token, cls_token="[CLS]", sep_token="[SEP]"):
47 | self.tokenizer = tokenizer
48 | self.pad_token = pad_token
49 | self.cls_token = cls_token
50 | self.sep_token = sep_token
51 |
52 | def convert_tokens_to_ids(self, tokens):
53 | return self.tokenizer.convert_tokens_to_ids(tokens)
54 |
55 |
56 | def truncate_sequences(
57 | self,
58 | ids,
59 | pair_ids,
60 | num_tokens_to_remove,
61 | truncation_strategy,
62 | stride,
63 | ):
64 |
65 | assert len(ids) > num_tokens_to_remove
66 | window_len = min(len(ids), stride + num_tokens_to_remove)
67 | overflowing_tokens = ids[-window_len:]
68 | ids = ids[:-num_tokens_to_remove]
69 |
70 | return (ids, pair_ids, overflowing_tokens)
71 |
72 | def encode_plus(self, text, text_pair, add_special_tokens, max_length, return_token_type_ids):
73 |
74 | # Tokenization
75 | token_ids_0 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
76 | len_ids = len(token_ids_0)
77 | if text_pair:
78 | token_ids_1 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_pair))
79 | len_pair_ids = len(token_ids_1)
80 | else:
81 | token_ids_1 = None
82 | len_pair_ids = 0
83 |
84 |
85 | # Truncation
86 | assert add_special_tokens
87 | num_special_tokens_to_add = (2 if not text_pair else 3)
88 | total_len = len_ids + len_pair_ids + num_special_tokens_to_add
89 | if max_length and total_len > max_length:
90 | token_ids_0, token_ids_1, overflowing_tokens = self.truncate_sequences(
91 | token_ids_0,
92 | pair_ids=token_ids_1,
93 | num_tokens_to_remove=total_len - max_length,
94 | truncation_strategy='only_first', # TODO(nijkamp): is this the correct truncation strategy for all GLUE tasks?
95 | stride=0,
96 | )
97 |
98 |
99 | # Add special tokens
100 | cls = [self.tokenizer.vocab[self.cls_token]]
101 | sep = [self.tokenizer.vocab[self.sep_token]]
102 |
103 | if not text_pair:
104 |
105 | input_ids = cls + token_ids_0 + sep
106 | token_type_ids = len(cls + token_ids_0 + sep) * [0]
107 |
108 | else:
109 |
110 | input_ids = cls + token_ids_0 + sep + token_ids_1 + sep
111 | token_type_ids = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
112 |
113 | assert len(input_ids) <= max_length
114 |
115 | return {"input_ids": input_ids, "token_type_ids": token_type_ids}
116 |
117 | def __len__(self):
118 | return len(self.tokenizer.vocab)
119 |
120 | def save_pretrained(self, outputdir):
121 | pass
122 |
123 | def wrap_tokenizer(tokenizer, pad_token):
124 | return TokenizerAdapter(tokenizer, pad_token)
125 |
126 |
127 | ##################################################
128 | # distilled Google-like/HF glue code
129 |
130 | def set_seed(args):
131 | random.seed(args.seed)
132 | np.random.seed(args.seed)
133 | torch.manual_seed(args.seed)
134 | if args.n_gpu > 0:
135 | torch.cuda.manual_seed_all(args.seed)
136 |
137 |
138 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
139 |     """ Create a schedule with a learning rate that increases linearly during a warmup
140 |     period and then decreases linearly to zero over the remaining training steps.
141 | """
142 |
143 | def lr_lambda(current_step):
144 | if current_step < num_warmup_steps:
145 | return float(current_step) / float(max(1, num_warmup_steps))
146 | return max(
147 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
148 | )
149 |
150 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
151 |
152 |
153 | def train(args, train_dataset, model, tokenizer):
154 | """ Train the model """
155 |
156 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
157 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
158 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
159 |
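    # t_total is the number of optimizer updates (batches divided by gradient_accumulation_steps);
    # it drives the linear warmup/decay schedule created below.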
160 | if args.max_steps > 0:
161 | t_total = args.max_steps
162 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
163 | else:
164 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
165 |
166 | # Prepare optimizer and schedule (linear warmup and decay)
167 | no_decay = ["bias", "LayerNorm.weight"]
168 | optimizer_grouped_parameters = [
169 | {
170 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
171 | "weight_decay": args.weight_decay,
172 | },
173 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
174 | ]
175 |
176 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
177 | scheduler = get_linear_schedule_with_warmup(
178 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
179 | )
180 |
181 | # Check if saved optimizer or scheduler states exist
182 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
183 | os.path.join(args.model_name_or_path, "scheduler.pt")
184 | ):
185 | # Load in optimizer and scheduler states
186 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
187 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
188 |
189 | if args.fp16:
190 | try:
191 | from apex import amp
192 | except ImportError:
193 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
194 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
195 |
196 | # multi-gpu training (should be after apex fp16 initialization)
197 | if args.n_gpu > 1:
198 | model = torch.nn.DataParallel(model)
199 |
200 | # Distributed training (should be after apex fp16 initialization)
201 | if args.local_rank != -1:
202 | model = torch.nn.parallel.DistributedDataParallel(
203 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
204 | )
205 |
206 | # Train!
207 | logger.info("***** Running training *****")
208 | logger.info(" Num examples = %d", len(train_dataset))
209 | logger.info(" Num Epochs = %d", args.num_train_epochs)
210 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
211 | logger.info(
212 | " Total train batch size (w. parallel, distributed & accumulation) = %d",
213 | args.train_batch_size
214 | * args.gradient_accumulation_steps
215 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
216 | )
217 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
218 | logger.info(" Total optimization steps = %d", t_total)
219 |
220 | global_step = 0
221 | epochs_trained = 0
222 | steps_trained_in_current_epoch = 0
223 | # Check if continuing training from a checkpoint
224 | if os.path.exists(args.model_name_or_path):
225 | # set global_step to global_step of last saved checkpoint from model path
226 | try:
227 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
228 | except ValueError:
229 | global_step = 0
230 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
231 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
232 |
233 | logger.info(" Continuing training from checkpoint, will skip to saved global_step")
234 | logger.info(" Continuing training from epoch %d", epochs_trained)
235 | logger.info(" Continuing training from global step %d", global_step)
236 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
237 |
238 | tr_loss, logging_loss = 0.0, 0.0
239 | model.zero_grad()
240 | train_iterator = trange(
241 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
242 | )
243 |     set_seed(args)  # Added here for reproducibility
244 | for _ in train_iterator:
245 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
246 | for step, batch in enumerate(epoch_iterator):
247 |
248 | # Skip past any already trained steps if resuming training
249 | if steps_trained_in_current_epoch > 0:
250 | steps_trained_in_current_epoch -= 1
251 | continue
252 |
253 | model.train()
254 | batch = tuple(t.to(args.device) for t in batch)
255 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
256 | inputs["token_type_ids"] = (
257 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
258 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
259 | outputs = model(**inputs)
260 | loss = outputs[0] # model outputs are always tuple in transformers (see doc)
261 |
262 | if args.n_gpu > 1:
263 | loss = loss.mean() # mean() to average on multi-gpu parallel training
264 | if args.gradient_accumulation_steps > 1:
265 | loss = loss / args.gradient_accumulation_steps
266 |
267 | if args.fp16:
268 | with amp.scale_loss(loss, optimizer) as scaled_loss:
269 | scaled_loss.backward()
270 | else:
271 | loss.backward()
272 |
273 | if step % 10 == 0:
274 | print(step, loss.item())
275 |
276 | tr_loss += loss.item()
277 | if (step + 1) % args.gradient_accumulation_steps == 0 or (
278 | # last step in epoch but step is always smaller than gradient_accumulation_steps
279 | len(epoch_iterator) <= args.gradient_accumulation_steps
280 | and (step + 1) == len(epoch_iterator)
281 | ):
282 | if args.fp16:
283 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
284 | else:
285 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
286 |
287 | optimizer.step()
288 | scheduler.step() # Update learning rate schedule
289 | model.zero_grad()
290 | global_step += 1
291 |
292 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
293 | logs = {}
294 | if (
295 | args.local_rank == -1 and args.evaluate_during_training
296 | ): # Only evaluate when single GPU otherwise metrics may not average well
297 | results = evaluate(args, model, tokenizer)
298 | for key, value in results.items():
299 | eval_key = "eval_{}".format(key)
300 | logs[eval_key] = value
301 |
302 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps
303 | learning_rate_scalar = scheduler.get_lr()[0]
304 | logs["learning_rate"] = learning_rate_scalar
305 | logs["loss"] = loss_scalar
306 | logging_loss = tr_loss
307 |
308 | print(json.dumps({**logs, **{"step": global_step}}))
309 |
310 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
311 | # Save model checkpoint
312 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
313 | if not os.path.exists(output_dir):
314 | os.makedirs(output_dir)
315 | model_to_save = (
316 | model.module if hasattr(model, "module") else model
317 | ) # Take care of distributed/parallel training
318 | model_to_save.save_pretrained(output_dir)
319 | tokenizer.save_pretrained(output_dir)
320 |
321 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
322 | logger.info("Saving model checkpoint to %s", output_dir)
323 |
324 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
325 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
326 | logger.info("Saving optimizer and scheduler states to %s", output_dir)
327 |
328 | if args.max_steps > 0 and global_step > args.max_steps:
329 | epoch_iterator.close()
330 | break
331 | if args.max_steps > 0 and global_step > args.max_steps:
332 | train_iterator.close()
333 | break
334 |
335 | return global_step, tr_loss / global_step
336 |
337 |
338 | def evaluate(args, model, tokenizer, prefix=""):
339 | # Loop to handle MNLI double evaluation (matched, mis-matched)
340 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
341 | eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
342 |
343 | results = {}
344 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
345 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
346 |
347 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
348 | os.makedirs(eval_output_dir)
349 |
350 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
351 | # Note that DistributedSampler samples randomly
352 | eval_sampler = SequentialSampler(eval_dataset)
353 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
354 |
355 | # multi-gpu eval
356 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
357 | model = torch.nn.DataParallel(model)
358 |
359 | # Eval!
360 | logger.info("***** Running evaluation {} *****".format(prefix))
361 | logger.info(" Num examples = %d", len(eval_dataset))
362 | logger.info(" Batch size = %d", args.eval_batch_size)
363 | eval_loss = 0.0
364 | nb_eval_steps = 0
365 | preds = None
366 | out_label_ids = None
367 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
368 | model.eval()
369 | batch = tuple(t.to(args.device) for t in batch)
370 |
371 | with torch.no_grad():
372 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
373 | if args.model_type != "distilbert":
374 | inputs["token_type_ids"] = (
375 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
376 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
377 | outputs = model(**inputs)
378 | tmp_eval_loss, logits = outputs[:2]
379 |
380 | eval_loss += tmp_eval_loss.mean().item()
381 | nb_eval_steps += 1
382 | if preds is None:
383 | preds = logits.detach().cpu().numpy()
384 | out_label_ids = inputs["labels"].detach().cpu().numpy()
385 | else:
386 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
387 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
388 |
389 | eval_loss = eval_loss / nb_eval_steps
390 | if args.output_mode == "classification":
391 | preds = np.argmax(preds, axis=1)
392 | print(preds)
393 | elif args.output_mode == "regression":
394 | preds = np.squeeze(preds)
395 | result = compute_metrics(eval_task, preds, out_label_ids)
396 | results.update(result)
397 |
398 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
399 | with open(output_eval_file, "w") as writer:
400 | logger.info("***** Eval results {} *****".format(prefix))
401 | for key in sorted(result.keys()):
402 | logger.info(" %s = %s", key, str(result[key]))
403 | writer.write("%s = %s\n" % (key, str(result[key])))
404 |
405 | return results
406 |
407 |
408 | def load_and_cache_examples(args, task, tokenizer, evaluate=False):
409 | if args.local_rank not in [-1, 0] and not evaluate:
410 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
411 |
412 | processor = processors[task]()
413 | output_mode = output_modes[task]
414 | # Load data features from cache or dataset file
415 | cached_features_file = os.path.join(
416 | args.data_dir,
417 | "cached_{}_{}_{}_{}".format(
418 | "dev" if evaluate else "train",
419 | list(filter(None, args.model_name_or_path.split("/"))).pop(),
420 | str(args.max_seq_length),
421 | str(task),
422 | ),
423 | )
424 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
425 | logger.info("Loading features from cached file %s", cached_features_file)
426 | features = torch.load(cached_features_file)
427 | else:
428 | logger.info("Creating features from dataset file at %s", args.data_dir)
429 | label_list = processor.get_labels()
430 | if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
431 | # HACK(label indices are swapped in RoBERTa pretrained model)
432 | label_list[1], label_list[2] = label_list[2], label_list[1]
433 | examples = (
434 | processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
435 | )
436 | features = convert_examples_to_features(
437 | examples,
438 | tokenizer,
439 | label_list=label_list,
440 | max_length=args.max_seq_length,
441 | output_mode=output_mode,
442 | pad_on_left=False, # pad on the left for xlnet
443 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
444 | pad_token_segment_id=0,
445 | )
446 | if args.local_rank in [-1, 0]:
447 | logger.info("Saving features into cached file %s", cached_features_file)
448 | torch.save(features, cached_features_file)
449 |
450 | if args.local_rank == 0 and not evaluate:
451 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
452 |
453 | # Convert to Tensors and build dataset
454 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
455 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
456 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
457 | if output_mode == "classification":
458 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
459 | elif output_mode == "regression":
460 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
461 |
462 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
463 | return dataset
464 |
465 |
466 |
467 |
468 | # python run_glue.py \
469 | # --model_name_or_path bert-base-uncased \
470 | # --task_name $TASK_NAME \
471 | # --do_train \
472 | # --do_eval \
473 | # --data_dir $GLUE_DIR/$TASK_NAME \
474 | # --max_seq_length 128 \
475 | # --per_gpu_train_batch_size 32 \
476 | # --learning_rate 2e-5 \
477 | # --num_train_epochs 3.0 \
478 | # --output_dir /tmp/$TASK_NAME \
479 | # --overwrite_output_dir \
480 | # --cache_dir cache_glue_bert
481 |
482 | def main(task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'):
483 | parser = argparse.ArgumentParser()
484 |
485 | # Required parameters
486 | parser.add_argument(
487 | "--data_dir",
488 | default=f'data/glue_data/{task}',
489 | type=str,
490 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
491 | )
492 | parser.add_argument(
493 | "--model_type",
494 | default="bert",
495 | type=str,
496 | )
497 | parser.add_argument(
498 | "--model_name_or_path",
499 | default=ckpt,
500 | type=str,
501 | )
502 | parser.add_argument(
503 | "--vocab_path",
504 | default='data/vocab.txt',
505 | type=str,
506 | )
507 | parser.add_argument(
508 | "--task_name",
509 | default=task,
510 | type=str,
511 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
512 | )
513 | parser.add_argument(
514 | "--output_dir",
515 | default='output/glue',
516 | type=str,
517 | help="The output directory where the model predictions and checkpoints will be written.",
518 | )
519 |
520 | # Other parameters
521 | parser.add_argument(
522 | "--cache_dir",
523 | default="",
524 | type=str,
525 | help="Where do you want to store the pre-trained models downloaded from s3",
526 | )
527 | parser.add_argument(
528 | "--max_seq_length",
529 | default=128,
530 | type=int,
531 | help="The maximum total input sequence length after tokenization. Sequences longer "
532 | "than this will be truncated, sequences shorter will be padded.",
533 | )
534 | parser.add_argument("--do_train", default=True, help="Whether to run training.")
535 | parser.add_argument("--do_eval", default=True, help="Whether to run eval on the dev set.")
536 | parser.add_argument(
537 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
538 | )
539 | parser.add_argument(
540 | "--do_lower_case", default=True, help="Set this flag if you are using an uncased model.",
541 | )
542 |
543 | parser.add_argument(
544 | "--per_gpu_train_batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.",
545 | )
546 | parser.add_argument(
547 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
548 | )
549 | parser.add_argument(
550 | "--gradient_accumulation_steps",
551 | type=int,
552 | default=1,
553 |         help="Number of update steps to accumulate before performing a backward/update pass.",
554 | )
555 | parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
556 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
557 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
558 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
559 | parser.add_argument(
560 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
561 | )
562 | parser.add_argument(
563 | "--max_steps",
564 | default=-1,
565 | type=int,
566 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
567 | )
568 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
569 |
570 |     parser.add_argument("--logging_steps", type=int, default=500, help="Log every X update steps.")
571 |     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X update steps.")
572 | parser.add_argument(
573 | "--eval_all_checkpoints",
574 | action="store_true",
575 |         help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number",
576 | )
577 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
578 | parser.add_argument(
579 | "--overwrite_output_dir", default=True, help="Overwrite the content of the output directory",
580 | )
581 | parser.add_argument(
582 | "--overwrite_cache", default=True, help="Overwrite the cached training and evaluation sets",
583 | )
584 | parser.add_argument("--seed", type=int, default=seed, help="random seed for initialization")
585 |
586 | parser.add_argument(
587 | "--fp16",
588 | action="store_true",
589 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
590 | )
591 | parser.add_argument(
592 | "--fp16_opt_level",
593 | type=str,
594 | default="O1",
595 |         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
596 | "See details at https://nvidia.github.io/apex/amp.html",
597 | )
598 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
599 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
600 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
601 | args = parser.parse_args()
602 |
603 | if (
604 | os.path.exists(args.output_dir)
605 | and os.listdir(args.output_dir)
606 | and args.do_train
607 | and not args.overwrite_output_dir
608 | ):
609 | raise ValueError(
610 | "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
611 | args.output_dir
612 | )
613 | )
614 |
615 | # Setup distant debugging if needed
616 | if args.server_ip and args.server_port:
617 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
618 | import ptvsd
619 |
620 | print("Waiting for debugger attach")
621 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
622 | ptvsd.wait_for_attach()
623 |
624 | # Setup CUDA, GPU & distributed training
625 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
626 | args.n_gpu = 1
627 | args.device = device
628 |
629 | # Setup logging
630 | logging.basicConfig(
631 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
632 | datefmt="%m/%d/%Y %H:%M:%S",
633 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
634 | )
635 | logger.warning(
636 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
637 | args.local_rank,
638 | device,
639 | args.n_gpu,
640 | bool(args.local_rank != -1),
641 | args.fp16,
642 | )
643 |
644 | # Set seed
645 | set_seed(args)
646 |
647 | # Prepare GLUE task
648 | args.task_name = args.task_name.lower()
649 | if args.task_name not in processors:
650 | raise ValueError("Task not found: %s" % (args.task_name))
651 | processor = processors[args.task_name]()
652 | args.output_mode = output_modes[args.task_name]
653 | label_list = processor.get_labels()
654 | num_labels = len(label_list)
655 |
656 | # Load pretrained model and tokenizer
657 | if args.local_rank not in [-1, 0]:
658 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
659 |
660 |     from transformers import AutoConfig, AutoModelForSequenceClassification, WEIGHTS_NAME  # WEIGHTS_NAME is needed by --eval_all_checkpoints below
661 | args.model_type = args.model_type.lower()
662 | config = AutoConfig.from_pretrained(
663 | args.model_name_or_path,
664 | num_labels=num_labels,
665 | finetuning_task=args.task_name,
666 | cache_dir=args.cache_dir if args.cache_dir else None,
667 | )
668 | model = AutoModelForSequenceClassification.from_pretrained(
669 | args.model_name_or_path,
670 | from_tf=bool(".ckpt" in args.model_name_or_path),
671 | config=config,
672 | cache_dir=args.cache_dir if args.cache_dir else None,
673 | )
674 |
675 |
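    # Reuse the WordPiece vocabulary from pretraining; wrap_tokenizer adapts it to the tokenizer
    # interface expected by load_and_cache_examples above.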
676 | from pretraining.openwebtext.dataset import new_tokenizer
677 | tokenizer = wrap_tokenizer(new_tokenizer(args.vocab_path), pad_token='[PAD]')
678 |
679 | if args.local_rank == 0:
680 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
681 |
682 | model.to(args.device)
683 |
684 | logger.info("Training/evaluation parameters %s", args)
685 |
686 | # Training
687 | if args.do_train:
688 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
689 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
690 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
691 |
692 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
693 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
694 | # Create output directory if needed
695 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
696 | os.makedirs(args.output_dir)
697 |
698 | logger.info("Saving model checkpoint to %s", args.output_dir)
699 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
700 | # They can then be reloaded using `from_pretrained()`
701 | model_to_save = (
702 | model.module if hasattr(model, "module") else model
703 | ) # Take care of distributed/parallel training
704 | model_to_save.save_pretrained(args.output_dir)
705 | tokenizer.save_pretrained(args.output_dir)
706 |
707 | # Good practice: save your training arguments together with the trained model
708 | torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
709 |
710 | # Load a trained model and vocabulary that you have fine-tuned
711 | model = model_to_save
712 | # TODO(nijkamp): we ignore model serialization
713 | # model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
714 | # tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
715 | model.to(args.device)
716 |
717 | # Evaluation
718 | results = {}
719 | if args.do_eval and args.local_rank in [-1, 0]:
720 | # TODO(nijkamp): we ignore model serialization
721 | # tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
722 | checkpoints = [args.output_dir]
723 | if args.eval_all_checkpoints:
724 | checkpoints = list(
725 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
726 | )
727 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
728 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
729 | for checkpoint in checkpoints:
730 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
731 | prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
732 |
733 | # TODO(nijkamp): we ignore model serialization
734 | # model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
735 | model.to(args.device)
736 | result = evaluate(args, model, tokenizer, prefix=prefix)
737 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
738 | results.update(result)
739 |
740 | return results
741 |
742 |
743 | if __name__ == "__main__":
744 | main()
--------------------------------------------------------------------------------
/examples/glue/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import copy
18 | import csv
19 | import dataclasses
20 | import json
21 | import logging
22 | from dataclasses import dataclass
23 | from typing import Optional
24 |
25 | # from ...file_utils import is_tf_available, is_torch_available
26 |
27 | is_torch_available = lambda: True
28 | is_tf_available = lambda: False
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | @dataclass(frozen=True)
35 | class InputExample:
36 | """
37 | A single training/test example for simple sequence classification.
38 |
39 | Args:
40 | guid: Unique id for the example.
41 | text_a: string. The untokenized text of the first sequence. For single
42 | sequence tasks, only this sequence must be specified.
43 | text_b: (Optional) string. The untokenized text of the second sequence.
44 |         Must be specified only for sequence pair tasks.
45 | label: (Optional) string. The label of the example. This should be
46 | specified for train and dev examples, but not for test examples.
47 | """
48 |
49 | guid: str
50 | text_a: str
51 | text_b: Optional[str] = None
52 | label: Optional[str] = None
53 |
54 | def to_json_string(self):
55 | """Serializes this instance to a JSON string."""
56 | return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
57 |
58 |
59 | class InputFeatures(object):
60 | """
61 | A single set of features of data.
62 |
63 | Args:
64 | input_ids: Indices of input sequence tokens in the vocabulary.
65 | attention_mask: Mask to avoid performing attention on padding token indices.
66 | Mask values selected in ``[0, 1]``:
67 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
68 | token_type_ids: Segment token indices to indicate first and second portions of the inputs.
69 | label: Label corresponding to the input
70 | """
71 |
72 | def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
73 | self.input_ids = input_ids
74 | self.attention_mask = attention_mask
75 | self.token_type_ids = token_type_ids
76 | self.label = label
77 |
78 | def __repr__(self):
79 | return str(self.to_json_string())
80 |
81 | def to_dict(self):
82 | """Serializes this instance to a Python dictionary."""
83 | output = copy.deepcopy(self.__dict__)
84 | return output
85 |
86 | def to_json_string(self):
87 | """Serializes this instance to a JSON string."""
88 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
89 |
90 |
91 | class DataProcessor(object):
92 | """Base class for data converters for sequence classification data sets."""
93 |
94 | def get_example_from_tensor_dict(self, tensor_dict):
95 | """Gets an example from a dict with tensorflow tensors
96 | Args:
97 | tensor_dict: Keys and values should match the corresponding Glue
98 | tensorflow_dataset examples.
99 | """
100 | raise NotImplementedError()
101 |
102 | def get_train_examples(self, data_dir):
103 | """Gets a collection of `InputExample`s for the train set."""
104 | raise NotImplementedError()
105 |
106 | def get_dev_examples(self, data_dir):
107 | """Gets a collection of `InputExample`s for the dev set."""
108 | raise NotImplementedError()
109 |
110 | def get_labels(self):
111 | """Gets the list of labels for this data set."""
112 | raise NotImplementedError()
113 |
114 | def tfds_map(self, example):
115 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
116 | This method converts examples to the correct format."""
117 | if len(self.get_labels()) > 1:
118 | example.label = self.get_labels()[int(example.label)]
119 | return example
120 |
121 | @classmethod
122 | def _read_tsv(cls, input_file, quotechar=None):
123 | """Reads a tab separated value file."""
124 | with open(input_file, "r", encoding="utf-8-sig") as f:
125 | return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
126 |
127 |
128 | class SingleSentenceClassificationProcessor(DataProcessor):
129 | """ Generic processor for a single sentence classification data set."""
130 |
131 | def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
132 | self.labels = [] if labels is None else labels
133 | self.examples = [] if examples is None else examples
134 | self.mode = mode
135 | self.verbose = verbose
136 |
137 | def __len__(self):
138 | return len(self.examples)
139 |
140 | def __getitem__(self, idx):
141 | if isinstance(idx, slice):
142 | return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
143 | return self.examples[idx]
144 |
145 | @classmethod
146 | def create_from_csv(
147 | cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
148 | ):
149 | processor = cls(**kwargs)
150 | processor.add_examples_from_csv(
151 | file_name,
152 | split_name=split_name,
153 | column_label=column_label,
154 | column_text=column_text,
155 | column_id=column_id,
156 | skip_first_row=skip_first_row,
157 | overwrite_labels=True,
158 | overwrite_examples=True,
159 | )
160 | return processor
161 |
162 | @classmethod
163 | def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
164 | processor = cls(**kwargs)
165 | processor.add_examples(texts_or_text_and_labels, labels=labels)
166 | return processor
167 |
168 | def add_examples_from_csv(
169 | self,
170 | file_name,
171 | split_name="",
172 | column_label=0,
173 | column_text=1,
174 | column_id=None,
175 | skip_first_row=False,
176 | overwrite_labels=False,
177 | overwrite_examples=False,
178 | ):
179 | lines = self._read_tsv(file_name)
180 | if skip_first_row:
181 | lines = lines[1:]
182 | texts = []
183 | labels = []
184 | ids = []
185 | for (i, line) in enumerate(lines):
186 | texts.append(line[column_text])
187 | labels.append(line[column_label])
188 | if column_id is not None:
189 | ids.append(line[column_id])
190 | else:
191 | guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
192 | ids.append(guid)
193 |
194 | return self.add_examples(
195 | texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
196 | )
197 |
198 | def add_examples(
199 | self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
200 | ):
201 | assert labels is None or len(texts_or_text_and_labels) == len(labels)
202 | assert ids is None or len(texts_or_text_and_labels) == len(ids)
203 | if ids is None:
204 | ids = [None] * len(texts_or_text_and_labels)
205 | if labels is None:
206 | labels = [None] * len(texts_or_text_and_labels)
207 | examples = []
208 | added_labels = set()
209 | for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
210 | if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
211 | text, label = text_or_text_and_label
212 | else:
213 | text = text_or_text_and_label
214 | added_labels.add(label)
215 | examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
216 |
217 | # Update examples
218 | if overwrite_examples:
219 | self.examples = examples
220 | else:
221 | self.examples.extend(examples)
222 |
223 | # Update labels
224 | if overwrite_labels:
225 | self.labels = list(added_labels)
226 | else:
227 | self.labels = list(set(self.labels).union(added_labels))
228 |
229 | return self.examples
230 |
231 | def get_features(
232 | self,
233 | tokenizer,
234 | max_length=None,
235 | pad_on_left=False,
236 | pad_token=0,
237 | mask_padding_with_zero=True,
238 | return_tensors=None,
239 | ):
240 | """
241 | Convert examples in a list of ``InputFeatures``
242 |
243 | Args:
244 | tokenizer: Instance of a tokenizer that will tokenize the examples
245 | max_length: Maximum example length
246 | task: GLUE task
247 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
248 | output_mode: String indicating the output mode. Either ``regression`` or ``classification``
249 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
250 | pad_token: Padding token
251 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
252 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
253 | actual values)
254 |
255 | Returns:
256 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
257 | containing the task-specific features. If the input is a list of ``InputExamples``, will return
258 | a list of task-specific ``InputFeatures`` which can be fed to the model.
259 |
260 | """
261 | if max_length is None:
262 | max_length = tokenizer.max_len
263 |
264 | label_map = {label: i for i, label in enumerate(self.labels)}
265 |
266 | all_input_ids = []
267 | for (ex_index, example) in enumerate(self.examples):
268 | if ex_index % 10000 == 0:
269 | logger.info("Tokenizing example %d", ex_index)
270 |
271 | input_ids = tokenizer.encode(
272 | example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
273 | )
274 | all_input_ids.append(input_ids)
275 |
276 | batch_length = max(len(input_ids) for input_ids in all_input_ids)
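        # Note: examples are padded to the longest sequence actually present (batch_length),
        # not necessarily to max_length.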
277 |
278 | features = []
279 | for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
280 | if ex_index % 10000 == 0:
281 | logger.info("Writing example %d/%d" % (ex_index, len(self.examples)))
282 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
283 | # tokens are attended to.
284 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
285 |
286 | # Zero-pad up to the sequence length.
287 | padding_length = batch_length - len(input_ids)
288 | if pad_on_left:
289 | input_ids = ([pad_token] * padding_length) + input_ids
290 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
291 | else:
292 | input_ids = input_ids + ([pad_token] * padding_length)
293 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
294 |
295 | assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
296 | len(input_ids), batch_length
297 | )
298 | assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
299 | len(attention_mask), batch_length
300 | )
301 |
302 | if self.mode == "classification":
303 | label = label_map[example.label]
304 | elif self.mode == "regression":
305 | label = float(example.label)
306 | else:
307 | raise ValueError(self.mode)
308 |
309 | if ex_index < 5 and self.verbose:
310 | logger.info("*** Example ***")
311 | logger.info("guid: %s" % (example.guid))
312 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
313 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
314 | logger.info("label: %s (id = %d)" % (example.label, label))
315 |
316 | features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
317 |
318 | if return_tensors is None:
319 | return features
320 | elif return_tensors == "tf":
321 | if not is_tf_available():
322 | raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
323 | import tensorflow as tf
324 |
325 | def gen():
326 | for ex in features:
327 | yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
328 |
329 | dataset = tf.data.Dataset.from_generator(
330 | gen,
331 | ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
332 | ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
333 | )
334 | return dataset
335 | elif return_tensors == "pt":
336 | if not is_torch_available():
337 | raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
338 | import torch
339 | from torch.utils.data import TensorDataset
340 |
341 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
342 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
343 | if self.mode == "classification":
344 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
345 | elif self.mode == "regression":
346 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
347 |
348 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
349 | return dataset
350 | else:
351 | raise ValueError("return_tensors should be one of 'tf' or 'pt'")
--------------------------------------------------------------------------------
/pretraining/openwebtext/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 |
4 | __all__ = ('Arg', 'Int', 'Float', 'Bool', 'Str', 'Choice', 'parse_to')
5 |
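# Illustrative sketch (not code from this file): declare a dataclass whose fields are annotated
# with these Arg types, and parse_to() builds an argparse CLI from them, e.g.
#
#     @dataclasses.dataclass(frozen=True)
#     class ExampleArgs:
#         vocab_file: Str = 'data/vocab.txt'
#         max_seq_length: Int = 128
#
#     args = parse_to(ExampleArgs)  # parses --vocab-file and --max-seq-length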
6 | class Arg:
7 | def __init__(self, **kwargs):
8 | super().__init__()
9 | self.kwargs = kwargs
10 |
11 |
12 | class Int(Arg):
13 | def __init__(self, **kwargs):
14 | super().__init__(type=int, **kwargs)
15 |
16 |
17 | class Float(Arg):
18 | def __init__(self, **kwargs):
19 | super().__init__(type=float, **kwargs)
20 |
21 |
22 | class Bool(Arg):
23 | def __init__(self, **kwargs):
24 | super().__init__(type=bool, **kwargs)
25 |
26 |
27 | class Str(Arg):
28 | def __init__(self, **kwargs):
29 | super().__init__(type=str, **kwargs)
30 |
31 |
32 | class _MetaChoice(type):
33 | def __getitem__(self, item):
34 | return self(choices=list(item), type=item)
35 |
36 |
37 | class Choice(Arg, metaclass=_MetaChoice):
38 | def __init__(self, choices, **kwargs):
39 | super().__init__(choices=choices, **kwargs)
40 |
41 |
42 | def parse_to(container_class, **kwargs):
43 | def mangle_name(name):
44 | return '--' + name.replace('_', '-')
45 |
46 | parser = argparse.ArgumentParser(description=container_class.__doc__)
47 | for field in dataclasses.fields(container_class):
48 | name = field.name
49 | default = field.default
50 | value_or_class = field.type
51 | if isinstance(value_or_class, type):
52 | value = value_or_class(default=default)
53 | else:
54 | value = value_or_class
55 | value.kwargs['default'] = default
56 | parser.add_argument(
57 | mangle_name(name), **value.kwargs)
58 |
59 | arg_dict = parser.parse_args(**kwargs)
60 | return container_class(**vars(arg_dict))
--------------------------------------------------------------------------------
/pretraining/openwebtext/dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | from dataclasses import dataclass
5 | from itertools import chain
6 | from functools import partial
7 | from pathlib import Path
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.utils.data
13 |
14 | from openwebtext import tokenization
15 |
16 |
17 | class ExampleBuilder:
18 | """Given a stream of input text, creates pretraining examples."""
19 |
20 | def __init__(self, vocab, max_length):
21 | self._vocab = vocab
22 | self._current_sentences = []
23 | self._current_length = 0
24 | self._max_length = max_length
25 | self._target_length = max_length
26 |
27 | def add_line(self, bert_tokids):
28 | """Adds a line of text to the current example being built."""
29 | # line = line.strip().replace("\n", " ")
30 | # if (not line) and self._current_length != 0: # empty lines separate docs
31 | # return self._create_example()
32 | # bert_tokens = self._tokenizer.tokenize(line)
33 | # bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens)
34 | self._current_sentences.append(bert_tokids)
35 | self._current_length += len(bert_tokids)
36 | if self._current_length >= self._target_length:
37 | return self._create_example()
38 | return None
39 |
40 | def _create_example(self):
41 | """Creates a pre-training example from the current list of sentences."""
42 | # small chance to only have one segment as in classification tasks
43 | if random.random() < 0.1:
44 | first_segment_target_length = 100000
45 | else:
46 | # -3 due to not yet having [CLS]/[SEP] tokens in the input text
47 | first_segment_target_length = (self._target_length - 3) // 2
48 |
49 | first_segment = []
50 | second_segment = []
51 | for sentence in self._current_sentences:
52 | # the sentence goes to the first segment if (1) the first segment is
53 | # empty, (2) the sentence doesn't put the first segment over length or
54 | # (3) 50% of the time when it does put the first segment over length
55 | if (len(first_segment) == 0 or
56 | len(first_segment) + len(sentence) < first_segment_target_length or
57 | (len(second_segment) == 0 and
58 | len(first_segment) < first_segment_target_length and
59 | random.random() < 0.5)):
60 | first_segment += sentence
61 | else:
62 | second_segment += sentence
63 |
64 | # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
65 | first_segment = first_segment[:self._max_length - 2]
66 | second_segment = second_segment[:max(0, self._max_length - len(first_segment) - 3)]
67 |
68 | # prepare to start building the next example
69 | self._current_sentences = []
70 | self._current_length = 0
71 | # small chance for random-length instead of max_length-length example
72 | if random.random() < 0.05:
73 | self._target_length = random.randint(5, self._max_length)
74 | else:
75 | self._target_length = self._max_length
76 |
77 | return self._make_tf_example(first_segment, second_segment)
78 |
79 | def _make_tf_example(self, first_segment, second_segment):
80 |         """Converts two "segments" of text into a padded feature dict of torch tensors (name kept from the original TF implementation)."""
81 | vocab = self._vocab
82 | input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
83 | segment_ids = [0] * len(input_ids)
84 | if second_segment:
85 | input_ids += second_segment + [vocab["[SEP]"]]
86 | segment_ids += [1] * (len(second_segment) + 1)
87 | input_mask = [1] * len(input_ids)
88 | input_ids += [0] * (self._max_length - len(input_ids))
89 | input_mask += [0] * (self._max_length - len(input_mask))
90 | segment_ids += [0] * (self._max_length - len(segment_ids))
91 |
92 | def create_int_feature(tensors):
93 | return torch.tensor(tensors)
94 |
95 | tf_example = {
96 | "input_ids": create_int_feature(input_ids),
97 | "input_mask": create_int_feature(input_mask),
98 | "segment_ids": create_int_feature(segment_ids)
99 | }
100 | return tf_example
101 |
102 |
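# Streams pre-tokenized tensors from the .pt shards written by preprocess.py; __len__ is an
# estimate based on n_tensors_per_file, since the last shard of a job may contain fewer tensors.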
103 | class OpenWebTextDataset(torch.utils.data.IterableDataset):
104 | def __init__(self, feature_set_paths, n_tensors_per_file):
105 | self.feature_set_paths = feature_set_paths
106 | self.n_tensors_per_file = n_tensors_per_file
107 |
108 | @staticmethod
109 | def parse_file(file_index):
110 | try:
111 | features = torch.load(str(file_index))
112 | yield from features
113 | except RuntimeError:
114 | raise RuntimeError(f'Corrupted file {file_index}')
115 |
116 | def __len__(self):
117 | return len(self.feature_set_paths) * self.n_tensors_per_file
118 |
119 | def __iter__(self):
120 | return chain.from_iterable(map(self.parse_file, self.feature_set_paths))
121 |
122 |
123 | class ExampleBuilderDataset(torch.utils.data.IterableDataset):
124 | def __init__(self, dataset, builder):
125 | self.dataset = dataset
126 | self.builder = builder
127 |
128 | def __len__(self):
129 | return len(self.dataset)
130 |
131 | def __iter__(self):
132 | def create_example():
133 | while True:
134 | token_ids = list(next(self.dataset).cpu().numpy())
135 | example = self.builder.add_line(token_ids)
136 | if example:
137 | return example
138 |
139 | while True:
140 | yield create_example()
141 |
142 |
143 |
144 | def cycle(iterable):
145 | while True:
146 | for x in iterable:
147 | yield x
148 |
149 |
150 | def new_tokenizer(vocab_file, do_lower_case=True):
151 | return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
152 |
153 |
154 | def parse_tokenizer(tokenizer, text):
155 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
156 |
157 |
158 | def create_tokenizer(vocab_file, do_lower_case=True):
159 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
160 | return partial(parse_tokenizer, tokenizer)
161 |
162 |
163 | def load_owt(owt_dir, n_tensors_per_file):
164 | owt_dir_path = Path(owt_dir)
165 | feature_set_paths = [owt_dir_path / feature_set_path for feature_set_path in os.listdir(owt_dir_path)]
166 | np.random.shuffle(feature_set_paths)
167 | assert len(feature_set_paths) > 0
168 | return OpenWebTextDataset(feature_set_paths, n_tensors_per_file=n_tensors_per_file)
169 |
170 |
171 | def wrap_example_builder(dataset, vocab, max_length):
172 | return ExampleBuilderDataset(cycle(iter(dataset)), ExampleBuilder(vocab, max_length))
173 |
--------------------------------------------------------------------------------
/pretraining/openwebtext/preprocess.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import math
4 | import multiprocessing
5 | import os
6 | import random
7 | import tarfile
8 | from dataclasses import dataclass
9 | from itertools import chain
10 | from functools import partial
11 | from pathlib import Path
12 |
13 | import numpy as np
14 |
15 | import torch
16 | import torch.utils.data
17 |
18 | from pretraining.openwebtext import arg
19 | from pretraining.openwebtext import tokenization
20 |
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def parse_tokenizer(tokenizer, text):
26 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
27 |
28 |
29 | def create_tokenizer(vocab_file, do_lower_case=True):
30 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
31 | return partial(parse_tokenizer, tokenizer)
32 |
33 |
34 | def preprocess_owt(tokenizer, src_dir, tmp_dir, trg_dir, n_dataset_building_processes, n_tensors_per_file, max_seq_length=None):
35 | # Preamble
36 | logger.info(f'Writing features to {trg_dir}.')
37 | os.makedirs(trg_dir, exist_ok=False)
38 |
39 | # Crunch files
40 | trg_dir = Path(trg_dir)
41 | src_dir = Path(src_dir)
42 | tmp_dir = Path(tmp_dir)
43 | archives = os.listdir(src_dir)
44 | n_archives_per_job = math.ceil(len(archives) / n_dataset_building_processes)
45 | job_archives = [
46 | archives[i * n_archives_per_job : (i + 1) * n_archives_per_job]
47 | for i in range(n_dataset_building_processes)
48 | ]
49 |
50 | logger.info(f'Processing {len(archives)} archives.')
51 | assert len(archives) > 0
52 |
53 | if n_dataset_building_processes == 1:
54 | feature_set_paths = preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0)
55 | else:
56 | pool = multiprocessing.Pool(processes=n_dataset_building_processes)
57 | preprocess_owt_job_partial = partial(preprocess_owt_job, tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length)
58 | feature_sets = pool.map(preprocess_owt_job_partial, range(n_dataset_building_processes))
59 | feature_set_paths = [file_path for feature_set in feature_sets for file_path in feature_set]
60 |
61 | return feature_set_paths
62 |
63 |
64 | def preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0):
65 | '''
66 | OpenWebText is saved under the following format:
67 | openwebtext.zip
68 | |-> archive_xxx.zip
69 | |-> file_xxx.txt
70 | |-> file_xxz.txt
71 | ...
72 | |-> archive_xxz.zip
73 | |-> file_xxy.txt
74 | ...
75 | ...
76 | '''
77 |
78 | # Preamble
79 | os.makedirs(tmp_dir, exist_ok=True)
80 |
81 | # Process
82 | feature_index = 0
83 | feature_set_paths = []
84 | features = []
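    # Tokenize each line of each extracted .txt file (skipping very short lines) and flush the
    # tensors to disk in shards of n_tensors_per_file; these are the .pt files load_owt() later streams.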
85 | for archive_id, archive in enumerate(job_archives[job_id]):
86 | if os.path.isdir(src_dir / archive):
87 | logger.info(f'Ignoring rogue directory {src_dir / archive}.')
88 | continue
89 |
90 | logger.info(f'Job {job_id}: Processing {archive_id}/{len(job_archives[job_id])} {src_dir / archive}.')
91 |
92 | with tarfile.open(src_dir / archive) as t:
93 | extracted_archive = tmp_dir / f'{archive}-extracted'
94 | t.extractall(extracted_archive)
95 |
96 | for file in os.listdir(extracted_archive):
97 | file_path = extracted_archive / file
98 |
99 | with open(file_path, 'r') as f:
100 | for line in f.readlines():
101 | line = line.strip()
102 | if len(line) > 2:
103 | encoding = tokenizer(line)
104 | features.append(torch.tensor(encoding))
105 |
106 | while len(features) > n_tensors_per_file:
107 | feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
108 | torch.save(features[:n_tensors_per_file], feature_set_path)
109 | features = features[n_tensors_per_file:]
110 | feature_index += 1
111 | feature_set_paths.append(feature_set_path)
112 |
113 | # Serialize
114 | if len(features) > 0:
115 | feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
116 | torch.save(features, feature_set_path)
117 | feature_set_paths.append(feature_set_path)
118 |
119 | return feature_set_paths
120 |
121 |
122 | @dataclass(frozen=True)
123 | class Args:
124 | src_dir: arg.Str = 'data/openwebtext'
125 | trg_dir: arg.Str = 'data/openwebtext_features'
126 | tmp_dir: arg.Str = '/tmp/owt'
127 | vocab_file: arg.Str = 'data/vocab.txt'
128 | n_dataset_building_processes: arg.Int = 32
129 | n_tensors_per_file: arg.Int = 2048
130 | max_seq_length: arg.Int = 128
131 |
132 |
133 | def main():
134 | args = arg.parse_to(Args)
135 |
136 | logging.basicConfig(
137 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
138 | datefmt='%m/%d/%Y %H:%M:%S',
139 | level=logging.INFO
140 | )
141 |
142 | tokenizer = create_tokenizer(args.vocab_file)
143 | preprocess_owt(tokenizer=tokenizer, src_dir=args.src_dir, tmp_dir=args.tmp_dir, trg_dir=args.trg_dir, n_dataset_building_processes=args.n_dataset_building_processes, n_tensors_per_file=args.n_tensors_per_file, max_seq_length=args.max_seq_length)
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
--------------------------------------------------------------------------------
/pretraining/openwebtext/pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | dir_path = os.path.dirname(os.path.realpath(__file__))
5 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
6 | sys.path.insert(0, parent_dir_path)
7 |
8 | import random
9 | import logging
10 | from time import time
11 | from dataclasses import dataclass
12 |
13 | import numpy as np
14 |
15 | import torch
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from torch.utils.data.dataloader import DataLoader
18 |
19 | from electra_pytorch import Electra
20 |
21 | from openwebtext import arg
22 | from openwebtext.dataset import load_owt, new_tokenizer, wrap_example_builder
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | ########################################################################################################
27 | ## args
28 |
29 | @dataclass
30 | class Args:
31 | data_dir: arg.Str = 'data/openwebtext_features'
32 | data_vocab_file: arg.Str = 'data/vocab.txt'
33 | data_n_tensors_per_file: arg.Int = 2048
34 | data_max_seq_length: arg.Int = 128
35 |
36 | gpu: arg.Int = 0
37 | gpu_enabled: arg.Bool = True
38 | gpu_deterministic: arg.Bool = False
39 | gpu_mixed_precision: arg.Bool = False
40 | distributed_port: arg.Int = 8888
41 | distributed_enabled: arg.Bool = True
42 | distributed_world_size: arg.Int = 4
43 |
44 | model_generator: arg.Str = 'pretraining/openwebtext/small_generator.json'
45 | model_discriminator: arg.Str = 'pretraining/openwebtext/small_discriminator.json'
46 | model_mask_prob: arg.Float = 0.15
47 |
48 | opt_lr: arg.Float = 5e-4
49 | opt_batch_size: arg.Int = 128 // (distributed_world_size if distributed_enabled else 1)
50 | opt_warmup_steps: arg.Int = 10_000
51 | opt_num_training_steps: arg.Int = 200_000
52 |
53 | step_log: arg.Int = 10
54 | step_ckpt: arg.Int = 10_000
55 |
56 |
57 | ########################################################################################################
58 | ## train
59 |
60 | def train(rank, args):
61 |
62 | #######################
63 | ## distributed
64 |
65 | if args.distributed_enabled:
66 | torch.distributed.init_process_group(
67 | backend='nccl',
68 | init_method='env://',
69 | world_size=args.distributed_world_size,
70 | rank=rank)
71 | if args.gpu_enabled:
72 | device = torch.device('cuda:{}'.format(rank))
73 | else:
74 | device = torch.device('cpu')
75 |
76 |     is_master = (not args.distributed_enabled) or rank == 0
77 |
78 |
79 | #######################
80 | ## preamble
81 |
82 | set_gpus(rank)
83 | set_seed(rank)
84 | set_cuda(deterministic=args.gpu_deterministic)
85 |
86 | output_dir = f'{args.output_dir}/{rank}'
87 | os.makedirs(output_dir, exist_ok=False)
88 |
89 | setup_logging(filename=f'{output_dir}/output.log', console=is_master)
90 |
91 |
92 | #######################
93 | ## dataset
94 |
95 | tokenizer = new_tokenizer(vocab_file=args.data_vocab_file)
96 | vocab_size = len(tokenizer.vocab)
97 | ds_train = wrap_example_builder(dataset=load_owt(owt_dir=args.data_dir, n_tensors_per_file=args.data_n_tensors_per_file), vocab=tokenizer.vocab, max_length=args.data_max_seq_length)
98 |
99 | pad_token_id = tokenizer.vocab['[PAD]']
100 | mask_token_id = tokenizer.vocab['[MASK]']
101 | cls_token_id = tokenizer.vocab['[CLS]']
102 | sep_token_id = tokenizer.vocab['[SEP]']
103 |
104 | assert pad_token_id == 0
105 | assert cls_token_id == 101
106 | assert sep_token_id == 102
107 | assert mask_token_id == 103
108 |
109 | def collate_batch(examples):
110 | input_ids = torch.nn.utils.rnn.pad_sequence([example['input_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
111 | input_mask = torch.nn.utils.rnn.pad_sequence([example['input_mask'] for example in examples], batch_first=True, padding_value=pad_token_id)
112 | segment_ids = torch.nn.utils.rnn.pad_sequence([example['segment_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
113 | return input_ids, input_mask, segment_ids
114 |
115 | def cycle(iterable):
116 | while True:
117 | for x in iterable:
118 | yield x
119 |
120 | ds_train_loader = iter(cycle(DataLoader(ds_train, batch_size=args.opt_batch_size, collate_fn=collate_batch)))
121 |
122 |
123 | #######################
124 | ## model
125 |
126 | def to_distributed_model(model):
127 | return model if not args.distributed_enabled else torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)
128 |
129 | def tie_weights(generator, discriminator):
130 | generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings
131 | generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings
132 | generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings
133 |
134 | class LogitsAdapter(torch.nn.Module):
135 | def __init__(self, adaptee):
136 | super().__init__()
137 | self.adaptee = adaptee
138 |
139 | def forward(self, *args, **kwargs):
140 | return self.adaptee(*args, **kwargs)[0]
141 |
142 | from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining
143 |
144 | generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator))
145 | discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator))
146 |
147 | tie_weights(generator, discriminator)
148 |
149 | model = to_distributed_model(Electra(
150 | LogitsAdapter(generator),
151 | LogitsAdapter(discriminator),
152 | num_tokens = vocab_size,
153 | mask_token_id = mask_token_id,
154 | pad_token_id = pad_token_id,
155 | mask_prob = args.model_mask_prob,
156 | mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
157 | random_token_prob = 0.0).to(device))
158 |
159 |
160 | #######################
161 | ## optimizer
162 |
163 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
164 | def lr_lambda(current_step):
165 | learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps)))
166 | learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps))
167 | return learning_rate
168 | return LambdaLR(optimizer, lr_lambda, last_epoch)
169 |
170 | def get_params_without_weight_decay_ln(named_params, weight_decay):
171 | no_decay = ['bias', 'LayerNorm.weight']
172 | optimizer_grouped_parameters = [
173 | {
174 | 'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
175 | 'weight_decay': weight_decay,
176 | },
177 | {
178 | 'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
179 | 'weight_decay': 0.0,
180 | },
181 | ]
182 | return optimizer_grouped_parameters
183 |
184 | optimizer = torch.optim.AdamW(get_params_without_weight_decay_ln(model.named_parameters(), weight_decay=0.1), lr=args.opt_lr, betas=(0.9, 0.999), eps=1e-08)
185 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.opt_warmup_steps, num_training_steps=args.opt_num_training_steps)
186 | scaler = torch.cuda.amp.GradScaler(enabled=args.gpu_mixed_precision)
187 |
188 |
189 | #######################
190 | ## train
191 |
192 | t, steps_s, eta_m = time(), 0., 0
193 |
194 | for step in range(args.opt_num_training_steps+1):
195 | input_ids, input_mask, segment_ids = next(ds_train_loader)
196 |
197 | input_ids = input_ids.to(device)
198 | input_mask = input_mask.to(device)
199 | segment_ids = segment_ids.to(device)
200 |
201 | assert input_ids.shape[1] <= args.data_max_seq_length
202 |
203 | optimizer.zero_grad()
204 |
205 | with torch.cuda.amp.autocast(enabled=args.gpu_mixed_precision):
206 | loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
207 |
208 | scaler.scale(loss).backward()
209 | scaler.unscale_(optimizer)
210 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
211 | scaler.step(optimizer)
212 | scaler.update()
213 | scheduler.step()
214 |
215 | metrics = {
216 | 'step': (step, '{:8d}'),
217 | 'loss': (loss.item(), '{:8.5f}'),
218 | 'loss_mlm': (loss_mlm.item(), '{:8.5f}'),
219 | 'loss_disc': (loss_disc.item(), '{:8.5f}'),
220 | 'acc_gen': (acc_gen.item(), '{:5.3f}'),
221 | 'acc_disc': (acc_disc.item(), '{:5.3f}'),
222 | 'lr': (scheduler.get_last_lr()[0], '{:8.7f}'),
223 | 'steps': (steps_s, '{:4.1f}/s'),
224 | 'eta': (eta_m, '{:4d}m'),
225 | }
226 |
227 | if step % args.step_log == 0:
228 | sep = ' ' * 2
229 | logger.info(sep.join([f'{k}: {v[1].format(v[0])}' for (k, v) in metrics.items()]))
230 |
231 | if step > 0 and step % 100 == 0:
232 | t2 = time()
233 | steps_s = 100. / (t2 - t)
234 | eta_m = int(((args.opt_num_training_steps - step) / steps_s) // 60)
235 | t = t2
236 |
237 | if step % 200 == 0:
238 | logger.info(np.array2string(disc_labels[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
239 | logger.info(np.array2string(disc_pred[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
240 |
241 | if step > 0 and step % args.step_ckpt == 0 and is_master:
242 | discriminator.electra.save_pretrained(f'{args.output_dir}/ckpt/{step}')
243 |
244 | ########################################################################################################
245 | ## preamble
246 |
247 | def set_gpus(gpu):
248 | torch.cuda.set_device(gpu)
249 |
250 |
251 | def set_seed(seed):
252 | os.environ['PYTHONHASHSEED'] = str(seed)
253 | random.seed(seed)
254 | np.random.seed(seed)
255 | torch.manual_seed(seed)
256 | if torch.cuda.is_available():
257 | torch.cuda.manual_seed(seed)
258 | torch.cuda.manual_seed_all(seed)
259 |
260 |
261 | def set_cuda(deterministic=True):
262 | if torch.cuda.is_available():
263 | torch.backends.cudnn.deterministic = deterministic
264 | torch.backends.cudnn.benchmark = not deterministic
265 |
266 |
267 | def get_exp_id(file):
268 | return os.path.splitext(os.path.basename(file))[0]
269 |
270 |
271 | def get_output_dir(exp_id):
272 | import datetime
273 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
274 | output_dir = os.path.join('output/' + exp_id, t)
275 | os.makedirs(output_dir, exist_ok=True)
276 | return output_dir
277 |
278 |
279 | def setup_logging(filename, console=True):
280 | log_format = logging.Formatter("%(asctime)s : %(message)s")
281 | logger = logging.getLogger()
282 | logger.handlers = []
283 | file_handler = logging.FileHandler(filename)
284 | file_handler.setFormatter(log_format)
285 | logger.addHandler(file_handler)
286 | if console:
287 | console_handler = logging.StreamHandler(sys.stdout)
288 | console_handler.setFormatter(log_format)
289 | logger.addHandler(console_handler)
290 | logger.setLevel(logging.INFO)
291 | return logger
292 |
293 |
294 | def copy_source(file, output_dir):
295 | import shutil
296 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
297 |
298 |
299 | ########################################################################################################
300 | ## main
301 |
302 | def main():
303 |
304 | # preamble
305 | exp_id = get_exp_id(__file__)
306 | output_dir = get_output_dir(exp_id)
307 | os.makedirs(output_dir, exist_ok=True)
308 | os.makedirs(f'{output_dir}/ckpt', exist_ok=False)
309 | copy_source(__file__, output_dir)
310 |
311 | # args
312 | args = arg.parse_to(Args)
313 | args.output_dir = output_dir
314 | args.exp_id = exp_id
315 |
316 | # distributed
317 | if args.distributed_enabled:
318 | os.environ['MASTER_ADDR'] = 'localhost'
319 | os.environ['MASTER_PORT'] = str(args.distributed_port)
320 | torch.multiprocessing.spawn(train, nprocs=args.distributed_world_size, args=(args,))
321 | else:
322 | train(rank=args.gpu, args=args)
323 |
324 |
325 | if __name__ == '__main__':
326 | main()
327 |
--------------------------------------------------------------------------------
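
A side note on the learning-rate schedule defined in pretrain.py above: `get_linear_schedule_with_warmup` multiplies a linear warmup ramp by a linear decay that counts from step 0, so the multiplier peaks just below 1.0 at the end of warmup and then falls to 0 at `opt_num_training_steps`. A minimal standalone sketch of that multiplier, using the defaults from `Args` (10,000 warmup steps, 200,000 training steps); the `lr_multiplier` helper is illustrative only and not part of the repository:

    def lr_multiplier(step, warmup=10_000, total=200_000):
        # same formula as lr_lambda in pretrain.py: linear decay times warmup ramp
        decay = max(0.0, 1.0 - step / total)
        ramp = min(1.0, step / warmup)
        return decay * ramp

    for step in (0, 5_000, 10_000, 100_000, 200_000):
        print(step, round(lr_multiplier(step), 4))
    # prints: 0 0.0, 5000 0.4875, 10000 0.95, 100000 0.5, 200000 0.0
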
/pretraining/openwebtext/small_discriminator.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "ElectraForPreTraining"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "embedding_size": 128,
7 | "hidden_act": "gelu",
8 | "hidden_dropout_prob": 0.1,
9 | "hidden_size": 256,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 1024,
12 | "layer_norm_eps": 1e-12,
13 | "max_position_embeddings": 512,
14 | "model_type": "electra",
15 | "num_attention_heads": 4,
16 | "num_hidden_layers": 12,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522
19 | }
--------------------------------------------------------------------------------
/pretraining/openwebtext/small_generator.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "ElectraForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "embedding_size": 128,
7 | "hidden_act": "gelu",
8 | "hidden_dropout_prob": 0.1,
9 | "hidden_size": 64,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 256,
12 | "layer_norm_eps": 1e-12,
13 | "max_position_embeddings": 512,
14 | "model_type": "electra",
15 | "num_attention_heads": 1,
16 | "num_hidden_layers": 12,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522
19 | }
--------------------------------------------------------------------------------
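
The two JSON configs above follow the ELECTRA-small setup of pairing a full-size discriminator (hidden_size 256, 4 attention heads, intermediate_size 1024) with a much smaller generator (hidden_size 64, 1 head, intermediate_size 256); both share the 128-dim embedding size and the 30522-token BERT vocabulary, and pretrain.py additionally ties their embedding tables. pretrain.py loads the configs through `AutoConfig`; a minimal sketch of that step, assuming the transformers==3.0.2 pin from setup.py and paths relative to the repository root:

    from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining

    gen_cfg = AutoConfig.from_pretrained('pretraining/openwebtext/small_generator.json')
    disc_cfg = AutoConfig.from_pretrained('pretraining/openwebtext/small_discriminator.json')

    # freshly initialized (not pretrained) models, as built in pretrain.py
    generator = ElectraForMaskedLM(gen_cfg)
    discriminator = ElectraForPreTraining(disc_cfg)
    print(gen_cfg.hidden_size, disc_cfg.hidden_size)  # 64 256
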
/pretraining/openwebtext/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tokenization classes, the same as used for BERT."""
17 |
18 |
19 | import collections
20 | import unicodedata
21 |
22 |
23 | def convert_to_unicode(text):
24 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
25 | if isinstance(text, str):
26 | return text
27 | elif isinstance(text, bytes):
28 | return text.decode("utf-8", "ignore")
29 | else:
30 | raise ValueError("Unsupported string type: %s" % (type(text)))
31 |
32 |
33 |
34 | def printable_text(text):
35 | """Returns text encoded in a way suitable for print."""
36 |
37 | # These functions want `str` for both Python2 and Python3, but in one case
38 | # it's a Unicode string and in the other it's a byte string.
39 | if isinstance(text, str):
40 | return text
41 | elif isinstance(text, bytes):
42 | return text.decode("utf-8", "ignore")
43 | else:
44 | raise ValueError("Unsupported string type: %s" % (type(text)))
45 |
46 |
47 |
48 | def load_vocab(vocab_file):
49 | """Loads a vocabulary file into a dictionary."""
50 | vocab = collections.OrderedDict()
51 | index = 0
52 | with open(vocab_file, "r") as reader:
53 | while True:
54 | token = convert_to_unicode(reader.readline())
55 | if not token:
56 | break
57 | token = token.strip()
58 | vocab[token] = index
59 | index += 1
60 | return vocab
61 |
62 |
63 | def convert_by_vocab(vocab, items):
64 | """Converts a sequence of [tokens|ids] using the vocab."""
65 | output = []
66 | for item in items:
67 | output.append(vocab[item])
68 | return output
69 |
70 |
71 | def convert_tokens_to_ids(vocab, tokens):
72 | return convert_by_vocab(vocab, tokens)
73 |
74 |
75 | def convert_ids_to_tokens(inv_vocab, ids):
76 | return convert_by_vocab(inv_vocab, ids)
77 |
78 |
79 | def whitespace_tokenize(text):
80 | """Runs basic whitespace cleaning and splitting on a piece of text."""
81 | text = text.strip()
82 | if not text:
83 | return []
84 | tokens = text.split()
85 | return tokens
86 |
87 |
88 | class FullTokenizer(object):
89 |   """Runs end-to-end tokenization."""
90 |
91 | def __init__(self, vocab_file, do_lower_case=True):
92 | self.vocab = load_vocab(vocab_file)
93 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
94 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
95 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
96 |
97 | def tokenize(self, text):
98 | split_tokens = []
99 | for token in self.basic_tokenizer.tokenize(text):
100 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
101 | split_tokens.append(sub_token)
102 |
103 | return split_tokens
104 |
105 | def convert_tokens_to_ids(self, tokens):
106 | return convert_by_vocab(self.vocab, tokens)
107 |
108 | def convert_ids_to_tokens(self, ids):
109 | return convert_by_vocab(self.inv_vocab, ids)
110 |
111 |
112 | class BasicTokenizer(object):
113 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
114 |
115 | def __init__(self, do_lower_case=True):
116 | """Constructs a BasicTokenizer.
117 |
118 | Args:
119 | do_lower_case: Whether to lower case the input.
120 | """
121 | self.do_lower_case = do_lower_case
122 |
123 | def tokenize(self, text):
124 | """Tokenizes a piece of text."""
125 | text = convert_to_unicode(text)
126 | text = self._clean_text(text)
127 |
128 | # This was added on November 1st, 2018 for the multilingual and Chinese
129 | # models. This is also applied to the English models now, but it doesn't
130 | # matter since the English models were not trained on any Chinese data
131 | # and generally don't have any Chinese data in them (there are Chinese
132 | # characters in the vocabulary because Wikipedia does have some Chinese
133 |     # words in the English Wikipedia).
134 | text = self._tokenize_chinese_chars(text)
135 |
136 | orig_tokens = whitespace_tokenize(text)
137 | split_tokens = []
138 | for token in orig_tokens:
139 | if self.do_lower_case:
140 | token = token.lower()
141 | token = self._run_strip_accents(token)
142 | split_tokens.extend(self._run_split_on_punc(token))
143 |
144 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
145 | return output_tokens
146 |
147 | def _run_strip_accents(self, text):
148 | """Strips accents from a piece of text."""
149 | text = unicodedata.normalize("NFD", text)
150 | output = []
151 | for char in text:
152 | cat = unicodedata.category(char)
153 | if cat == "Mn":
154 | continue
155 | output.append(char)
156 | return "".join(output)
157 |
158 | def _run_split_on_punc(self, text):
159 | """Splits punctuation on a piece of text."""
160 | chars = list(text)
161 | i = 0
162 | start_new_word = True
163 | output = []
164 | while i < len(chars):
165 | char = chars[i]
166 | if _is_punctuation(char):
167 | output.append([char])
168 | start_new_word = True
169 | else:
170 | if start_new_word:
171 | output.append([])
172 | start_new_word = False
173 | output[-1].append(char)
174 | i += 1
175 |
176 | return ["".join(x) for x in output]
177 |
178 | def _tokenize_chinese_chars(self, text):
179 | """Adds whitespace around any CJK character."""
180 | output = []
181 | for char in text:
182 | cp = ord(char)
183 | if self._is_chinese_char(cp):
184 | output.append(" ")
185 | output.append(char)
186 | output.append(" ")
187 | else:
188 | output.append(char)
189 | return "".join(output)
190 |
191 | def _is_chinese_char(self, cp):
192 | """Checks whether CP is the codepoint of a CJK character."""
193 | # This defines a "chinese character" as anything in the CJK Unicode block:
194 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
195 | #
196 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
197 | # despite its name. The modern Korean Hangul alphabet is a different block,
198 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
199 | # space-separated words, so they are not treated specially and handled
200 |     # like all of the other languages.
201 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
202 | (cp >= 0x3400 and cp <= 0x4DBF) or #
203 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
204 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
205 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
206 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
207 | (cp >= 0xF900 and cp <= 0xFAFF) or #
208 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
209 | return True
210 |
211 | return False
212 |
213 | def _clean_text(self, text):
214 | """Performs invalid character removal and whitespace cleanup on text."""
215 | output = []
216 | for char in text:
217 | cp = ord(char)
218 | if cp == 0 or cp == 0xfffd or _is_control(char):
219 | continue
220 | if _is_whitespace(char):
221 | output.append(" ")
222 | else:
223 | output.append(char)
224 | return "".join(output)
225 |
226 |
227 | class WordpieceTokenizer(object):
228 |   """Runs WordPiece tokenization."""
229 |
230 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
231 | self.vocab = vocab
232 | self.unk_token = unk_token
233 | self.max_input_chars_per_word = max_input_chars_per_word
234 |
235 | def tokenize(self, text):
236 | """Tokenizes a piece of text into its word pieces.
237 |
238 | This uses a greedy longest-match-first algorithm to perform tokenization
239 | using the given vocabulary.
240 |
241 | For example:
242 | input = "unaffable"
243 | output = ["un", "##aff", "##able"]
244 |
245 | Args:
246 | text: A single token or whitespace separated tokens. This should have
247 |         already been passed through `BasicTokenizer`.
248 |
249 | Returns:
250 | A list of wordpiece tokens.
251 | """
252 |
253 | text = convert_to_unicode(text)
254 |
255 | output_tokens = []
256 | for token in whitespace_tokenize(text):
257 | chars = list(token)
258 | if len(chars) > self.max_input_chars_per_word:
259 | output_tokens.append(self.unk_token)
260 | continue
261 |
262 | is_bad = False
263 | start = 0
264 | sub_tokens = []
265 | while start < len(chars):
266 | end = len(chars)
267 | cur_substr = None
268 | while start < end:
269 | substr = "".join(chars[start:end])
270 | if start > 0:
271 | substr = "##" + substr
272 | if substr in self.vocab:
273 | cur_substr = substr
274 | break
275 | end -= 1
276 | if cur_substr is None:
277 | is_bad = True
278 | break
279 | sub_tokens.append(cur_substr)
280 | start = end
281 |
282 | if is_bad:
283 | output_tokens.append(self.unk_token)
284 | else:
285 | output_tokens.extend(sub_tokens)
286 | return output_tokens
287 |
288 |
289 | def _is_whitespace(char):
290 |   """Checks whether `char` is a whitespace character."""
291 |   # \t, \n, and \r are technically control characters but we treat them
292 | # as whitespace since they are generally considered as such.
293 | if char == " " or char == "\t" or char == "\n" or char == "\r":
294 | return True
295 | cat = unicodedata.category(char)
296 | if cat == "Zs":
297 | return True
298 | return False
299 |
300 |
301 | def _is_control(char):
302 |   """Checks whether `char` is a control character."""
303 | # These are technically control characters but we count them as whitespace
304 | # characters.
305 | if char == "\t" or char == "\n" or char == "\r":
306 | return False
307 | cat = unicodedata.category(char)
308 | if cat.startswith("C"):
309 | return True
310 | return False
311 |
312 |
313 | def _is_punctuation(char):
314 |   """Checks whether `char` is a punctuation character."""
315 | cp = ord(char)
316 | # We treat all non-letter/number ASCII as punctuation.
317 | # Characters such as "^", "$", and "`" are not in the Unicode
318 | # Punctuation class but we treat them as punctuation anyways, for
319 | # consistency.
320 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
321 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
322 | return True
323 | cat = unicodedata.category(char)
324 | if cat.startswith("P"):
325 | return True
326 | return False
327 |
--------------------------------------------------------------------------------
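
The greedy longest-match-first behaviour described in the `WordpieceTokenizer.tokenize` docstring can be reproduced without a vocab file by passing a toy vocabulary directly. A minimal sketch, assuming it is run from pretraining/openwebtext so that `tokenization` is importable; the in-line vocabulary is illustrative only (a real run would use `load_vocab('data/vocab.txt')`):

    from tokenization import WordpieceTokenizer

    vocab = {'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3}
    wp = WordpieceTokenizer(vocab=vocab)

    print(wp.tokenize('unaffable'))  # ['un', '##aff', '##able']
    print(wp.tokenize('zzz'))        # ['[UNK]'] (no prefix of 'zzz' is in the vocab)
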
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'electra-pytorch',
5 | packages = find_packages(),
6 | version = '0.1.2',
7 | license='MIT',
8 | description = 'Electra - Pytorch',
9 | author = 'Erik Nijkamp, Phil Wang',
10 | author_email = 'erik.nijkamp@gmail.com, lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/electra-pytorch',
12 | keywords = [
13 | 'transformers',
14 | 'artificial intelligence',
15 | 'pretraining'
16 | ],
17 | install_requires=[
18 | 'torch>=1.6.0',
19 | 'transformers==3.0.2',
20 | 'scipy',
21 | 'sklearn'
22 | ],
23 | setup_requires=[
24 | 'pytest-runner'
25 | ],
26 | tests_require=[
27 | 'pytest',
28 | 'reformer-pytorch'
29 | ],
30 | classifiers=[
31 | 'Development Status :: 4 - Beta',
32 | 'Intended Audience :: Developers',
33 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 | 'License :: OSI Approved :: MIT License',
35 | 'Programming Language :: Python :: 3.7',
36 | ],
37 | )
--------------------------------------------------------------------------------
/tests/test_electra_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from reformer_pytorch import ReformerLM
4 |
5 | from electra_pytorch import Electra
6 |
7 | def test_electra():
8 | generator = ReformerLM(
9 | num_tokens = 20000,
10 | dim = 512,
11 | depth = 1,
12 | max_seq_len = 1024
13 | )
14 |
15 | discriminator = ReformerLM(
16 | num_tokens = 20000,
17 | dim = 512,
18 | depth = 2,
19 | max_seq_len = 1024
20 | )
21 |
22 | generator.token_emb = discriminator.token_emb
23 | generator.pos_emb = discriminator.pos_emb
24 |
25 | trainer = Electra(
26 | generator,
27 | discriminator,
28 | num_tokens = 20000,
29 | discr_dim = 512,
30 | discr_layer = 'reformer',
31 | pad_token_id = 1,
32 | mask_ignore_token_ids = [2, 3]
33 | )
34 |
35 | data = torch.randint(0, 20000, (1, 1024))
36 | results = trainer(data)
37 | results.loss.backward()
38 |
39 | def test_electra_without_magic():
40 | generator = ReformerLM(
41 | num_tokens = 20000,
42 | dim = 512,
43 | depth = 1,
44 | max_seq_len = 1024
45 | )
46 |
47 | discriminator = ReformerLM(
48 | num_tokens = 20000,
49 | dim = 512,
50 | depth = 2,
51 | max_seq_len = 1024,
52 | return_embeddings = True
53 | )
54 |
55 | generator.token_emb = discriminator.token_emb
56 | generator.pos_emb = discriminator.pos_emb
57 |
58 |
59 | discriminator_with_adapter = nn.Sequential(
60 | discriminator,
61 | nn.Linear(512, 1),
62 | nn.Sigmoid()
63 | )
64 |
65 | trainer = Electra(
66 | generator,
67 | discriminator_with_adapter,
68 | num_tokens = 20000,
69 | pad_token_id = 1,
70 | mask_ignore_token_ids = [2, 3]
71 | )
72 |
73 | data = torch.randint(0, 20000, (1, 1024))
74 | results = trainer(data)
75 | results.loss.backward()
76 |
--------------------------------------------------------------------------------
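
The two tests above drive the same `Electra` wrapper in two ways: `test_electra` passes `discr_dim` and `discr_layer`, presumably letting the wrapper hook the named 'reformer' layer and attach its own per-token head, while `test_electra_without_magic` has the discriminator return raw embeddings and wires up an explicit adapter. A minimal shape check for that explicit adapter (the dummy hidden-state tensor is illustrative only):

    import torch
    from torch import nn

    # per-token hidden states (batch, seq, 512) -> per-token probabilities (batch, seq, 1),
    # matching the nn.Sequential adapter built in test_electra_without_magic
    adapter = nn.Sequential(nn.Linear(512, 1), nn.Sigmoid())
    hidden = torch.randn(1, 1024, 512)
    print(adapter(hidden).shape)  # torch.Size([1, 1024, 1])
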