├── .flake8 ├── .github └── workflows │ └── dockerhub.yml ├── .gitignore ├── README.md ├── docker └── Dockerfile.pytorch-cuda111 └── src ├── mint ├── __init__.py ├── bart.py ├── bert.py ├── common.py ├── data.py ├── examples │ ├── __init__.py │ ├── bart_completer.py │ ├── bert_completer.py │ ├── bert_searcher.py │ ├── build_search_index.py │ ├── eval_gpt_lm.py │ ├── gpt_completer.py │ ├── pretrain_bart_simple.py │ ├── pretrain_bart_wiki.py │ ├── pretrain_bert_simple.py │ ├── pretrain_bert_wiki.py │ ├── pretrain_gpt_simple.py │ ├── pretrain_t5_simple.py │ ├── t5_completer.py │ ├── tune_bart_for_cls.py │ ├── tune_bert_for_cls.py │ ├── tune_bert_for_paired_cls.py │ └── tune_gpt2_for_cls.py ├── gpt.py ├── postln.py ├── preln.py ├── t5.py └── train.py ├── setup.cfg └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,W503,E712,E203,E722,E741,E731,C901 3 | exclude = .git,__pycache__,docs/source/conf.py,old,build,dist 4 | max-complexity = 10 5 | -------------------------------------------------------------------------------- /.github/workflows/dockerhub.yml: -------------------------------------------------------------------------------- 1 | name: Publish to DockerHub 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v2 14 | - name: Set up Docker Buildx 15 | uses: docker/setup-buildx-action@v1 16 | - name: Login to DockerHub 17 | uses: docker/login-action@v1 18 | with: 19 | username: ${{ secrets.DOCKER_USERNAME }} 20 | password: ${{ secrets.DOCKER_PASSWORD }} 21 | - name: Build & Push PyTorch CUDA 11.1 image 22 | uses: docker/build-push-action@v2 23 | with: 24 | file: docker/Dockerfile.pytorch-cuda111 25 | push: true 26 | tags: dpressel/mint-cuda111:latest 27 | 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | **/pip-wheel-metadata/ 29 | pip-wheel-metadata/ 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 | 
53 | # Translations
54 | *.mo
55 | *.pot
56 | 
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | 
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 | 
66 | # Scrapy stuff:
67 | .scrapy
68 | 
69 | # Sphinx documentation
70 | docs/_build/
71 | 
72 | # PyBuilder
73 | target/
74 | 
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 | 
78 | # pyenv
79 | .python-version
80 | 
81 | # celery beat schedule file
82 | celerybeat-schedule
83 | 
84 | # SageMath parsed files
85 | *.sage.py
86 | 
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 | 
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 | 
100 | # Rope project settings
101 | .ropeproject
102 | 
103 | # mkdocs documentation
104 | /site
105 | 
106 | # mypy
107 | .mypy_cache/
108 | 
109 | python/README.md
110 | python/MANIFEST.in
111 | 
112 | scripts/.lrs
113 | baseline.iml
114 | .idea/*
115 | tags
116 | *.zip
117 | 
118 | *.pkl
119 | *.conll
120 | *.onnx
121 | *.script
122 | *.pth
123 | 
124 | api-examples/pytorch_cpp/json.hpp
125 | api-examples/pytorch_cpp/libtorch*
126 | api-examples/pytorch_cpp/tag-text
127 | api-examples/pytorch_cpp/classify-text
128 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MinT: Minimal Transformer Library and Tutorials

A minimalistic implementation of common Transformers from scratch!

## Colabs

A series of tutorials on building common Transformer models from scratch. Each tutorial builds on the previous one, so they should be done in order.

- [BERT from scratch](https://colab.research.google.com/drive/175hnhLkJcXH40tGGpO-1kbBrb2IIcIuT?usp=sharing)
- [GPT & GPT2 from scratch](https://colab.research.google.com/drive/1svaeO-TF1UEEIq8aew4B5x-y4i79fIXv?usp=sharing)
- [BART from scratch](https://colab.research.google.com/drive/12C764uTLwPMM9hUlprm_a4bUwHz91a7P?usp=sharing)
- [T5 from scratch](https://colab.research.google.com/drive/1G3egJjNRrXog-8reY1Ssfoa6c92Dp4jh?usp=sharing)
- [Build your own SentenceBERT](https://colab.research.google.com/drive/1P11ogAYU-EZ_Kbo7WorMM7p35qvwPuMo?usp=sharing)

The code here is also factored out as a Python package for easy use outside of the tutorials.

Because this is written for a tutorial to explain the modeling and training approach, we currently depend on the
HuggingFace tokenizers library to implement subword tokenization. I selected it because it's fast and widely used.
There are also other good, fast libraries (like BlingFire) that cover multiple subword approaches, but this library
doesn't support them at this time.

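As a concrete illustration of that dependency, here is a minimal sketch of the `tokenizers` calls the library and examples rely on. The `vocab.txt` path is a placeholder for any BERT-style WordPiece vocabulary file:

```python
# Minimal sketch of the HuggingFace tokenizers usage assumed throughout MinT.
# The vocab path is a placeholder for any BERT-style WordPiece vocab file.
from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer("vocab.txt", lowercase=True)
encoded = tokenizer.encode("MinT is a minimal Transformer library")
print(encoded.tokens)  # subword strings, e.g. ['[CLS]', 'mint', ..., '[SEP]']
print(encoded.ids)     # the integer ids that get fed to the models
```
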
## A Tiny Library for Transformers from the ground up

Minimal PyTorch implementation of common Transformer architectures. It currently implements:

- Encoder Only
  - [BERT](https://aclanthology.org/N19-1423/) / [RoBERTa](https://arxiv.org/pdf/1907.11692.pdf)
- Decoder Only
  - [GPT](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
  - [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)
- Encoder-Decoder
  - [BART](https://arxiv.org/pdf/1910.13461v1.pdf)
  - [T5](https://arxiv.org/pdf/1910.10683.pdf)
- Dual-Encoder
  - [SentenceBERT](https://aclanthology.org/D19-1410.pdf)


## Pretraining

There are example programs showing how to pretrain from scratch (or continue pretraining from an already pretrained model).

### In-memory training on a small dataset
There are two pretraining examples. The first is a toy example that works well for small datasets like Wikitext-2.
The loader preprocesses the data and slurps the tensors into a TensorDataset.
It uses the `SimpleTrainer` to train for several epochs. Because the dataset is small and map-style, it makes sense to train a whole epoch and then evaluate on the whole test dataset. For large datasets, I would not recommend this approach. A rough sketch of this style of data loading is shown below.

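This sketch is not the exact loader in `pretrain_bert_simple.py`, and the corpus and vocab paths are placeholders, but it shows the general shape of the in-memory approach: tokenize the whole corpus, chunk it into fixed-length blocks, and wrap the result in a `TensorDataset`:

```python
# Hedged sketch of in-memory pretraining data prep; file paths are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer("vocab.txt", lowercase=True)
seq_len = 512

# Tokenize the entire (small) corpus into one long stream of subword ids
ids = []
with open("wikitext-2/wiki.train.tokens", encoding="utf-8") as f:
    for line in f:
        ids.extend(tokenizer.encode(line.strip(), add_special_tokens=False).ids)

# Chunk the stream into fixed-length blocks and keep everything in memory
num_blocks = len(ids) // seq_len
x = torch.tensor(ids[: num_blocks * seq_len]).view(num_blocks, seq_len)
dataset = TensorDataset(x)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
```

Because everything fits in memory, a map-style dataset with `shuffle=True` is sufficient; the sharded IterableDataset approach below exists for corpora where that is not possible.
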
### Out-of-memory training on a large dataset
The second example uses an infinite IterableDataset to read multiple files (shards) and convert them to tensors on the fly.
This program is a more realistic example of language modeling.

### Out-of-memory preprocessed shards on a large dataset

The library also supports fully preprocessed datasets, but there is no example of that usage at this time.

### Wikipedia

To pretrain on English Wikipedia with this program, you'll need an XML Wikipedia dump.
This is usually named `enwiki-latest-pages-articles.xml.bz2` and can be found on the [Wikipedia dump site](https://dumps.wikimedia.org/enwiki/latest/).
For example, this should work for downloading:

```
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
```
You also need to use this repository:

```
git clone https://github.com/attardi/wikiextractor
cd wikiextractor
git checkout 16186e290d9eb0eb3a3784c6c0635a9ed7e855c3

```
Here is how I ran it for my example:

```
python WikiExtractor.py ${INPUT}/enwiki-latest-pages-articles.xml.bz2 \
    -q --json \
    --processes 7 \
    --output ${OUTPUT}/enwiki-extracted \
    --bytes 100M \
    --compress \
    --links \
    --discard_elements gallery,timeline,noinclude \
    --min_text_length 0 \
    --filter_disambig_pages
```
Regarding the command line above, only use `--compress` if you have bzip2 on your system and your Python can successfully run:

```python
import bz2
```

In each generated target directory (e.g. AA, AB, AC), rename the files with the directory name as a prefix (e.g. AA):

```
for file in *.bz2; do mv "$file" "AA_$file"; done;
```
We can then copy these to a single directory, or split them however we would like into train and test sets.

Here is how you can train on multiple workers with DistributedDataParallel:

```
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9 python -m torch.distributed.launch \
    --nnodes=1 \
    --nproc_per_node=8 \
    --node_rank=0 \
    --master_port=$PORT \
    pretrain_bert_wiki.py \
    --vocab_file /data/k8s/hf-models/bert-base-uncased/vocab.txt \
    --lowercase \
    --train_file "/path/to/enwiki-extracted/train/" \
    --valid_file "/path/to/enwiki-extracted/valid/" \
    --num_train_workers 4 \
    --num_valid_workers 1 --batch_size $B --num_steps $N --saves_per_cycle 1 \
    --train_cycle_size 10000 \
    --eval_cycle_size 500 \
    --distributed

```

## Fine-tuning

The [tune_bert_for_cls](src/mint/examples/tune_bert_for_cls.py) program is a simple example of fine-tuning
our from-scratch BERT implementation.

## Completer REPL

The [bert_completer](src/mint/examples/bert_completer.py) program allows you to type in masked strings and
see how BERT would complete them. When it starts, you can pass `--sample` to sample from the output distribution;
otherwise it uses the most likely values. You can switch between the two modes at runtime using:

```
BERT>> :sample
```
or
```
BERT>> :max
```
This example uses `prompt_toolkit`, which is not a core dependency, but you can install it like this:
```
pip install .[examples]
```

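Under the hood, the completer simply runs the masked-language-model head directly. The following is a condensed sketch of that flow, mirroring what `bert_completer.py` does; the checkpoint and vocab paths are placeholders:

```python
# Condensed sketch of the masked completion flow; paths are placeholders.
import torch
from tokenizers import BertWordPieceTokenizer
from mint.bert import BertCreator

tokenizer = BertWordPieceTokenizer("vocab.txt", lowercase=True)
model = BertCreator.mlm_from_pretrained("bert-base-uncased/pytorch_model.bin").eval()

text = "The capital of France is [MASK] ."
enc = tokenizer.encode(text)
with torch.no_grad():
    logits = model(torch.tensor(enc.ids).unsqueeze(0)).squeeze(0)  # [T, V]
best = logits.argmax(-1).tolist()
tokens = [tokenizer.id_to_token(best[i]) if t == "[MASK]" else t
          for i, t in enumerate(enc.tokens)]
print(" ".join(tokens[1:-1]).replace(" ##", ""))
```
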

## More Info Soon

--------------------------------------------------------------------------------
/docker/Dockerfile.pytorch-cuda111:
--------------------------------------------------------------------------------
1 | FROM meadml/cuda11.1-cudnn8-devel-ubuntu18.04-python3.8
2 | 
3 | COPY . /usr/mint
4 | WORKDIR /usr/mint
5 | 
6 | RUN cd src && pip install --no-use-pep517 .[examples]
7 | 
8 | # Set env variables
9 | ENV TIMING_LOG_LEVEL=DEBUG
10 | # Set terminal encodings
11 | ENV LC_ALL=C.UTF-8
12 | ENV LANG=C.UTF-8
13 | ENV PATH=/usr/local/cuda/bin:$PATH
14 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
15 | 
16 | COPY . /usr/mint
17 | 
18 | # Install pytorch
19 | 
20 | 
21 | RUN python3.8 -m pip --no-cache-dir install torch==1.8.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html && \
22 |     python3.8 -m pip install tensorboard
23 | 
24 | 
25 | 
--------------------------------------------------------------------------------
/src/mint/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dpressel/mint/c4f046ccd620e28ff0392b166b50ae1540e7f758/src/mint/__init__.py
--------------------------------------------------------------------------------
/src/mint/bart.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | from typing import Optional
6 | from mint.postln import TransformerEncoderDecoder, TransformerSequenceGenerator
7 | 
8 | import logging
9 | import random
10 | 
11 | logger = logging.getLogger("mint")
12 | 
13 | 
14 | def create_dst_from_src(input_ids: torch.Tensor, decoder_start_token_id: int = 2):
15 |     dst_ids = torch.ones_like(input_ids)
16 |     dst_ids[:, 0] = decoder_start_token_id
17 |     dst_ids[:, 1:] = input_ids[:, :-1]
18 |     return dst_ids
19 | 
20 | 
21 | class BartLearnedPositionalEmbedding(nn.Module):
22 |     """Learned positional embeddings for BART
23 | 
24 |     The embeddings are a combination of 2 inputs, word embeddings and positional embeddings.
25 |     The word embedding is a learned lookup that converts word one-hots to dense representations.
26 |     These embeddings are added together in the forward pass
27 |     """
28 | 
29 |     BART_POS_OFFSET = 2
30 | 
31 |     def __init__(
32 |         self,
33 |         vocab_dim: int,
34 |         hidden_dim: int = 768,
35 |         padding_idx: int = 1,
36 |         max_seq_len: int = 1024,
37 |     ):
38 |         super().__init__()
39 |         self.word_embeddings = nn.Embedding(vocab_dim, hidden_dim, padding_idx)
40 |         self.position_embeddings = nn.Embedding(
41 |             max_seq_len + BartLearnedPositionalEmbedding.BART_POS_OFFSET, hidden_dim
42 |         )
43 | 
44 |     def forward(
45 |         self, x: torch.Tensor, token_type: Optional[torch.Tensor] = None
46 |     ) -> torch.Tensor:
47 |         """Takes a tensor of shape `[B, T]` and an optional `token_type` of the same shape
48 | 
49 |         :param x: A tensor of word one-hots, shape `[B, T]`
50 |         :param token_type: Ignored for BART!
51 | :return: The sum of the positional and word embeddings 52 | """ 53 | embed = self.word_embeddings(x) 54 | 55 | position = self.position_embeddings( 56 | torch.arange(x.shape[-1], dtype=x.dtype).to(x.device) 57 | + BartLearnedPositionalEmbedding.BART_POS_OFFSET 58 | ).unsqueeze(0) 59 | 60 | return embed + position 61 | 62 | @property 63 | def weight(self): 64 | """Access word_embeddings weights 65 | 66 | :return: The word_embeddings weights 67 | """ 68 | return self.word_embeddings.weight 69 | 70 | 71 | class BartEncoderDecoder(TransformerEncoderDecoder): 72 | def __init__( 73 | self, 74 | vocab_size: int, 75 | padding_idx: int = 1, 76 | hidden_size: int = 768, 77 | num_heads: int = 12, 78 | num_encoder_layers: int = 6, 79 | num_decoder_layers: int = 6, 80 | dropout: float = 0.1, 81 | layer_norm_eps=1e-12, 82 | activation: nn.Module = nn.GELU(), 83 | feed_forward_size: Optional[int] = None, 84 | max_seq_len: int = 1024, 85 | ): 86 | super().__init__( 87 | BartLearnedPositionalEmbedding, 88 | vocab_size, 89 | padding_idx, 90 | hidden_size, 91 | num_heads, 92 | num_encoder_layers, 93 | num_decoder_layers, 94 | dropout, 95 | layer_norm_eps, 96 | activation, 97 | feed_forward_size, 98 | max_seq_len, 99 | ) 100 | 101 | 102 | class BartPooledEncoderDecoder(TransformerEncoderDecoder): 103 | EOS_TOKEN = 2 104 | 105 | def __init__( 106 | self, 107 | vocab_size: int, 108 | padding_idx: int = 1, 109 | hidden_size: int = 768, 110 | num_heads: int = 12, 111 | num_encoder_layers: int = 6, 112 | num_decoder_layers: int = 6, 113 | dropout: float = 0.1, 114 | layer_norm_eps=1e-12, 115 | activation: nn.Module = nn.GELU(), 116 | feed_forward_size: Optional[int] = None, 117 | output: Optional[nn.Module] = None, 118 | max_seq_len: int = 1024, 119 | **kwargs, 120 | ): 121 | super().__init__( 122 | BartLearnedPositionalEmbedding, 123 | vocab_size, 124 | padding_idx, 125 | hidden_size, 126 | num_heads, 127 | num_encoder_layers, 128 | num_decoder_layers, 129 | dropout, 130 | layer_norm_eps, 131 | activation, 132 | feed_forward_size, 133 | max_seq_len, 134 | ) 135 | 136 | self.output = output if output else nn.Identity() 137 | self.apply(self.init_layer_weights) 138 | 139 | def forward( 140 | self, 141 | src: torch.Tensor, 142 | dst: Optional[torch.Tensor] = None, 143 | src_mask: Optional[torch.Tensor] = None, 144 | dst_mask: Optional[torch.Tensor] = None, 145 | ) -> torch.Tensor: 146 | dst = create_dst_from_src(src) 147 | dst_enc = super().forward(src, dst, src_mask, dst_mask) 148 | 149 | eos_mask = dst.eq(BartPooledEncoderDecoder.EOS_TOKEN) 150 | eos = dst_enc[eos_mask] 151 | pooled_output = eos.view(dst_enc.shape[0], -1, dst_enc.shape[-1])[:, -1] 152 | y = self.output(pooled_output) 153 | 154 | return y 155 | 156 | 157 | class BartSequenceGenerator(TransformerSequenceGenerator): 158 | def __init__( 159 | self, 160 | vocab_size: int, 161 | padding_idx: int = 1, 162 | hidden_size: int = 768, 163 | num_heads: int = 12, 164 | num_encoder_layers: int = 6, 165 | num_decoder_layers: int = 6, 166 | dropout: float = 0.1, 167 | layer_norm_eps=1e-12, 168 | activation: nn.Module = nn.GELU(), 169 | feed_forward_size: Optional[int] = None, 170 | max_seq_len: int = 1024, 171 | **kwargs, 172 | ): 173 | super().__init__( 174 | BartLearnedPositionalEmbedding, 175 | vocab_size, 176 | padding_idx, 177 | hidden_size, 178 | num_heads, 179 | num_encoder_layers, 180 | num_decoder_layers, 181 | dropout, 182 | layer_norm_eps, 183 | activation, 184 | feed_forward_size, 185 | max_seq_len, 186 | ) 187 | 
self.apply(self.init_layer_weights) 188 | 189 | def create_loss(self): 190 | return nn.CrossEntropyLoss(ignore_index=1) 191 | 192 | 193 | class BartCreator: 194 | @classmethod 195 | def convert_state_dict(cls, tlm, bert_state_dict): 196 | """Convert the state dict to TFS compatible names 197 | 198 | The encoder token embeddings (AKA word_embeddings) are shared with the decoder token embeddings, and 199 | in the HF implementation, this is done via `self.shared` so all 3 items are in the original checkpoint, 200 | and we only need one of them. We have tied these together by assignment already, so loading the encoder's 201 | word embeddings updates the decoder word embeddings too 202 | 203 | Note that the positional embeddings are different for encoder and decoder, so these are not shared and both 204 | are loaded 205 | 206 | :param tlm: 207 | :param bert_state_dict: 208 | :return: 209 | """ 210 | tlm_field_names = set(k for k in tlm.state_dict().keys()) 211 | hf_field_names = bert_state_dict.keys() 212 | 213 | unused_checkpoint_fields = set(hf_field_names) 214 | remap = {} 215 | for field_name in hf_field_names: 216 | new_field_name = field_name.replace( 217 | "encoder.embed_tokens", "encoder_embeddings.word_embeddings" 218 | ) 219 | new_field_name = new_field_name.replace( 220 | "encoder.embed_positions", "encoder_embeddings.position_embeddings" 221 | ) 222 | new_field_name = new_field_name.replace( 223 | "decoder.embed_positions", "decoder_embeddings.position_embeddings" 224 | ) 225 | new_field_name = new_field_name.replace( 226 | "encoder.layernorm_embedding", "encoder_embeddings_layer_norm" 227 | ) 228 | new_field_name = new_field_name.replace( 229 | "decoder.layernorm_embedding", "decoder_embeddings_layer_norm" 230 | ) 231 | 232 | new_field_name = new_field_name.replace("self_attn", "self_attention") 233 | new_field_name = new_field_name.replace("encoder_attn", "encoder_attention") 234 | new_field_name = new_field_name.replace("k_proj", "key") 235 | new_field_name = new_field_name.replace("q_proj", "query") 236 | new_field_name = new_field_name.replace("v_proj", "value") 237 | new_field_name = new_field_name.replace("out_proj", "output") 238 | new_field_name = new_field_name.replace(".layers", "") 239 | new_field_name = new_field_name.replace( 240 | "attention.output.dense", "self_attention.output" 241 | ) 242 | new_field_name = new_field_name.replace("fc1", "ffn.0") 243 | new_field_name = new_field_name.replace("fc2", "ffn.2") 244 | new_field_name = new_field_name.replace( 245 | "final_layer_norm", "output_layer_norm" 246 | ) 247 | if new_field_name in tlm_field_names: 248 | tlm_field_names.remove(new_field_name) 249 | unused_checkpoint_fields.remove(field_name) 250 | remap[new_field_name] = bert_state_dict[field_name] 251 | 252 | tlm.load_state_dict(remap, strict=False) 253 | return tlm_field_names, unused_checkpoint_fields 254 | 255 | @classmethod 256 | def get_vocab_and_hidden_dims(cls, hf_dict: dict) -> tuple: 257 | try: 258 | embeddings_weight = hf_dict[ 259 | [k for k in hf_dict if "encoder.embed_tokens.weight" in k][0] 260 | ] 261 | except: 262 | embeddings_weight = hf_dict[ 263 | [ 264 | k 265 | for k in hf_dict 266 | if "encoder_embeddings.word_embeddings.weight" in k 267 | ][0] 268 | ] 269 | return embeddings_weight.shape 270 | 271 | @classmethod 272 | def from_pretrained(cls, checkpoint_file_or_dir: str, map_location=None, **kwargs): 273 | if os.path.isdir(checkpoint_file_or_dir): 274 | checkpoint = os.path.join(checkpoint_file_or_dir, "pytorch_model.bin") 275 | else: 276 | 
checkpoint = checkpoint_file_or_dir 277 | hf_dict = torch.load(checkpoint, map_location=map_location) 278 | vocab_size, hidden_size = BartCreator.get_vocab_and_hidden_dims(hf_dict) 279 | seq2seq = BartSequenceGenerator(vocab_size, **kwargs) 280 | missing, unused = BartCreator.convert_state_dict(seq2seq, hf_dict) 281 | logging.info(f"Unset params: {missing}") 282 | logging.info(f"Unused checkpoint fields: {unused}") 283 | return seq2seq 284 | 285 | @classmethod 286 | def pooled_from_pretrained( 287 | cls, checkpoint_file_or_dir: str, map_location=None, **kwargs 288 | ): 289 | if os.path.isdir(checkpoint_file_or_dir): 290 | checkpoint = os.path.join(checkpoint_file_or_dir, "pytorch_model.bin") 291 | else: 292 | checkpoint = checkpoint_file_or_dir 293 | hf_dict = torch.load(checkpoint, map_location=map_location) 294 | vocab_size, hidden_size = BartCreator.get_vocab_and_hidden_dims(hf_dict) 295 | seq2seq = BartPooledEncoderDecoder(vocab_size, **kwargs) 296 | missing, unused = BartCreator.convert_state_dict(seq2seq, hf_dict) 297 | logging.info(f"Unset params: {missing}") 298 | logging.info(f"Unused checkpoint fields: {unused}") 299 | return seq2seq 300 | 301 | 302 | def sentence_permute(inputs, labels, vocab): 303 | """A document is divided into sentences which are shuffled in random order. 304 | 305 | This is used in the final model in the paper. Our version of this is going to be an approximation where 306 | we demarcate sentences with common punctuation marks 307 | 308 | :param inputs: The inputs to the encoder 309 | :param labels: The outputs of the decoder. This starts as a copy of inputs, and each operation transforms it 310 | :param vocab: A dictionary of strings to integers 311 | :return: The transformed labels 312 | """ 313 | pad_value = vocab.get("") 314 | end_values = [vocab.get(punc) for punc in [".", "!", "?"]] 315 | mask = labels != pad_value 316 | 317 | eos_mask = np.zeros_like(mask) 318 | for eos in end_values: 319 | this_eos = inputs == eos 320 | eos_mask = eos_mask | this_eos 321 | 322 | eos_mask &= mask 323 | 324 | def _next_sentence(ids): 325 | end_positions = [0] + np.where(eos_mask)[0].tolist() + [len(ids)] 326 | for i, (begin, end) in enumerate(zip(end_positions[:-1], end_positions[1:])): 327 | yield ids[begin:end] 328 | 329 | reordered = [sentence for sentence in _next_sentence(inputs[1:-1])] 330 | random.shuffle(reordered) 331 | reordered = [inputs[:1]] + reordered + [inputs[-1:]] 332 | inputs_shuf = np.concatenate(reordered) 333 | return inputs_shuf, labels 334 | 335 | 336 | def token_mask(inputs, labels, vocab): 337 | """Following BERT, random tokens are sampled and replaced with token. 338 | 339 | This is not used in the final model in the paper. Unsuprisingly, text infilling is superior and easy in seq2seq 340 | 341 | :param inputs: The inputs to the encoder 342 | :param labels: The outputs of the decoder. 
This starts as a copy of inputs, and each operation transforms it 343 | :param vocab: A dictionary of strings to integers 344 | :return: The transformed labels 345 | """ 346 | pad_value = vocab.get("") 347 | mask_value = vocab.get("") 348 | vocab_size = len(vocab) 349 | masked_indices = np.random.binomial(size=len(inputs), n=1, p=0.15) 350 | # make sure if the input is padded we dont mask 351 | masked_indices = masked_indices & (labels != pad_value) 352 | # ensure at least one token is masked 353 | masked_indices[np.random.randint(1, sum(labels != pad_value))] = 1 354 | masked_indices[0] = 0 355 | masked_indices[-1] = 0 356 | # Anything not masked is 0 so no loss 357 | labels[masked_indices == 0] = 0 358 | # Of the masked items, mask 80% of them with [MASK] 359 | indices_replaced = np.random.binomial(size=len(inputs), n=1, p=0.8) 360 | indices_replaced = indices_replaced & masked_indices 361 | inputs[indices_replaced == 1] = mask_value 362 | indices_random = np.random.binomial(size=len(inputs), n=1, p=0.5) 363 | # Replace 10% of them with random words, rest preserved for auto-encoding 364 | indices_random = indices_random & masked_indices & ~indices_replaced 365 | # Dont predict [PAD] which is zero for bert and 1 for RoBERTa 366 | # We will assume here that PAD is one of the tokens near the beginning of the vocab 367 | random_words = np.random.randint( 368 | low=pad_value + 1, high=vocab_size - 1, size=len(inputs) 369 | ) 370 | inputs[indices_random == 1] = random_words[indices_random == 1] 371 | return inputs, labels 372 | 373 | 374 | def token_delete(inputs, labels, vocab): 375 | """Random tokens are deleted from the input and the model must decide which positions are missing input. 376 | 377 | In contrast to token masking, the model must decide which positions are missing inputs. This is not used in 378 | the final model produced in the paper 379 | 380 | :param inputs: The inputs to the encoder 381 | :param labels: The outputs of the decoder. This starts as a copy of inputs, and each operation transforms it 382 | :param vocab: A dictionary of strings to integers 383 | :return: The transformed labels 384 | """ 385 | pad_value = vocab.get("") 386 | deleted_indices = np.random.binomial(size=len(inputs), n=1, p=0.15) 387 | inputs_del = np.ones_like(inputs) 388 | 389 | # make sure if the input is padded we dont mask 390 | deleted_indices = deleted_indices & (labels != pad_value) 391 | unmutated = inputs[deleted_indices == False] 392 | inputs_del[: len(unmutated)] = unmutated 393 | return inputs_del, labels 394 | 395 | 396 | def document_rotate(inputs, labels, _): 397 | """A token is chosen uniformly at random, and the document is rotated so that it begins with that token. 398 | 399 | This task trains the model to identify the start of the document. This is not used in the final model produced in 400 | the paper. The function assumes that no padding is required in the encoder. This should be true during pretraining 401 | but is not necessarily correct in fine-tuning (e.g. it would not necessarily be true for NMT) 402 | 403 | :param inputs: The inputs to the encoder 404 | :param labels: The outputs of the decoder. 
This starts as a copy of inputs, and each operation transforms it 405 | :param vocab: A dictionary of strings to integers 406 | :return: The transformed labels 407 | """ 408 | 409 | # Leave first token on the front 410 | start_token = np.random.choice(np.arange(len(inputs) - 1)) 411 | inputs_rot = np.array( 412 | [inputs[0]] + np.roll(inputs[1:-1], -start_token).tolist() + [inputs[-1]] 413 | ) 414 | return inputs_rot, labels 415 | 416 | 417 | def text_infill(inputs, labels, vocab): 418 | """N-grams are sampled and replaced by a single token. 419 | 420 | A number of text spans are sampled, with span lengths drawn from a Poisson distribution 421 | (λ = 3). Each span is replaced with a single token. 0-length spans correspond to the insertion of 422 | tokens. Text infilling is inspired by SpanBERT (Joshi et al., 2019), but SpanBERT samples 423 | span lengths from a different (clamped geometric) distribution, and replaces each span with a sequence of 424 | tokens of exactly the same length. Text infilling teaches the model to predict how many tokens are 425 | missing from a span. 426 | 427 | :param inputs: The inputs to the encoder 428 | :param labels: The outputs of the decoder. This starts as a copy of inputs, and each operation transforms it 429 | :param vocab: A dictionary of strings to integers 430 | :return: The transformed labels 431 | """ 432 | pad_value = vocab.get("") 433 | mask_token = vocab.get("") 434 | start_value = vocab.get("") 435 | eos_value = vocab.get("") 436 | span_lengths = np.random.poisson(3, len(inputs)) 437 | masked_indices = np.random.binomial(size=len(inputs), n=1, p=0.3) 438 | # make sure if the input is padded we dont mask 439 | masked_indices = ( 440 | masked_indices 441 | & (labels != pad_value) 442 | & (labels != start_value) 443 | & (labels != eos_value) 444 | ) 445 | last = 0 446 | masked = [] 447 | 448 | for start in masked_indices.nonzero()[0]: 449 | if start <= last: 450 | continue 451 | span_end = start + span_lengths[start] 452 | if span_end >= len(inputs) - 1: 453 | break 454 | masked += inputs[last:start].tolist() + [mask_token] 455 | last = start + span_lengths[start] 456 | if last < len(inputs): 457 | masked += inputs[last:].tolist() 458 | 459 | num_masked = len(labels) - len(masked) 460 | if num_masked > 0: 461 | masked += [pad_value] * num_masked 462 | return np.array(masked), labels 463 | 464 | 465 | def noise_inputs(inputs, vocab, ops=[sentence_permute, text_infill]): 466 | """we use a combination of text infilling and masking 30% of the tokens in each doc and permuting all sentences 467 | 468 | We mask 30% of tokens in each document, and permute all sentences. The Y is an exact match of the X before 469 | sentence permutation and N-gram masking (AKA text-infilling). The first token of X should be 0 and the last 470 | non-padded token of X should be 2. Padding is 1 for BART (like RoBERTa). 471 | 472 | The label copy is going to be one element longer, so that it can be used for the Y values as labels[:, 1:] 473 | and for the teacher forcing as labels[:, :-1]. At the end of evaluation for an unperturbed sequence, the 474 | output targets would match X. 
475 | 476 | :param inputs: An array of one-hot integers 477 | :param vocab: A dictionary of strings to integers 478 | :param ops: A list of noising operations to complete (sequentially) 479 | :return: the corrupted inputs and the truth labels for those inputs 480 | """ 481 | decoder_demarc = vocab.get("") 482 | labels = np.copy(inputs) 483 | for op in ops: 484 | inputs, labels = op(inputs, labels, vocab) 485 | labels = np.concatenate((np.array([decoder_demarc], dtype=labels.dtype), labels)) 486 | return inputs, labels 487 | 488 | 489 | class NoisingCollator: 490 | """For each item in a batch, noise it and return noised and denoised tensors""" 491 | 492 | def __init__(self, vocab): 493 | super().__init__() 494 | self.vocab = vocab 495 | 496 | def __call__(self, batch): 497 | """Take a batch of inputs of X, and convert it to a noised X, Y""" 498 | noised = [] 499 | denoised = [] 500 | for x in batch: 501 | x_noise, x_recon = noise_inputs(x[0].numpy(), self.vocab) 502 | noised.append(torch.from_numpy(x_noise)) 503 | denoised.append(torch.from_numpy(x_recon)) 504 | 505 | noised = torch.stack(noised) 506 | denoised = torch.stack(denoised) 507 | return noised, denoised 508 | -------------------------------------------------------------------------------- /src/mint/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional 4 | import math 5 | import logging 6 | 7 | logger = logging.getLogger("mint") 8 | 9 | 10 | class WeightTiedVocabProjection(nn.Module): 11 | """Projection layer tied to the input embeddings 12 | 13 | This is equivalent to an nn.Linear(hidden_size, vocab_size, bias=False) where the weights come from the 14 | input word embeddings. The embeddings are passed in, and we use their weights for our forward function. 15 | """ 16 | 17 | def __init__(self, from_module: nn.Module, pre_scale=1.0): 18 | """This uses another module (usually an `nn.Embedding`) to implement its forward function 19 | 20 | :param from_module: Typically an `nn.Embedding` whose weights we use to implement our linear projection 21 | """ 22 | super().__init__() 23 | self.from_module = from_module 24 | self.pre_scale = pre_scale 25 | 26 | @property 27 | def weight(self): 28 | return self.from_module.weight 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | """Project a dense hidden vector to the vocab space 32 | 33 | :param x: A dense hidden vector 34 | :return: The vocab space output 35 | """ 36 | return nn.functional.linear(x * self.pre_scale, self.weight) 37 | 38 | 39 | class MultiHeadedAttention(nn.Module): 40 | """Multi-headed attention implementation using scaled dot product 41 | 42 | Converts the input tensor to 3 low-order projections, query, key and value and performs 43 | multi-headed scaled dot-product attention on them following the Vaswani paper. The result 44 | is re-projected a single output representation 45 | 46 | """ 47 | 48 | def __init__(self, hidden_size: int, num_heads: int): 49 | """Each block has the same hidden unit size (`d_model` in the paper). 
Must be a multiple of num heads 50 | 51 | :param hidden_size: The number of units (both input and output) of the MHA block 52 | :param num_heads: The number of heads to split into 53 | """ 54 | super().__init__() 55 | 56 | d_k = hidden_size // num_heads 57 | self.query = nn.Linear(hidden_size, num_heads * d_k) 58 | self.key = nn.Linear(hidden_size, num_heads * d_k) 59 | self.value = nn.Linear(hidden_size, num_heads * d_k) 60 | self.output = nn.Linear(num_heads * d_k, hidden_size) 61 | self.num_heads = num_heads 62 | self.d_k = d_k 63 | self.scale = 1 / math.sqrt(d_k) 64 | 65 | def forward( 66 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None 67 | ) -> torch.Tensor: 68 | """ 69 | 70 | :param x: A `[B, T, C]` tensor where B is batch, T is time, C is hidden size 71 | :param mask: An optional mask to apply to the attention matrix 72 | :return: The attended value vector projected into the output space 73 | """ 74 | B, T, _ = x.shape 75 | query_vec = self.query(x).view(B, T, self.num_heads, -1).transpose(1, 2) 76 | key_vec = self.key(x).view(B, T, self.num_heads, -1).transpose(1, 2) 77 | value_vec = self.value(x).view(B, T, self.num_heads, -1).transpose(1, 2) 78 | 79 | # [B, H, T_q, D] x [B, H, D, T_k] = [B, H, T_q, T_k] 80 | dot_prod = (query_vec @ key_vec.transpose(-1, -2)) * self.scale 81 | 82 | if mask is not None: 83 | dot_prod = dot_prod.masked_fill(mask == False, -1e9) 84 | 85 | attn = nn.functional.softmax(dot_prod, dim=-1) 86 | pre_output = attn @ value_vec 87 | 88 | pre_output = pre_output.transpose(1, 2).contiguous() 89 | output = self.output(pre_output.view(B, T, -1)) 90 | return output 91 | 92 | 93 | class MultiHeadedEncoderDecoderAttention(nn.Module): 94 | """Multi-headed encoder-decoder attention implementation using scaled dot product 95 | 96 | Converts the input tensors to 3 low-order projections, query, key and value and performs 97 | multi-headed scaled dot-product attention on them following the Vaswani paper. The result 98 | is re-projected a single output representation 99 | 100 | """ 101 | 102 | def __init__(self, hidden_size: int, num_heads: int): 103 | """Each block has the same hidden unit size (`d_model` in the paper). 
Must be a multiple of num heads 104 | 105 | :param hidden_size: The number of units (both input and output) of the MHA block 106 | :param num_heads: The number of heads to split into 107 | """ 108 | super().__init__() 109 | 110 | d_k = hidden_size // num_heads 111 | self.query = nn.Linear(hidden_size, num_heads * d_k) 112 | self.key = nn.Linear(hidden_size, num_heads * d_k) 113 | self.value = nn.Linear(hidden_size, num_heads * d_k) 114 | self.output = nn.Linear(num_heads * d_k, hidden_size) 115 | self.num_heads = num_heads 116 | self.d_k = d_k 117 | self.scale = 1 / math.sqrt(d_k) 118 | 119 | def forward( 120 | self, src: torch.Tensor, dst: torch.Tensor, mask: Optional[torch.Tensor] = None 121 | ) -> torch.Tensor: 122 | """ 123 | 124 | :param src: A `[B, T_k, C]` tensor where B is batch, T_k is time, C is hidden size 125 | :param dst: A `[B, T_q, C]` tensor where B is batch, T_q is time, C is hidden size 126 | :param mask: An optional mask to apply to the src (keys) tensor 127 | :return: The attended value vector projected into the output space 128 | """ 129 | B, T_k, _ = src.shape 130 | T_q = dst.shape[1] 131 | query_vec = self.query(dst).view(B, T_q, self.num_heads, -1).transpose(1, 2) 132 | key_vec = self.key(src).view(B, T_k, self.num_heads, -1).transpose(1, 2) 133 | value_vec = self.value(src).view(B, T_k, self.num_heads, -1).transpose(1, 2) 134 | 135 | # [B, H, T_q, D] x [B, H, D, T_k] = [B, H, T_q, T_k] 136 | dot_prod = (query_vec @ key_vec.transpose(-1, -2)) * self.scale 137 | 138 | if mask is not None: 139 | dot_prod = dot_prod.masked_fill(mask == False, -1e9) 140 | 141 | attn = nn.functional.softmax(dot_prod, dim=-1) 142 | pre_output = attn @ value_vec 143 | 144 | pre_output = pre_output.transpose(1, 2).contiguous() 145 | output = self.output(pre_output.view(B, T_q, -1)) 146 | return output 147 | 148 | 149 | def create_feed_forward_layer( 150 | hidden_size: int, 151 | feed_forward_size: Optional[int] = None, 152 | activation: nn.Module = nn.GELU(), 153 | ): 154 | """Create a feed-forward layer (called FFN in the paper) 155 | 156 | This uses nn.Sequential to string together each part (the MLP and down-projection back to the output size) 157 | 158 | :param hidden_size: The transformer block size (d_model in the paper) 159 | :param feed_forward_size: The feed-forward layer size, or 4 * hidden_size. 160 | :param activation: The activation function, defaults to `nn.GELU()` 161 | :return: An n.Sequential that wraps the whole FFN transformation block 162 | """ 163 | d_ff = feed_forward_size if feed_forward_size else 4 * hidden_size 164 | return nn.Sequential( 165 | nn.Linear(hidden_size, d_ff), activation, nn.Linear(d_ff, hidden_size) 166 | ) 167 | 168 | 169 | class DefaultLayerFactory: 170 | """Implements Transformer primitives using the basic defaults we have used so far""" 171 | 172 | _instance = None 173 | 174 | @staticmethod 175 | def get_instance(): 176 | """Access the abstract factory pattern in this way 177 | 178 | It will be created on first use 179 | """ 180 | if DefaultLayerFactory._instance is None: 181 | DefaultLayerFactory() 182 | 183 | return DefaultLayerFactory._instance 184 | 185 | def __init__(self): 186 | if DefaultLayerFactory._instance is not None: 187 | raise Exception("Singleton constructor call. 
Expected no definition") 188 | self.encoder_multihead_attention = MultiHeadedAttention 189 | self.decoder_multihead_attention = MultiHeadedAttention 190 | self.encoder_decoder_attention = MultiHeadedEncoderDecoderAttention 191 | self.layer_norm = nn.LayerNorm 192 | self.feed_forward = create_feed_forward_layer 193 | DefaultLayerFactory._instance = self 194 | -------------------------------------------------------------------------------- /src/mint/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import IterableDataset, TensorDataset 3 | import random 4 | import logging 5 | import glob 6 | import gzip 7 | import os 8 | import json 9 | from typing import Callable, Optional, List 10 | 11 | logger = logging.getLogger("mint") 12 | 13 | try: 14 | import bz2 15 | except: 16 | logger.warning("Could not import bzip2 decompression lib") 17 | 18 | 19 | def jsonl_parser(field: str = "x") -> Callable: 20 | def get_jsonl(line) -> torch.tensor: 21 | x = json.loads(line)[field] 22 | return x if x else None 23 | 24 | return get_jsonl 25 | 26 | 27 | def gpt2_splitter(): 28 | """This is the tokenizer applied to GPT2. Its not needed now, as Tokenizers provides this logic 29 | 30 | :return: A function to tokenize a string into tokens (prior to subword splitting) 31 | """ 32 | import regex 33 | 34 | BPE_PATTERN = regex.compile( 35 | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 36 | ) 37 | return lambda text: [w.strip() for w in regex.findall(BPE_PATTERN, text)] 38 | 39 | 40 | class TextFile: 41 | def __init__(self, filename): 42 | self.filename = filename 43 | self.file = None 44 | 45 | def _open(self): 46 | if self.filename.endswith(".gz"): 47 | self.file = gzip.open(self.filename, mode="rt", encoding="utf-8") 48 | elif self.filename.endswith(".bz2"): 49 | self.file = bz2.open(self.filename, mode="rt", encoding="utf-8") 50 | else: 51 | self.file = open(self.filename, encoding="utf-8") 52 | 53 | def __enter__(self): 54 | self._open() 55 | return self.file 56 | 57 | def __exit__(self, exc_type, exc_value, exc_traceback): 58 | self.file.close() 59 | 60 | 61 | class RawInfiniteDataset(IterableDataset): 62 | """Infinite dataset on shards with multiple workers and preprocessing on-the-fly 63 | 64 | This approach roughly follows the algorithm from the original BERT paper. 65 | We split the files up over the worker threads and we shuffle the files. 
66 | Then we use a shuffle buffer of configurable size to further shuffle each 67 | cooked tensor 68 | 69 | """ 70 | 71 | def __init__( 72 | self, 73 | pattern: str, 74 | tokenizer, 75 | training=True, 76 | prefix="[CLS]", 77 | suffix="[SEP]", 78 | seq_len: int = 512, 79 | get_data_fn: Optional[Callable] = None, 80 | shuf_buf_len: int = 100, 81 | ): 82 | super().__init__() 83 | self.get_data_fn = get_data_fn if get_data_fn else lambda x: x if x else None 84 | self.tokenizer = tokenizer 85 | self.start_token = self.tokenizer.token_to_id(prefix) 86 | self.end_token = self.tokenizer.token_to_id(suffix) 87 | self.pattern = ( 88 | pattern if not os.path.isdir(pattern) else os.path.join(pattern, "*") 89 | ) 90 | 91 | self.samples = 0 92 | self.rank = 0 93 | self.world_size = 1 94 | self.training = training 95 | self.seq_len = seq_len 96 | self.shuffle_buffer_len = shuf_buf_len 97 | 98 | if torch.distributed.is_initialized() and self.training: 99 | self.rank = torch.distributed.get_rank() 100 | self.world_size = torch.distributed.get_world_size() 101 | 102 | def _get_worker_info(self): 103 | return torch.utils.data.get_worker_info() if self.training else None 104 | 105 | def _init_read_order(self): 106 | # Each node has the same worker_info, so the unique offsets for each is 107 | # rank * num_workers + worker_id 108 | # and the total available workers is world_size * num_workers 109 | worker_info = self._get_worker_info() 110 | logger.debug("Globbing %s", self.pattern) 111 | files = sorted(list(glob.glob(self.pattern))) 112 | logger.debug("Found %d files", len(files)) 113 | if worker_info is None: 114 | num_workers_per_node = 1 115 | node_worker_id = 0 116 | else: 117 | num_workers_per_node = worker_info.num_workers 118 | node_worker_id = worker_info.id 119 | all_workers = self.world_size * num_workers_per_node 120 | offset = self.rank * num_workers_per_node + node_worker_id 121 | read_file_order = list(range(offset, len(files), all_workers)) 122 | if not read_file_order: 123 | if offset > 0: 124 | # This means the user didnt create more shards than workers 125 | logger.warning( 126 | f"There are no files to read for worker {node_worker_id}, offset {offset}!" 
127 | + " This might mean that you are passing an incorrect training or validation directory" 128 | ) 129 | else: 130 | raise Exception(f"No files of pattern {self.pattern} were found!") 131 | return files, read_file_order, node_worker_id 132 | 133 | def __iter__(self): 134 | files, read_file_order, _ = self._init_read_order() 135 | # If we have multiple files per worker, possibly shuffle the file read order 136 | shuffle_buffer = [] 137 | tokens = [] 138 | while True: 139 | if self.training: 140 | random.shuffle(read_file_order) 141 | for file_idx in read_file_order: 142 | file = files[file_idx] 143 | with TextFile(file) as rf: 144 | for line in rf: 145 | line = self.get_data_fn(line.strip()) 146 | if line: 147 | line = self.tokenizer.encode(line, add_special_tokens=False) 148 | tokens += line.ids 149 | if len(tokens) >= (self.seq_len - 2): 150 | tensor = torch.tensor( 151 | [self.start_token] 152 | + tokens[: self.seq_len - 2] 153 | + [self.end_token] 154 | ) 155 | tokens = tokens[self.seq_len - 2 :] 156 | shuffle_buffer.append(tensor) 157 | if len(shuffle_buffer) == self.shuffle_buffer_len: 158 | if self.training: 159 | random.shuffle(shuffle_buffer) 160 | # Drain the shuffle buffer 161 | for element in shuffle_buffer: 162 | yield (element,) 163 | shuffle_buffer = [] 164 | 165 | 166 | class InfinitePreprocessedDataset(IterableDataset): 167 | """Infinite dataset on shards with multiple workers and preprocessing on-the-fly""" 168 | 169 | def __init__( 170 | self, pattern: str, training=True, get_data_fn=None, shuf_buf_len: int = 100 171 | ): 172 | super().__init__() 173 | self.pattern = pattern 174 | self.samples = 0 175 | self.rank = 0 176 | self.world_size = 1 177 | self.training = training 178 | self.shuf_buf_len = shuf_buf_len 179 | self.get_data_fn = get_data_fn if get_data_fn else lambda x: x if x else None 180 | if torch.distributed.is_initialized(): 181 | self.rank = torch.distributed.get_rank() 182 | self.world_size = torch.distributed.get_world_size() 183 | 184 | def _get_worker_info(self): 185 | return torch.utils.data.get_worker_info() if self.training else None 186 | 187 | def _init_read_order(self): 188 | # Each node has the same worker_info, so the unique offsets for each is 189 | # rank * num_workers + worker_id 190 | # and the total available workers is world_size * num_workers 191 | worker_info = self._get_worker_info() 192 | files = sorted(list(glob.glob(f"{self.directory}/{self.pattern}"))) 193 | 194 | if worker_info is None: 195 | num_workers_per_node = 1 196 | node_worker_id = 0 197 | else: 198 | num_workers_per_node = worker_info.num_workers 199 | node_worker_id = worker_info.id 200 | all_workers = self.world_size * num_workers_per_node 201 | offset = self.rank * num_workers_per_node + node_worker_id 202 | read_file_order = list(range(offset, len(files), all_workers)) 203 | if not read_file_order: 204 | if offset > 0: 205 | # This means the user didnt create more shards than workers 206 | logger.warning( 207 | f"There are no files to read for worker {node_worker_id}, offset {offset}!" 208 | + " This might mean that you are passing an incorrect training or validation directory" 209 | ) 210 | else: 211 | raise Exception( 212 | f"No files of pattern {self.pattern} were found in {self.directory}!" 
213 | ) 214 | return files, read_file_order, node_worker_id 215 | 216 | def __iter__(self): 217 | files, read_file_order, _ = self._init_read_order() 218 | shuffle_buffer = [] 219 | while True: 220 | if self.training: 221 | random.shuffle(read_file_order) 222 | for file_idx in read_file_order: 223 | file = files[file_idx] 224 | with open(file) as rf: 225 | lines = rf.readlines() 226 | for sample in lines: 227 | sample = self.get_data(sample.strip()) 228 | if sample: 229 | shuffle_buffer.append(sample) 230 | if len(shuffle_buffer) == self.shuffle_buffer_len: 231 | if self.training: 232 | random.shuffle(shuffle_buffer) 233 | # Drain the shuffle buffer 234 | for element in shuffle_buffer: 235 | yield (element,) 236 | shuffle_buffer = [] 237 | 238 | 239 | def read_cls_dataset( 240 | file: str, 241 | tokenizer, 242 | pad_index=0, 243 | get_data_fn: Optional[Callable] = None, 244 | max_seq_len=512, 245 | label_list: Optional[List[str]] = None, 246 | ) -> TensorDataset: 247 | def read_space_delim_line(line: str): 248 | toks = line.split() 249 | label = toks[0] 250 | tokens = " ".join(toks[1:]) 251 | return label, tokens 252 | 253 | if get_data_fn is None: 254 | get_data_fn = read_space_delim_line 255 | 256 | label2index = {} if not label_list else {k: i for i, k in enumerate(label_list)} 257 | label_offset = len(label2index) 258 | x_tensor = [] 259 | y_tensor = [] 260 | with TextFile(file) as rf: 261 | for line in rf: 262 | label, example_str = get_data_fn(line.strip()) 263 | if label not in label2index: 264 | label2index[label] = label_offset 265 | label_offset += 1 266 | tokens = torch.tensor(tokenizer.encode(example_str).ids) 267 | padded = torch.full((max_seq_len,), pad_index, dtype=tokens.dtype) 268 | padded[: len(tokens)] = tokens 269 | x_tensor.append(padded) 270 | y_tensor.append(label2index[label]) 271 | x_tensor = torch.stack(x_tensor) 272 | y_tensor = torch.tensor(y_tensor, dtype=torch.long) 273 | label_list = [0] * label_offset 274 | for label, idx in label2index.items(): 275 | label_list[idx] = label 276 | return TensorDataset(x_tensor, y_tensor), label_list 277 | -------------------------------------------------------------------------------- /src/mint/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dpressel/mint/c4f046ccd620e28ff0392b166b50ae1540e7f758/src/mint/examples/__init__.py -------------------------------------------------------------------------------- /src/mint/examples/bart_completer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import os 4 | import torch 5 | from prompt_toolkit import prompt 6 | from prompt_toolkit.history import FileHistory 7 | from tokenizers import Tokenizer 8 | from bart import BartCreator 9 | 10 | logger = logging.getLogger(__file__) 11 | DECODER_START_TOKEN = 2 12 | """An example program where you can provide your BART model with a priming sequence and have it complete 13 | 14 | """ 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="An interactive shell with BART") 19 | parser.add_argument("--model", type=str, required=True, help="Start from a model") 20 | parser.add_argument( 21 | "--tok_file", type=str, required=True, help="Path to tokenizer.json file" 22 | ) 23 | parser.add_argument( 24 | "--query", 25 | type=str, 26 | help="Optional query. 
If you pass this we wont use the repl", 27 | ) 28 | parser.add_argument("--history_file", type=str, default=".bart_history") 29 | parser.add_argument("--max_len", type=int, default=50) 30 | parser.add_argument("--sample", action="store_true") 31 | parser.add_argument("--temperature", default=1.0, type=float) 32 | parser.add_argument( 33 | "--device", 34 | type=str, 35 | default="cuda" if torch.cuda.is_available() else "cpu", 36 | help="Device (cuda or cpu)", 37 | ) 38 | args = parser.parse_args() 39 | logging.basicConfig(level=logging.INFO) 40 | if os.path.isdir(args.tok_file): 41 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 42 | tokenizer = Tokenizer.from_file(args.tok_file) 43 | 44 | model = BartCreator.from_pretrained(args.model).eval() 45 | model.to(args.device) 46 | 47 | def complete(query, sampling, temperature): 48 | logger.info("Query: %s", query) 49 | tokenized_input = tokenizer.encode(query) 50 | logger.info("Input Sequence: %s", " ".join(tokenized_input.tokens)) 51 | input_ids = torch.tensor(tokenized_input.ids, device=args.device).unsqueeze(0) 52 | 53 | input_enc = model.encode(input_ids) 54 | outputs = [DECODER_START_TOKEN] 55 | with torch.no_grad(): 56 | 57 | for i in range(args.max_len): 58 | 59 | decode_ids = torch.tensor(outputs, device=args.device) 60 | # signature is encoder, decoder (up till now), encoder_mask, decoder_mask 61 | response = model.decode(input_enc, decode_ids.unsqueeze(0)).squeeze(0) 62 | response = response[len(decode_ids) - 1] 63 | if sampling: 64 | sample_dist = torch.softmax(response / temperature, -1) 65 | output = torch.multinomial(sample_dist, num_samples=1) 66 | response = output.squeeze().item() 67 | else: 68 | response = response.argmax(-1).item() 69 | 70 | outputs.append(response) 71 | outputs = tokenizer.decode(outputs[2:]) 72 | return outputs 73 | 74 | if args.query: 75 | print(complete(args.query, args.sample, args.temperature)) 76 | return 77 | 78 | prompt_name = "BART>> " 79 | history = FileHistory(args.history_file) 80 | while True: 81 | query = prompt(prompt_name, history=history) 82 | query = query.strip() 83 | if query == ":quit" or query == "quit": 84 | break 85 | if query == ":sample": 86 | args.sample = True 87 | print("Turn sampling mode on") 88 | continue 89 | if query == ":max": 90 | args.sample = False 91 | print("Turn sampling mode off") 92 | continue 93 | print(complete(query, args.sample)) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /src/mint/examples/bert_completer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from prompt_toolkit import prompt 5 | from prompt_toolkit.history import FileHistory 6 | from mint.bert import BertCreator 7 | from tokenizers import BertWordPieceTokenizer 8 | 9 | logger = logging.getLogger(__file__) 10 | 11 | """An example program where you can provide your BERT model with masked tokens and have it unmask them 12 | """ 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="An interactive shell with BERT") 17 | parser.add_argument("--model", type=str, required=True, help="Start from a model") 18 | parser.add_argument( 19 | "--vocab_file", type=str, required=True, help="Path to vocab file" 20 | ) 21 | parser.add_argument( 22 | "--query", 23 | type=str, 24 | help="Optional query. 
If you pass this we wont use the repl", 25 | ) 26 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 27 | parser.add_argument("--history_file", type=str, default=".bert_history") 28 | parser.add_argument("--sample", action="store_true") 29 | parser.add_argument( 30 | "--device", 31 | type=str, 32 | default="cuda" if torch.cuda.is_available() else "cpu", 33 | help="Device (cuda or cpu)", 34 | ) 35 | args = parser.parse_args() 36 | logging.basicConfig(level=logging.INFO) 37 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 38 | model = BertCreator.mlm_from_pretrained(args.model).eval() 39 | model.to(args.device) 40 | 41 | def complete(query, sampling): 42 | with torch.no_grad(): 43 | tokenized_input = tokenizer.encode(query) 44 | masked_offsets = [ 45 | i for i, t in enumerate(tokenized_input.tokens) if t == "[MASK]" 46 | ] 47 | tokens = tokenized_input.tokens 48 | logger.debug("Masked: %s", " ".join(tokens)) 49 | ids = torch.tensor(tokenized_input.ids, device=args.device).unsqueeze(0) 50 | response = model(ids).squeeze(0) 51 | if sampling: 52 | sample_dist = torch.softmax(response, -1) 53 | output = torch.multinomial(sample_dist, num_samples=1) 54 | response = output.squeeze().tolist() 55 | else: 56 | response = response.argmax(-1).tolist() 57 | for off in masked_offsets: 58 | tokens[off] = tokenizer.id_to_token(response[off]) 59 | return " ".join(tokens[1:-1]).replace(" ##", "") 60 | 61 | if args.query: 62 | print(complete(args.query, args.sample)) 63 | return 64 | 65 | prompt_name = "BERT>> " 66 | history = FileHistory(args.history_file) 67 | while True: 68 | query = prompt(prompt_name, history=history) 69 | query = query.strip() 70 | if query == ":quit" or query == "quit": 71 | break 72 | if query == ":sample": 73 | args.sample = True 74 | print("Turn sampling mode on") 75 | continue 76 | if query == ":max": 77 | args.sample = False 78 | print("Turn sampling mode off") 79 | continue 80 | print(complete(query, args.sample)) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /src/mint/examples/bert_searcher.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from prompt_toolkit import prompt 5 | from prompt_toolkit.history import FileHistory 6 | from mint.bert import BertCreator 7 | from tokenizers import BertWordPieceTokenizer 8 | import faiss 9 | from dataclasses import dataclass 10 | import numpy as np 11 | from typing import List, Optional 12 | 13 | logger = logging.getLogger(__file__) 14 | 15 | """An example program where you can search for nearby embeddings from a search index 16 | """ 17 | 18 | 19 | @dataclass 20 | class SearchEntry: 21 | embeddings: np.ndarray 22 | text: str 23 | 24 | 25 | @dataclass 26 | class Hit: 27 | """A hit represents a single search result with a score""" 28 | 29 | sim: float 30 | id: int 31 | text: str 32 | vec: Optional[np.ndarray] = None 33 | 34 | 35 | class SearchIndex: 36 | def __init__(self, basename): 37 | """Reload an existing index""" 38 | index_filename = f"{basename}.index" 39 | vec_file = f"{basename}.npz" 40 | self.index = faiss.read_index(index_filename) 41 | arr = np.load(vec_file) 42 | self.data = SearchEntry(arr["embeds"], arr["texts"]) 43 | 44 | def search( 45 | self, embedding: np.ndarray, num_results: int = 5, return_vec=False 46 | ) -> List[Hit]: 47 | sim, I = self.index.search(embedding, 
num_results) 48 | sim = sim.reshape(-1) 49 | I = I.reshape(-1) 50 | hits = [] 51 | 52 | for sim, id in zip(sim, I): 53 | vec = self.data.embeddings[id] 54 | text = self.data.text[id] 55 | hit = Hit(id=id, text=text, sim=sim, vec=vec if return_vec else None) 56 | hits.append(hit) 57 | return hits 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser( 62 | description="An interactive search shell with BERT" 63 | ) 64 | parser.add_argument( 65 | "--query", 66 | type=str, 67 | help="Optional query. If you pass this we wont use the repl", 68 | ) 69 | parser.add_argument( 70 | "--k", type=int, default=3, help="How many K for nearest neighbor results?" 71 | ) 72 | parser.add_argument("--history_file", type=str, default=".embed_history") 73 | parser.add_argument( 74 | "--index", 75 | help="Index name. This is used as the base name for an NPZ and a FAISS index file", 76 | default="search", 77 | ) 78 | parser.add_argument("--model", help="A model path or checkpoint", required=True) 79 | parser.add_argument( 80 | "--hidden_size", 81 | type=int, 82 | default=768, 83 | help="Model dimension (and embedding dsz)", 84 | ) 85 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 86 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 87 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 88 | parser.add_argument( 89 | "--max_seq_len", 90 | type=int, 91 | default=512, 92 | help="Max sequence length for our embeddings", 93 | ) 94 | parser.add_argument("--batch_size", type=int, default=32, help="Batch Size") 95 | parser.add_argument("--pool_type", type=str, default="mean") 96 | parser.add_argument( 97 | "--vocab_file", type=str, help="The WordPiece model file", required=True 98 | ) 99 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 100 | parser.add_argument( 101 | "--file_type", type=str, choices=["tsv", "txt", "jsonl"], default="txt" 102 | ) 103 | parser.add_argument("--column", type=str, default="0") 104 | parser.add_argument( 105 | "--has_header", 106 | action="store_true", 107 | help="The file has a header line that must be skipped (or used)", 108 | ) 109 | parser.add_argument( 110 | "--device", 111 | type=str, 112 | default="cuda" if torch.cuda.is_available() else "cpu", 113 | help="Device (cuda or cpu)", 114 | ) 115 | args = parser.parse_args() 116 | logging.basicConfig(level=logging.INFO) 117 | 118 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 119 | embedder = ( 120 | BertCreator.pooled_enc_from_pretrained( 121 | args.model, use_mlp_layer=False, **vars(args) 122 | ) 123 | .eval() 124 | .to(args.device) 125 | ) 126 | logger.info("Loaded model") 127 | search_index = SearchIndex(args.index) 128 | logger.info("Initialized search index") 129 | 130 | def search(query, k): 131 | with torch.no_grad(): 132 | tokenized_input = tokenizer.encode(query) 133 | logger.info(tokenized_input.tokens) 134 | ids = torch.tensor(tokenized_input.ids, device=args.device).unsqueeze(0) 135 | embedding = embedder(ids).cpu().numpy() 136 | print(query) 137 | print("=" * 50) 138 | hits = search_index.search(embedding, k) 139 | for hit in hits: 140 | print(hit) 141 | 142 | if args.query: 143 | search(args.query, args.k) 144 | return 145 | 146 | prompt_name = "Search>> " 147 | history = FileHistory(args.history_file) 148 | while True: 149 | query = prompt(prompt_name, history=history) 150 | query = query.strip() 151 | if query == ":quit" or query == "quit": 152 | break 
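For reference, the faiss round trip that SearchIndex.search() wraps can be exercised on its own; the snippet below is a hedged sketch using made-up 768-dimensional random vectors rather than anything from this repo. Note that IndexFlatL2 reports squared L2 distances, so smaller values mean closer matches even though the field is named sim.

import faiss
import numpy as np

d = 768                                         # assumed embedding width
xb = np.random.rand(100, d).astype("float32")   # stand-in for the saved embeddings
index = faiss.IndexFlatL2(d)
index.add(xb)
dists, ids = index.search(xb[:1], 5)            # both outputs have shape (1, 5)
print(ids[0], dists[0])                         # row ids of the nearest vectors and their distances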
153 | if query.startswith(":k"): 154 | query = query.split() 155 | if len(query) == 2: 156 | args.k = int(query[-1]) 157 | print(f"Setting k={args.k}") 158 | continue 159 | search(query, args.k) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /src/mint/examples/build_search_index.py: -------------------------------------------------------------------------------- 1 | """Create a search index for Transformer embeddings using faiss 2 | """ 3 | import argparse 4 | import torch 5 | import logging 6 | from bert import BertCreator 7 | import faiss 8 | from typing import List 9 | from tokenizers import BertWordPieceTokenizer 10 | from mint.data import TextFile 11 | import numpy as np 12 | import json 13 | 14 | 15 | def padded_batch(seqs: List[torch.Tensor]) -> torch.Tensor: 16 | max_batch_len = max([len(seq) for seq in seqs]) 17 | pad_batch = torch.zeros((len(seqs), max_batch_len), dtype=torch.long) 18 | for i, seq in enumerate(seqs): 19 | pad_batch[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) 20 | return pad_batch 21 | 22 | 23 | def read_batch(file, file_type, column, has_header, batch_size, tokenizer, max_seq_len): 24 | 25 | col2index = {} 26 | if file_type == "jsonl": 27 | read_fn = lambda x: json.loads(x)[column] 28 | elif file_type == "txt": 29 | read_fn = lambda x: x 30 | elif file_type == "tsv" and not has_header: 31 | column = int(column) 32 | read_fn = lambda x: x.split("\t")[column] 33 | else: 34 | read_fn = lambda x: x.split("\t")[col2index[column]] 35 | 36 | texts = [] 37 | seqs = [] 38 | with TextFile(file) as rf: 39 | if has_header and file_type != "jsonl": 40 | header = next(rf) 41 | header = header.split("\t") 42 | col2index.update({h: i for i, h in enumerate(header)}) 43 | 44 | for line in rf: 45 | s = read_fn(line.strip()) 46 | texts.append(s) 47 | seq = tokenizer.encode(s).ids[:max_seq_len] 48 | seqs.append(seq) 49 | if len(seqs) == batch_size: 50 | batch = padded_batch(seqs) 51 | seqs = [] 52 | batch_texts = texts 53 | texts = [] 54 | yield batch, batch_texts 55 | if seqs: 56 | yield padded_batch(seqs), texts 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser( 61 | description="Build a search index using BERT (or SentenceBERT)" 62 | ) 63 | parser.add_argument("--input", help="Input file") 64 | parser.add_argument( 65 | "--index", 66 | help="Index name. 
This will yield an NPZ and a FAISS index file", 67 | default="search", 68 | ) 69 | parser.add_argument("--model", help="A model path or checkpoint", required=True) 70 | parser.add_argument( 71 | "--hidden_size", 72 | type=int, 73 | default=768, 74 | help="Model dimension (and embedding dsz)", 75 | ) 76 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 77 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 78 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 79 | parser.add_argument( 80 | "--max_seq_len", 81 | type=int, 82 | default=512, 83 | help="Max sequence length for our embeddings", 84 | ) 85 | parser.add_argument("--batch_size", type=int, default=32, help="Batch Size") 86 | parser.add_argument("--pool_type", type=str, default="mean") 87 | parser.add_argument( 88 | "--vocab_file", type=str, help="The WordPiece model file", required=True 89 | ) 90 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 91 | parser.add_argument( 92 | "--file_type", type=str, choices=["tsv", "txt", "jsonl"], default="txt" 93 | ) 94 | parser.add_argument("--column", type=str, default="0") 95 | parser.add_argument( 96 | "--has_header", 97 | action="store_true", 98 | help="The file has a header line that must be skipped (or used)", 99 | ) 100 | parser.add_argument( 101 | "--device", 102 | type=str, 103 | default="cuda" if torch.cuda.is_available() else "cpu", 104 | help="Device (cuda or cpu)", 105 | ) 106 | args = parser.parse_args() 107 | logging.basicConfig(level=logging.INFO) 108 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 109 | embedder = ( 110 | BertCreator.pooled_enc_from_pretrained( 111 | args.model, use_mlp_layer=False, **vars(args) 112 | ) 113 | .eval() 114 | .to(args.device) 115 | ) 116 | 117 | all_embeddings = [] 118 | all_texts = [] 119 | with torch.no_grad(): 120 | for (batch, texts) in read_batch( 121 | args.input, 122 | args.file_type, 123 | args.column, 124 | args.has_header, 125 | args.batch_size, 126 | tokenizer, 127 | args.max_seq_len, 128 | ): 129 | batch = batch.to(device=args.device) 130 | embedded_batch = ( 131 | embedder(batch, embedder.create_pad_mask(batch)).cpu().numpy() 132 | ) 133 | all_embeddings.append(embedded_batch) 134 | all_texts += texts 135 | all_embeddings = np.vstack(all_embeddings) 136 | np.savez(args.index + ".npz", embeds=all_embeddings, texts=all_texts) 137 | index = faiss.IndexFlatL2(args.hidden_size) 138 | index.add(all_embeddings) 139 | faiss.write_index(index, args.index + ".index") 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /src/mint/examples/eval_gpt_lm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset, TensorDataset, DataLoader 5 | from mint.gpt import GPTCreator, GPT2Creator 6 | from tokenizers import Tokenizer 7 | import os 8 | from mint.train import Average 9 | from tqdm import tqdm 10 | import time 11 | import math 12 | 13 | logger = logging.getLogger(__file__) 14 | 15 | 16 | def create_single_file_dataset(tokenizer, fname: str, seq_len: int = 1024) -> Dataset: 17 | 18 | with open(fname) as rf: 19 | full_text = rf.read() 20 | 21 | num_words = len(full_text.split()) 22 | tokens = tokenizer.encode(full_text).ids 23 | num_samples = len(tokens) // seq_len 24 | trunc = 
num_samples * seq_len 25 | if trunc == num_samples: 26 | tokens.append(tokens[0]) 27 | x_tensors = torch.tensor(tokens[:trunc]) 28 | y_tensors = torch.tensor(tokens[1 : trunc + 1]) 29 | num_subwords = y_tensors.nelement() 30 | return ( 31 | TensorDataset(x_tensors.view(num_samples, -1), y_tensors.view(num_samples, -1)), 32 | num_words, 33 | num_subwords, 34 | ) 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="GPT perplexity on test set") 39 | parser.add_argument("--model", type=str, required=True, help="Start from a model") 40 | parser.add_argument( 41 | "--vocab_file", type=str, required=True, help="Path to vocab file" 42 | ) 43 | parser.add_argument( 44 | "--merges_file", type=str, required=True, help="Path to vocab file" 45 | ) 46 | parser.add_argument("--file", type=str, help="A test file") 47 | parser.add_argument("--version", type=int, choices=[1, 2], default=2) 48 | parser.add_argument("--batch_size", type=int, default=32) 49 | parser.add_argument( 50 | "--device", 51 | type=str, 52 | default="cuda" if torch.cuda.is_available() else "cpu", 53 | help="Device (cuda or cpu)", 54 | ) 55 | args = parser.parse_args() 56 | logging.basicConfig(level=logging.INFO) 57 | tok_model = os.path.join( 58 | args.model if os.path.isdir(args.model) else os.path.dirname(args.model), 59 | "tokenizer.json", 60 | ) 61 | print(tok_model) 62 | tokenizer = Tokenizer.from_file(tok_model) 63 | Creator = GPT2Creator if args.version == 2 else GPTCreator 64 | seq_len = 1024 if args.version == 2 else 512 65 | model = Creator.lm_from_pretrained(args.model).eval() 66 | model.to(args.device) 67 | loss_function = model.create_loss().to(args.device) 68 | eval_dataset, num_words, num_subwords = create_single_file_dataset( 69 | tokenizer, args.file, seq_len=seq_len 70 | ) 71 | logger.info( 72 | "Num samples in dataset [%d], num words [%d], num subwords [%d]", 73 | len(eval_dataset), 74 | num_words, 75 | num_subwords, 76 | ) 77 | eval_data_loader = DataLoader(eval_dataset, batch_size=args.batch_size) 78 | 79 | compute_perplexity( 80 | args.device, eval_data_loader, loss_function, model, num_subwords, num_words 81 | ) 82 | 83 | 84 | def compute_perplexity( 85 | device, eval_data_loader, loss_function, model, num_subwords, num_words=None 86 | ): 87 | start = time.time() 88 | progress = tqdm(enumerate(eval_data_loader), total=len(eval_data_loader)) 89 | avg = Average("avg_loss") 90 | with torch.no_grad(): 91 | for iters, (x, y) in progress: 92 | x = x.to(device=device) 93 | y = y.to(device=device) 94 | logits = model(x) 95 | loss = loss_function(logits.reshape(-1, model.vocab_size), y.view(-1)) 96 | avg.update(loss.item()) 97 | loss = avg.avg 98 | ppl = math.exp(loss) 99 | elapsed = time.time() - start 100 | print(f"Evaluation completed [{elapsed: .2f}s]") 101 | print( 102 | f"Subword loss & perplexity: loss {loss:.4f}. 
perplexity (subword): {ppl: .3f}" 103 | ) 104 | 105 | if num_words is not None: 106 | word_ppl = math.exp(loss * (num_subwords - 1) / (num_words - 1)) 107 | print(f"Word level perplexity {word_ppl: .3f}") 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /src/mint/examples/gpt_completer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import os 4 | import torch 5 | from prompt_toolkit import prompt 6 | from prompt_toolkit.history import FileHistory 7 | from mint.gpt import GPTCreator, GPT2Creator 8 | from tokenizers import Tokenizer 9 | 10 | logger = logging.getLogger(__file__) 11 | 12 | """An example program where you can provide your GPT model with a priming sequence and have it complete 13 | """ 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description="An interactive shell with GPT/GPT2") 18 | parser.add_argument("--model", type=str, required=True, help="Start from a model") 19 | parser.add_argument( 20 | "--tok_file", type=str, required=True, help="Path to tokenizer.json file" 21 | ) 22 | parser.add_argument( 23 | "--query", 24 | type=str, 25 | help="Optional query. If you pass this we wont use the repl", 26 | ) 27 | parser.add_argument("--history_file", type=str, default=".gpt_history") 28 | parser.add_argument("--max_len", type=int, default=50) 29 | parser.add_argument("--sample", action="store_true") 30 | parser.add_argument("--temperature", default=1.0, type=float) 31 | parser.add_argument("--version", type=int, choices=[1, 2], default=2) 32 | parser.add_argument( 33 | "--device", 34 | type=str, 35 | default="cuda" if torch.cuda.is_available() else "cpu", 36 | help="Device (cuda or cpu)", 37 | ) 38 | args = parser.parse_args() 39 | logging.basicConfig(level=logging.INFO) 40 | if os.path.isdir(args.tok_file): 41 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 42 | tokenizer = Tokenizer.from_file(args.tok_file) 43 | 44 | Creator = GPT2Creator if args.version == 2 else GPTCreator 45 | model = Creator.lm_from_pretrained(args.model).eval() 46 | model.to(args.device) 47 | 48 | def complete(query, sampling, temperature): 49 | logger.info("Query: %s", query) 50 | tokenized_input = tokenizer.encode(query) 51 | logger.info("Priming Sequence: %s", " ".join(tokenized_input.tokens)) 52 | inputs = tokenized_input.ids 53 | outputs = [] 54 | with torch.no_grad(): 55 | 56 | for i in range(args.max_len): 57 | 58 | ids = torch.tensor(inputs, device=args.device) 59 | response = model(ids.unsqueeze(0)).squeeze(0) 60 | response = response[len(inputs) - 1] 61 | if sampling: 62 | sample_dist = torch.softmax(response / temperature, -1) 63 | output = torch.multinomial(sample_dist, num_samples=1) 64 | response = output.squeeze().item() 65 | else: 66 | response = response.argmax(-1).item() 67 | 68 | inputs.append(response) 69 | outputs.append(response) 70 | outputs = tokenizer.decode(outputs) 71 | return outputs 72 | 73 | if args.query: 74 | print(complete(args.query, args.sample, args.temperature)) 75 | return 76 | 77 | prompt_name = f"GPT{args.version}>> " 78 | history = FileHistory(args.history_file) 79 | while True: 80 | query = prompt(prompt_name, history=history) 81 | query = query.strip() 82 | if query == ":quit" or query == "quit": 83 | break 84 | if query == ":sample": 85 | args.sample = True 86 | print("Turn sampling mode on") 87 | continue 88 | if query == ":max": 89 | args.sample = False 90 
| print("Turn sampling mode off") 91 | continue 92 | print(complete(query, args.sample, args.temperature)) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_bart_simple.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset, TensorDataset 5 | from mint.bart import BartCreator, NoisingCollator, BartSequenceGenerator 6 | from mint.train import SingleDeviceSeq2SeqTrainer 7 | from tokenizers import Tokenizer 8 | import os 9 | 10 | logger = logging.getLogger(__file__) 11 | 12 | """Pre-train a BART model in PyTorch (Simple single file version) 13 | 14 | This works for a small dataset that fits in memory. We will use the SimpleTrainer's train_epochs() 15 | function to train this. 16 | 17 | """ 18 | 19 | 20 | def create_single_file_dataset( 21 | tokenizer: Tokenizer, fname: str, seq_len: int = 1024 22 | ) -> Dataset: 23 | bos_token = tokenizer.token_to_id("") 24 | eos_token = tokenizer.token_to_id("") 25 | with open(fname) as rf: 26 | tokens = [] 27 | for line in rf: 28 | line = line.strip() 29 | if line: 30 | line = tokenizer.encode(line, add_special_tokens=False) 31 | tokens += line.ids 32 | 33 | num_toks = seq_len - 2 # Ignore CLS and SEP 34 | num_samples = len(tokens) // num_toks * num_toks 35 | tensors = [ 36 | [bos_token] + tokens[i : i + num_toks] + [eos_token] 37 | for i in range(0, num_samples, num_toks) 38 | ] 39 | tensors = torch.tensor(tensors, dtype=torch.long) 40 | return TensorDataset(tensors) 41 | 42 | 43 | def try_get_global_step(checkpoint_name) -> int: 44 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 45 | 46 | We will assume that any checkpoint we reload has the exact same parameterization as this 47 | run. 
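As a concrete illustration of the chunking in create_single_file_dataset() above: with seq_len=6 there are num_toks=4 content positions per row, so 11 stream tokens yield 11 // 4 * 4 = 8 usable tokens, i.e. two rows framed by BART's start/end special tokens (conventionally <s> and </s> in the RoBERTa/BART vocabulary), with the 3 leftover tokens dropped. A toy-sized sketch with assumed token ids:

seq_len = 6
num_toks = seq_len - 2          # room left after the two special tokens
tokens = list(range(11))        # pretend these are subword ids
bos_token, eos_token = 0, 2     # assumed ids for <s> and </s>
num_samples = len(tokens) // num_toks * num_toks
rows = [[bos_token] + tokens[i : i + num_toks] + [eos_token]
        for i in range(0, num_samples, num_toks)]
print(rows)                     # [[0, 0, 1, 2, 3, 2], [0, 4, 5, 6, 7, 2]]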
If thats not the case the learning params will be different 48 | 49 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 50 | :return: Int representing the global step 51 | """ 52 | import re 53 | 54 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 55 | global_step = 0 56 | if match: 57 | global_step = int(match[2]) 58 | return global_step 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser(description="Pretrain BART (simple)") 63 | parser.add_argument("--model_checkpoint_dir", type=str) 64 | parser.add_argument( 65 | "--train_file", type=str, required=True, help="File path to use for train file" 66 | ) 67 | parser.add_argument( 68 | "--valid_file", type=str, required=True, help="File path to use for valid file" 69 | ) 70 | parser.add_argument( 71 | "--hidden_size", 72 | type=int, 73 | default=768, 74 | help="Model dimension (and embedding dsz)", 75 | ) 76 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 77 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 78 | parser.add_argument( 79 | "--num_encoder_layers", type=int, default=6, help="Number of encoder layers" 80 | ) 81 | parser.add_argument( 82 | "--num_decoder_layers", type=int, default=6, help="Number of decoder layers" 83 | ) 84 | parser.add_argument( 85 | "--num_train_workers", type=int, default=4, help="Number train workers" 86 | ) 87 | parser.add_argument( 88 | "--num_valid_workers", type=int, default=1, help="Number train workers" 89 | ) 90 | parser.add_argument("--seq_len", type=int, default=1024, help="Max input length") 91 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 92 | parser.add_argument( 93 | "--tok_file", type=str, help="The path to the GPT2 tokenizer", required=True 94 | ) 95 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 96 | parser.add_argument( 97 | "--decay_type", 98 | choices=["cosine", "linear"], 99 | help="The type of learning rate decay scheduler", 100 | ) 101 | parser.add_argument( 102 | "--alpha_decay", 103 | type=float, 104 | default=0.0, 105 | help="fraction of learning rate by end of training", 106 | ) 107 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 108 | parser.add_argument( 109 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 110 | ) 111 | parser.add_argument( 112 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 113 | ) 114 | parser.add_argument("--epochs", type=int, default=1, help="Num training epochs") 115 | parser.add_argument( 116 | "--restart_from", 117 | type=str, 118 | help="Option allows you to restart from a previous checkpoint", 119 | ) 120 | parser.add_argument( 121 | "--warmup_fract", 122 | type=float, 123 | default=0.1, 124 | help="Fraction of steps spent warming up", 125 | ) 126 | parser.add_argument( 127 | "--plateau_fract", 128 | type=float, 129 | default=0.0, 130 | help="Fraction of steps spent holding at max lr", 131 | ) 132 | parser.add_argument( 133 | "--saves_per_epoch", 134 | type=int, 135 | default=10, 136 | help="The number of checkpoints to save per epoch", 137 | ) 138 | parser.add_argument( 139 | "--device", 140 | type=str, 141 | default="cuda" if torch.cuda.is_available() else "cpu", 142 | help="Device (cuda or cpu)", 143 | ) 144 | args = parser.parse_args() 145 | logging.basicConfig(level=logging.INFO) 146 | 147 | if args.model_checkpoint_dir is None: 148 | args.model_checkpoint_dir = f"s2s-{os.getpid()}" 149 | if not 
os.path.exists(args.model_checkpoint_dir): 150 | os.makedirs(args.model_checkpoint_dir) 151 | 152 | if os.path.isdir(args.tok_file): 153 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 154 | 155 | tokenizer = Tokenizer.from_file(args.tok_file) 156 | vocab = tokenizer.get_vocab() 157 | 158 | if args.restart_from: 159 | global_step = try_get_global_step(args.restart_from) 160 | 161 | model = BartCreator.from_pretrained(args.restart_from, **vars(args)) 162 | else: 163 | global_step = 0 164 | model = BartSequenceGenerator(tokenizer.get_vocab_size(), **vars(args)) 165 | print(model) 166 | trainer = SingleDeviceSeq2SeqTrainer( 167 | model, 168 | global_step=global_step, 169 | collate_function=NoisingCollator(vocab), 170 | **vars(args), 171 | ) 172 | logger.info(trainer) 173 | train_dataset = create_single_file_dataset(tokenizer, args.train_file, args.seq_len) 174 | valid_dataset = create_single_file_dataset(tokenizer, args.valid_file, args.seq_len) 175 | 176 | trainer.train_epochs( 177 | train_dataset, 178 | valid_dataset, 179 | os.path.join(args.model_checkpoint_dir, "ckpt"), 180 | args.epochs, 181 | ) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_bart_wiki.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset 5 | from mint.bart import BartCreator, NoisingCollator, BartSequenceGenerator 6 | from mint.train import SingleDeviceSeq2SeqTrainer, DistributedSeq2SeqTrainer 7 | from mint.data import RawInfiniteDataset 8 | from tokenizers import Tokenizer 9 | import json 10 | import os 11 | 12 | logger = logging.getLogger(__file__) 13 | 14 | 15 | """Pre-train a BART model in PyTorch on all of wikipedia via https://github.com/attardi/wikiextractor 16 | 17 | We will process the data on-the-fly from the readers, and shard over multiple workers, and we will 18 | train with the SimpleTrainer's train_steps() function 19 | 20 | """ 21 | 22 | 23 | def wikipedia_parser(): 24 | from bs4 import BeautifulSoup 25 | 26 | def get_doc(line): 27 | line = json.loads(line)["text"] 28 | text = BeautifulSoup(line, features="lxml") 29 | # This is a trivial way to replace, we will just sample other surface terms and use those 30 | for link in text.find_all("a"): 31 | surface = link.get_text() 32 | link.replace_with(surface) 33 | text = text.get_text() 34 | return text 35 | 36 | return get_doc 37 | 38 | 39 | def create_sharded_dataset( 40 | tokenizer: Tokenizer, 41 | glob_path: str, 42 | is_train, 43 | seq_len: int = 1024, 44 | start_token="", 45 | end_token="", 46 | ) -> Dataset: 47 | dataset = RawInfiniteDataset( 48 | glob_path, 49 | tokenizer, 50 | is_train, 51 | prefix=start_token, 52 | suffix=end_token, 53 | seq_len=seq_len, 54 | get_data_fn=wikipedia_parser(), 55 | ) 56 | return dataset 57 | 58 | 59 | def try_get_global_step(checkpoint_name) -> int: 60 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 61 | 62 | We will assume that any checkpoint we reload has the exact same parameterization as this 63 | run. 
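The checkpoint-name convention this docstring describes can be checked quickly with the same regular expression the function uses; a small sketch with made-up file names:

import re

for name in ("ckpt-step-25000.pth", "bart-base.pth"):
    m = re.match(r"(\S+)-step-(\d+).pth", name)
    print(name, int(m[2]) if m else 0)
# ckpt-step-25000.pth 25000
# bart-base.pth 0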
If thats not the case the learning params will be different 64 | 65 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 66 | :return: Int representing the global step 67 | """ 68 | import re 69 | 70 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 71 | global_step = 0 72 | if match: 73 | global_step = int(match[2]) 74 | return global_step 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser(description="Pretrain BART (wiki)") 79 | parser.add_argument("--model_checkpoint_dir", type=str) 80 | parser.add_argument( 81 | "--train_file", type=str, required=True, help="File path to use for train file" 82 | ) 83 | parser.add_argument( 84 | "--valid_file", type=str, required=True, help="File path to use for valid file" 85 | ) 86 | parser.add_argument( 87 | "--hidden_size", 88 | type=int, 89 | default=768, 90 | help="Model dimension (and embedding dsz)", 91 | ) 92 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 93 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 94 | parser.add_argument( 95 | "--num_encoder_layers", type=int, default=12, help="Number of encoder layers" 96 | ) 97 | parser.add_argument( 98 | "--num_decoder_layers", type=int, default=12, help="Number of decoder layers" 99 | ) 100 | parser.add_argument( 101 | "--num_train_workers", type=int, default=4, help="Number train workers" 102 | ) 103 | parser.add_argument( 104 | "--num_valid_workers", type=int, default=1, help="Number train workers" 105 | ) 106 | parser.add_argument("--seq_len", type=int, default=1024, help="Max input length") 107 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 108 | parser.add_argument( 109 | "--tok_file", type=str, help="The Tokenizer file or model dir", required=True 110 | ) 111 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 112 | parser.add_argument( 113 | "--decay_type", 114 | choices=["cosine", "linear"], 115 | default="cosine", 116 | help="The type of learning rate decay scheduler", 117 | ) 118 | parser.add_argument( 119 | "--alpha_decay", 120 | type=float, 121 | default=0.0, 122 | help="fraction of learning rate by end of training", 123 | ) 124 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 125 | parser.add_argument( 126 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 127 | ) 128 | parser.add_argument( 129 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 130 | ) 131 | parser.add_argument( 132 | "--num_steps", type=int, default=250_000, help="Num training steps" 133 | ) 134 | parser.add_argument( 135 | "--restart_from", 136 | type=str, 137 | help="Option allows you to restart from a previous checkpoint", 138 | ) 139 | parser.add_argument( 140 | "--warmup_fract", 141 | type=float, 142 | default=0.1, 143 | help="Fraction of steps spent warming up", 144 | ) 145 | parser.add_argument( 146 | "--plateau_fract", 147 | type=float, 148 | default=0.0, 149 | help="Fraction of steps spent holding at max lr", 150 | ) 151 | parser.add_argument( 152 | "--saves_per_cycle", 153 | type=int, 154 | default=1, 155 | help="The number of checkpoints to save per epoch", 156 | ) 157 | parser.add_argument( 158 | "--train_cycle_size", 159 | type=int, 160 | default=1000, 161 | help="The many training steps to run before eval", 162 | ) 163 | parser.add_argument( 164 | "--eval_cycle_size", 165 | type=int, 166 | default=200, 167 | help="How many steps to evaluate each time", 168 | ) 
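To make the link handling in wikipedia_parser() earlier in this file concrete, here is a standalone sketch of the same BeautifulSoup calls on a fabricated wikiextractor-style line (assumes bs4 and lxml are installed):

import json
from bs4 import BeautifulSoup

line = json.dumps({"text": 'See the <a href="/wiki/Apollo">Apollo program</a> article.'})
text = BeautifulSoup(json.loads(line)["text"], features="lxml")
for link in text.find_all("a"):
    link.replace_with(link.get_text())   # keep only the surface form of each link
print(text.get_text())                   # -> See the Apollo program article.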
169 | parser.add_argument( 170 | "--plot_lr_plan", 171 | action="store_true", 172 | help="Shows the learning rate curve (requires matplotlib)", 173 | ) 174 | parser.add_argument( 175 | "--local_rank", 176 | type=int, 177 | default=-1, 178 | help="Local rank for distributed training (-1 means use the environment variables to find)", 179 | ) 180 | parser.add_argument( 181 | "--distributed", action="store_true", help="Are we doing distributed training?" 182 | ) 183 | 184 | parser.add_argument( 185 | "--device", 186 | type=str, 187 | default="cuda" if torch.cuda.is_available() else "cpu", 188 | help="Device (cuda or cpu)", 189 | ) 190 | args = parser.parse_args() 191 | logging.basicConfig(level=logging.INFO) 192 | 193 | if args.model_checkpoint_dir is None: 194 | args.model_checkpoint_dir = f"mlm-{os.getpid()}" 195 | if not os.path.exists(args.model_checkpoint_dir): 196 | os.makedirs(args.model_checkpoint_dir) 197 | 198 | if os.path.isdir(args.tok_file): 199 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 200 | tokenizer = Tokenizer.from_file(args.tok_file) 201 | 202 | if args.restart_from: 203 | global_step = try_get_global_step(args.restart_from) 204 | model = BartCreator.from_pretrained(args.restart_from, **vars(args)) 205 | else: 206 | global_step = 0 207 | model = BartSequenceGenerator(tokenizer.get_vocab_size(), **vars(args)) 208 | 209 | Trainer = ( 210 | DistributedSeq2SeqTrainer if args.distributed else SingleDeviceSeq2SeqTrainer 211 | ) 212 | trainer = Trainer( 213 | model, 214 | global_step=global_step, 215 | collate_function=NoisingCollator(tokenizer.get_vocab()), 216 | **vars(args), 217 | ) 218 | logger.info(trainer) 219 | if args.plot_lr_plan: 220 | trainer.show_lr_plan(args.num_steps) 221 | 222 | train_dataset = create_sharded_dataset( 223 | tokenizer, args.train_file, True, seq_len=args.seq_len 224 | ) 225 | valid_dataset = create_sharded_dataset( 226 | tokenizer, args.valid_file, False, seq_len=args.seq_len 227 | ) 228 | 229 | trainer.train_steps( 230 | train_dataset, 231 | valid_dataset, 232 | os.path.join(args.model_checkpoint_dir, "ckpt"), 233 | args.num_steps, 234 | args.saves_per_cycle, 235 | args.train_cycle_size, 236 | args.eval_cycle_size, 237 | ) 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_bert_simple.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset, TensorDataset 5 | from mint.bert import BertCreator, NoisingCollator, TransformerMLM 6 | from mint.train import SingleDeviceLMTrainer 7 | from tokenizers import BertWordPieceTokenizer 8 | import os 9 | 10 | logger = logging.getLogger(__file__) 11 | 12 | """Pre-train a BERT/RoBERTa model in PyTorch (Simple single file version) 13 | 14 | This works for a small dataset that fits in memory. We will use the SimpleTrainer's train_epochs() 15 | function to train this. 
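The collator used below builds the masked-LM training pairs. As a rough sketch of the standard BERT-style masking recipe (the NoisingCollator in this repo may differ in details such as the random-replacement and keep probabilities, so treat this as illustrative only):

import torch

def bert_style_mask(ids: torch.Tensor, mask_id: int, p: float = 0.15):
    """Return (masked_inputs, labels) where labels are -100 except at masked slots."""
    labels = ids.clone()
    mask = torch.bernoulli(torch.full(ids.shape, p, dtype=torch.float)).bool()
    labels[~mask] = -100        # -100 is ignored by nn.CrossEntropyLoss by default
    inputs = ids.clone()
    inputs[mask] = mask_id      # the full recipe also uses random/keep variants
    return inputs, labels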
16 | 17 | """ 18 | 19 | 20 | def create_single_file_dataset( 21 | tokenizer: BertWordPieceTokenizer, fname: str, seq_len: int = 512 22 | ) -> Dataset: 23 | cls_token = tokenizer.token_to_id("[CLS]") 24 | sep_token = tokenizer.token_to_id("[SEP]") 25 | with open(fname) as rf: 26 | tokens = [] 27 | for line in rf: 28 | line = line.strip() 29 | if line: 30 | line = tokenizer.encode(line, add_special_tokens=False) 31 | tokens += line.ids 32 | 33 | num_toks = seq_len - 2 # Ignore CLS and SEP 34 | num_samples = len(tokens) // num_toks * num_toks 35 | tensors = [ 36 | [cls_token] + tokens[i : i + num_toks] + [sep_token] 37 | for i in range(0, num_samples, num_toks) 38 | ] 39 | tensors = torch.tensor(tensors, dtype=torch.long) 40 | return TensorDataset(tensors) 41 | 42 | 43 | def try_get_global_step(checkpoint_name) -> int: 44 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 45 | 46 | We will assume that any checkpoint we reload has the exact same parameterization as this 47 | run. If thats not the case the learning params will be different 48 | 49 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 50 | :return: Int representing the global step 51 | """ 52 | import re 53 | 54 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 55 | global_step = 0 56 | if match: 57 | global_step = int(match[2]) 58 | return global_step 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser(description="Pretrain BERT (simple)") 63 | parser.add_argument("--model_checkpoint_dir", type=str) 64 | parser.add_argument( 65 | "--train_file", type=str, required=True, help="File path to use for train file" 66 | ) 67 | parser.add_argument( 68 | "--valid_file", type=str, required=True, help="File path to use for valid file" 69 | ) 70 | parser.add_argument( 71 | "--hidden_size", 72 | type=int, 73 | default=768, 74 | help="Model dimension (and embedding dsz)", 75 | ) 76 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 77 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 78 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 79 | parser.add_argument( 80 | "--num_train_workers", type=int, default=4, help="Number train workers" 81 | ) 82 | parser.add_argument( 83 | "--num_valid_workers", type=int, default=1, help="Number train workers" 84 | ) 85 | parser.add_argument("--seq_len", type=int, default=512, help="Max input length") 86 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 87 | parser.add_argument( 88 | "--vocab_file", type=str, help="The WordPiece model file", required=True 89 | ) 90 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 91 | parser.add_argument( 92 | "--decay_type", 93 | choices=["cosine", "linear"], 94 | help="The type of learning rate decay scheduler", 95 | ) 96 | parser.add_argument( 97 | "--alpha_decay", 98 | type=float, 99 | default=0.0, 100 | help="fraction of learning rate by end of training", 101 | ) 102 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 103 | parser.add_argument( 104 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 105 | ) 106 | parser.add_argument( 107 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 108 | ) 109 | parser.add_argument("--epochs", type=int, default=1, help="Num training epochs") 110 | parser.add_argument( 111 | "--restart_from", 112 | type=str, 113 | help="Option allows you 
to restart from a previous checkpoint", 114 | ) 115 | parser.add_argument( 116 | "--warmup_fract", 117 | type=float, 118 | default=0.1, 119 | help="Fraction of steps spent warming up", 120 | ) 121 | parser.add_argument( 122 | "--plateau_fract", 123 | type=float, 124 | default=0.0, 125 | help="Fraction of steps spent holding at max lr", 126 | ) 127 | parser.add_argument( 128 | "--saves_per_epoch", 129 | type=int, 130 | default=10, 131 | help="The number of checkpoints to save per epoch", 132 | ) 133 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 134 | parser.add_argument( 135 | "--device", 136 | type=str, 137 | default="cuda" if torch.cuda.is_available() else "cpu", 138 | help="Device (cuda or cpu)", 139 | ) 140 | args = parser.parse_args() 141 | logging.basicConfig(level=logging.INFO) 142 | 143 | if args.model_checkpoint_dir is None: 144 | args.model_checkpoint_dir = f"mlm-{os.getpid()}" 145 | if not os.path.exists(args.model_checkpoint_dir): 146 | os.makedirs(args.model_checkpoint_dir) 147 | 148 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 149 | vocab_size = tokenizer.get_vocab_size() 150 | pad_value = tokenizer.token_to_id("[PAD]") 151 | mask_value = tokenizer.token_to_id("[MASK]") 152 | 153 | if args.restart_from: 154 | global_step = try_get_global_step(args.restart_from) 155 | 156 | model = BertCreator.mlm_from_pretrained(args.restart_from, **vars(args)) 157 | else: 158 | global_step = 0 159 | model = TransformerMLM(tokenizer.get_vocab_size(), **vars(args)) 160 | 161 | trainer = SingleDeviceLMTrainer( 162 | model, 163 | global_step=global_step, 164 | collate_function=NoisingCollator(vocab_size, mask_value, pad_value), 165 | **vars(args), 166 | ) 167 | logger.info(trainer) 168 | train_dataset = create_single_file_dataset(tokenizer, args.train_file, args.seq_len) 169 | valid_dataset = create_single_file_dataset(tokenizer, args.valid_file, args.seq_len) 170 | 171 | trainer.train_epochs( 172 | train_dataset, 173 | valid_dataset, 174 | os.path.join(args.model_checkpoint_dir, "ckpt"), 175 | args.epochs, 176 | ) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_bert_wiki.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset 5 | from mint.bert import BertCreator, NoisingCollator, TransformerMLM 6 | from mint.train import DistributedLMTrainer, SingleDeviceLMTrainer 7 | from mint.data import RawInfiniteDataset 8 | from tokenizers import BertWordPieceTokenizer 9 | import json 10 | import os 11 | 12 | logger = logging.getLogger(__file__) 13 | 14 | 15 | """Pre-train a BERT/RoBERTa model in PyTorch on all of wikipedia via https://github.com/attardi/wikiextractor 16 | 17 | We will process the data on-the-fly from the readers, and shard over multiple workers, and we will 18 | train with the SimpleTrainer's train_steps() function 19 | 20 | """ 21 | 22 | 23 | def wikipedia_parser(): 24 | from bs4 import BeautifulSoup 25 | 26 | def get_doc(line): 27 | line = json.loads(line)["text"] 28 | text = BeautifulSoup(line, features="lxml") 29 | # This is a trivial way to replace, we will just sample other surface terms and use those 30 | for link in text.find_all("a"): 31 | surface = link.get_text() 32 | link.replace_with(surface) 33 | text = text.get_text() 34 | return text 35 | 36 | return 
get_doc 37 | 38 | 39 | def create_sharded_dataset( 40 | tokenizer: BertWordPieceTokenizer, 41 | glob_path: str, 42 | is_train, 43 | seq_len: int = 512, 44 | start_token="[CLS]", 45 | end_token="[SEP]", 46 | ) -> Dataset: 47 | dataset = RawInfiniteDataset( 48 | glob_path, 49 | tokenizer, 50 | is_train, 51 | prefix=start_token, 52 | suffix=end_token, 53 | seq_len=seq_len, 54 | get_data_fn=wikipedia_parser(), 55 | ) 56 | return dataset 57 | 58 | 59 | def try_get_global_step(checkpoint_name) -> int: 60 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 61 | 62 | We will assume that any checkpoint we reload has the exact same parameterization as this 63 | run. If thats not the case the learning params will be different 64 | 65 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 66 | :return: Int representing the global step 67 | """ 68 | import re 69 | 70 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 71 | global_step = 0 72 | if match: 73 | global_step = int(match[2]) 74 | return global_step 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser(description="Pretrain BERT (wiki)") 79 | parser.add_argument("--model_checkpoint_dir", type=str) 80 | parser.add_argument( 81 | "--train_file", type=str, required=True, help="File path to use for train file" 82 | ) 83 | parser.add_argument( 84 | "--valid_file", type=str, required=True, help="File path to use for valid file" 85 | ) 86 | parser.add_argument( 87 | "--hidden_size", 88 | type=int, 89 | default=768, 90 | help="Model dimension (and embedding dsz)", 91 | ) 92 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 93 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 94 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 95 | parser.add_argument( 96 | "--num_train_workers", type=int, default=4, help="Number train workers" 97 | ) 98 | parser.add_argument( 99 | "--num_valid_workers", type=int, default=1, help="Number train workers" 100 | ) 101 | parser.add_argument("--seq_len", type=int, default=512, help="Max input length") 102 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 103 | parser.add_argument( 104 | "--vocab_file", type=str, help="The WordPiece model file", required=True 105 | ) 106 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 107 | parser.add_argument( 108 | "--decay_type", 109 | choices=["cosine", "linear"], 110 | default="cosine", 111 | help="The type of learning rate decay scheduler", 112 | ) 113 | parser.add_argument( 114 | "--alpha_decay", 115 | type=float, 116 | default=0.0, 117 | help="fraction of learning rate by end of training", 118 | ) 119 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 120 | parser.add_argument( 121 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 122 | ) 123 | parser.add_argument( 124 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 125 | ) 126 | parser.add_argument( 127 | "--num_steps", type=int, default=250_000, help="Num training steps" 128 | ) 129 | parser.add_argument( 130 | "--restart_from", 131 | type=str, 132 | help="Option allows you to restart from a previous checkpoint", 133 | ) 134 | parser.add_argument( 135 | "--warmup_fract", 136 | type=float, 137 | default=0.1, 138 | help="Fraction of steps spent warming up", 139 | ) 140 | parser.add_argument( 141 | "--plateau_fract", 142 | type=float, 143 
| default=0.0, 144 | help="Fraction of steps spent holding at max lr", 145 | ) 146 | parser.add_argument( 147 | "--saves_per_cycle", 148 | type=int, 149 | default=1, 150 | help="The number of checkpoints to save per epoch", 151 | ) 152 | parser.add_argument( 153 | "--train_cycle_size", 154 | type=int, 155 | default=1000, 156 | help="The many training steps to run before eval", 157 | ) 158 | parser.add_argument( 159 | "--eval_cycle_size", 160 | type=int, 161 | default=200, 162 | help="How many steps to evaluate each time", 163 | ) 164 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 165 | parser.add_argument( 166 | "--plot_lr_plan", 167 | action="store_true", 168 | help="Shows the learning rate curve (requires matplotlib)", 169 | ) 170 | parser.add_argument( 171 | "--local_rank", 172 | type=int, 173 | default=-1, 174 | help="Local rank for distributed training (-1 means use the environment variables to find)", 175 | ) 176 | parser.add_argument( 177 | "--distributed", action="store_true", help="Are we doing distributed training?" 178 | ) 179 | 180 | parser.add_argument( 181 | "--device", 182 | type=str, 183 | default="cuda" if torch.cuda.is_available() else "cpu", 184 | help="Device (cuda or cpu)", 185 | ) 186 | args = parser.parse_args() 187 | logging.basicConfig(level=logging.INFO) 188 | 189 | if args.model_checkpoint_dir is None: 190 | args.model_checkpoint_dir = f"mlm-{os.getpid()}" 191 | if not os.path.exists(args.model_checkpoint_dir): 192 | os.makedirs(args.model_checkpoint_dir) 193 | 194 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 195 | vocab_size = tokenizer.get_vocab_size() 196 | pad_value = tokenizer.token_to_id("[PAD]") 197 | mask_value = tokenizer.token_to_id("[MASK]") 198 | 199 | if args.restart_from: 200 | global_step = try_get_global_step(args.restart_from) 201 | model = BertCreator.mlm_from_pretrained(args.restart_from, **vars(args)) 202 | else: 203 | global_step = 0 204 | model = TransformerMLM(tokenizer.get_vocab_size(), **vars(args)) 205 | 206 | Trainer = DistributedLMTrainer if args.distributed else SingleDeviceLMTrainer 207 | trainer = Trainer( 208 | model, 209 | global_step=global_step, 210 | collate_function=NoisingCollator(vocab_size, mask_value, pad_value), 211 | **vars(args), 212 | ) 213 | logger.info(trainer) 214 | if args.plot_lr_plan: 215 | trainer.show_lr_plan(args.num_steps) 216 | 217 | train_dataset = create_sharded_dataset( 218 | tokenizer, args.train_file, True, seq_len=args.seq_len 219 | ) 220 | valid_dataset = create_sharded_dataset( 221 | tokenizer, args.valid_file, False, seq_len=args.seq_len 222 | ) 223 | 224 | trainer.train_steps( 225 | train_dataset, 226 | valid_dataset, 227 | os.path.join(args.model_checkpoint_dir, "ckpt"), 228 | args.num_steps, 229 | args.saves_per_cycle, 230 | args.train_cycle_size, 231 | args.eval_cycle_size, 232 | ) 233 | 234 | 235 | if __name__ == "__main__": 236 | main() 237 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_gpt_simple.py: -------------------------------------------------------------------------------- 1 | """Pre-train a GPT2 model in PyTorch (Simple single file version) 2 | 3 | This works for a small dataset that fits in memory. We will use the SimpleTrainer's train_epochs() 4 | function to train this. 
5 | 6 | """ 7 | 8 | import logging 9 | import argparse 10 | import torch 11 | from torch.utils.data import Dataset, TensorDataset 12 | from mint.gpt import GPT2TransformerLM, GPTTransformerLM, GPTCreator, GPT2Creator 13 | from mint.train import SingleDeviceLMTrainer 14 | from tokenizers import Tokenizer 15 | import os 16 | 17 | logger = logging.getLogger(__file__) 18 | 19 | 20 | def create_single_file_dataset(tokenizer: Tokenizer, fname: str, seq_len) -> Dataset: 21 | 22 | with open(fname) as rf: 23 | full_text = rf.read() 24 | tokens = tokenizer.encode(full_text).ids 25 | num_samples = len(tokens) // seq_len 26 | trunc = num_samples * seq_len 27 | if trunc == num_samples: 28 | tokens.append(tokens[0]) 29 | x_tensors = torch.tensor(tokens[:trunc]) 30 | y_tensors = torch.tensor(tokens[1 : trunc + 1]) 31 | return TensorDataset( 32 | x_tensors.view(num_samples, -1), y_tensors.view(num_samples, -1) 33 | ) 34 | 35 | 36 | def try_get_global_step(checkpoint_name) -> int: 37 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 38 | 39 | We will assume that any checkpoint we reload has the exact same parameterization as this 40 | run. If thats not the case the learning params will be different 41 | 42 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 43 | :return: Int representing the global step 44 | """ 45 | import re 46 | 47 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 48 | global_step = 0 49 | if match: 50 | global_step = int(match[2]) 51 | return global_step 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser(description="Pretrain GPT (simple)") 56 | parser.add_argument("--model_checkpoint_dir", type=str) 57 | parser.add_argument("--version", type=int, choices=[1, 2], default=2) 58 | parser.add_argument( 59 | "--train_file", type=str, required=True, help="File path to use for train file" 60 | ) 61 | parser.add_argument( 62 | "--valid_file", type=str, required=True, help="File path to use for valid file" 63 | ) 64 | parser.add_argument( 65 | "--hidden_size", 66 | type=int, 67 | default=768, 68 | help="Model dimension (and embedding dsz)", 69 | ) 70 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 71 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 72 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 73 | parser.add_argument( 74 | "--num_train_workers", type=int, default=4, help="Number train workers" 75 | ) 76 | parser.add_argument( 77 | "--num_valid_workers", type=int, default=1, help="Number train workers" 78 | ) 79 | parser.add_argument("--seq_len", type=int, default=512, help="Max input length") 80 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 81 | parser.add_argument("--tok_file", type=str, help="The vocab file", required=True) 82 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 83 | parser.add_argument( 84 | "--decay_type", 85 | choices=["cosine", "linear"], 86 | help="The type of learning rate decay scheduler", 87 | ) 88 | parser.add_argument( 89 | "--alpha_decay", 90 | type=float, 91 | default=0.0, 92 | help="fraction of learning rate by end of training", 93 | ) 94 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 95 | parser.add_argument( 96 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 97 | ) 98 | parser.add_argument( 99 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 
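The x/y construction in create_single_file_dataset() above is the usual next-token-prediction shift: the target at position t is the token at position t + 1, so the stream needs trunc + 1 tokens for the last position to have a target. A toy-sized sketch of the intended layout, using an 8-token stream, seq_len=4, and one token appended explicitly to supply the final target:

import torch

tokens = [10, 11, 12, 13, 14, 15, 16, 17]                      # pretend subword ids
seq_len = 4
num_samples = len(tokens) // seq_len                           # -> 2
trunc = num_samples * seq_len                                  # -> 8
tokens.append(tokens[0])                                       # extra token for the last target
x = torch.tensor(tokens[:trunc]).view(num_samples, -1)         # rows of inputs
y = torch.tensor(tokens[1 : trunc + 1]).view(num_samples, -1)  # same rows shifted by one
print(x[0].tolist(), y[0].tolist())                            # [10, 11, 12, 13] [11, 12, 13, 14]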
100 | ) 101 | parser.add_argument("--epochs", type=int, default=1, help="Num training epochs") 102 | parser.add_argument( 103 | "--restart_from", 104 | type=str, 105 | help="Option allows you to restart from a previous checkpoint", 106 | ) 107 | parser.add_argument( 108 | "--warmup_fract", 109 | type=float, 110 | default=0.1, 111 | help="Fraction of steps spent warming up", 112 | ) 113 | parser.add_argument( 114 | "--plateau_fract", 115 | type=float, 116 | default=0.0, 117 | help="Fraction of steps spent holding at max lr", 118 | ) 119 | parser.add_argument( 120 | "--saves_per_epoch", 121 | type=int, 122 | default=10, 123 | help="The number of checkpoints to save per epoch", 124 | ) 125 | parser.add_argument( 126 | "--device", 127 | type=str, 128 | default="cuda" if torch.cuda.is_available() else "cpu", 129 | help="Device (cuda or cpu)", 130 | ) 131 | args = parser.parse_args() 132 | logging.basicConfig(level=logging.INFO) 133 | 134 | if args.model_checkpoint_dir is None: 135 | args.model_checkpoint_dir = f"clm-{os.getpid()}" 136 | if not os.path.exists(args.model_checkpoint_dir): 137 | os.makedirs(args.model_checkpoint_dir) 138 | 139 | if os.path.isdir(args.tok_file): 140 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 141 | tokenizer = Tokenizer.from_file(args.tok_file) 142 | seq_len = 1024 if args.version == 2 else 512 143 | 144 | if args.restart_from: 145 | global_step = try_get_global_step(args.restart_from) 146 | Creator = GPT2Creator if args.version == 2 else GPTCreator 147 | model = Creator.lm_from_pretrained(args.restart_from, **vars(args)) 148 | else: 149 | global_step = 0 150 | GPT = GPT2TransformerLM if args.version == 2 else GPTTransformerLM 151 | model = GPT(tokenizer.get_vocab_size(), **vars(args)) 152 | 153 | trainer = SingleDeviceLMTrainer( 154 | model, 155 | global_step=global_step, 156 | **vars(args), 157 | ) 158 | logger.info(trainer) 159 | train_dataset = create_single_file_dataset(tokenizer, args.train_file, seq_len) 160 | valid_dataset = create_single_file_dataset(tokenizer, args.valid_file, seq_len) 161 | 162 | trainer.train_epochs( 163 | train_dataset, 164 | valid_dataset, 165 | os.path.join(args.model_checkpoint_dir, "ckpt"), 166 | args.epochs, 167 | ) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /src/mint/examples/pretrain_t5_simple.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import torch 4 | from torch.utils.data import Dataset, TensorDataset 5 | from mint.t5 import T5Creator, NoisingCollator, T5SequenceGenerator 6 | from mint.train import SingleDeviceSeq2SeqTrainer 7 | from tokenizers import Tokenizer 8 | import os 9 | 10 | logger = logging.getLogger(__file__) 11 | 12 | """Pre-train a T5 model in PyTorch (Simple single file version) 13 | 14 | This works for a small dataset that fits in memory. We will use the SimpleTrainer's train_epochs() 15 | function to train this. 
16 | 17 | """ 18 | 19 | 20 | def create_single_file_dataset( 21 | tokenizer: Tokenizer, fname: str, seq_len: int = 512 22 | ) -> Dataset: 23 | eos_token = tokenizer.token_to_id("") 24 | with open(fname) as rf: 25 | tokens = [] 26 | for line in rf: 27 | line = line.strip() 28 | if line: 29 | line = tokenizer.encode(line, add_special_tokens=False) 30 | tokens += line.ids 31 | 32 | num_toks = seq_len - 1 # Ignore CLS and SEP 33 | num_samples = len(tokens) // num_toks * num_toks 34 | tensors = [ 35 | tokens[i : i + num_toks] + [eos_token] 36 | for i in range(0, num_samples, num_toks) 37 | ] 38 | tensors = torch.tensor(tensors, dtype=torch.long) 39 | return TensorDataset(tensors) 40 | 41 | 42 | def try_get_global_step(checkpoint_name) -> int: 43 | """If its a checkpoint we saved the suffix will be -step-{global_step}.pth 44 | 45 | We will assume that any checkpoint we reload has the exact same parameterization as this 46 | run. If thats not the case the learning params will be different 47 | 48 | :param checkpoint_name: Either a huggingface pretrained checkpoint or one we saved here 49 | :return: Int representing the global step 50 | """ 51 | import re 52 | 53 | match = re.match("(\\S+)-step-(\\d+).pth", checkpoint_name) 54 | global_step = 0 55 | if match: 56 | global_step = int(match[2]) 57 | return global_step 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser(description="Pretrain T5 (simple)") 62 | parser.add_argument("--model_checkpoint_dir", type=str) 63 | parser.add_argument( 64 | "--train_file", type=str, required=True, help="File path to use for train file" 65 | ) 66 | parser.add_argument( 67 | "--valid_file", type=str, required=True, help="File path to use for valid file" 68 | ) 69 | parser.add_argument( 70 | "--hidden_size", 71 | type=int, 72 | default=768, 73 | help="Model dimension (and embedding dsz)", 74 | ) 75 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 76 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 77 | parser.add_argument( 78 | "--num_encoder_layers", type=int, default=12, help="Number of encoder layers" 79 | ) 80 | parser.add_argument( 81 | "--num_decoder_layers", type=int, default=12, help="Number of decoder layers" 82 | ) 83 | parser.add_argument( 84 | "--num_train_workers", type=int, default=4, help="Number train workers" 85 | ) 86 | parser.add_argument( 87 | "--num_valid_workers", type=int, default=1, help="Number train workers" 88 | ) 89 | parser.add_argument("--seq_len", type=int, default=512, help="Max input length") 90 | parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") 91 | parser.add_argument( 92 | "--tok_file", type=str, help="The path to the GPT2 tokenizer", required=True 93 | ) 94 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 95 | parser.add_argument( 96 | "--decay_type", 97 | choices=["cosine", "linear"], 98 | help="The type of learning rate decay scheduler", 99 | ) 100 | parser.add_argument( 101 | "--alpha_decay", 102 | type=float, 103 | default=0.0, 104 | help="fraction of learning rate by end of training", 105 | ) 106 | parser.add_argument("--lr", type=float, default=1.0e-4, help="Learning rate") 107 | parser.add_argument( 108 | "--clip", type=float, default=1.0, help="Clipping gradient norm" 109 | ) 110 | parser.add_argument( 111 | "--weight_decay", type=float, default=1.0e-2, help="Weight decay" 112 | ) 113 | parser.add_argument("--epochs", type=int, default=1, help="Num training epochs") 114 | parser.add_argument( 
115 | "--restart_from", 116 | type=str, 117 | help="Option allows you to restart from a previous checkpoint", 118 | ) 119 | parser.add_argument( 120 | "--warmup_fract", 121 | type=float, 122 | default=0.1, 123 | help="Fraction of steps spent warming up", 124 | ) 125 | parser.add_argument( 126 | "--plateau_fract", 127 | type=float, 128 | default=0.0, 129 | help="Fraction of steps spent holding at max lr", 130 | ) 131 | parser.add_argument( 132 | "--saves_per_epoch", 133 | type=int, 134 | default=10, 135 | help="The number of checkpoints to save per epoch", 136 | ) 137 | parser.add_argument( 138 | "--device", 139 | type=str, 140 | default="cuda" if torch.cuda.is_available() else "cpu", 141 | help="Device (cuda or cpu)", 142 | ) 143 | args = parser.parse_args() 144 | logging.basicConfig(level=logging.INFO) 145 | 146 | if args.model_checkpoint_dir is None: 147 | args.model_checkpoint_dir = f"s2s-{os.getpid()}" 148 | if not os.path.exists(args.model_checkpoint_dir): 149 | os.makedirs(args.model_checkpoint_dir) 150 | 151 | if os.path.isdir(args.tok_file): 152 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 153 | 154 | tokenizer = Tokenizer.from_file(args.tok_file) 155 | vocab = tokenizer.get_vocab() 156 | 157 | if args.restart_from: 158 | global_step = try_get_global_step(args.restart_from) 159 | 160 | model = T5Creator.from_pretrained(args.restart_from, **vars(args)) 161 | else: 162 | global_step = 0 163 | model = T5SequenceGenerator(tokenizer.get_vocab_size(), **vars(args)) 164 | print(model) 165 | trainer = SingleDeviceSeq2SeqTrainer( 166 | model, 167 | global_step=global_step, 168 | collate_function=NoisingCollator(vocab), 169 | **vars(args), 170 | ) 171 | logger.info(trainer) 172 | train_dataset = create_single_file_dataset(tokenizer, args.train_file, args.seq_len) 173 | valid_dataset = create_single_file_dataset(tokenizer, args.valid_file, args.seq_len) 174 | 175 | trainer.train_epochs( 176 | train_dataset, 177 | valid_dataset, 178 | os.path.join(args.model_checkpoint_dir, "ckpt"), 179 | args.epochs, 180 | ) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /src/mint/examples/t5_completer.py: -------------------------------------------------------------------------------- 1 | """An example program where you can provide your T5 model with a priming sequence and have it complete 2 | """ 3 | import logging 4 | import argparse 5 | import os 6 | import torch 7 | from prompt_toolkit import prompt 8 | from prompt_toolkit.history import FileHistory 9 | from tokenizers import Tokenizer 10 | from t5 import T5Creator 11 | 12 | logger = logging.getLogger(__file__) 13 | DECODER_START_TOKEN = 0 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description="An interactive shell with T5") 18 | parser.add_argument("--model", type=str, required=True, help="Start from a model") 19 | parser.add_argument( 20 | "--tok_file", type=str, required=True, help="Path to tokenizer.json file" 21 | ) 22 | parser.add_argument( 23 | "--query", 24 | type=str, 25 | help="Optional query. 
If you pass this we won't use the REPL", 26 | ) 27 | parser.add_argument("--history_file", type=str, default=".t5_history") 28 | parser.add_argument("--max_len", type=int, default=50) 29 | parser.add_argument("--num_encoder_layers", default=12, type=int) 30 | parser.add_argument("--num_decoder_layers", default=12, type=int) 31 | parser.add_argument("--num_heads", default=12, type=int) 32 | parser.add_argument("--sample", action="store_true") 33 | parser.add_argument("--temperature", default=1.0, type=float) 34 | parser.add_argument( 35 | "--device", 36 | type=str, 37 | default="cuda" if torch.cuda.is_available() else "cpu", 38 | help="Device (cuda or cpu)", 39 | ) 40 | args = parser.parse_args() 41 | logging.basicConfig(level=logging.INFO) 42 | if os.path.isdir(args.tok_file): 43 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 44 | tokenizer = Tokenizer.from_file(args.tok_file) 45 | 46 | model = T5Creator.from_pretrained(args.model, **vars(args)).eval() 47 | model.to(args.device) 48 | 49 | EOS_ID = tokenizer.get_vocab().get("") 50 | 51 | def complete(query, sampling, temperature, decode_as_text=True): 52 | logger.info("Query: %s", query) 53 | tokenized_input = tokenizer.encode(query) 54 | logger.info("Input Sequence: %s", " ".join(tokenized_input.tokens)) 55 | input_ids = torch.tensor(tokenized_input.ids, device=args.device).unsqueeze(0) 56 | input_enc = model.encode(input_ids) 57 | outputs = [DECODER_START_TOKEN] 58 | with torch.no_grad(): 59 | 60 | for i in range(args.max_len): 61 | 62 | decode_ids = torch.tensor(outputs, device=args.device) 63 | # signature is encoder, decoder (up till now), encoder_mask, decoder_mask 64 | response = model.decode(input_enc, decode_ids.unsqueeze(0)).squeeze(0) 65 | response = response[len(decode_ids) - 1] 66 | if sampling: 67 | sample_dist = torch.softmax(response / temperature, -1) 68 | output = torch.multinomial(sample_dist, num_samples=1) 69 | response = output.squeeze().item() 70 | else: 71 | response = response.argmax(-1).item() 72 | 73 | if response == EOS_ID: 74 | break 75 | outputs.append(response) 76 | outputs = tokenizer.decode(outputs[1:]) if decode_as_text else outputs 77 | return outputs 78 | 79 | if args.query: 80 | print(complete(args.query, args.sample, args.temperature)) 81 | return 82 | 83 | prompt_name = "T5>> " 84 | history = FileHistory(args.history_file) 85 | while True: 86 | query = prompt(prompt_name, history=history) 87 | query = query.strip() 88 | if query == ":quit" or query == "quit": 89 | break 90 | if query == ":sample": 91 | args.sample = True 92 | print("Turn sampling mode on") 93 | continue 94 | if query == ":max": 95 | args.sample = False 96 | print("Turn sampling mode off") 97 | continue 98 | print(complete(query, args.sample, args.temperature)) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /src/mint/examples/tune_bart_for_cls.py: -------------------------------------------------------------------------------- 1 | from mint.bart import BartCreator 2 | from mint.train import Average 3 | from mint.data import read_cls_dataset 4 | from tokenizers import Tokenizer 5 | import argparse 6 | import sys 7 | import torch 8 | import logging 9 | from tqdm import tqdm 10 | import os 11 | 12 | PAD_VALUE = 1 13 | 14 | logger = logging.getLogger(__file__) 15 | 16 | """Fine-tune BART as a classifier 17 | 18 | This program fine-tunes a pre-trained BART for an unstructured prediction (classification) task.
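For example, a single (hypothetical) training line might be:

    positive    a surprisingly sharp and funny film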
19 | The input is assumed to be a 2-column file with the label first. The delimiter between columns should 20 | be a space or tab. 21 | 22 | Early stopping is performed on the dataset in order to determine the best checkpoint. 23 | 24 | It reads the files into a TensorDataset for simplicity, and it trims the batch to the max length 25 | observed in a minibatch. 26 | 27 | If there is a `test_file` provided in the args, we will run an evaluation on our best checkpoint. 28 | 29 | """ 30 | 31 | 32 | def valid_epoch(epoch, loss_function, model, valid_loader, device, phase="valid"): 33 | model.eval() 34 | valid_loss = Average("valid_loss") 35 | progress = tqdm(enumerate(valid_loader), total=len(valid_loader)) 36 | valid_correct = 0 37 | valid_total = 0 38 | with torch.no_grad(): 39 | for i, (x, y) in progress: 40 | x = x.to(device) 41 | y = y.to(device) 42 | y_pred = model(x, model.create_pad_mask(x)) 43 | loss = loss_function(y_pred, y) 44 | valid_loss.update(loss.item()) 45 | y_pred = y_pred.argmax(dim=-1).view(-1) 46 | y = y.view(-1) 47 | valid_correct += (y == y_pred).sum() 48 | valid_total += y.shape[0] 49 | 50 | valid_acc = 100.0 * valid_correct / valid_total 51 | progress.set_description( 52 | f"{phase} epoch {epoch + 1}, step {i}: loss {valid_loss.avg:.3f}, accuracy {valid_acc:.2f}%" 53 | ) 54 | 55 | return valid_correct / valid_total 56 | 57 | 58 | def train_epoch(epoch, loss_function, model, optimizer, train_loader, device): 59 | model.train() 60 | train_loss = Average("train_loss") 61 | progress = tqdm(enumerate(train_loader), total=len(train_loader)) 62 | train_correct = 0 63 | train_total = 0 64 | for i, (x, y) in progress: 65 | optimizer.zero_grad() 66 | x = x.to(device) 67 | y = y.to(device) 68 | y_pred = model(x, model.create_pad_mask(x)) 69 | loss = loss_function(y_pred, y) 70 | train_loss.update(loss.item()) 71 | loss.backward() 72 | optimizer.step() 73 | y_pred = y_pred.argmax(dim=-1).view(-1) 74 | y = y.view(-1) 75 | train_correct += (y == y_pred).sum() 76 | train_total += y.shape[0] 77 | train_acc = 100.0 * (train_correct / train_total) 78 | progress.set_description( 79 | f"train epoch {epoch + 1}, step {i}: loss {train_loss.avg:.3f}, accuracy {train_acc:.2f}%" 80 | ) 81 | 82 | 83 | def trim_to_shortest_len(batch): 84 | max_len = max((example[0] != PAD_VALUE).sum() for example in batch) + 1 85 | y = torch.stack([example[1] for example in batch]) 86 | x = torch.stack([example[0][:max_len] for example in batch]) 87 | return x, y 88 | 89 | 90 | def main(): 91 | parser = argparse.ArgumentParser( 92 | description="fine-tune BERT for classification (single text input only)" 93 | ) 94 | parser.add_argument("--model", type=str) 95 | parser.add_argument("--train_file", type=str, required=True) 96 | parser.add_argument("--valid_file", type=str, required=True) 97 | parser.add_argument("--test_file", type=str) 98 | parser.add_argument( 99 | "--hidden_size", 100 | type=int, 101 | default=768, 102 | help="Model dimension (and embedding dsz)", 103 | ) 104 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 105 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 106 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 107 | parser.add_argument( 108 | "--num_train_workers", type=int, default=4, help="Number train workers" 109 | ) 110 | parser.add_argument( 111 | "--num_valid_workers", type=int, default=1, help="Number train workers" 112 | ) 113 | parser.add_argument( 114 | "--max_seq_len", type=int, 
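        # default is 1024 (rather than the 512 used in the BERT example) to match the larger
        # learned position embedding table that BART checkpoints ship with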
default=1024, help="Max input length" 115 | ) 116 | parser.add_argument("--batch_size", type=int, default=20, help="Batch Size") 117 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 118 | parser.add_argument( 119 | "--tok_file", type=str, required=True, help="Path to tokenizer.json file" 120 | ) 121 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate") 122 | parser.add_argument("--ckpt_base", type=str, default="ckpt-") 123 | parser.add_argument("--num_epochs", type=int, default=5) 124 | parser.add_argument( 125 | "--device", 126 | type=str, 127 | default="cuda" if torch.cuda.is_available() else "cpu", 128 | help="Device (cuda or cpu)", 129 | ) 130 | args = parser.parse_args() 131 | logging.basicConfig(level=logging.INFO) 132 | if os.path.isdir(args.tok_file): 133 | args.tok_file = os.path.join(args.tok_file, "tokenizer.json") 134 | tokenizer = Tokenizer.from_file(args.tok_file) 135 | # TODO: read the pad_index in 136 | train_set, labels = read_cls_dataset( 137 | args.train_file, tokenizer, pad_index=1, max_seq_len=args.max_seq_len 138 | ) 139 | train_loader = torch.utils.data.DataLoader( 140 | train_set, 141 | batch_size=args.batch_size, 142 | shuffle=True, 143 | collate_fn=trim_to_shortest_len, 144 | ) 145 | logger.info(labels) 146 | valid_set, labels = read_cls_dataset( 147 | args.valid_file, 148 | tokenizer, 149 | pad_index=1, 150 | max_seq_len=args.max_seq_len, 151 | label_list=labels, 152 | ) 153 | valid_loader = torch.utils.data.DataLoader( 154 | valid_set, 155 | batch_size=args.batch_size, 156 | shuffle=False, 157 | collate_fn=trim_to_shortest_len, 158 | ) 159 | 160 | num_classes = len(labels) 161 | output_layer = torch.nn.Linear(args.hidden_size, num_classes) 162 | model = BartCreator.pooled_from_pretrained( 163 | args.model, output=output_layer, **vars(args) 164 | ).to(args.device) 165 | loss_function = torch.nn.CrossEntropyLoss().to(args.device) 166 | 167 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 168 | logger.info(optimizer) 169 | checkpoint_name = args.ckpt_base + ".pth" 170 | best_valid_acc = 0.0 171 | 172 | device = args.device 173 | for epoch in range(args.num_epochs): 174 | train_epoch(epoch, loss_function, model, optimizer, train_loader, device) 175 | 176 | valid_acc_fract = valid_epoch(epoch, loss_function, model, valid_loader, device) 177 | # Early stopping check 178 | if valid_acc_fract > best_valid_acc: 179 | best_valid_acc = valid_acc_fract 180 | acc = 100.0 * valid_acc_fract 181 | logger.info(f"New best validation accuracy {acc:.2f}%") 182 | torch.save(model.state_dict(), checkpoint_name) 183 | if not args.test_file: 184 | logger.info("No test file provided, exiting") 185 | sys.exit(1) 186 | 187 | test_set, final_labels = read_cls_dataset( 188 | args.test_file, 189 | tokenizer, 190 | pad_index=1, 191 | max_seq_len=args.max_seq_len, 192 | label_list=labels, 193 | ) 194 | if len(final_labels) != num_classes: 195 | raise Exception( 196 | "The test set adds new classes with no samples in the training or validation" 197 | ) 198 | test_loader = torch.utils.data.DataLoader( 199 | test_set, 200 | batch_size=args.batch_size, 201 | shuffle=False, 202 | collate_fn=trim_to_shortest_len, 203 | ) 204 | 205 | best_state = torch.load(checkpoint_name) 206 | model.load_state_dict(best_state) 207 | eval_fract = valid_epoch(0, loss_function, model, test_loader, device, phase="test") 208 | eval_acc = 100.0 * eval_fract 209 | print(f"final test accuracy {eval_acc:.2f}%") 210 | 211 | 212 | if __name__ == "__main__": 213 
| main() 214 | -------------------------------------------------------------------------------- /src/mint/examples/tune_bert_for_cls.py: -------------------------------------------------------------------------------- 1 | from mint.bert import BertCreator 2 | from mint.train import Average 3 | from mint.data import read_cls_dataset 4 | from tokenizers import BertWordPieceTokenizer 5 | import argparse 6 | import sys 7 | import torch 8 | import logging 9 | from tqdm import tqdm 10 | 11 | PAD_VALUE = 0 12 | 13 | logger = logging.getLogger(__file__) 14 | 15 | """Fine-tune BERT as a classifier 16 | 17 | This program fine-tunes a pre-trained BERT for an unstructured prediction (classification) task. 18 | The input is assumed to be a 2-column file with the label first. The delimiter between columns should 19 | be a space or tab. 20 | 21 | Early stopping is performed on the dataset in order to determine the best checkpoint. 22 | 23 | It reads the files into a TensorDataset for simplicity, and it trims the batch to the max length 24 | observed in a minibatch. 25 | 26 | If there is a `test_file` provided in the args, we will run an evaluation on our best checkpoint. 27 | 28 | """ 29 | 30 | 31 | def valid_epoch(epoch, loss_function, model, valid_loader, device, phase="valid"): 32 | model.eval() 33 | valid_loss = Average("valid_loss") 34 | progress = tqdm(enumerate(valid_loader), total=len(valid_loader)) 35 | valid_correct = 0 36 | valid_total = 0 37 | with torch.no_grad(): 38 | for i, (x, y) in progress: 39 | x = x.to(device) 40 | y = y.to(device) 41 | y_pred = model(x, model.create_pad_mask(x)) 42 | loss = loss_function(y_pred, y) 43 | valid_loss.update(loss.item()) 44 | y_pred = y_pred.argmax(dim=-1).view(-1) 45 | y = y.view(-1) 46 | valid_correct += (y == y_pred).sum() 47 | valid_total += y.shape[0] 48 | 49 | valid_acc = 100.0 * valid_correct / valid_total 50 | progress.set_description( 51 | f"{phase} epoch {epoch + 1}, step {i}: loss {valid_loss.avg:.3f}, accuracy {valid_acc:.2f}%" 52 | ) 53 | 54 | return valid_correct / valid_total 55 | 56 | 57 | def train_epoch(epoch, loss_function, model, optimizer, train_loader, device): 58 | model.train() 59 | train_loss = Average("train_loss") 60 | progress = tqdm(enumerate(train_loader), total=len(train_loader)) 61 | train_correct = 0 62 | train_total = 0 63 | for i, (x, y) in progress: 64 | optimizer.zero_grad() 65 | x = x.to(device) 66 | y = y.to(device) 67 | y_pred = model(x, model.create_pad_mask(x)) 68 | loss = loss_function(y_pred, y) 69 | train_loss.update(loss.item()) 70 | loss.backward() 71 | optimizer.step() 72 | y_pred = y_pred.argmax(dim=-1).view(-1) 73 | y = y.view(-1) 74 | train_correct += (y == y_pred).sum() 75 | train_total += y.shape[0] 76 | train_acc = 100.0 * (train_correct / train_total) 77 | progress.set_description( 78 | f"train epoch {epoch + 1}, step {i}: loss {train_loss.avg:.3f}, accuracy {train_acc:.2f}%" 79 | ) 80 | 81 | 82 | def trim_to_shortest_len(batch): 83 | max_len = max((example[0] != PAD_VALUE).sum() for example in batch) 84 | y = torch.stack([example[1] for example in batch]) 85 | x = torch.stack([example[0][:max_len] for example in batch]) 86 | return x, y 87 | 88 | 89 | def main(): 90 | parser = argparse.ArgumentParser( 91 | description="fine-tune BERT for classification (single text input only)" 92 | ) 93 | parser.add_argument("--model", type=str) 94 | parser.add_argument("--train_file", type=str, required=True) 95 | parser.add_argument("--valid_file", type=str, required=True) 96 | parser.add_argument("--test_file", 
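        # optional: when provided, the best checkpoint (by validation accuracy) is reloaded
        # and evaluated on this file at the end of training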
type=str) 97 | parser.add_argument( 98 | "--hidden_size", 99 | type=int, 100 | default=768, 101 | help="Model dimension (and embedding dsz)", 102 | ) 103 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 104 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 105 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 106 | parser.add_argument( 107 | "--num_train_workers", type=int, default=4, help="Number train workers" 108 | ) 109 | parser.add_argument( 110 | "--num_valid_workers", type=int, default=1, help="Number train workers" 111 | ) 112 | parser.add_argument("--max_seq_len", type=int, default=512, help="Max input length") 113 | parser.add_argument("--batch_size", type=int, default=20, help="Batch Size") 114 | parser.add_argument( 115 | "--vocab_file", type=str, help="The WordPiece model file", required=True 116 | ) 117 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 118 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 119 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate") 120 | parser.add_argument("--ckpt_base", type=str, default="ckpt-") 121 | parser.add_argument("--num_epochs", type=int, default=5) 122 | parser.add_argument( 123 | "--device", 124 | type=str, 125 | default="cuda" if torch.cuda.is_available() else "cpu", 126 | help="Device (cuda or cpu)", 127 | ) 128 | args = parser.parse_args() 129 | logging.basicConfig(level=logging.INFO) 130 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 131 | # TODO: read the pad_index in 132 | train_set, labels = read_cls_dataset( 133 | args.train_file, tokenizer, pad_index=0, max_seq_len=args.max_seq_len 134 | ) 135 | train_loader = torch.utils.data.DataLoader( 136 | train_set, 137 | batch_size=args.batch_size, 138 | shuffle=True, 139 | collate_fn=trim_to_shortest_len, 140 | ) 141 | logger.info(labels) 142 | valid_set, labels = read_cls_dataset( 143 | args.valid_file, 144 | tokenizer, 145 | pad_index=0, 146 | max_seq_len=args.max_seq_len, 147 | label_list=labels, 148 | ) 149 | valid_loader = torch.utils.data.DataLoader( 150 | valid_set, 151 | batch_size=args.batch_size, 152 | shuffle=False, 153 | collate_fn=trim_to_shortest_len, 154 | ) 155 | 156 | num_classes = len(labels) 157 | output_layer = torch.nn.Linear(args.hidden_size, num_classes) 158 | model = BertCreator.pooled_enc_from_pretrained( 159 | args.model, output=output_layer, **vars(args) 160 | ).to(args.device) 161 | loss_function = torch.nn.CrossEntropyLoss().to(args.device) 162 | 163 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 164 | logger.info(optimizer) 165 | checkpoint_name = args.ckpt_base + ".pth" 166 | best_valid_acc = 0.0 167 | 168 | device = args.device 169 | for epoch in range(args.num_epochs): 170 | train_epoch(epoch, loss_function, model, optimizer, train_loader, device) 171 | 172 | valid_acc_fract = valid_epoch(epoch, loss_function, model, valid_loader, device) 173 | # Early stopping check 174 | if valid_acc_fract > best_valid_acc: 175 | best_valid_acc = valid_acc_fract 176 | acc = 100.0 * valid_acc_fract 177 | logger.info(f"New best validation accuracy {acc:.2f}%") 178 | torch.save(model.state_dict(), checkpoint_name) 179 | if not args.test_file: 180 | logger.info("No test file provided, exiting") 181 | sys.exit(1) 182 | 183 | test_set, final_labels = read_cls_dataset( 184 | args.test_file, 185 | tokenizer, 186 | pad_index=0, 187 | 
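        # pass the label list learned from the training split so class indices stay consistent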
max_seq_len=args.max_seq_len, 188 | label_list=labels, 189 | ) 190 | if len(final_labels) != num_classes: 191 | raise Exception( 192 | "The test set adds new classes with no samples in the training or validation" 193 | ) 194 | test_loader = torch.utils.data.DataLoader( 195 | test_set, 196 | batch_size=args.batch_size, 197 | shuffle=False, 198 | collate_fn=trim_to_shortest_len, 199 | ) 200 | 201 | best_state = torch.load(checkpoint_name) 202 | model.load_state_dict(best_state) 203 | eval_fract = valid_epoch(0, loss_function, model, test_loader, device, phase="test") 204 | eval_acc = 100.0 * eval_fract 205 | print(f"final test accuracy {eval_acc:.2f}%") 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /src/mint/examples/tune_bert_for_paired_cls.py: -------------------------------------------------------------------------------- 1 | """Fine-tune BERT as a cross-encoder or dual-encoder (following SentenceBERT) classifier 2 | 3 | This program fine-tunes a pre-trained BERT for an unstructured prediction (classification) 4 | task with 2 text inputs. This typically corresponds to so-called Natural Language Inference 5 | datasets. 6 | 7 | The label space should be ternary, with -1 meaning contradiction, 1 for entailment or 0 for neutral. 8 | 9 | The loss is a cross-entropy loss 10 | 11 | For dual encoding, the network is shared at the lower layers, up until a pooling operation 12 | is performed for each channel, yielding a fixed width vector for each. The model should predict for each channel's 13 | vector, are they entailment, contradiction or neutral. This will yield a model that can be used for distance 14 | queries. 15 | 16 | Early stopping is performed on the dataset in order to determine the best checkpoint. 17 | 18 | If there is a `test_file` provided in the args, we will run an evaluation on our best checkpoint. 
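With the default column and label names, each input line is a JSON object along these lines
(the sentences here are hypothetical):

    {"label": "entailment", "sentence1": "A man is playing a guitar.", "sentence2": "A person is making music."}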
19 | 20 | """ 21 | 22 | from mint.bert import BertCreator 23 | from mint.train import Average 24 | import json 25 | from typing import Optional, List 26 | from mint.data import TextFile, TensorDataset 27 | from tokenizers import BertWordPieceTokenizer 28 | import argparse 29 | import sys 30 | import os 31 | import torch 32 | import logging 33 | from tqdm import tqdm 34 | 35 | PAD_VALUE = 0 36 | 37 | logger = logging.getLogger(__file__) 38 | 39 | 40 | class PairedTextTrainer: 41 | def __init__(self, model, loss_function, device, lr): 42 | self.model = model 43 | self.loss_function = loss_function 44 | self.epoch = 0 45 | self.device = device 46 | self.lr = lr 47 | self.optimizer = torch.optim.Adam(model.parameters(), lr=lr) 48 | logger.info(self.optimizer) 49 | 50 | def valid_epoch(self, valid_loader, phase="valid"): 51 | self.model.eval() 52 | valid_loss = Average("valid_loss") 53 | progress = tqdm(enumerate(valid_loader), total=len(valid_loader)) 54 | valid_correct = 0 55 | valid_total = 0 56 | with torch.no_grad(): 57 | for i, batch in progress: 58 | loss, y, y_pred = self.valid_step(batch) 59 | valid_loss.update(loss.item()) 60 | y_pred = y_pred.argmax(dim=-1).view(-1) 61 | y = y.view(-1) 62 | valid_correct += (y == y_pred).sum() 63 | valid_total += y.shape[0] 64 | 65 | valid_acc = 100.0 * valid_correct / valid_total 66 | progress.set_description( 67 | f"{phase} epoch {self.epoch + 1}, step {i}: loss {valid_loss.avg:.3f}, accuracy {valid_acc:.2f}%" 68 | ) 69 | 70 | return valid_correct / valid_total 71 | 72 | def train_epoch(self, train_loader): 73 | self.model.train() 74 | train_loss = Average("train_loss") 75 | 76 | warmup_steps = 1 77 | if self.epoch == 0: 78 | warmup_steps = int(0.1 * len(train_loader)) 79 | logger.info("Warmup steps %d", warmup_steps) 80 | progress = tqdm(enumerate(train_loader), total=len(train_loader)) 81 | train_correct = 0 82 | train_total = 0 83 | 84 | for i, batch in progress: 85 | 86 | lr_factor = min(1.0, (i + 1) / warmup_steps) 87 | for p in self.optimizer.param_groups: 88 | p["lr"] = self.lr * lr_factor 89 | 90 | self.optimizer.zero_grad() 91 | loss, y, y_pred = self.train_step(batch) 92 | train_loss.update(loss.item()) 93 | loss.backward() 94 | self.optimizer.step() 95 | y_pred = y_pred.argmax(dim=-1).view(-1) 96 | y = y.view(-1) 97 | train_correct += (y == y_pred).sum() 98 | train_total += y.shape[0] 99 | train_acc = 100.0 * (train_correct / train_total) 100 | progress.set_description( 101 | f"train epoch {self.epoch + 1}, step {i}: loss {train_loss.avg:.3f}, accuracy {train_acc:.2f}%" 102 | ) 103 | 104 | 105 | class DualEncoderTrainer(PairedTextTrainer): 106 | def train_step(self, batch): 107 | (x1, x2, y) = batch 108 | x1 = x1.to(self.device) 109 | x2 = x2.to(self.device) 110 | y = y.to(self.device) 111 | y_pred = self.model( 112 | x1, x2, self.model.create_pad_mask(x1), self.model.create_pad_mask(x2) 113 | ) 114 | loss = self.loss_function(y_pred, y) 115 | return loss, y, y_pred 116 | 117 | def valid_step(self, batch): 118 | (x1, x2, y) = batch 119 | x1 = x1.to(self.device) 120 | x2 = x2.to(self.device) 121 | y = y.to(self.device) 122 | y_pred = self.model( 123 | x1, x2, self.model.create_pad_mask(x1), self.model.create_pad_mask(x2) 124 | ) 125 | loss = self.loss_function(y_pred, y) 126 | return loss, y, y_pred 127 | 128 | 129 | class CrossEncoderTrainer(PairedTextTrainer): 130 | def train_step(self, batch): 131 | (x, tt, y) = batch 132 | x = x.to(self.device) 133 | tt = tt.to(self.device) 134 | y = y.to(self.device) 135 | y_pred = self.model(x, 
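            # pad mask marks the real (non-pad) tokens; token_type marks which of the two
            # packed sentences each position came from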
self.model.create_pad_mask(x), token_type=tt) 136 | loss = self.loss_function(y_pred, y) 137 | return loss, y, y_pred 138 | 139 | def valid_step(self, batch): 140 | (x, tt, y) = batch 141 | x = x.to(self.device) 142 | tt = tt.to(self.device) 143 | y = y.to(self.device) 144 | y_pred = self.model(x, self.model.create_pad_mask(x), token_type=tt) 145 | loss = self.loss_function(y_pred, y) 146 | return loss, y, y_pred 147 | 148 | 149 | def read_jsonl_dual_cls_dataset( 150 | file: str, 151 | tokenizer, 152 | pad_index=0, 153 | max_seq_len=512, 154 | cols: Optional[List[str]] = None, 155 | label_list: Optional[List[str]] = None, 156 | ) -> TensorDataset: 157 | if cols is None: 158 | cols = ["label", "sentence1", "sentence2"] 159 | if label_list is None: 160 | label_list = ["contradiction", "neutral", "entailment"] 161 | 162 | label2index = {} if not label_list else {k: i for i, k in enumerate(label_list)} 163 | 164 | def read_line(l): 165 | obj = json.loads(l) 166 | label = label2index[obj[cols[0]]] 167 | tokens = [ 168 | torch.tensor(tokenizer.encode(obj[cols[1]]).ids)[:max_seq_len], 169 | torch.tensor(tokenizer.encode(obj[cols[2]]).ids)[:max_seq_len], 170 | ] 171 | padded = [torch.full((max_seq_len,), pad_index, dtype=tokens[0].dtype)] * 2 172 | padded[0][: len(tokens[0])] = tokens[0] 173 | padded[1][: len(tokens[1])] = tokens[1] 174 | return padded + [label] 175 | 176 | if ( 177 | os.path.exists(file + ".x1.th") 178 | and os.path.exists(file + ".x2.th") 179 | and os.path.exists(file + ".y.th") 180 | ): 181 | logger.info( 182 | "Found cached tensor files, reloading. If you dont want this, delete *.th from %s", 183 | os.path.dirname(file), 184 | ) 185 | x1_tensor = torch.load(file + ".x1.th") 186 | x2_tensor = torch.load(file + ".x2.th") 187 | y_tensor = torch.load(file + ".y.th") 188 | return TensorDataset(x1_tensor, x2_tensor, y_tensor), label_list 189 | 190 | x1_tensor = [] 191 | x2_tensor = [] 192 | y_tensor = [] 193 | with TextFile(file) as rf: 194 | for line in rf: 195 | padded_x1, padded_x2, label = read_line(line.strip()) 196 | 197 | x1_tensor.append(padded_x1) 198 | x2_tensor.append(padded_x2) 199 | y_tensor.append(label) 200 | x1_tensor = torch.stack(x1_tensor) 201 | x2_tensor = torch.stack(x2_tensor) 202 | y_tensor = torch.tensor(y_tensor, dtype=torch.long) 203 | logger.info("Caching tensors for %s in its parent directory", file) 204 | torch.save(x1_tensor, file + ".x1.th") 205 | torch.save(x2_tensor, file + ".x2.th") 206 | torch.save(y_tensor, file + ".y.th") 207 | return TensorDataset(x1_tensor, x2_tensor, y_tensor), label_list 208 | 209 | 210 | def read_jsonl_cross_cls_dataset( 211 | file: str, 212 | tokenizer, 213 | pad_index=0, 214 | max_seq_len=512, 215 | cols: Optional[List[str]] = None, 216 | label_list: Optional[List[str]] = None, 217 | ) -> TensorDataset: 218 | if cols is None: 219 | cols = ["label", "sentence1", "sentence2"] 220 | if label_list is None: 221 | label_list = ["contradiction", "neutral", "entailment"] 222 | 223 | label2index = {} if not label_list else {k: i for i, k in enumerate(label_list)} 224 | 225 | def read_line(l): 226 | obj = json.loads(l) 227 | label = label2index[obj[cols[0]]] 228 | x1 = torch.tensor(tokenizer.encode(obj[cols[1]]).ids[:max_seq_len]) 229 | x2 = torch.tensor(tokenizer.encode(obj[cols[2]]).ids[1:]) 230 | x1_tt = torch.zeros_like(x1) 231 | tokens = torch.cat((x1, x2), -1)[:max_seq_len] 232 | # The output is going to look like [CLS] t1 [SEP] t2 [SEP] 233 | padded = torch.full((max_seq_len,), pad_index, dtype=tokens[0].dtype) 234 | tt = 
torch.ones_like(padded) 235 | tt[: len(x1_tt)] = x1_tt 236 | padded[: len(tokens)] = tokens 237 | return [padded, tt] + [label] 238 | 239 | if ( 240 | os.path.exists(file + ".x.th") 241 | and os.path.exists(file + ".x.th") 242 | and os.path.exists(file + ".y.th") 243 | ): 244 | logger.info( 245 | "Found cached tensor files, reloading. If you dont want this, delete *.th from %s", 246 | os.path.dirname(file), 247 | ) 248 | x_tensor = torch.load(file + ".x.th") 249 | tt_tensor = torch.load(file + ".tt.th") 250 | y_tensor = torch.load(file + ".y.th") 251 | return TensorDataset(x_tensor, tt_tensor, y_tensor), label_list 252 | 253 | x_tensor = [] 254 | tt_tensor = [] 255 | y_tensor = [] 256 | 257 | with TextFile(file) as rf: 258 | for line in rf: 259 | padded_x, tt_x, label = read_line(line.strip()) 260 | x_tensor.append(padded_x) 261 | tt_tensor.append(tt_x) 262 | y_tensor.append(label) 263 | x_tensor = torch.stack(x_tensor) 264 | tt_tensor = torch.stack(tt_tensor) 265 | y_tensor = torch.tensor(y_tensor, dtype=torch.long) 266 | logger.info("Caching tensors for %s in its parent directory", file) 267 | torch.save(x_tensor, file + ".x.th") 268 | torch.save(tt_tensor, file + ".tt.th") 269 | torch.save(y_tensor, file + ".y.th") 270 | return TensorDataset(x_tensor, tt_tensor, y_tensor), label_list 271 | 272 | 273 | def trim_to_shortest_len_dual(batch): 274 | max_x1_len = max((example[0] != PAD_VALUE).sum() for example in batch) 275 | max_x2_len = max((example[1] != PAD_VALUE).sum() for example in batch) 276 | 277 | y = torch.stack([example[2] for example in batch]) 278 | x1 = torch.stack([example[0][:max_x1_len] for example in batch]) 279 | x2 = torch.stack([example[1][:max_x2_len] for example in batch]) 280 | return x1, x2, y 281 | 282 | 283 | def trim_to_shortest_len_cross(batch): 284 | max_x_len = max((example[0] != PAD_VALUE).sum() for example in batch) 285 | x = torch.stack([example[0][:max_x_len] for example in batch]) 286 | tt = torch.stack([example[1][:max_x_len] for example in batch]) 287 | y = torch.stack([example[2] for example in batch]) 288 | 289 | return x, tt, y 290 | 291 | 292 | def main(): 293 | parser = argparse.ArgumentParser( 294 | description="fine-tune BERT for classification (dual text input only)" 295 | ) 296 | parser.add_argument("--model", type=str) 297 | parser.add_argument("--train_file", type=str, required=True) 298 | parser.add_argument("--valid_file", type=str, required=True) 299 | parser.add_argument("--test_file", type=str) 300 | parser.add_argument( 301 | "--hidden_size", 302 | type=int, 303 | default=768, 304 | help="Model dimension (and embedding dsz)", 305 | ) 306 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 307 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 308 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 309 | parser.add_argument( 310 | "--num_train_workers", type=int, default=4, help="Number train workers" 311 | ) 312 | parser.add_argument( 313 | "--num_valid_workers", type=int, default=1, help="Number train workers" 314 | ) 315 | parser.add_argument("--max_seq_len", type=int, default=512, help="Max input length") 316 | parser.add_argument("--batch_size", type=int, default=16, help="Batch Size") 317 | parser.add_argument( 318 | "--encoder_type", 319 | choices=["cross-encoder", "dual-encoder"], 320 | default="dual-encoder", 321 | help="Train a dual-encoder or a cross-encoder", 322 | ) 323 | parser.add_argument( 324 | "--vocab_file", type=str, help="The 
WordPiece model file", required=True 325 | ) 326 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 327 | parser.add_argument("--lowercase", action="store_true", help="Vocab is lower case") 328 | parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate") 329 | parser.add_argument("--ckpt_base", type=str, default="ckpt") 330 | parser.add_argument("--num_epochs", type=int, default=1) 331 | parser.add_argument( 332 | "--label_names", 333 | type=str, 334 | nargs="+", 335 | default=["contradiction", "neutral", "entailment"], 336 | ) 337 | parser.add_argument( 338 | "--col_names", type=str, nargs="+", default=["label", "sentence1", "sentence2"] 339 | ) 340 | parser.add_argument( 341 | "--device", 342 | type=str, 343 | default="cuda" if torch.cuda.is_available() else "cpu", 344 | help="Device (cuda or cpu)", 345 | ) 346 | args = parser.parse_args() 347 | logging.basicConfig(level=logging.INFO) 348 | tokenizer = BertWordPieceTokenizer(args.vocab_file, lowercase=args.lowercase) 349 | 350 | if args.encoder_type == "dual-encoder": 351 | trim_batch_fn = trim_to_shortest_len_dual 352 | read_fn = read_jsonl_dual_cls_dataset 353 | else: 354 | logger.info("Using cross-encoder") 355 | trim_batch_fn = trim_to_shortest_len_cross 356 | read_fn = read_jsonl_cross_cls_dataset 357 | 358 | train_set, labels = read_fn( 359 | args.train_file, 360 | tokenizer, 361 | pad_index=0, 362 | max_seq_len=args.max_seq_len, 363 | cols=args.col_names, 364 | label_list=args.label_names, 365 | ) 366 | train_loader = torch.utils.data.DataLoader( 367 | train_set, batch_size=args.batch_size, shuffle=True, collate_fn=trim_batch_fn 368 | ) 369 | logger.info(labels) 370 | valid_set, labels = read_fn( 371 | args.valid_file, 372 | tokenizer, 373 | pad_index=0, 374 | max_seq_len=args.max_seq_len, 375 | label_list=labels, 376 | ) 377 | valid_loader = torch.utils.data.DataLoader( 378 | valid_set, batch_size=args.batch_size, shuffle=False, collate_fn=trim_batch_fn 379 | ) 380 | 381 | num_classes = len(labels) 382 | loss_function = torch.nn.CrossEntropyLoss().to(args.device) 383 | if args.encoder_type == "dual-encoder": 384 | model = BertCreator.dual_encoder_from_pretrained( 385 | args.model, num_classes=num_classes, **vars(args) 386 | ).to(args.device) 387 | trainer = DualEncoderTrainer(model, loss_function, args.device, args.lr) 388 | else: 389 | num_classes = len(labels) 390 | output_layer = torch.nn.Linear(args.hidden_size, num_classes) 391 | model = BertCreator.pooled_enc_from_pretrained( 392 | args.model, output=output_layer, **vars(args) 393 | ).to(args.device) 394 | trainer = CrossEncoderTrainer(model, loss_function, args.device, args.lr) 395 | 396 | best_valid_acc = 0.0 397 | 398 | checkpoint_name = args.ckpt_base + ".pth" 399 | for epoch in range(args.num_epochs): 400 | trainer.train_epoch(train_loader) 401 | 402 | valid_acc_fract = trainer.valid_epoch(valid_loader) 403 | # Early stopping check 404 | if valid_acc_fract > best_valid_acc: 405 | best_valid_acc = valid_acc_fract 406 | acc = 100.0 * valid_acc_fract 407 | logger.info(f"New best validation accuracy {acc:.2f}%") 408 | torch.save(model.state_dict(), checkpoint_name) 409 | if not args.test_file: 410 | logger.info("No test file provided, exiting") 411 | sys.exit(1) 412 | 413 | test_set, final_labels = read_fn( 414 | args.test_file, 415 | tokenizer, 416 | pad_index=0, 417 | max_seq_len=args.max_seq_len, 418 | label_list=labels, 419 | ) 420 | if len(final_labels) != num_classes: 421 | raise Exception( 422 | "The test set adds new 
classes with no samples in the training or validation" 423 | ) 424 | test_loader = torch.utils.data.DataLoader( 425 | test_set, batch_size=args.batch_size, shuffle=False, collate_fn=trim_batch_fn 426 | ) 427 | 428 | best_state = torch.load(checkpoint_name) 429 | model.load_state_dict(best_state) 430 | eval_fract = trainer.valid_epoch(test_loader, phase="test") 431 | eval_acc = 100.0 * eval_fract 432 | print(f"final test accuracy {eval_acc:.2f}%") 433 | 434 | 435 | if __name__ == "__main__": 436 | main() 437 | -------------------------------------------------------------------------------- /src/mint/examples/tune_gpt2_for_cls.py: -------------------------------------------------------------------------------- 1 | from mint.gpt import GPT2Creator 2 | from mint.train import Average 3 | from tokenizers import ByteLevelBPETokenizer 4 | from typing import Optional, Callable, List 5 | from torch.utils.data import TensorDataset 6 | from mint.data import TextFile 7 | import argparse 8 | import sys 9 | import torch 10 | import logging 11 | from tqdm import tqdm 12 | 13 | PAD_VALUE = 0 14 | 15 | logger = logging.getLogger(__file__) 16 | 17 | """Fine-tune GPT2 as a classifier 18 | 19 | This program fine-tunes a pre-trained GPT2 for an unstructured prediction (classification) task. 20 | The input is assumed to be a 2-column file with the label first. The delimiter between columns should 21 | be a space or tab. 22 | 23 | Early stopping is performed on the dataset in order to determine the best checkpoint. 24 | 25 | It reads the files into a TensorDataset for simplicity, and it trims the batch to the max length 26 | observed in a minibatch. 27 | 28 | If there is a `test_file` provided in the args, we will run an evaluation on our best checkpoint. 29 | 30 | """ 31 | 32 | 33 | def read_cls_dataset_gpt2( 34 | file: str, 35 | tokenizer, 36 | pad_index=0, 37 | get_data_fn: Optional[Callable] = None, 38 | max_seq_len=1024, 39 | label_list: Optional[List[str]] = None, 40 | ) -> TensorDataset: 41 | def read_space_delim_line(line: str): 42 | toks = line.split() 43 | label = toks[0] 44 | tokens = " ".join(toks[1:]) 45 | return label, tokens 46 | 47 | eos = tokenizer.token_to_id("<|endoftext|>") 48 | if get_data_fn is None: 49 | get_data_fn = read_space_delim_line 50 | 51 | label2index = {} if not label_list else {k: i for i, k in enumerate(label_list)} 52 | label_offset = len(label2index) 53 | x_tensor = [] 54 | y_tensor = [] 55 | with TextFile(file) as rf: 56 | for line in rf: 57 | label, example_str = get_data_fn(line.strip()) 58 | if label not in label2index: 59 | label2index[label] = label_offset 60 | label_offset += 1 61 | tokens = torch.tensor(tokenizer.encode(example_str).ids + [eos]) 62 | padded = torch.full((max_seq_len,), pad_index, dtype=tokens.dtype) 63 | padded[: len(tokens)] = tokens 64 | x_tensor.append(padded) 65 | y_tensor.append(label2index[label]) 66 | x_tensor = torch.stack(x_tensor) 67 | y_tensor = torch.tensor(y_tensor, dtype=torch.long) 68 | label_list = [0] * label_offset 69 | for label, idx in label2index.items(): 70 | label_list[idx] = label 71 | return TensorDataset(x_tensor, y_tensor), label_list 72 | 73 | 74 | def valid_epoch(epoch, loss_function, model, valid_loader, device, phase="valid"): 75 | model.eval() 76 | valid_loss = Average("valid_loss") 77 | progress = tqdm(enumerate(valid_loader), total=len(valid_loader)) 78 | valid_correct = 0 79 | valid_total = 0 80 | with torch.no_grad(): 81 | for i, (x, y) in progress: 82 | x = x.to(device) 83 | y = y.to(device) 84 | y_pred = model(x, 
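            # the pooled GPT2 classifier reduces the hidden states by mean pooling, or by the
            # final <|endoftext|> position when --pool_type last is used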
model.create_pad_mask(x)) 85 | loss = loss_function(y_pred, y) 86 | valid_loss.update(loss.item()) 87 | y_pred = y_pred.argmax(dim=-1).view(-1) 88 | y = y.view(-1) 89 | valid_correct += (y == y_pred).sum() 90 | valid_total += y.shape[0] 91 | 92 | valid_acc = 100.0 * valid_correct / valid_total 93 | progress.set_description( 94 | f"{phase} epoch {epoch + 1}, step {i}: loss {valid_loss.avg:.3f}, accuracy {valid_acc:.2f}%" 95 | ) 96 | 97 | return valid_correct / valid_total 98 | 99 | 100 | def train_epoch(epoch, loss_function, model, optimizer, train_loader, device): 101 | model.train() 102 | train_loss = Average("train_loss") 103 | progress = tqdm(enumerate(train_loader), total=len(train_loader)) 104 | train_correct = 0 105 | train_total = 0 106 | for i, (x, y) in progress: 107 | optimizer.zero_grad() 108 | x = x.to(device) 109 | y = y.to(device) 110 | y_pred = model(x, model.create_pad_mask(x)) 111 | loss = loss_function(y_pred, y) 112 | train_loss.update(loss.item()) 113 | loss.backward() 114 | optimizer.step() 115 | y_pred = y_pred.argmax(dim=-1).view(-1) 116 | y = y.view(-1) 117 | train_correct += (y == y_pred).sum() 118 | train_total += y.shape[0] 119 | train_acc = 100.0 * (train_correct / train_total) 120 | progress.set_description( 121 | f"train epoch {epoch + 1}, step {i}: loss {train_loss.avg:.3f}, accuracy {train_acc:.2f}%" 122 | ) 123 | 124 | 125 | def trim_to_shortest_len(batch): 126 | max_len = max((example[0] != PAD_VALUE).sum() for example in batch) 127 | y = torch.stack([example[1] for example in batch]) 128 | x = torch.stack([example[0][:max_len] for example in batch]) 129 | return x, y 130 | 131 | 132 | def main(): 133 | parser = argparse.ArgumentParser( 134 | description="fine-tune GPT for classification (single text input only)" 135 | ) 136 | parser.add_argument("--model", type=str) 137 | parser.add_argument("--train_file", type=str, required=True) 138 | parser.add_argument("--valid_file", type=str, required=True) 139 | parser.add_argument("--test_file", type=str) 140 | parser.add_argument( 141 | "--vocab_file", type=str, required=True, help="Path to vocab file" 142 | ) 143 | parser.add_argument( 144 | "--merges_file", type=str, required=True, help="Path to vocab file" 145 | ) 146 | parser.add_argument( 147 | "--hidden_size", 148 | type=int, 149 | default=768, 150 | help="Model dimension (and embedding dsz)", 151 | ) 152 | parser.add_argument("--feed_forward_size", type=int, help="FFN dimension") 153 | parser.add_argument("--num_heads", type=int, default=12, help="Number of heads") 154 | parser.add_argument("--num_layers", type=int, default=12, help="Number of layers") 155 | parser.add_argument( 156 | "--num_train_workers", type=int, default=4, help="Number train workers" 157 | ) 158 | parser.add_argument( 159 | "--num_valid_workers", type=int, default=1, help="Number train workers" 160 | ) 161 | parser.add_argument( 162 | "--max_seq_len", type=int, default=1024, help="Max input length" 163 | ) 164 | parser.add_argument("--batch_size", type=int, default=20, help="Batch Size") 165 | parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") 166 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate") 167 | parser.add_argument("--ckpt_base", type=str, default="ckpt-") 168 | parser.add_argument( 169 | "--pool_type", 170 | type=str, 171 | choices=["mean", "last"], 172 | help="token to use for pooling, defaults to mean pooling", 173 | ) 174 | parser.add_argument("--num_epochs", type=int, default=5) 175 | parser.add_argument( 176 | 
"--device", 177 | type=str, 178 | default="cuda" if torch.cuda.is_available() else "cpu", 179 | help="Device (cuda or cpu)", 180 | ) 181 | args = parser.parse_args() 182 | logging.basicConfig(level=logging.INFO) 183 | tokenizer = ByteLevelBPETokenizer(args.vocab_file, args.merges_file) 184 | # TODO: read the pad_index in 185 | train_set, labels = read_cls_dataset_gpt2( 186 | args.train_file, tokenizer, pad_index=0, max_seq_len=args.max_seq_len 187 | ) 188 | train_loader = torch.utils.data.DataLoader( 189 | train_set, 190 | batch_size=args.batch_size, 191 | shuffle=True, 192 | collate_fn=trim_to_shortest_len, 193 | ) 194 | logger.info(labels) 195 | valid_set, labels = read_cls_dataset_gpt2( 196 | args.valid_file, 197 | tokenizer, 198 | pad_index=0, 199 | max_seq_len=args.max_seq_len, 200 | label_list=labels, 201 | ) 202 | valid_loader = torch.utils.data.DataLoader( 203 | valid_set, 204 | batch_size=args.batch_size, 205 | shuffle=False, 206 | collate_fn=trim_to_shortest_len, 207 | ) 208 | 209 | num_classes = len(labels) 210 | output_layer = torch.nn.Linear(args.hidden_size, num_classes) 211 | pool_id = ( 212 | tokenizer.token_to_id("<|endoftext|>") if args.pool_type == "last" else None 213 | ) 214 | model = GPT2Creator.pooled_enc_from_pretrained( 215 | args.model, output=output_layer, pool_id=pool_id, **vars(args) 216 | ).to(args.device) 217 | loss_function = torch.nn.CrossEntropyLoss().to(args.device) 218 | 219 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 220 | logger.info(optimizer) 221 | checkpoint_name = args.ckpt_base + ".pth" 222 | best_valid_acc = 0.0 223 | 224 | device = args.device 225 | for epoch in range(args.num_epochs): 226 | train_epoch(epoch, loss_function, model, optimizer, train_loader, device) 227 | 228 | valid_acc_fract = valid_epoch(epoch, loss_function, model, valid_loader, device) 229 | # Early stopping check 230 | if valid_acc_fract > best_valid_acc: 231 | best_valid_acc = valid_acc_fract 232 | acc = 100.0 * valid_acc_fract 233 | logger.info(f"New best validation accuracy {acc:.2f}%") 234 | torch.save(model.state_dict(), checkpoint_name) 235 | if not args.test_file: 236 | logger.info("No test file provided, exiting") 237 | sys.exit(1) 238 | 239 | test_set, final_labels = read_cls_dataset_gpt2( 240 | args.test_file, 241 | tokenizer, 242 | pad_index=0, 243 | max_seq_len=args.max_seq_len, 244 | label_list=labels, 245 | ) 246 | if len(final_labels) != num_classes: 247 | raise Exception( 248 | "The test set adds new classes with no samples in the training or validation" 249 | ) 250 | test_loader = torch.utils.data.DataLoader( 251 | test_set, 252 | batch_size=args.batch_size, 253 | shuffle=False, 254 | collate_fn=trim_to_shortest_len, 255 | ) 256 | 257 | best_state = torch.load(checkpoint_name) 258 | model.load_state_dict(best_state) 259 | eval_fract = valid_epoch(0, loss_function, model, test_loader, device, phase="test") 260 | eval_acc = 100.0 * eval_fract 261 | print(f"final test accuracy {eval_acc:.2f}%") 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | -------------------------------------------------------------------------------- /src/mint/postln.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, Callable 4 | from mint.common import DefaultLayerFactory, WeightTiedVocabProjection 5 | import logging 6 | 7 | logger = logging.getLogger("mint") 8 | 9 | 10 | class TransformerEncoderLayer(nn.Module): 11 | """A single (post-layer-norm 
style) Transformer layer 12 | 13 | This layer implements a post-layer-norm style Transformer. The MultiHeadedAttention is applied first, with 14 | optional dropout and added to its input, followed by normalization. Then the FFN block is applied, where 15 | an MLP layer with a larger size is applied followed by an activation and down-projection back to the input size. 16 | Dropout is again applied, and again we add the output to the input of FFN, followed by a layer norm. 17 | 18 | As this is a post-layer-norm architecture, a normalization operation should be applied prior to sending the 19 | data through this layer 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | hidden_size: int = 768, 26 | num_heads: int = 12, 27 | dropout: float = 0.1, 28 | layer_norm_eps: float = 1e-12, 29 | activation: nn.Module = nn.GELU(), 30 | feed_forward_size: Optional[int] = None, 31 | layer_factory=None, 32 | ): 33 | """Initialize our transformer, uses bert-base defaults 34 | 35 | :param hidden_size: Size of the transformer inputs and outputs (d_model in the paper) 36 | :param num_heads: The number of heads for multi-headed attention 37 | :param dropout: A dropout to apply to each sub-blocks outputs 38 | :param layer_norm_eps: The noise applied in the layer norm calculation 39 | :param activation: The activation function to use 40 | :param feed_forward_size: The optional size of the FFN internal representation. Defaults to 4*hidden_size 41 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 42 | """ 43 | super().__init__() 44 | 45 | self.hidden_size = hidden_size 46 | self.dropout = dropout 47 | self.d_ff = feed_forward_size 48 | self.self_attention = layer_factory.encoder_multihead_attention( 49 | hidden_size, num_heads 50 | ) 51 | self.self_attention_layer_norm = layer_factory.layer_norm( 52 | hidden_size, layer_norm_eps 53 | ) 54 | self.ffn = layer_factory.feed_forward( 55 | hidden_size, feed_forward_size, activation 56 | ) 57 | self.output_layer_norm = layer_factory.layer_norm(hidden_size, layer_norm_eps) 58 | 59 | def maybe_dropout(self, x: torch.Tensor) -> torch.Tensor: 60 | """Apply dropout operator in graph only if training 61 | 62 | TODO: this function could also test dropout to make sure its > 0, pruning an unnecessary op 63 | if training with no dropout 64 | 65 | :param x: The output of the sub-layer 66 | :return: A (maybe) dropped out version of the input 67 | """ 68 | return nn.functional.dropout(x, self.dropout) if self.training else x 69 | 70 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): 71 | """Pass an x tensor and optional mask through the transformer layer 72 | 73 | :param x: A `[B, T, C]` tensor where B is batch, T is time, and C is the num hidden units 74 | :param mask: An optional attention mask. True where the input is valid, and false where it isnt 75 | :return: The output of the block 76 | """ 77 | y = self.self_attention_layer_norm( 78 | x + self.maybe_dropout(self.self_attention(x, mask)) 79 | ) 80 | y = self.output_layer_norm(y + self.maybe_dropout(self.ffn(y))) 81 | return y 82 | 83 | 84 | class TransformerEncoder(nn.Module): 85 | """A Post-Layer Norm Transformer Encoder (with no task heads) 86 | 87 | This encoder encapsulates the entire front-end of the Transformer from one-hots up to the final 88 | encoding. For tasks like MLM and fine-tuning we will inherit this module and provide additional 89 | functionality to the forward. 
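A minimal usage sketch (`MyEmbeddings` here is a stand-in for any embedding module whose
constructor accepts (vocab_size, hidden_size, padding_idx=..., max_seq_len=...)):

    encoder = TransformerEncoder(MyEmbeddings, vocab_size=30522, num_layers=2)
    x = torch.randint(1, 30522, (2, 16))        # [B, T] token ids
    y = encoder(x, encoder.create_pad_mask(x))  # [B, T, hidden_size]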
90 | 91 | This set up via inheritance to keep sub-classing and configuration params being passed to a minimum 92 | The tutorial mentions other ways that you could organize this 93 | 94 | """ 95 | 96 | def __init__( 97 | self, 98 | EmbeddingClass: Callable, 99 | vocab_size: int, 100 | padding_idx: int = 0, 101 | hidden_size: int = 768, 102 | num_heads: int = 12, 103 | num_layers: int = 12, 104 | dropout: float = 0.1, 105 | layer_norm_eps=1e-12, 106 | activation: nn.Module = nn.GELU(), 107 | feed_forward_size: Optional[int] = None, 108 | max_seq_len: int = 512, 109 | do_embeddings_layer_norm=True, 110 | layer_factory=None, 111 | ): 112 | """Set up initialization for a (post-layer-norm) Transformer. Defaults to bert-base settings 113 | 114 | :param vocab_size: The size of the input vocabulary 115 | :param padding_idx: The padding index, defaults to 0 116 | :param hidden_size: The number of hidden units 117 | :param num_heads: The number of heads for multi-headed attn. Should divide evenly into hidden_size 118 | :param num_layers: The number of transformer layers (MHA+FFN) in the architecture 119 | :param dropout: The value to apply for dropout 120 | :param layer_norm_eps: The noising term for layer norm 121 | :param activation: The activation function to use throughout 122 | :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size 123 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 124 | 125 | """ 126 | super().__init__() 127 | self.padding_idx = padding_idx 128 | 129 | if layer_factory is None: 130 | layer_factory = DefaultLayerFactory.get_instance() 131 | 132 | self.embeddings_layer_norm = ( 133 | layer_factory.layer_norm(hidden_size, layer_norm_eps) 134 | if do_embeddings_layer_norm 135 | else nn.Identity() 136 | ) 137 | self.embeddings = EmbeddingClass( 138 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 139 | ) 140 | self.encoder = nn.ModuleList( 141 | [ 142 | TransformerEncoderLayer( 143 | hidden_size, 144 | num_heads, 145 | dropout, 146 | layer_norm_eps, 147 | activation, 148 | feed_forward_size, 149 | layer_factory, 150 | ) 151 | for _ in range(num_layers) 152 | ] 153 | ) 154 | self.LayerNormImpl = layer_factory.layer_norm 155 | 156 | @property 157 | def hidden_size(self): 158 | """Useful to see the hidden size of the arch., but we dont a member var, its going to be all over the layers 159 | :return: 160 | """ 161 | return self.embeddings.word_embeddings.weight.shape[1] 162 | 163 | @property 164 | def vocab_size(self): 165 | """Useful to see the vocab size, but we dont need to store as a member, its the first dim of word embeddings 166 | 167 | :return: 168 | """ 169 | return self.embeddings.word_embeddings.weight.shape[0] 170 | 171 | def create_pad_mask(self, x: torch.Tensor) -> torch.Tensor: 172 | """For input padded using the padding_idx, generate an attention mask for that 173 | 174 | :param x: 175 | :return: 176 | """ 177 | mask = x != self.padding_idx 178 | return mask.unsqueeze(1).unsqueeze(1).to(device=x.device) 179 | 180 | def forward( 181 | self, 182 | x: torch.Tensor, 183 | mask: Optional[torch.Tensor] = None, 184 | token_type: Optional[torch.Tensor] = None, 185 | ) -> torch.Tensor: 186 | """ 187 | 188 | :param x: A one-hot (long) tensor of shape `[B, T]` 189 | :param mask: An optional mask to take in for attention 190 | :param token_type: 191 | :return: 192 | """ 193 | y = self.embeddings(x, token_type) 194 | y = 
self.embeddings_layer_norm(y) 195 | for t in self.encoder: 196 | y = t(y, mask) 197 | return y 198 | 199 | def init_layer_weights(self, module): 200 | """This not directly used on initialization. If you want to use it, call `module.apply()` on it 201 | 202 | The base classes do make use of it for MLM and pooling in their constructors 203 | :param module: 204 | :return: 205 | """ 206 | if isinstance(module, (nn.Linear, nn.Embedding, self.LayerNormImpl)): 207 | module.weight.data.normal_(mean=0.0, std=0.02) 208 | if ( 209 | isinstance(module, (nn.Linear, self.LayerNormImpl)) 210 | and module.bias is not None 211 | ): 212 | module.bias.data.zero_() 213 | 214 | 215 | class TransformerDecoderLayer(nn.Module): 216 | """A single (post-layer-norm style) Transformer Decoder layer 217 | 218 | This layer implements a post-layer-norm style Transformer Decoder (in the NMT/Encoder-Decoder sense). 219 | This module contains both self-attention (used in the decoder portion, and Encoder-Decoder cross-attention) 220 | 221 | As this is a post-layer-norm architecture, a normalization operation should be applied prior to sending the 222 | data through this layer 223 | 224 | """ 225 | 226 | def __init__( 227 | self, 228 | hidden_size: int = 768, 229 | num_heads: int = 12, 230 | dropout: float = 0.1, 231 | layer_norm_eps: float = 1e-12, 232 | activation: nn.Module = nn.GELU(), 233 | feed_forward_size: Optional[int] = None, 234 | layer_factory=None, 235 | ): 236 | """Initialize our transformer, uses bert-base defaults 237 | 238 | :param hidden_size: Size of the transformer inputs and outputs (d_model in the paper) 239 | :param num_heads: The number of heads for multi-headed attention 240 | :param dropout: A dropout to apply to each sub-blocks outputs 241 | :param layer_norm_eps: The noise applied in the layer norm calculation 242 | :param activation: The activation function to use 243 | :param feed_forward_size: The optional size of the FFN internal representation. 
Defaults to 4*hidden_size 244 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 245 | 246 | """ 247 | super().__init__() 248 | 249 | self.hidden_size = hidden_size 250 | self.dropout = dropout 251 | self.d_ff = feed_forward_size 252 | if layer_factory is None: 253 | layer_factory = DefaultLayerFactory.get_instance() 254 | self.self_attention = layer_factory.decoder_multihead_attention( 255 | hidden_size, num_heads 256 | ) 257 | self.self_attention_layer_norm = layer_factory.layer_norm( 258 | hidden_size, layer_norm_eps 259 | ) 260 | self.encoder_attention = layer_factory.encoder_decoder_attention( 261 | hidden_size, num_heads 262 | ) 263 | self.encoder_attention_layer_norm = layer_factory.layer_norm( 264 | hidden_size, layer_norm_eps 265 | ) 266 | self.ffn = layer_factory.feed_forward( 267 | hidden_size, feed_forward_size, activation 268 | ) 269 | self.output_layer_norm = layer_factory.layer_norm(hidden_size, layer_norm_eps) 270 | 271 | def maybe_dropout(self, x: torch.Tensor) -> torch.Tensor: 272 | """Apply dropout operator in graph only if training 273 | 274 | TODO: this function could also test dropout to make sure its > 0, pruning an unnecessary op 275 | if training with no dropout 276 | 277 | :param x: The output of the sub-layer 278 | :return: A (maybe) dropped out version of the input 279 | """ 280 | return nn.functional.dropout(x, self.dropout) if self.training else x 281 | 282 | def forward( 283 | self, 284 | src: torch.Tensor, 285 | dst: torch.Tensor, 286 | src_mask: Optional[torch.Tensor] = None, 287 | dst_mask: Optional[torch.Tensor] = None, 288 | ): 289 | """Pass an x tensor and optional mask through the transformer layer 290 | 291 | :param x: A `[B, T, C]` tensor where B is batch, T is time, and C is the num hidden units 292 | :param mask: An optional attention mask. True where the input is valid, and false where it isnt 293 | :return: The output of the block 294 | """ 295 | y = self.self_attention_layer_norm( 296 | dst + self.maybe_dropout(self.self_attention(dst, dst_mask)) 297 | ) 298 | y = self.encoder_attention_layer_norm( 299 | y + self.maybe_dropout(self.encoder_attention(src, y, src_mask)) 300 | ) 301 | y = self.output_layer_norm(y + self.maybe_dropout(self.ffn(y))) 302 | return y 303 | 304 | 305 | class TransformerEncoderDecoder(nn.Module): 306 | """A Post-Layer Norm Transformer Decoder (with no task heads) 307 | 308 | This encoder encapsulates the entire front-end of the Transformer from one-hots up to the final 309 | encoding. For tasks like MLM and fine-tuning we will inherit this module and provide additional 310 | functionality to the forward. 311 | 312 | This set up via inheritance to keep sub-classing and configuration params being passed to a minimum 313 | The tutorial mentions other ways that you could organize this 314 | 315 | """ 316 | 317 | def __init__( 318 | self, 319 | EmbeddingClass: Callable, 320 | vocab_size: int, 321 | padding_idx: int = 0, 322 | hidden_size: int = 768, 323 | num_heads: int = 12, 324 | num_encoder_layers: int = 6, 325 | num_decoder_layers: int = 6, 326 | dropout: float = 0.1, 327 | layer_norm_eps=1e-12, 328 | activation: nn.Module = nn.GELU(), 329 | feed_forward_size: Optional[int] = None, 330 | max_seq_len: int = 512, 331 | do_embeddings_layer_norm=True, 332 | layer_factory=None, 333 | ): 334 | """Set up initialization for a (post-layer-norm) Transformer. 
Defaults to bert-base settings 335 | 336 | :param vocab_size: The size of the input vocabulary 337 | :param padding_idx: The padding index, defaults to 0 338 | :param hidden_size: The number of hidden units 339 | :param num_heads: The number of heads for multi-headed attn. Should divide evenly into hidden_size 340 | :param num_layers: The number of transformer layers (MHA+FFN) in the architecture 341 | :param dropout: The value to apply for dropout 342 | :param layer_norm_eps: The noising term for layer norm 343 | :param activation: The activation function to use throughout 344 | :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size 345 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 346 | """ 347 | super().__init__() 348 | self.padding_idx = padding_idx 349 | if layer_factory is None: 350 | layer_factory = DefaultLayerFactory.get_instance() 351 | self.encoder_embeddings_layer_norm = ( 352 | layer_factory.layer_norm(hidden_size, layer_norm_eps) 353 | if do_embeddings_layer_norm 354 | else nn.Identity() 355 | ) 356 | self.decoder_embeddings_layer_norm = ( 357 | layer_factory.layer_norm(hidden_size, layer_norm_eps) 358 | if do_embeddings_layer_norm 359 | else nn.Identity() 360 | ) 361 | self.encoder_embeddings = EmbeddingClass( 362 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 363 | ) 364 | self.decoder_embeddings = EmbeddingClass( 365 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 366 | ) 367 | 368 | self.decoder_embeddings.word_embeddings = ( 369 | self.encoder_embeddings.word_embeddings 370 | ) 371 | 372 | self.encoder = nn.ModuleList( 373 | [ 374 | TransformerEncoderLayer( 375 | hidden_size, 376 | num_heads, 377 | dropout, 378 | layer_norm_eps, 379 | activation, 380 | feed_forward_size, 381 | layer_factory, 382 | ) 383 | for _ in range(num_encoder_layers) 384 | ] 385 | ) 386 | self.decoder = nn.ModuleList( 387 | [ 388 | TransformerDecoderLayer( 389 | hidden_size, 390 | num_heads, 391 | dropout, 392 | layer_norm_eps, 393 | activation, 394 | feed_forward_size, 395 | layer_factory, 396 | ) 397 | for _ in range(num_decoder_layers) 398 | ] 399 | ) 400 | 401 | self.register_buffer( 402 | "causal_mask", 403 | torch.tril( 404 | torch.ones( 405 | ( 406 | max_seq_len, 407 | max_seq_len, 408 | ), 409 | dtype=torch.uint8, 410 | ) 411 | ) 412 | .unsqueeze(0) 413 | .unsqueeze(0), 414 | ) 415 | self.LayerNormImpl = layer_factory.layer_norm 416 | 417 | @property 418 | def hidden_size(self): 419 | """Useful to see the hidden size of the arch., but we dont a member var, its going to be all over the layers 420 | :return: 421 | """ 422 | return self.encoder_embeddings.word_embeddings.weight.shape[1] 423 | 424 | @property 425 | def vocab_size(self): 426 | """Useful to see the vocab size, but we dont need to store as a member, its the first dim of word embeddings 427 | 428 | :return: 429 | """ 430 | return self.encoder_embeddings.word_embeddings.weight.shape[0] 431 | 432 | def create_pad_mask(self, x: torch.Tensor) -> torch.Tensor: 433 | """For input padded using the padding_idx, generate an attention mask for that 434 | 435 | :param x: 436 | :return: 437 | """ 438 | mask = x != self.padding_idx 439 | return mask.unsqueeze(1).unsqueeze(1).to(device=x.device) 440 | 441 | def forward( 442 | self, 443 | src: torch.Tensor, 444 | dst: torch.Tensor, 445 | src_mask: Optional[torch.Tensor] = None, 446 | dst_mask: Optional[torch.Tensor] = 
None, 447 | ) -> torch.Tensor: 448 | """ 449 | 450 | :param src: A one-hot (long) tensor of shape `[B, T_k]` 451 | :param dst: A one-hot (long) tensor of shape `[B, T_q]` 452 | :param src_mask: An optional mask to take in for attention 453 | :param dst_mask: An optional mask to take in for attention 454 | :return: 455 | """ 456 | 457 | src_enc = self.encode(src, src_mask) 458 | dst_enc = self.decode(src_enc, dst, src_mask, dst_mask) 459 | return dst_enc 460 | 461 | def decode( 462 | self, 463 | src_enc, 464 | dst, 465 | src_mask: Optional[torch.Tensor] = None, 466 | dst_mask: Optional[torch.Tensor] = None, 467 | ): 468 | futures_mask = self.causal_mask[:, :, : dst.shape[1], : dst.shape[1]] 469 | if dst_mask is not None: 470 | futures_mask = dst_mask & futures_mask.to(dtype=torch.bool) 471 | dst_enc = self.decoder_embeddings(dst) 472 | dst_enc = self.decoder_embeddings_layer_norm(dst_enc) 473 | for t in self.decoder: 474 | dst_enc = t(src_enc, dst_enc, src_mask, futures_mask) 475 | return dst_enc 476 | 477 | def encode( 478 | self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None 479 | ) -> torch.Tensor: 480 | src_enc = self.encoder_embeddings(src) 481 | src_enc = self.encoder_embeddings_layer_norm(src_enc) 482 | for t in self.encoder: 483 | src_enc = t(src_enc, src_mask) 484 | return src_enc 485 | 486 | def init_layer_weights(self, module): 487 | """This not directly used on initialization. If you want to use it, call `module.apply()` on it 488 | 489 | The base classes do make use of it for MLM and pooling in their constructors 490 | :param module: 491 | :return: 492 | """ 493 | if isinstance(module, (nn.Linear, nn.Embedding, self.LayerNormImpl)): 494 | module.weight.data.normal_(mean=0.0, std=0.02) 495 | if ( 496 | isinstance(module, (nn.Linear, self.LayerNormImpl)) 497 | and module.bias is not None 498 | ): 499 | module.bias.data.zero_() 500 | 501 | 502 | class TransformerSequenceGenerator(TransformerEncoderDecoder): 503 | """An encoder-decoder that produces word output in the decoder 504 | 505 | For training, this works with teacher forcing, where both the encoder inputs and the 506 | lagged generated tokens at each timestep are provided, starting with some well-known 507 | decoder begin token. 
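    In practice the decoder input is just the gold target sequence shifted right by one position (beginning with that start token), so every output position can be predicted in one parallel pass rather than one step at a time.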
508 | 509 | At inference time, we will do some decoding over time, and so we need to be able to 510 | call the encoder once, and the decoder N times 511 | """ 512 | 513 | def __init__( 514 | self, 515 | EmbeddingClass: Callable, 516 | vocab_size: int, 517 | padding_idx: int = 0, 518 | hidden_size: int = 768, 519 | num_heads: int = 12, 520 | num_encoder_layers: int = 6, 521 | num_decoder_layers: int = 6, 522 | dropout: float = 0.1, 523 | layer_norm_eps=1e-12, 524 | activation: nn.Module = nn.GELU(), 525 | feed_forward_size: Optional[int] = None, 526 | max_seq_len: int = 1024, 527 | do_embeddings_layer_norm=True, 528 | layer_factory=None, 529 | ): 530 | super().__init__( 531 | EmbeddingClass, 532 | vocab_size, 533 | padding_idx, 534 | hidden_size, 535 | num_heads, 536 | num_encoder_layers, 537 | num_decoder_layers, 538 | dropout, 539 | layer_norm_eps, 540 | activation, 541 | feed_forward_size, 542 | max_seq_len, 543 | do_embeddings_layer_norm, 544 | layer_factory, 545 | ) 546 | self.output_proj = WeightTiedVocabProjection( 547 | self.decoder_embeddings.word_embeddings 548 | ) 549 | self.apply(self.init_layer_weights) 550 | 551 | def decode( 552 | self, 553 | src_enc, 554 | dst, 555 | src_mask: Optional[torch.Tensor] = None, 556 | dst_mask: Optional[torch.Tensor] = None, 557 | ): 558 | dst_enc = super().decode(src_enc, dst, src_mask, dst_mask) 559 | return self.output_proj(dst_enc) 560 | -------------------------------------------------------------------------------- /src/mint/preln.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, Callable 4 | from mint.common import DefaultLayerFactory, WeightTiedVocabProjection 5 | import math 6 | import logging 7 | 8 | logger = logging.getLogger("mint") 9 | 10 | 11 | class PreLayerNormTransformerEncoderLayer(nn.Module): 12 | """A single (pre-layer-norm style) Transformer layer 13 | 14 | This layer implements a pre-layer-norm style Transformer. Normalization is applied first, and then 15 | MultiHeadedAttention is applied with optional dropout and added to its input. 16 | 17 | Normalization is again applied and then the FFN block is applied. Then an MLP layer with a larger size 18 | is applied followed by an activation and down-projection back to the input size. Dropout is again applied, 19 | and again we add the output to the input of FFN. 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | hidden_size: int = 768, 26 | num_heads: int = 12, 27 | dropout: float = 0.1, 28 | layer_norm_eps: float = 1e-12, 29 | activation: nn.Module = nn.GELU(), 30 | feed_forward_size: Optional[int] = None, 31 | layer_factory=None, 32 | ): 33 | """Initialize our transformer, uses bert-base defaults 34 | 35 | :param hidden_size: Size of the transformer inputs and outputs (d_model in the paper) 36 | :param num_heads: The number of heads for multi-headed attention 37 | :param dropout: A dropout to apply to each sub-blocks outputs 38 | :param layer_norm_eps: The noise applied in the layer norm calculation 39 | :param activation: The activation function to use 40 | :param feed_forward_size: The optional size of the FFN internal representation. 
Defaults to 4*hidden_size 41 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 42 | 43 | """ 44 | super().__init__() 45 | if layer_factory is None: 46 | layer_factory = DefaultLayerFactory.get_instance() 47 | self.hidden_size = hidden_size 48 | self.dropout = dropout 49 | self.d_ff = feed_forward_size 50 | self.self_attention = layer_factory.encoder_multihead_attention( 51 | hidden_size, num_heads 52 | ) 53 | self.self_attention_layer_norm = layer_factory.layer_norm( 54 | hidden_size, layer_norm_eps 55 | ) 56 | self.ffn = layer_factory.feed_forward( 57 | hidden_size, feed_forward_size, activation 58 | ) 59 | self.output_layer_norm = layer_factory.layer_norm(hidden_size, layer_norm_eps) 60 | 61 | def maybe_dropout(self, x: torch.Tensor) -> torch.Tensor: 62 | """Apply dropout operator in graph only if training 63 | 64 | TODO: this function could also test dropout to make sure its > 0, pruning an unnecessary op 65 | if training with no dropout 66 | 67 | :param x: The output of the sub-layer 68 | :return: A (maybe) dropped out version of the input 69 | """ 70 | return nn.functional.dropout(x, self.dropout) if self.training else x 71 | 72 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): 73 | """Pass an x tensor and optional mask through the transformer layer 74 | 75 | :param x: A `[B, T, C]` tensor where B is batch, T is time, and C is the num hidden units 76 | :param mask: An optional attention mask. True where the input is valid, and false where it isnt 77 | :return: The output of the block 78 | """ 79 | 80 | residual = x 81 | y = self.self_attention_layer_norm(x) 82 | y = residual + self.maybe_dropout(self.self_attention(y, mask)) 83 | residual = y 84 | y = self.output_layer_norm(y) 85 | y = residual + self.maybe_dropout(self.ffn(y)) 86 | return y 87 | 88 | 89 | class PreLayerNormTransformerEncoder(nn.Module): 90 | """A Pre-Layer Norm Transformer Encoder (with no task heads) 91 | 92 | This encoder encapsulates the entire front-end of the Transformer from one-hots up to the final 93 | encoding. For tasks like LM and fine-tuning we will inherit this module and provide additional 94 | functionality to the forward. 95 | 96 | This set up via inheritance to keep sub-classing and configuration params being passed to a minimum 97 | The tutorial mentions other ways that you could organize this 98 | 99 | """ 100 | 101 | def __init__( 102 | self, 103 | EmbeddingClass: Callable, 104 | vocab_size: int, 105 | padding_idx: int = 0, 106 | hidden_size: int = 768, 107 | num_heads: int = 12, 108 | num_layers: int = 12, 109 | dropout: float = 0.1, 110 | layer_norm_eps=1e-12, 111 | activation: nn.Module = nn.GELU(), 112 | feed_forward_size: Optional[int] = None, 113 | max_seq_len: int = 512, 114 | layer_factory=None, 115 | ): 116 | """Set up initialization for a (pre-layer-norm) Transformer 117 | 118 | :param vocab_size: The size of the input vocabulary 119 | :param padding_idx: The padding index, defaults to 0 120 | :param hidden_size: The number of hidden units 121 | :param num_heads: The number of heads for multi-headed attn. 
Should divide evenly into hidden_size 122 | :param num_layers: The number of transformer layers (MHA+FFN) in the architecture 123 | :param dropout: The value to apply for dropout 124 | :param layer_norm_eps: The noising term for layer norm 125 | :param activation: The activation function to use throughout 126 | :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size 127 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 128 | 129 | """ 130 | super().__init__() 131 | if layer_factory is None: 132 | layer_factory = DefaultLayerFactory.get_instance() 133 | self.padding_idx = padding_idx 134 | self.embeddings = EmbeddingClass( 135 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 136 | ) 137 | self.encoder = nn.ModuleList( 138 | [ 139 | PreLayerNormTransformerEncoderLayer( 140 | hidden_size, 141 | num_heads, 142 | dropout, 143 | layer_norm_eps, 144 | activation, 145 | feed_forward_size, 146 | layer_factory, 147 | ) 148 | for _ in range(num_layers) 149 | ] 150 | ) 151 | 152 | self.LayerNormImpl = layer_factory.layer_norm 153 | self.layer_norm = self.LayerNormImpl(hidden_size, layer_norm_eps) 154 | 155 | @property 156 | def hidden_size(self): 157 | """Useful to see the hidden size of the arch., but we dont a member var, its going to be all over the layers 158 | :return: 159 | """ 160 | return self.embeddings.word_embeddings.weight.shape[1] 161 | 162 | @property 163 | def vocab_size(self): 164 | """Useful to see the vocab size, but we dont need to store as a member, its the first dim of word embeddings 165 | 166 | :return: 167 | """ 168 | return self.embeddings.word_embeddings.weight.shape[0] 169 | 170 | def create_pad_mask(self, x: torch.Tensor) -> torch.Tensor: 171 | """For input padded using the padding_idx, generate an attention mask for that 172 | 173 | :param x: 174 | :return: 175 | """ 176 | mask = x != self.padding_idx 177 | return mask.unsqueeze(1).unsqueeze(1).to(device=x.device) 178 | 179 | def forward( 180 | self, 181 | x: torch.Tensor, 182 | mask: Optional[torch.Tensor] = None, 183 | token_type: Optional[torch.Tensor] = None, 184 | ) -> torch.Tensor: 185 | """ 186 | 187 | :param x: A one-hot (long) tensor of shape `[B, T]` 188 | :param mask: An optional mask to take in for attention 189 | :param token_type: 190 | :return: 191 | """ 192 | y = self.embeddings(x, token_type) 193 | for t in self.encoder: 194 | y = t(y, mask) 195 | 196 | y = self.layer_norm(y) 197 | return y 198 | 199 | def init_layer_weights(self, module): 200 | """This not directly used on initialization. If you want to use it, call `module.apply()` on it 201 | 202 | The base classes do make use of it for MLM and pooling in their constructors 203 | :param module: 204 | :return: 205 | """ 206 | if isinstance(module, (nn.Linear, nn.Embedding, self.LayerNormImpl)): 207 | module.weight.data.normal_(mean=0.0, std=0.02) 208 | if ( 209 | isinstance(module, (nn.Linear, self.LayerNormImpl)) 210 | and module.bias is not None 211 | ): 212 | module.bias.data.zero_() 213 | 214 | # TODO: GPT2 only, move this up into the LM? 
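        # GPT-2-style residual re-scaling (the loop below): weights whose names match the FFN
        # down-projection ("ffn.2.weight") or an output projection ("output.weight") are re-initialized
        # with std 0.02 / sqrt(2 * num_layers), which keeps the residual-stream variance roughly
        # constant as the network gets deeper.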
215 |         for name, p in module.named_parameters():
216 |             if "ffn.2.weight" in name or "output.weight" in name:
217 |                 p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * len(self.encoder))))
218 | 
219 | 
220 | class PreLayerNormTransformerDecoderLayer(nn.Module):
221 |     """A single (pre-layer-norm style) Transformer Decoder layer
222 | 
223 |     This layer implements a pre-layer-norm style Transformer Decoder (in the NMT/Encoder-Decoder sense).
224 |     This module contains both self-attention (over the decoder inputs) and Encoder-Decoder cross-attention
225 | 
226 |     As this is a pre-layer-norm architecture, each sub-layer normalizes its own input; a final normalization
227 |     is typically applied after the last decoder layer (see PreLayerNormTransformerEncoderDecoder)
228 | 
229 |     """
230 | 
231 |     def __init__(
232 |         self,
233 |         hidden_size: int = 768,
234 |         num_heads: int = 12,
235 |         dropout: float = 0.1,
236 |         layer_norm_eps: float = 1e-12,
237 |         activation: nn.Module = nn.GELU(),
238 |         feed_forward_size: Optional[int] = None,
239 |         layer_factory=None,
240 |     ):
241 |         """Initialize our transformer, uses bert-base defaults
242 | 
243 |         :param hidden_size: Size of the transformer inputs and outputs (d_model in the paper)
244 |         :param num_heads: The number of heads for multi-headed attention
245 |         :param dropout: A dropout to apply to each sub-block's outputs
246 |         :param layer_norm_eps: The noise applied in the layer norm calculation
247 |         :param activation: The activation function to use
248 |         :param feed_forward_size: The optional size of the FFN internal representation. Defaults to 4*hidden_size
249 |         :param layer_factory: An optional implementation of all layers, useful for specific model implementation details
250 |         """
251 |         super().__init__()
252 | 
253 |         if layer_factory is None:
254 |             layer_factory = DefaultLayerFactory.get_instance()
255 |         self.hidden_size = hidden_size
256 |         self.dropout = dropout
257 |         self.d_ff = feed_forward_size
258 |         self.self_attention = layer_factory.decoder_multihead_attention(
259 |             hidden_size, num_heads
260 |         )
261 |         self.self_attention_layer_norm = layer_factory.layer_norm(
262 |             hidden_size, layer_norm_eps
263 |         )
264 |         self.encoder_attention = layer_factory.encoder_decoder_attention(
265 |             hidden_size, num_heads
266 |         )
267 |         self.encoder_attention_layer_norm = layer_factory.layer_norm(
268 |             hidden_size, layer_norm_eps
269 |         )
270 |         self.ffn = layer_factory.feed_forward(
271 |             hidden_size, feed_forward_size, activation
272 |         )
273 |         self.output_layer_norm = layer_factory.layer_norm(hidden_size, layer_norm_eps)
274 | 
275 |     def maybe_dropout(self, x: torch.Tensor) -> torch.Tensor:
276 |         """Apply dropout operator in graph only if training
277 | 
278 |         TODO: this function could also test dropout to make sure it's > 0, pruning an unnecessary op
279 |         if training with no dropout
280 | 
281 |         :param x: The output of the sub-layer
282 |         :return: A (maybe) dropped out version of the input
283 |         """
284 |         return nn.functional.dropout(x, self.dropout) if self.training else x
285 | 
286 |     def forward(
287 |         self,
288 |         src: torch.Tensor,
289 |         dst: torch.Tensor,
290 |         src_mask: Optional[torch.Tensor] = None,
291 |         dst_mask: Optional[torch.Tensor] = None,
292 |     ):
293 |         """Pass the encoder output (src) and the decoder input (dst) through the layer, returning the block output
294 |         :param src: A `[B, T_k, C]` tensor of encoder outputs, where B is batch, T is time, and C is the num hidden units
295 |         :param dst: A `[B, T_q, C]` tensor of decoder inputs
296 |         :param src_mask: An optional attention mask over src. True where the input is valid, and false where it isn't
297 |         :param dst_mask: An optional (typically causal) attention mask over dst
298 |         """
299 | 
300 |         h = self.self_attention_layer_norm(dst)
301 |         y = dst + self.maybe_dropout(self.self_attention(h, dst_mask))
302 |         h = self.encoder_attention_layer_norm(y)
303 |         y = y + self.maybe_dropout(self.encoder_attention(src, h, src_mask))
304 | 
305 |         h = self.output_layer_norm(y)
306 |         y = y + self.maybe_dropout(self.ffn(h))
307 |         return y
308 | 
309 | 
310 | class PreLayerNormTransformerEncoderDecoder(nn.Module):
311 |     """A Pre-Layer Norm Transformer Encoder-Decoder (with no task heads)
312 | 
313 |     This module encapsulates the entire encoder-decoder Transformer, from one-hot inputs up to the final
314 |     decoder encoding. For tasks like sequence generation and fine-tuning we will inherit this module and
315 |     provide additional functionality to the forward.
316 | 
317 |     This is set up via inheritance to keep sub-classing and the configuration params being passed to a minimum.
318 |     The tutorial mentions other ways that you could organize this
319 | 
320 |     """
321 | 
322 |     def __init__(
323 |         self,
324 |         EmbeddingClass: Callable,
325 |         vocab_size: int,
326 |         padding_idx: int = 0,
327 |         hidden_size: int = 768,
328 |         num_heads: int = 12,
329 |         num_encoder_layers: int = 6,
330 |         num_decoder_layers: int = 6,
331 |         dropout: float = 0.1,
332 |         layer_norm_eps=1e-12,
333 |         activation: nn.Module = nn.GELU(),
334 |         feed_forward_size: Optional[int] = None,
335 |         max_seq_len: int = 512,
336 |         layer_factory=None,
337 |     ):
338 |         """Set up initialization for a (pre-layer-norm) Transformer. Defaults to bert-base settings
339 | 
340 |         :param vocab_size: The size of the input vocabulary
341 |         :param padding_idx: The padding index, defaults to 0
342 |         :param hidden_size: The number of hidden units
343 |         :param num_heads: The number of heads for multi-headed attn.
Should divide evenly into hidden_size 344 | :param num_layers: The number of transformer layers (MHA+FFN) in the architecture 345 | :param dropout: The value to apply for dropout 346 | :param layer_norm_eps: The noising term for layer norm 347 | :param activation: The activation function to use throughout 348 | :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size 349 | :param layer_factory: An optional implementation of all layers, useful for specific model implementation details 350 | 351 | """ 352 | super().__init__() 353 | self.padding_idx = padding_idx 354 | if layer_factory is None: 355 | layer_factory = DefaultLayerFactory.get_instance() 356 | self.encoder_embeddings = EmbeddingClass( 357 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 358 | ) 359 | self.decoder_embeddings = EmbeddingClass( 360 | vocab_size, hidden_size, padding_idx=padding_idx, max_seq_len=max_seq_len 361 | ) 362 | 363 | self.decoder_embeddings.word_embeddings = ( 364 | self.encoder_embeddings.word_embeddings 365 | ) 366 | 367 | self.encoder = nn.ModuleList( 368 | [ 369 | PreLayerNormTransformerEncoderLayer( 370 | hidden_size, 371 | num_heads, 372 | dropout, 373 | layer_norm_eps, 374 | activation, 375 | feed_forward_size, 376 | layer_factory, 377 | ) 378 | for _ in range(num_encoder_layers) 379 | ] 380 | ) 381 | self.decoder = nn.ModuleList( 382 | [ 383 | PreLayerNormTransformerDecoderLayer( 384 | hidden_size, 385 | num_heads, 386 | dropout, 387 | layer_norm_eps, 388 | activation, 389 | feed_forward_size, 390 | layer_factory, 391 | ) 392 | for _ in range(num_decoder_layers) 393 | ] 394 | ) 395 | 396 | self.register_buffer( 397 | "causal_mask", 398 | torch.tril( 399 | torch.ones( 400 | ( 401 | max_seq_len, 402 | max_seq_len, 403 | ), 404 | dtype=torch.uint8, 405 | ) 406 | ) 407 | .unsqueeze(0) 408 | .unsqueeze(0), 409 | ) 410 | self.LayerNormImpl = layer_factory.layer_norm 411 | self.encoder_layer_norm = self.LayerNormImpl(hidden_size, layer_norm_eps) 412 | self.decoder_layer_norm = self.LayerNormImpl(hidden_size, layer_norm_eps) 413 | 414 | @property 415 | def hidden_size(self): 416 | """Useful to see the hidden size of the arch., but we dont a member var, its going to be all over the layers 417 | :return: 418 | """ 419 | return self.encoder_embeddings.word_embeddings.weight.shape[1] 420 | 421 | @property 422 | def vocab_size(self): 423 | """Useful to see the vocab size, but we dont need to store as a member, its the first dim of word embeddings 424 | 425 | :return: 426 | """ 427 | return self.encoder_embeddings.word_embeddings.weight.shape[0] 428 | 429 | def create_pad_mask(self, x: torch.Tensor) -> torch.Tensor: 430 | """For input padded using the padding_idx, generate an attention mask for that 431 | 432 | :param x: 433 | :return: 434 | """ 435 | mask = x != self.padding_idx 436 | return mask.unsqueeze(1).unsqueeze(1).to(device=x.device) 437 | 438 | def encode( 439 | self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None 440 | ) -> torch.Tensor: 441 | src_enc = self.encoder_embeddings(src) 442 | for t in self.encoder: 443 | src_enc = t(src_enc, src_mask) 444 | src_enc = self.encoder_layer_norm(src_enc) 445 | return src_enc 446 | 447 | def decode( 448 | self, 449 | src_enc: torch.Tensor, 450 | dst: torch.Tensor, 451 | src_mask: Optional[torch.Tensor] = None, 452 | dst_mask: Optional[torch.Tensor] = None, 453 | ): 454 | futures_mask = self.causal_mask[:, :, : dst.shape[1], : dst.shape[1]] 455 | if dst_mask is not None: 456 | 
futures_mask = dst_mask & futures_mask.to(dtype=torch.bool) 457 | dst_enc = self.decoder_embeddings(dst) 458 | for t in self.decoder: 459 | dst_enc = t(src_enc, dst_enc, src_mask, futures_mask) 460 | dst_enc = self.decoder_layer_norm(dst_enc) 461 | return dst_enc 462 | 463 | def forward( 464 | self, 465 | src: torch.Tensor, 466 | dst: torch.Tensor, 467 | src_mask: Optional[torch.Tensor] = None, 468 | dst_mask: Optional[torch.Tensor] = None, 469 | ) -> torch.Tensor: 470 | """ 471 | 472 | :param src: A one-hot (long) tensor of shape `[B, T_k]` 473 | :param dst: A one-hot (long) tensor of shape `[B, T_q]` 474 | :param src_mask: An optional mask to take in for attention 475 | :param src_mask: An optional mask to take in for attention 476 | :return: 477 | """ 478 | 479 | src_enc = self.encode(src, src_mask) 480 | dst_enc = self.decode(src_enc, dst, src_mask, dst_mask) 481 | return dst_enc 482 | 483 | def init_layer_weights(self, module): 484 | """This not directly used on initialization. If you want to use it, call `module.apply()` on it 485 | 486 | The base classes do make use of it for MLM and pooling in their constructors 487 | :param module: 488 | :return: 489 | """ 490 | if isinstance(module, (nn.Linear, nn.Embedding, self.LayerNormImpl)): 491 | module.weight.data.normal_(mean=0.0, std=0.02) 492 | if ( 493 | isinstance(module, (nn.Linear, self.LayerNormImpl)) 494 | and module.bias is not None 495 | ): 496 | module.bias.data.zero_() 497 | 498 | 499 | class PreLayerNormTransformerSequenceGenerator(PreLayerNormTransformerEncoderDecoder): 500 | def __init__( 501 | self, 502 | EmbeddingClass: Callable, 503 | vocab_size: int, 504 | padding_idx: int = 0, 505 | hidden_size: int = 768, 506 | num_heads: int = 12, 507 | num_encoder_layers: int = 6, 508 | num_decoder_layers: int = 6, 509 | dropout: float = 0.1, 510 | layer_norm_eps=1e-12, 511 | activation: nn.Module = nn.GELU(), 512 | feed_forward_size: Optional[int] = None, 513 | max_seq_len: int = 1024, 514 | layer_factory=None, 515 | ): 516 | super().__init__( 517 | EmbeddingClass, 518 | vocab_size, 519 | padding_idx, 520 | hidden_size, 521 | num_heads, 522 | num_encoder_layers, 523 | num_decoder_layers, 524 | dropout, 525 | layer_norm_eps, 526 | activation, 527 | feed_forward_size, 528 | max_seq_len, 529 | layer_factory, 530 | ) 531 | self.output_proj = WeightTiedVocabProjection( 532 | self.decoder_embeddings.word_embeddings 533 | ) 534 | self.apply(self.init_layer_weights) 535 | 536 | def decode( 537 | self, 538 | src_enc: torch.Tensor, 539 | dst: torch.Tensor, 540 | src_mask: Optional[torch.Tensor] = None, 541 | dst_mask: Optional[torch.Tensor] = None, 542 | ) -> torch.Tensor: 543 | dst_enc = super().decode(src_enc, dst, src_mask, dst_mask) 544 | y = self.output_proj(dst_enc) 545 | return y 546 | -------------------------------------------------------------------------------- /src/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | NAME = tfs 3 | version = 0.0.1 4 | description = MinT: Minimalist Transformers Library 5 | author = dpressel 6 | author_email = dpressel@gmail.com 7 | license = Apache 2.0 8 | keywords= 9 | deep-learning 10 | transformers 11 | long_description_content_type = text/markdown 12 | long_description = file: README.md 13 | license_files = 14 | ../LICENSE 15 | ../NOTICE 16 | classifiers = 17 | Development Status :: 3 - Alpha 18 | Environment :: Console 19 | Intended Audience :: Developers 20 | Intended Audience :: Science/Research 21 | License :: OSI Approved :: 
Apache Software License 22 | Natural Language :: English 23 | Operating System :: OS Independent 24 | Programming Language :: Python :: 3.5 25 | Programming Language :: Python :: 3.6 26 | Programming Language :: Python :: 3.7 27 | Programming Language :: Python :: 3.8 28 | Topic :: Scientific/Engineering :: Artificial Intelligence 29 | 30 | [options] 31 | packages = find: 32 | 33 | install_requires = 34 | numpy 35 | tqdm 36 | tokenizers >= 0.10.0 37 | 38 | [options.entry_points] 39 | console_scripts = 40 | bert_completer = tfs.examples.bert_completer:main 41 | tune_bert_for_cls = tfs.examples.tune_bert_for_cls:main 42 | pretrain_bert_wiki = tfs.examples.pretrain_bert_wiki:main 43 | 44 | 45 | [options.extras_require] 46 | examples: 47 | lxml 48 | bs4 49 | prompt_toolkit >= 2.0.0 50 | matplotlib 51 | test: 52 | pytest 53 | 54 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | --------------------------------------------------------------------------------
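To make the model classes above concrete, here is a minimal usage sketch (not part of the repository) that wires the pre-layer-norm sequence generator from src/mint/preln.py up for a single teacher-forced forward pass. The `ToyEmbeddings` class is a hypothetical stand-in: the constructors above only assume a callable that maps `[B, T]` token ids to `[B, T, hidden]` tensors and exposes a `word_embeddings` attribute (used for weight tying), so any embedding module with that interface could be substituted.

import torch
import torch.nn as nn

from mint.preln import PreLayerNormTransformerSequenceGenerator


class ToyEmbeddings(nn.Module):
    """Hypothetical word + learned positional embeddings exposing `word_embeddings` for weight tying."""

    def __init__(self, vocab_size, hidden_size, padding_idx=0, max_seq_len=1024):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)

    def forward(self, x, token_type=None):
        # [B, T] token ids -> [B, T, hidden]; positions are simply 0..T-1
        positions = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        return self.word_embeddings(x) + self.position_embeddings(positions)


model = PreLayerNormTransformerSequenceGenerator(
    ToyEmbeddings,
    vocab_size=1000,
    hidden_size=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    feed_forward_size=256,
)

src = torch.randint(1, 1000, (2, 7))  # [B, T_k] source token ids
dst = torch.randint(1, 1000, (2, 5))  # [B, T_q] teacher-forced decoder inputs (targets shifted right)
logits = model(src, dst, model.create_pad_mask(src), model.create_pad_mask(dst))
print(logits.shape)  # torch.Size([2, 5, 1000]) -- per-timestep vocabulary logits

At inference time the same model would call `encode` once on the source and then call `decode` repeatedly, appending the chosen token at each step, as described in the sequence generator docstrings above.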