├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── coco.png
├── coco_lm_pytorch
│   ├── __init__.py
│   └── coco_lm_pytorch.py
└── setup.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## COCO LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in PyTorch. The authors were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.

## Install

```bash
$ pip install coco-lm-pytorch
```

## Usage

An example using the `x-transformers` library:

```bash
$ pip install x-transformers
```

Then

```python
import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure the generator is roughly a quarter to half the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,      # smaller hidden dimension
        heads = 4,      # fewer heads
        ff_mult = 2,    # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator whose output is used to predict whether each token is original or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to exclude from masking, e.g. [cls], [sep]
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (4, 1024))  # batch size must be greater than 1 for the contrastive loss

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, './pretrained-model.pt')
```
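
The loss returned by `trainer(data)` is the weighted sum of the three objectives (contrastive, corrective, and MLM), using the `cl_weight`, `disc_weight` and `gen_weight` values above, so those weights are the knobs for rebalancing pretraining. If you would rather keep just the weights instead of pickling the whole module, a minimal sketch (the filename is only an example):

```python
# save only the discriminator's parameters
torch.save(discriminator.state_dict(), './pretrained-discriminator.pt')

# later: rebuild a discriminator with the same configuration, then restore the weights
discriminator.load_state_dict(torch.load('./pretrained-discriminator.pt'))
```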

## Citations

```bibtex
@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining},
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

--------------------------------------------------------------------------------
/coco.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/coco-lm-pytorch/516c1783b9f8ec6c27bab80d3aa8521d217a18cb/coco.png

--------------------------------------------------------------------------------
/coco_lm_pytorch/__init__.py:
--------------------------------------------------------------------------------
from coco_lm_pytorch.coco_lm_pytorch import COCO

--------------------------------------------------------------------------------
/coco_lm_pytorch/coco_lm_pytorch.py:
--------------------------------------------------------------------------------
import math
from functools import reduce

import torch
from torch import nn, einsum
import torch.nn.functional as F

# helpers

def log(t, eps=1e-9):
    return torch.log(t + eps)

def norm(t):
    return F.normalize(t, p = 2, dim = -1)

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1.):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)

def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

# hidden layer extractor class, for magically adding adapter to language model to be pretrained

class HiddenLayerExtractor(nn.Module):
    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = output

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def forward(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden
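
# usage sketch for the extractor above (token_ids is a hypothetical batch of token ids):
#
#   extractor = HiddenLayerExtractor(discriminator, layer = 'norm')
#   hidden = extractor(token_ids)   # output of the submodule named 'norm', captured by the forward hook
#
# the COCO class below relies on exactly this behavior to read hidden states out of the discriminator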

# main COCO class

class COCO(nn.Module):
    def __init__(
        self,
        generator,
        discriminator,
        *,
        discr_dim,
        num_tokens = None,
        discr_layer = -1,
        mask_prob = 0.15,
        replace_prob = 0.85,
        random_token_prob = 0.,
        pad_token_id = 0,
        cls_token_id = 1,
        mask_token_id = 2,
        mask_ignore_token_ids = [],
        disc_weight = 50.,
        gen_weight = 1.,
        cl_weight = 1.,
        temperature = 1.,
        crop_percentage = 0.5
    ):
        super().__init__()

        self.generator = generator
        self.discriminator = discriminator

        self.discriminator = HiddenLayerExtractor(discriminator, layer = discr_layer)
        self.to_correction_logits = nn.Linear(discr_dim, 1)

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.cls_token_id = cls_token_id
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id, cls_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight
        self.cl_weight = cl_weight

        self.cl_temperature = nn.Parameter(torch.tensor(1.))

        self.crop_percentage = crop_percentage

    def forward(self, input, **kwargs):
        b, t, device = *input.shape, input.device
        assert b > 1, 'batch size needs to be greater than 1 for contrastive learning'

        cls_tokens = torch.empty(b, 1, dtype = torch.long, device = device).fill_(self.cls_token_id)

        input = torch.cat((cls_tokens, input), dim = 1)
        input = input[:, :-1]

        replace_prob = prob_mask_like(input, self.replace_prob)

        # do not mask [pad] tokens, or any other tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # get random cropped input for contrastive learning
        random_crop = get_mask_subset_with_prob(~no_mask, self.crop_percentage)
        crop_length = int(t * self.crop_percentage)
        cropped_input = input.masked_select(random_crop).reshape(b, crop_length)
        cropped_input = torch.cat((cls_tokens, cropped_input), dim = 1)
        cropped_input = F.pad(cropped_input, (0, t - crop_length - 1), value = self.pad_token_id)

        # get mask indices
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_input = input.clone().detach()

        # if random token probability > 0 for mlm
        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to COCO for randomizing tokens during masked language modeling'

            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask
            random_indices = torch.nonzero(random_token_prob, as_tuple=True)
            masked_input[random_indices] = random_tokens[random_indices]

        # [mask] input
        masked_input = masked_input.masked_fill(mask * replace_prob, self.mask_token_id)

        # set inverse of mask to padding tokens for labels
        gen_labels = input.masked_fill(~mask, self.pad_token_id)

        # get generator output and get mlm loss
        logits = self.generator(masked_input, **kwargs)

        mlm_loss = F.cross_entropy(
            logits.transpose(1, 2),
            gen_labels,
            ignore_index = self.pad_token_id
        )

        # use mask from before to select logits that need sampling
        sample_logits = logits[mask_indices]

        # sample
        sampled = gumbel_sample(sample_logits, temperature = self.temperature)

        # scatter the sampled values back to the input
        disc_input = input.clone()
        disc_input[mask_indices] = sampled.detach()

        # generate discriminator labels, with replaced as True and original as False
        disc_labels = (input != disc_input).float().detach()

        # get discriminator predictions of replaced / original
        non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)

        # get discriminator output and binary cross entropy loss
        disc_embeddings_correction = self.discriminator(disc_input, **kwargs)

        correction_logits = self.to_correction_logits(disc_embeddings_correction)
        disc_logits = correction_logits.reshape_as(disc_labels)

        disc_loss = F.binary_cross_entropy_with_logits(
            disc_logits[non_padded_indices],
            disc_labels[non_padded_indices]
        )

        # contrastive loss
        disc_embeddings_cropped = self.discriminator(cropped_input, **kwargs)

        cls_tokens_corrected, cls_tokens_cropped = disc_embeddings_correction[:, 0], disc_embeddings_cropped[:, 0]
        cls_tokens_corrected, cls_tokens_cropped = map(norm, (cls_tokens_corrected, cls_tokens_cropped))

        cl_temperature = self.cl_temperature.exp()

        sim = einsum('i d, j d -> i j', cls_tokens_corrected, cls_tokens_cropped) * cl_temperature
        labels = torch.arange(b, device = device)

        cl_loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) * 0.5
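
        # note: the two cross entropies above form a symmetric InfoNCE objective: `sim` is a
        # (batch x batch) similarity matrix, its diagonal pairs each corrected sequence's [CLS]
        # embedding with the [CLS] embedding of its own cropped view, and every other sequence
        # in the batch serves as a negative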

        # weight all losses
        weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss
        return weighted_loss

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name = 'coco-lm-pytorch',
    packages = find_packages(),
    version = '0.0.2',
    license = 'MIT',
    description = 'COCO - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/coco-lm-pytorch',
    keywords = [
        'transformers',
        'artificial intelligence',
        'deep learning',
        'pretraining'
    ],
    install_requires = [
        'torch>=1.6.0',
        'einops',
        'x-transformers'
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.7',
    ],
)

--------------------------------------------------------------------------------