├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── coco.png
├── coco_lm_pytorch
│   ├── __init__.py
│   └── coco_lm_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 |   release:
8 |     types: [created]
9 |
10 | jobs:
11 |   deploy:
12 |
13 |     runs-on: ubuntu-latest
14 |
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## COCO LM Pretraining (wip)
4 |
5 | Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in PyTorch. The authors were able to make contrastive learning work in a self-supervised manner for language model pretraining. It looks like a solid successor to ELECTRA.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install coco-lm-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | An example using the `x-transformers` library
16 |
17 | ```bash
18 | $ pip install x-transformers
19 | ```
20 | Then
21 |
22 | ```python
23 | import torch
24 | from coco_lm_pytorch import COCO
25 |
26 | # (1) instantiate the generator and discriminator, making sure the generator is roughly a quarter to a half the size of the discriminator
27 |
28 | from x_transformers import TransformerWrapper, Encoder
29 |
30 | generator = TransformerWrapper(
31 |     num_tokens = 20000,
32 |     emb_dim = 128,
33 |     max_seq_len = 1024,
34 |     attn_layers = Encoder(
35 |         dim = 256,     # smaller hidden dimension
36 |         heads = 4,     # fewer heads
37 |         ff_mult = 2,   # smaller feedforward dimension
38 |         depth = 1
39 |     )
40 | )
41 |
42 | discriminator = TransformerWrapper(
43 |     num_tokens = 20000,
44 |     emb_dim = 128,
45 |     max_seq_len = 1024,
46 |     attn_layers = Encoder(
47 |         dim = 1024,
48 |         heads = 16,
49 |         ff_mult = 4,
50 |         depth = 12
51 |     )
52 | )
53 |
54 | # (2) weight tie the token and positional embeddings of generator and discriminator
55 |
56 | generator.token_emb = discriminator.token_emb
57 | generator.pos_emb = discriminator.pos_emb
58 |
59 | # weight tie any other embeddings if available, token type embeddings, etc.
60 |
61 | # (3) instantiate COCO
62 |
63 | trainer = COCO(
64 |     generator,
65 |     discriminator,
66 |     discr_dim = 1024,            # the embedding dimension of the discriminator
67 |     discr_layer = 'norm',        # the name of the layer in the discriminator whose output is used to predict whether each token is original or replaced (see the note below this example)
68 |     cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
69 |     mask_token_id = 2,           # the token id reserved for masking
70 |     pad_token_id = 0,            # the token id for padding
71 |     mask_prob = 0.15,            # masking probability for masked language modeling
72 |     mask_ignore_token_ids = [],  # ids of tokens to ignore when masking, e.g. [cls], [sep]
73 |     cl_weight = 1.,              # weight for the contrastive learning loss
74 |     disc_weight = 1.,            # weight for the corrective (replaced token detection) loss
75 |     gen_weight = 1.              # weight for the MLM loss
76 | )
77 |
78 | # (4) train
79 |
80 | data = torch.randint(0, 20000, (8, 1024))  # batch size must be greater than 1 for the contrastive loss
81 |
82 | loss = trainer(data)
83 | loss.backward()
84 |
85 | # after much training, the discriminator should have improved
86 |
87 | torch.save(discriminator, './pretrained-model.pt')
88 | ```
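
Below is a minimal, hypothetical sketch of wrapping the trainer above in an actual optimization loop. The random-token dataset, batch size, and learning rate are placeholders for your own corpus and hyperparameters; note that the batch size must be greater than 1, since the contrastive loss uses the other sequences in the batch as negatives.

```python
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

# placeholder corpus: random token ids standing in for real tokenized text
token_ids = torch.randint(0, 20000, (512, 1024))
loader = DataLoader(TensorDataset(token_ids), batch_size = 8, shuffle = True)

# trainer.parameters() covers the generator, the discriminator, and the COCO wrapper's own parameters
optimizer = Adam(trainer.parameters(), lr = 3e-4)

for epoch in range(2):
    for (batch,) in loader:
        loss = trainer(batch)  # combined MLM + corrective + contrastive loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```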
89 |
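A note on `discr_layer`: it names a submodule of the discriminator, and the trainer hooks that module's forward output (see `HiddenLayerExtractor` in `coco_lm_pytorch.py`) to obtain the per-token embeddings fed to the corrective and contrastive heads. If you are unsure which name to pass for your own encoder, you can list the candidates; this small sketch assumes the `discriminator` defined above.

```python
# print every named submodule of the discriminator; any of these names can be
# passed as `discr_layer` (an integer index into the model's children also works)
for name, module in discriminator.named_modules():
    print(name, type(module).__name__)
```
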
90 | ## Citations
91 |
92 | ```bibtex
93 | @misc{meng2021cocolm,
94 | title = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining},
95 | author = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
96 | year = {2021},
97 | eprint = {2102.08473},
98 | archivePrefix = {arXiv},
99 | primaryClass = {cs.CL}
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/coco.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/coco-lm-pytorch/516c1783b9f8ec6c27bab80d3aa8521d217a18cb/coco.png
--------------------------------------------------------------------------------
/coco_lm_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from coco_lm_pytorch.coco_lm_pytorch import COCO
2 |
--------------------------------------------------------------------------------
/coco_lm_pytorch/coco_lm_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import reduce
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 |
8 | # helpers
9 |
10 | def log(t, eps=1e-9):
11 |     return torch.log(t + eps)
12 |
13 | def norm(t):
14 |     return F.normalize(t, p = 2, dim = -1)
15 |
16 | def gumbel_noise(t):
17 |     noise = torch.zeros_like(t).uniform_(0, 1)
18 |     return -log(-log(noise))
19 |
20 | def gumbel_sample(t, temperature = 1.):
21 |     return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)
22 |
23 | def prob_mask_like(t, prob):
24 |     return torch.zeros_like(t).float().uniform_(0, 1) < prob
25 |
26 | def mask_with_tokens(t, token_ids):
27 |     init_no_mask = torch.full_like(t, False, dtype=torch.bool)
28 |     mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
29 |     return mask
30 |
31 | def get_mask_subset_with_prob(mask, prob):
32 |     batch, seq_len, device = *mask.shape, mask.device
33 |     max_masked = math.ceil(prob * seq_len)
34 |
35 |     num_tokens = mask.sum(dim=-1, keepdim=True)
36 |     mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
37 |     mask_excess = mask_excess[:, :max_masked]
38 |
39 |     rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
40 |     _, sampled_indices = rand.topk(max_masked, dim=-1)
41 |     sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
42 |
43 |     new_mask = torch.zeros((batch, seq_len + 1), device=device)
44 |     new_mask.scatter_(-1, sampled_indices, 1)
45 |     return new_mask[:, 1:].bool()
46 |
47 | # hidden layer extractor class, for magically adding an adapter to the language model to be pretrained
48 |
49 | class HiddenLayerExtractor(nn.Module):
50 |     def __init__(self, net, layer = -2):
51 |         super().__init__()
52 |         self.net = net
53 |         self.layer = layer
54 |
55 |         self.hidden = None
56 |         self.hook_registered = False
57 |
58 |     def _find_layer(self):
59 |         if type(self.layer) == str:
60 |             modules = dict([*self.net.named_modules()])
61 |             return modules.get(self.layer, None)
62 |         elif type(self.layer) == int:
63 |             children = [*self.net.children()]
64 |             return children[self.layer]
65 |         return None
66 |
67 |     def _hook(self, _, __, output):
68 |         self.hidden = output
69 |
70 |     def _register_hook(self):
71 |         layer = self._find_layer()
72 |         assert layer is not None, f'hidden layer ({self.layer}) not found'
73 |         handle = layer.register_forward_hook(self._hook)
74 |         self.hook_registered = True
75 |
76 |     def forward(self, x):
77 |         if self.layer == -1:
78 |             return self.net(x)
79 |
80 |         if not self.hook_registered:
81 |             self._register_hook()
82 |
83 |         _ = self.net(x)
84 |         hidden = self.hidden
85 |         self.hidden = None
86 |         assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
87 |         return hidden
88 |
89 | # main COCO class
90 |
91 | class COCO(nn.Module):
92 |     def __init__(
93 |         self,
94 |         generator,
95 |         discriminator,
96 |         *,
97 |         discr_dim,
98 |         num_tokens = None,
99 |         discr_layer = -1,
100 |         mask_prob = 0.15,
101 |         replace_prob = 0.85,
102 |         random_token_prob = 0.,
103 |         pad_token_id = 0,
104 |         cls_token_id = 1,
105 |         mask_token_id = 2,
106 |         mask_ignore_token_ids = [],
107 |         disc_weight = 50.,
108 |         gen_weight = 1.,
109 |         cl_weight = 1.,
110 |         temperature = 1.,
111 |         crop_percentage = 0.5
112 |     ):
113 |         super().__init__()
114 |
115 |         self.generator = generator
116 |         self.discriminator = discriminator
117 |
118 |         self.discriminator = HiddenLayerExtractor(discriminator, layer = discr_layer)
119 |         self.to_correction_logits = nn.Linear(discr_dim, 1)
120 |
121 |         # mlm related probabilities
122 |         self.mask_prob = mask_prob
123 |         self.replace_prob = replace_prob
124 |
125 |         self.num_tokens = num_tokens
126 |         self.random_token_prob = random_token_prob
127 |
128 |         # token ids
129 |         self.cls_token_id = cls_token_id
130 |         self.pad_token_id = pad_token_id
131 |         self.mask_token_id = mask_token_id
132 |         self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id, cls_token_id])
133 |
134 |         # sampling temperature
135 |         self.temperature = temperature
136 |
137 |         # loss weights
138 |         self.disc_weight = disc_weight
139 |         self.gen_weight = gen_weight
140 |         self.cl_weight = cl_weight
141 |
142 |         self.cl_temperature = nn.Parameter(torch.tensor(1.))
143 |
144 |         self.crop_percentage = crop_percentage
145 |
146 |     def forward(self, input, **kwargs):
147 |         b, t, device = *input.shape, input.device
148 |         assert b > 1, 'batch size needs to be greater than 1 for contrastive learning'
149 |
150 |         cls_tokens = torch.empty(b, 1, dtype = torch.long, device = device).fill_(self.cls_token_id)
151 |
152 |         input = torch.cat((cls_tokens, input), dim = 1)
153 |         input = input[:, :-1]
154 |
155 |         replace_prob = prob_mask_like(input, self.replace_prob)
156 |
157 |         # do not mask [pad] tokens, or any other tokens designated to be excluded ([cls], [sep])
158 |         # also do not include these special tokens in the tokens chosen at random
159 |         no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
160 |         mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)
161 |
162 |         # get random cropped input for contrastive learning
163 |         random_crop = get_mask_subset_with_prob(~no_mask, self.crop_percentage)
164 |         crop_length = int(t * self.crop_percentage)
165 |         cropped_input = input.masked_select(random_crop).reshape(b, crop_length)
166 |         cropped_input = torch.cat((cls_tokens, cropped_input), dim = 1)
167 |         cropped_input = F.pad(cropped_input, (0, t - crop_length - 1), value = self.pad_token_id)
168 |
169 |         # get mask indices
170 |         mask_indices = torch.nonzero(mask, as_tuple=True)
171 |
172 |         # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
173 |         masked_input = input.clone().detach()
174 |
175 |         # if random token probability > 0 for mlm
176 |         if self.random_token_prob > 0:
177 |             assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to COCO for randomizing tokens during masked language modeling'
178 |
179 |             random_token_prob = prob_mask_like(input, self.random_token_prob)
180 |             random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
181 |             random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
182 |             random_token_prob &= ~random_no_mask
183 |             random_indices = torch.nonzero(random_token_prob, as_tuple=True)
184 |             masked_input[random_indices] = random_tokens[random_indices]
185 |
186 |         # [mask] input
187 |         masked_input = masked_input.masked_fill(mask * replace_prob, self.mask_token_id)
188 |
189 |         # set inverse of mask to padding tokens for labels
190 |         gen_labels = input.masked_fill(~mask, self.pad_token_id)
191 |
192 |         # get generator output and get mlm loss
193 |         logits = self.generator(masked_input, **kwargs)
194 |
195 |         mlm_loss = F.cross_entropy(
196 |             logits.transpose(1, 2),
197 |             gen_labels,
198 |             ignore_index = self.pad_token_id
199 |         )
200 |
201 |         # use mask from before to select logits that need sampling
202 |         sample_logits = logits[mask_indices]
203 |
204 |         # sample
205 |         sampled = gumbel_sample(sample_logits, temperature = self.temperature)
206 |
207 |         # scatter the sampled values back to the input
208 |         disc_input = input.clone()
209 |         disc_input[mask_indices] = sampled.detach()
210 |
211 |         # generate discriminator labels, with replaced as True and original as False
212 |         disc_labels = (input != disc_input).float().detach()
213 |
214 |         # get discriminator predictions of replaced / original
215 |         non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)
216 |
217 |         # get discriminator output and binary cross entropy loss
218 |         disc_embeddings_correction = self.discriminator(disc_input, **kwargs)
219 |
220 |         correction_logits = self.to_correction_logits(disc_embeddings_correction)
221 |         disc_logits = correction_logits.reshape_as(disc_labels)
222 |
223 |         disc_loss = F.binary_cross_entropy_with_logits(
224 |             disc_logits[non_padded_indices],
225 |             disc_labels[non_padded_indices]
226 |         )
227 |
228 |         # contrastive loss
229 |         disc_embeddings_cropped = self.discriminator(cropped_input, **kwargs)
230 |
231 |         cls_tokens_corrected, cls_tokens_cropped = disc_embeddings_correction[:, 0], disc_embeddings_cropped[:, 0]
232 |         cls_tokens_corrected, cls_tokens_cropped = map(norm, (cls_tokens_corrected, cls_tokens_cropped))
233 |
234 |         cl_temperature = self.cl_temperature.exp()
235 |
236 |         sim = einsum('i d, j d -> i j', cls_tokens_corrected, cls_tokens_cropped) * cl_temperature
237 |         labels = torch.arange(b, device = device)
238 |
239 |         cl_loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) * 0.5
240 |
241 |         # weight all losses
242 |         weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss
243 |         return weighted_loss
244 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 |   name = 'coco-lm-pytorch',
5 |   packages = find_packages(),
6 |   version = '0.0.2',
7 |   license='MIT',
8 |   description = 'COCO - Pytorch',
9 |   author = 'Phil Wang',
10 |   author_email = 'lucidrains@gmail.com',
11 |   url = 'https://github.com/lucidrains/coco-lm-pytorch',
12 |   keywords = [
13 |     'transformers',
14 |     'artificial intelligence',
15 |     'deep learning',
16 |     'pretraining'
17 |   ],
18 |   install_requires=[
19 |     'torch>=1.6.0',
20 |     'einops',
21 |     'x-transformers'
22 |   ],
23 |   classifiers=[
24 |     'Development Status :: 4 - Beta',
25 |     'Intended Audience :: Developers',
26 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 |     'License :: OSI Approved :: MIT License',
28 |     'Programming Language :: Python :: 3.7',
29 |   ],
30 | )
--------------------------------------------------------------------------------