├── CHANGELOG.md ├── README.md ├── cs336_basics ├── VERSION └── __init__.py ├── cs336_spring2024_assignment1_basics.pdf ├── pytest.ini ├── requirements-test.txt ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── adapters.py ├── common.py ├── fixtures ├── adamw_expected_params.pt ├── address.txt ├── corpus.en ├── german.txt ├── gpt2_merges.txt ├── gpt2_vocab.json ├── in_features.pt ├── in_indices.pt ├── in_indices_truncated.pt ├── positionwise_feedforward_expected_output.pt ├── positionwise_feedforward_weights.pt ├── rmsnorm_expected_output.pt ├── rmsnorm_weights.pt ├── scaled_dot_product_attention_K.pt ├── scaled_dot_product_attention_Q.pt ├── scaled_dot_product_attention_V.pt ├── scaled_dot_product_attention_expected_output.pt ├── scaled_dot_product_attention_mask.pt ├── tinystories_sample.txt ├── tinystories_sample_5M.txt ├── train-bpe-reference-merges.txt ├── train-bpe-reference-vocab.json ├── transformer_block_expected_output.pt ├── transformer_block_weights.pt ├── transformer_lm_expected_output.pt ├── transformer_lm_truncated_expected_output.pt ├── transformer_lm_weights.pt ├── unbatched_multihead_self_attention_expected_output.pt └── unbatched_multihead_self_attention_weights.pt ├── test_data.py ├── test_model.py ├── test_nn_utils.py ├── test_optimizer.py ├── test_serialization.py ├── test_tokenizer.py └── test_train_bpe.py /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All changes we make to the assignment code or PDF will be documented in this file. 4 | 5 | ## [unreleased] - yyyy-mm-dd 6 | 7 | ### Added 8 | 9 | ### Changed 10 | 11 | ### Fixed 12 | 13 | - code: fix `test_get_batch` to handle "AssertionError: Torch not compiled with CUDA enabled". 14 | - handout: clarify that the gradient clipping norm is calculated over all the parameters. 15 | - code: fix gradient clipping test that compared the wrong tensors. 16 | - code: add tests that gradient clipping skips parameters with no gradient and properly computes the norm over multiple parameters. 17 | 18 | ## [0.1.6] - 2024-04-13 19 | 20 | ### Added 21 | 22 | ### Changed 23 | 24 | ### Fixed 25 | 26 | - handout: edit expected TinyStories run time to 30-40 minutes. 27 | - handout: add more details about how to use `np.memmap` or the `mmap_mode` flag 28 | to `np.load`. 29 | - code: fix `get_tokenizer()` docstring. 30 | - handout: specify that problem `main_experiment` should use the same settings 31 | as TinyStories. 32 | - code: replace mentions of layernorm with RMSNorm. 33 | 34 | ## [0.1.5] - 2024-04-06 35 | 36 | ### Added 37 | 38 | ### Changed 39 | 40 | - handout: clarify example of preferring lexicographically greater merges to 41 | specify that we want tuple comparison. 42 | 43 | ### Fixed 44 | 45 | - handout: fix expected number of training tokens for TinyStories, should be 46 | 327,680,000. 47 | - code: fix typo in `run_get_lr_cosine_schedule` return docstring. 48 | - code: fix typo in `test_tokenizer.py` 49 | 50 | ## [0.1.4] - 2024-04-04 51 | 52 | ### Added 53 | 54 | ### Changed 55 | 56 | - code: skip `Tokenizer` memory-related tests on non-Linux systems, since 57 | support for RLIMIT_AS is inconsistent. 58 | - code: increase atol on end-to-end Transformer forward pass tests. 59 | - code: remove dropout in model-related tests to improve determinism across 60 | platforms. 61 | - code: add `attn_pdrop` to `run_multihead_self_attention` adapter. 62 | - code: clarify `{q,k,v}_proj` dimension orders in the adapters.
63 | - code: increase atol on cross-entropy tests 64 | - code: remove unnecessary warning in `test_get_lr_cosine_schedule` 65 | 66 | ### Fixed 67 | 68 | - handout: fix signature of `Tokenizer.__init__` to include `self`. 69 | - handout: mention that `Tokenizer.from_files` should be a class method. 70 | - handout: clarify the list of model hyperparameters in `adamwAccounting`. 71 | - handout: clarify that `adamwAccounting` (b) considers a GPT-2 XL-shaped model 72 | (with our architecture), not necessarily the literal GPT-2 XL model. 73 | - handout: moved the softmax problem to where softmax is first mentioned (Scaled Dot-Product Attention, Section 3.4.3). 74 | - handout: removed redundant initialization (t = 0) in the AdamW pseudocode. 75 | - handout: added the resources needed for BPE training. 76 | 77 | ## [0.1.3] - 2024-04-02 78 | 79 | ### Added 80 | 81 | ### Changed 82 | 83 | - handout: edit `adamwAccounting`, part (d) to define MFU and mention that the 84 | backward pass is typically assumed to have twice the FLOPS of the forward pass. 85 | - handout: provide a hint about desired behavior when a user passes in input IDs 86 | to `Tokenizer.decode` that correspond to invalid UTF-8 bytes. 87 | 88 | ### Fixed 89 | 90 | ## [0.1.2] - 2024-04-02 91 | 92 | ### Added 93 | 94 | - handout: added some more information about submitting to the leaderboard. 95 | 96 | ### Changed 97 | 98 | ### Fixed 99 | 100 | ## [0.1.1] - 2024-04-01 101 | 102 | ### Added 103 | 104 | - code: add a note to README.md that pull requests and issues are welcome and 105 | encouraged. 106 | 107 | ### Changed 108 | 109 | - handout: edit motivation for pre-tokenization to include a note about 110 | desired behavior with tokens that differ only in punctuation. 111 | - handout: remove total number of points after each section. 112 | - handout: mention that large language models (e.g., LLaMA and GPT-3) often use 113 | AdamW betas of (0.9, 0.95) (in contrast to the PyTorch defaults of (0.9, 0.999)). 114 | - handout: explicitly mention the deliverable in the `adamw` problem. 115 | - code: rename `test_serialization::test_checkpoint` to 116 | `test_serialization::test_checkpointing` to match the handout. 117 | - code: slightly relax the time limit in `test_train_bpe_speed`. 118 | 119 | ### Fixed 120 | 121 | - code: fix an issue in the `train_bpe` tests where the expected merges and vocab did 122 | not properly reflect tiebreaking with the lexicographically greatest pair. 123 | - This occurred because our reference implementation (which checks against HF) 124 | follows the GPT-2 tokenizer in remapping bytes that aren't human-readable to 125 | printable unicode strings. To match the HF code, we were erroneously tiebreaking 126 | on this remapped unicode representation instead of the original bytes. 127 | - handout: fix the expected number of non-embedding parameters for the model with 128 | recommended TinyStories hyperparameters (section 7.2). 129 | - handout: replace `<|endofsequence|>` with `<|endoftext|>` in the `decoding` problem. 130 | - code: fix the setup command (`pip install -e .'[test]'`) to improve zsh compatibility. 131 | - handout: fix various trivial typos and formatting errors. 132 | 133 | ## [0.1.0] - 2024-04-01 134 | 135 | Initial release.
136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CS336 Spring 2024 Assignment 1: Basics 2 | 3 | For a full description of the assignment, see the assignment handout at 4 | [cs336_spring2024_assignment1_basics.pdf](./cs336_spring2024_assignment1_basics.pdf) 5 | 6 | If you see any issues with the assignment handout or code, please feel free to 7 | raise a GitHub issue or open a pull request with a fix. 8 | 9 | ## Setup 10 | 11 | 0. Set up a conda environment and install packages: 12 | 13 | ``` sh 14 | conda create -n cs336_basics python=3.10 --yes 15 | conda activate cs336_basics 16 | pip install -e .'[test]' 17 | ``` 18 | 19 | 1. Run unit tests: 20 | 21 | ``` sh 22 | pytest 23 | ``` 24 | 25 | Initially, all tests should fail with `NotImplementedError`s. 26 | To connect your implementation to the tests, complete the 27 | functions in [./tests/adapters.py](./tests/adapters.py). 28 | 29 | 2. Download the TinyStories data and a subsample of OpenWebText: 30 | 31 | ``` sh 32 | mkdir -p data 33 | cd data 34 | 35 | wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt 36 | wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt 37 | 38 | wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz 39 | gunzip owt_train.txt.gz 40 | wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz 41 | gunzip owt_valid.txt.gz 42 | 43 | cd .. 44 | ``` 45 | 46 | -------------------------------------------------------------------------------- /cs336_basics/VERSION: -------------------------------------------------------------------------------- 1 | 0.1.7-dev 2 | -------------------------------------------------------------------------------- /cs336_basics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/cs336_basics/__init__.py -------------------------------------------------------------------------------- /cs336_spring2024_assignment1_basics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/cs336_spring2024_assignment1_basics.pdf -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli = True 3 | log_cli_level = WARNING 4 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | # These requirements are only for tests. 
2 | pytest 3 | tiktoken 4 | psutil 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | regex 2 | torch==2.2.1 3 | numpy 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def read(*paths, **kwargs): 8 | content = "" 9 | with io.open( 10 | os.path.join(os.path.dirname(__file__), *paths), 11 | encoding=kwargs.get("encoding", "utf8"), 12 | ) as open_file: 13 | content = open_file.read().strip() 14 | return content 15 | 16 | 17 | def read_requirements(path): 18 | return [ 19 | line.strip() 20 | for line in read(path).split("\n") 21 | if not line.startswith(('"', "#", "-", "git+")) 22 | ] 23 | 24 | 25 | setup( 26 | name="cs336_basics", 27 | version=read("cs336_basics", "VERSION"), 28 | description="CS336: basics", 29 | long_description=read("README.md"), 30 | long_description_content_type="text/markdown", 31 | packages=find_packages(exclude=["tests", ".github"]), 32 | install_requires=read_requirements("requirements.txt"), 33 | extras_require={ 34 | "test": read_requirements("requirements-test.txt"), 35 | }, 36 | ) 37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /tests/adapters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import annotations 3 | 4 | import os 5 | from typing import IO, BinaryIO, Iterable, Optional, Type 6 | 7 | import numpy.typing as npt 8 | import torch 9 | 10 | 11 | def run_positionwise_feedforward( 12 | d_model: int, 13 | d_ff: int, 14 | weights: dict[str, torch.FloatTensor], 15 | in_features: torch.FloatTensor, 16 | ) -> torch.FloatTensor: 17 | """Given the weights of a position-wise feedforward network, return 18 | the output of your implementation with these weights. 19 | 20 | Args: 21 | d_model: int 22 | Dimensionality of the feedforward input and output. 23 | d_ff: int 24 | Dimensionality of the feedforward network's inner layer. 25 | weights: dict[str, torch.FloatTensor] 26 | State dict of our reference implementation. 27 | The keys of this dictionary are `w1.weight` and `w2.weight`. 28 | `w1` is the first linear transformation, and `w2` is the second 29 | linear transformation (eq. 2 of Vaswani et al., 2017). 30 | `w1.weight` is of shape (d_ff, d_model). 31 | `w2.weight` is of shape (d_model, d_ff). 32 | ) 33 | in_features: torch.FloatTensor 34 | Tensor to run your implementation on. 35 | 36 | Returns: 37 | torch.FloatTensor with the output of running your position-wise feedforward network 38 | with the provided `weights` on the provided `in_features`. 
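    Example: a minimal sketch of the forward computation (an assumption based on the
    weight shapes above and the GELU activation used in this assignment, rather than
    the ReLU of eq. 2; `run_gelu` refers to the adapter defined later in this file):

        hidden = run_gelu(in_features @ weights["w1.weight"].T)  # (..., d_ff)
        out = hidden @ weights["w2.weight"].T  # (..., d_model)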
39 | """ 40 | # Example: 41 | # If your state dict keys match, you can use `load_state_dict()` 42 | # my_ffn.load_state_dict(weights) 43 | # You can also manually assign the weights 44 | # my_ffn.w1.weight.data = weights["w1.weight"] 45 | # my_ffn.w2.weight.data = weights["w2.weight"] 46 | raise NotImplementedError 47 | 48 | 49 | def run_scaled_dot_product_attention( 50 | K: torch.FloatTensor, 51 | Q: torch.FloatTensor, 52 | V: torch.FloatTensor, 53 | mask: Optional[torch.BoolTensor] = None, 54 | pdrop: Optional[float] = None, 55 | ) -> torch.FloatTensor: 56 | """Given key (K), query (Q), and value (V) tensors, return 57 | the output of your scaled dot product attention implementation. 58 | 59 | Args: 60 | K: torch.FloatTensor 61 | Tensor with attention keys. Shape is 62 | (batch_size, ..., seq_len, key_dimension), where 63 | "..." is optional and represents any number of other 64 | batch dimensions (e.g., num_heads). 65 | Q: torch.FloatTensor 66 | Tensor with attention queries. Shape is 67 | (batch_size, ..., seq_len, key_dimension), where 68 | "..." is optional and represents any number of other 69 | batch dimensions (e.g., num_heads). 70 | V: torch.FloatTensor 71 | Tensor with attention values. Shape is 72 | (batch_size, ..., seq_len, value_dimension), where 73 | "..." is optional and represents any number of other 74 | batch dimensions (e.g., num_heads). 75 | mask: Optional[torch.BoolTensor] 76 | An (optional) mask of shape (seq_len, seq_len). 77 | Attention scores for positions with a mask value of `True` should 78 | be masked out, i.e., not affect the softmaxed attention probabilities. 79 | pdrop: Optional[float], default is None. 80 | If given, drop-out the attention probabilities (the softmax-normalized 81 | attention scores) with this rate. 82 | 83 | Returns: 84 | torch.FloatTensor of shape (batch_size, ..., seq_len, value_dimension) 85 | with the output of running your scaled dot product attention 86 | implementation with the provided key, query, and value tensors. 87 | """ 88 | raise NotImplementedError 89 | 90 | 91 | def run_multihead_self_attention( 92 | d_model: int, 93 | num_heads: int, 94 | attn_pdrop: float, 95 | weights: dict[str, torch.FloatTensor], 96 | in_features: torch.FloatTensor, 97 | ) -> torch.FloatTensor: 98 | """Given the key, query, and value projection weights of a naive unbatched 99 | implementation of multi-head attention, return the output of an optimized batched 100 | implementation. This implementation should handle the key, query, and value projections 101 | for all heads in a single matrix multiply. 102 | See section 3.2.2 of Vaswani et al., 2017. 103 | 104 | Args: 105 | d_model: int 106 | Dimensionality of the feedforward input and output. 107 | num_heads: int 108 | Number of heads to use in multi-headed attention. 109 | attn_pdrop: float 110 | Drop-out the attention probabilities (the softmax-normalized 111 | attention scores) with this rate. 112 | weights: dict[str, torch.FloatTensor] 113 | State dict of our reference implementation. 114 | The keys of this dictionary are: 115 | - `q_heads.{N}.weight`, `q_heads.{N}.weight`: 116 | Weights for the query projection heads. 117 | N is an integer from 0 to `num_heads - 1`. 118 | Shape of each tensor is (d_key, d_model). 119 | - `k_heads.{N}.weight`, `k_heads.{N}.weight`: 120 | Weights for the key projection heads. 121 | N is an integer from 0 to `num_heads - 1`. 122 | Shape of each tensor is (d_key, d_model). 123 | - `v_heads.{N}.weight`, `v_heads.{N}.weight`: 124 | Weights for the value projection heads. 
125 | N is an integer from 0 to `num_heads - 1`. 126 | Shape of each tensor is (d_value, d_model). 127 | - `output_proj.weight`: 128 | Weight of the output projection 129 | (W^{O} in the original Transformer paper) 130 | Shape of (d_model, d_value * num_heads). 131 | in_features: torch.FloatTensor 132 | Tensor to run your implementation on. 133 | 134 | Returns: 135 | torch.FloatTensor with the output of running your optimized, batched multi-headed attention 136 | implementation with the given QKV projection weights and input features. 137 | """ 138 | raise NotImplementedError 139 | 140 | 141 | def run_transformer_block( 142 | d_model: int, 143 | num_heads: int, 144 | d_ff: int, 145 | attn_pdrop: float, 146 | residual_pdrop: float, 147 | weights: dict[str, torch.FloatTensor], 148 | in_features: torch.FloatTensor, 149 | ) -> torch.FloatTensor: 150 | """Given the weights of a pre-norm Transformer block and input features, 151 | return the output of running the Transformer block on the input features. 152 | 153 | Args: 154 | d_model: int 155 | The dimensionality of the Transformer block input. 156 | num_heads: int 157 | Number of heads to use in multi-headed attention. `d_model` must be 158 | evenly divisible by `num_heads`. 159 | d_ff: int 160 | Dimensionality of the feed-forward inner layer (section 3.3). 161 | attn_pdrop: float 162 | Drop-out the attention probabilities (the softmax-normalized 163 | attention scores) with this rate. 164 | residual_pdrop: float 165 | Apply dropout to the output of each sub-layer, before it 166 | is added to the sub-layer input and normalized (section 5.4). 167 | weights: dict[str, torch.FloatTensor] 168 | State dict of our reference implementation. 169 | The keys of this dictionary are: 170 | - `attn.q_proj.weight` 171 | The query projections for all `num_heads` attention heads. 172 | Shape is (num_heads * (d_model / num_heads), d_model). 173 | The rows are ordered by matrices of shape (num_heads, d_k), 174 | so `attn.q_proj.weight == torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`. 175 | - `attn.k_proj.weight` 176 | The key projections for all `num_heads` attention heads. 177 | Shape is (num_heads * (d_model / num_heads), d_model). 178 | The rows are ordered by matrices of shape (num_heads, d_k), 179 | so `attn.k_proj.weight == torch.cat([k_heads.0.weight, ..., k_heads.N.weight], dim=0)`. 180 | - `attn.v_proj.weight` 181 | The value projections for all `num_heads` attention heads. 182 | Shape is (num_heads * (d_model / num_heads), d_model). 183 | The rows are ordered by matrices of shape (num_heads, d_v), 184 | so `attn.v_proj.weight == torch.cat([v_heads.0.weight, ..., v_heads.N.weight], dim=0)`. 185 | - `attn.output_proj.weight` 186 | Weight of the multi-head self-attention output projection 187 | Shape is (d_model, (d_model / num_heads) * num_heads). 188 | - `ln1.weight` 189 | Weights of affine transform for the first RMSNorm 190 | applied in the transformer block. 191 | Shape is (d_model,). 192 | - `ffn.w1.weight` 193 | Weight of the first linear transformation in the FFN. 194 | Shape is (d_ff, d_model). 195 | - `ffn.w2.weight` 196 | Weight of the second linear transformation in the FFN. 197 | Shape is (d_model, d_ff). 198 | - `ln2.weight` 199 | Weights of affine transform for the second RMSNorm 200 | applied in the transformer block. 201 | Shape is (d_model,). 202 | in_features: torch.FloatTensor 203 | Tensor to run your implementation on. 204 | Shape is (batch_size, sequence_length, d_model). 
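    Note: a sketch of the pre-norm sub-layer ordering these weights imply (an
    assumption based on the handout's pre-norm block; the helper names below are
    illustrative placeholders, not adapters in this file):

        y = in_features + dropout(multihead_self_attention(rmsnorm_ln1(in_features)))
        out = y + dropout(ffn(rmsnorm_ln2(y)))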
205 | 206 | Returns: 207 | FloatTensor of shape (batch_size, sequence_length, d_model) with the output of 208 | running the Transformer block on the input features. 209 | """ 210 | raise NotImplementedError 211 | 212 | 213 | def run_transformer_lm( 214 | vocab_size: int, 215 | context_length: int, 216 | d_model: int, 217 | num_layers: int, 218 | num_heads: int, 219 | d_ff: int, 220 | attn_pdrop: float, 221 | residual_pdrop: float, 222 | weights: dict[str, torch.FloatTensor], 223 | in_indices: torch.LongTensor, 224 | ) -> torch.FloatTensor: 225 | """Given the weights of a Transformer language model and input indices, 226 | return the output of running a forward pass on the input indices. 227 | 228 | Args: 229 | vocab_size: int 230 | The number of unique items in the output vocabulary to be predicted. 231 | context_length: int, 232 | The maximum number of tokens to process at once. 233 | d_model: int 234 | The dimensionality of the model embeddings and sublayer outputs. 235 | num_layers: int 236 | The number of Transformer layers to use. 237 | num_heads: int 238 | Number of heads to use in multi-headed attention. `d_model` must be 239 | evenly divisible by `num_heads`. 240 | d_ff: int 241 | Dimensionality of the feed-forward inner layer (section 3.3). 242 | attn_pdrop: float 243 | Drop-out the attention probabilities (the softmax-normalized 244 | attention scores) with this rate. 245 | residual_pdrop: float 246 | Apply dropout to the sum of the token and position embeddings 247 | as well as the output of each sub-layer, before it is added to the 248 | sub-layer input and normalized (section 5.4). 249 | weights: dict[str, torch.FloatTensor] 250 | State dict of our reference implementation. {num_layers} refers to an 251 | integer between `0` and `num_layers - 1` (the layer index). 252 | The keys of this dictionary are: 253 | - `token_embeddings.weight` 254 | Token embedding matrix. Shape is (vocab_size, d_model). 255 | - `position_embeddings.weight` 256 | Positional embedding matrix. Shape is (context_length, d_model). 257 | - `layers.{num_layers}.attn.q_proj.weight` 258 | The query projections for all `num_heads` attention heads. 259 | Shape is (num_heads * (d_model / num_heads), d_model). 260 | The rows are ordered by matrices of shape (num_heads, d_k), 261 | so `attn.q_proj.weight == torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`. 262 | - `layers.{num_layers}.attn.k_proj.weight` 263 | The key projections for all `num_heads` attention heads. 264 | Shape is (num_heads * (d_model / num_heads), d_model). 265 | The rows are ordered by matrices of shape (num_heads, d_k), 266 | so `attn.k_proj.weight == torch.cat([k_heads.0.weight, ..., k_heads.N.weight], dim=0)`. 267 | - `layers.{num_layers}.attn.v_proj.weight` 268 | The value projections for all `num_heads` attention heads. 269 | Shape is (num_heads * (d_model / num_heads), d_model). 270 | The rows are ordered by matrices of shape (num_heads, d_v), 271 | so `attn.v_proj.weight == torch.cat([v_heads.0.weight, ..., v_heads.N.weight], dim=0)`. 272 | - `layers.{num_layers}.attn.output_proj.weight` 273 | Weight of the multi-head self-attention output projection 274 | Shape is ((d_model / num_heads) * num_heads, d_model). 275 | - `layers.{num_layers}.ln1.weight` 276 | Weights of affine transform for the first RMSNorm 277 | applied in the transformer block. 278 | Shape is (d_model,). 279 | - `layers.{num_layers}.ffn.w1.weight` 280 | Weight of the first linear transformation in the FFN. 281 | Shape is (d_ff, d_model). 
282 | - `layers.{num_layers}.ffn.w2.weight` 283 | Weight of the second linear transformation in the FFN. 284 | Shape is (d_model, d_ff). 285 | - `layers.{num_layers}.ln2.weight` 286 | Weights of affine transform for the second RMSNorm 287 | applied in the transformer block. 288 | Shape is (d_model,). 289 | - `ln_final.weight` 290 | Weights of affine transform for RMSNorm applied to the output of the final transformer block. 291 | Shape is (d_model, ). 292 | - `lm_head.weight` 293 | Weights of the language model output embedding. 294 | Shape is (vocab_size, d_model). 295 | in_indices: torch.LongTensor 296 | Tensor with input indices to run the language model on. Shape is (batch_size, sequence_length), where 297 | `sequence_length` is at most `context_length`. 298 | 299 | Returns: 300 | FloatTensor of shape (batch size, sequence_length, vocab_size) with the predicted unnormalized 301 | next-word distribution for each token. 302 | """ 303 | raise NotImplementedError 304 | 305 | 306 | def run_rmsnorm( 307 | d_model: int, 308 | eps: float, 309 | weights: dict[str, torch.FloatTensor], 310 | in_features: torch.FloatTensor, 311 | ) -> torch.FloatTensor: 312 | """Given the weights of a RMSNorm affine transform, 313 | return the output of running RMSNorm on the input features. 314 | 315 | Args: 316 | d_model: int 317 | The dimensionality of the RMSNorm input. 318 | eps: float, default is 1e-5 319 | A value added to the denominator for numerical stability. 320 | weights: dict[str, torch.FloatTensor] 321 | State dict of our reference implementation. 322 | The keys of this dictionary are: 323 | - `weight` 324 | Weights of the RMSNorm affine transform. 325 | Shape is (d_model,). 326 | in_features: torch.FloatTensor 327 | Input features to run RMSNorm on. Tensor of (*, d_model), where * 328 | can be an arbitrary number of dimensions with arbitrary values. 329 | 330 | Returns: 331 | FloatTensor of with the same shape as `in_features` with the output of running 332 | RMSNorm of the `in_features`. 333 | """ 334 | raise NotImplementedError 335 | 336 | 337 | def run_gelu(in_features: torch.FloatTensor) -> torch.FloatTensor: 338 | """Given a tensor of inputs, return the output of applying GELU 339 | to each element. 340 | 341 | Args: 342 | in_features: torch.FloatTensor 343 | Input features to run GELU on. Shape is arbitrary. 344 | 345 | Returns: 346 | FloatTensor of with the same shape as `in_features` with the output of applying 347 | GELU to each element. 348 | """ 349 | raise NotImplementedError 350 | 351 | 352 | def run_get_batch( 353 | dataset: npt.NDArray, batch_size: int, context_length: int, device: str 354 | ) -> tuple[torch.Tensor, torch.Tensor]: 355 | """ 356 | Given a dataset (a 1D numpy array of integers) and a desired batch size and 357 | context length, sample language modeling input sequences and their corresponding 358 | labels from the dataset. 359 | 360 | Args: 361 | dataset: np.array 362 | 1D numpy array of integer token IDs in the dataset. 363 | batch_size: int 364 | Desired batch size to sample. 365 | context_length: int 366 | Desired context length of each sampled example. 367 | device: str 368 | PyTorch device string (e.g., 'cpu' or 'cuda:0') indicating the device 369 | to place the sampled input sequences and labels on. 370 | 371 | Returns: 372 | Tuple of torch.LongTensors of shape (batch_size, context_length). The first tuple item 373 | is the sampled input sequences, and the second tuple item is the corresponding 374 | language modeling labels. 
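    Example: a minimal sampling sketch (an assumption rather than a required
    implementation; note that this file only imports `numpy.typing`, so `numpy`
    would additionally need to be imported as `np`):

        starts = np.random.randint(0, len(dataset) - context_length, size=batch_size)
        x = torch.from_numpy(np.stack([dataset[i : i + context_length] for i in starts])).long().to(device)
        y = torch.from_numpy(np.stack([dataset[i + 1 : i + 1 + context_length] for i in starts])).long().to(device)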
375 | """ 376 | raise NotImplementedError 377 | 378 | 379 | def run_softmax(in_features: torch.FloatTensor, dim: int) -> torch.FloatTensor: 380 | """Given a tensor of inputs, return the output of softmaxing the given `dim` 381 | of the input. 382 | 383 | Args: 384 | in_features: torch.FloatTensor 385 | Input features to softmax. Shape is arbitrary. 386 | dim: int 387 | Dimension of the `in_features` to apply softmax to. 388 | 389 | Returns: 390 | FloatTensor of with the same shape as `in_features` with the output of 391 | softmax normalizing the specified `dim`. 392 | """ 393 | raise NotImplementedError 394 | 395 | 396 | def run_cross_entropy(inputs: torch.FloatTensor, targets: torch.LongTensor): 397 | """Given a tensor of inputs and targets, compute the average cross-entropy 398 | loss across examples. 399 | 400 | Args: 401 | inputs: torch.FloatTensor 402 | FloatTensor of shape (batch_size, num_classes). inputs[i][j] is the 403 | unnormalized logit of jth class for the ith example. 404 | targets: torch.LongTensor 405 | LongTensor of shape (batch_size, ) with the index of the correct class. 406 | Each value must be between 0 and `num_classes - 1`. 407 | 408 | Returns: 409 | Tensor of shape () with the average cross-entropy loss across examples. 410 | """ 411 | raise NotImplementedError 412 | 413 | 414 | def run_gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float): 415 | """Given a set of parameters, clip their combined gradients to have l2 norm at most max_l2_norm. 416 | 417 | Args: 418 | parameters: collection of trainable parameters. 419 | max_l2_norm: a positive value containing the maximum l2-norm. 420 | 421 | The gradients of the parameters (parameter.grad) should be modified in-place. 422 | 423 | Returns: 424 | None 425 | """ 426 | raise NotImplementedError 427 | 428 | 429 | def get_adamw_cls() -> Type[torch.optim.Optimizer]: 430 | """ 431 | Returns a torch.optim.Optimizer that implements AdamW. 432 | """ 433 | raise NotImplementedError 434 | 435 | 436 | def run_get_lr_cosine_schedule( 437 | it: int, 438 | max_learning_rate: float, 439 | min_learning_rate: float, 440 | warmup_iters: int, 441 | cosine_cycle_iters: int, 442 | ): 443 | """ 444 | Given the parameters of a cosine learning rate decay schedule (with linear 445 | warmup) and an iteration number, return the learning rate at the given 446 | iteration under the specified schedule. 447 | 448 | Args: 449 | it: int 450 | Iteration number to get learning rate for. 451 | max_learning_rate: float 452 | alpha_max, the maximum learning rate for 453 | cosine learning rate schedule (with warmup). 454 | min_learning_rate: float 455 | alpha_min, the minimum / final learning rate for 456 | the cosine learning rate schedule (with warmup). 457 | warmup_iters: int 458 | T_w, the number of iterations to linearly warm-up 459 | the learning rate. 460 | cosine_cycle_iters: int 461 | T_c, the number of cosine annealing iterations. 462 | 463 | Returns: 464 | Learning rate at the given iteration under the specified schedule. 465 | """ 466 | raise NotImplementedError 467 | 468 | 469 | def run_save_checkpoint( 470 | model: torch.nn.Module, 471 | optimizer: torch.optim.Optimizer, 472 | iteration: int, 473 | out: str | os.PathLike | BinaryIO | IO[bytes], 474 | ): 475 | """ 476 | Given a model, optimizer, and an iteration number, serialize them to disk. 477 | 478 | Args: 479 | model: torch.nn.Module 480 | Serialize the state of this model. 481 | optimizer: torch.optim.Optimizer, 482 | Serialize the state of this optimizer. 
483 | iteration: int 484 | Serialize this value, which represents the number of training iterations 485 | we've completed. 486 | out: str | os.PathLike | BinaryIO | IO[bytes] 487 | Path or file-like object to serialize the model, optimizer, and iteration to. 488 | """ 489 | raise NotImplementedError 490 | 491 | 492 | def run_load_checkpoint( 493 | src: str | os.PathLike | BinaryIO | IO[bytes], 494 | model: torch.nn.Module, 495 | optimizer: torch.optim.Optimizer, 496 | ): 497 | """ 498 | Given a serialized checkpoint (path or file-like object), restore the 499 | serialized state to the given model and optimizer. 500 | Return the number of iterations that we previously serialized in 501 | the checkpoint. 502 | 503 | Args: 504 | src: str | os.PathLike | BinaryIO | IO[bytes] 505 | Path or file-like object to serialized checkpoint. 506 | model: torch.nn.Module 507 | Restore the state of this model. 508 | optimizer: torch.optim.Optimizer, 509 | Restore the state of this optimizer. 510 | Returns: 511 | int, the previously-serialized number of iterations. 512 | """ 513 | raise NotImplementedError 514 | 515 | 516 | def get_tokenizer( 517 | vocab: dict[int, bytes], 518 | merges: list[tuple[bytes, bytes]], 519 | special_tokens: Optional[list[str]] = None, 520 | ): 521 | """Given a vocabulary, a list of merges, and a list of special tokens, 522 | return a BPE tokenizer that uses the provided vocab, merges, and special tokens. 523 | 524 | Args: 525 | vocab: dict[int, bytes] 526 | The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) 527 | to bytes (token bytes) 528 | merges: list[tuple[bytes, bytes]] 529 | BPE merges. Each list item is a tuple of bytes (<token1>, <token2>), 530 | representing that <token1> was merged with <token2>. 531 | Merges are ordered by order of creation. 532 | special_tokens: Optional[list[str]] 533 | A list of string special tokens for the tokenizer. These strings will never 534 | be split into multiple tokens, and will always be kept as a single token. 535 | 536 | Returns: 537 | A BPE tokenizer that uses the provided vocab, merges, and special tokens. 538 | """ 539 | raise NotImplementedError 540 | 541 | 542 | def run_train_bpe( 543 | input_path: str | os.PathLike, 544 | vocab_size: int, 545 | special_tokens: list[str], 546 | **kwargs, 547 | ): 548 | """Given the path to an input corpus, train a BPE tokenizer and 549 | output its vocabulary and merges. 550 | 551 | Args: 552 | input_path: str | os.PathLike 553 | Path to BPE tokenizer training data. 554 | vocab_size: int 555 | Total number of items in the tokenizer's vocabulary (including special tokens). 556 | special_tokens: list[str] 557 | A list of string special tokens to be added to the tokenizer vocabulary. 558 | These strings will never be split into multiple tokens, and will always be 559 | kept as a single token. If these special tokens occur in the `input_path`, 560 | they are treated as any other string. 561 | 562 | Returns: 563 | Tuple of (vocab, merges): 564 | vocab: dict[int, bytes] 565 | The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary) 566 | to bytes (token bytes) 567 | merges: list[tuple[bytes, bytes]] 568 | BPE merges. Each list item is a tuple of bytes (<token1>, <token2>), 569 | representing that <token1> was merged with <token2>. 570 | Merges are ordered by order of creation.
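    Note: a sketch of the expected procedure (an assumption based on the handout,
    not a prescribed implementation): initialize the vocabulary with the special
    tokens plus all 256 single-byte tokens, pre-tokenize the corpus, then repeatedly
    count the frequencies of adjacent token pairs within each pre-token and merge the
    most frequent pair (breaking frequency ties by preferring the lexicographically
    greater pair, compared as a tuple of bytes) until the vocabulary has `vocab_size`
    entries, recording each merge in the order it was created.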
571 | """ 572 | raise NotImplementedError 573 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import annotations 3 | 4 | import pathlib 5 | from functools import lru_cache 6 | 7 | FIXTURES_PATH = (pathlib.Path(__file__).resolve().parent) / "fixtures" 8 | 9 | 10 | @lru_cache() 11 | def gpt2_bytes_to_unicode() -> dict[int, str]: 12 | """ 13 | Returns a mapping from every possible byte (an integer from 0 to 255) to a 14 | printable unicode string character representation. This function is taken 15 | from the GPT-2 code. 16 | 17 | For example, `chr(0)` is `\x00`, which is an unprintable character: 18 | 19 | >>> chr(0) 20 | '\x00' 21 | >>> print(chr(0)) 22 | 23 | As a result, this function returns a dictionary `d` where `d[0]` returns `Ā`. 24 | The bytes that are visually printable keep their original string representation [1]. 25 | For example, `chr(33)` returns `!`, and so accordingly `d[33]` returns `!`. 26 | Note in particular that the space character `chr(32)` becomes `d[32]`, which 27 | returns 'Ġ'. 28 | 29 | For unprintable characters, the function takes the integer representing 30 | the Unicode code point of that character (returned by the Python `ord` function) 31 | and shifts it by 256. For example, `ord(" ")` returns `32`, so the space character 32 | ' ' is shifted to `256 + 32`. Since `chr(256 + 32)` returns `Ġ`, we use that as the 33 | string representation of the space. 34 | 35 | This function can simplify the BPE implementation and make it slightly easier to 36 | manually inspect the generated merges after they're serialized to a file. 37 | """ 38 | # These 188 integers can be used as-is, since they are not whitespace or control characters. 39 | # See https://www.ssec.wisc.edu/~tomw/java/unicode.html. 40 | bs = ( 41 | list(range(ord("!"), ord("~") + 1)) 42 | + list(range(ord("¡"), ord("¬") + 1)) 43 | + list(range(ord("®"), ord("ÿ") + 1)) 44 | ) 45 | cs = bs[:] 46 | # Now get the representations of the other 68 integers that do need shifting. 47 | # Each will get mapped to chr(256 + n), where n grows from 0 to 67 in the loop, 48 | # giving printable representations of the remaining 68 integers. 49 | n = 0 50 | for b in range(2**8): 51 | if b not in bs: 52 | # If this integer isn't in our list of visually-representable 53 | # characters, then map it to the next nice character (offset by 256) 54 | bs.append(b) 55 | cs.append(2**8 + n) 56 | n += 1 57 | characters = [chr(n) for n in cs] 58 | d = dict(zip(bs, characters)) 59 | return d 60 | -------------------------------------------------------------------------------- /tests/fixtures/adamw_expected_params.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/adamw_expected_params.pt -------------------------------------------------------------------------------- /tests/fixtures/address.txt: -------------------------------------------------------------------------------- 1 | Four score and seven years ago our fathers brought forth, on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.
2 | Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure. We are met on a great battle-field of that war. We have come to dedicate a portion of that field, as a final resting place for those who here gave their lives that that nation might live. It is altogether fitting and proper that we should do this. 3 | But, in a larger sense, we can not dedicate—we can not consecrate—we can not hallow—this ground. The brave men, living and dead, who struggled here, have consecrated it, far above our poor power to add or detract. The world will little note, nor long remember what we say here, but it can never forget what they did here. It is for us the living, rather, to be dedicated here to the unfinished work which they who fought here have thus far so nobly advanced. It is rather for us to be here dedicated to the great task remaining before us—that from these honored dead we take increased devotion to that cause for which they gave the last full measure of devotion—that we here highly resolve that these dead shall not have died in vain—that this nation, under God, shall have a new birth of freedom—and that government of the people, by the people, for the people, shall not perish from the earth. 4 | -------------------------------------------------------------------------------- /tests/fixtures/german.txt: -------------------------------------------------------------------------------- 1 | Die Leland Stanford Junior University (kurz Stanford University oder Stanford, Spitzname „Die Farm“) ist eine private US-amerikanische Universität in Stanford, Kalifornien. Sie liegt etwa 60 Kilometer südöstlich von San Francisco in der Nähe von Palo Alto und wurde von Leland Stanford und seiner Ehefrau Jane Stanford im Jahr 1891 im Andenken an ihren früh verstorbenen, einzigen Sohn Leland Stanford junior gegründet. 2021 waren 16.937 Studenten an der Universität eingeschrieben und studierten an einer der sieben Fakultäten. 
Ihr Präsident war bis 2023 Marc Tessier-Lavigne.[2] 2 | -------------------------------------------------------------------------------- /tests/fixtures/in_features.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/in_features.pt -------------------------------------------------------------------------------- /tests/fixtures/in_indices.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/in_indices.pt -------------------------------------------------------------------------------- /tests/fixtures/in_indices_truncated.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/in_indices_truncated.pt -------------------------------------------------------------------------------- /tests/fixtures/positionwise_feedforward_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/positionwise_feedforward_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/positionwise_feedforward_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/positionwise_feedforward_weights.pt -------------------------------------------------------------------------------- /tests/fixtures/rmsnorm_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/rmsnorm_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/rmsnorm_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/rmsnorm_weights.pt -------------------------------------------------------------------------------- /tests/fixtures/scaled_dot_product_attention_K.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/scaled_dot_product_attention_K.pt -------------------------------------------------------------------------------- /tests/fixtures/scaled_dot_product_attention_Q.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/scaled_dot_product_attention_Q.pt -------------------------------------------------------------------------------- /tests/fixtures/scaled_dot_product_attention_V.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/scaled_dot_product_attention_V.pt -------------------------------------------------------------------------------- /tests/fixtures/scaled_dot_product_attention_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/scaled_dot_product_attention_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/scaled_dot_product_attention_mask.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/scaled_dot_product_attention_mask.pt -------------------------------------------------------------------------------- /tests/fixtures/tinystories_sample.txt: -------------------------------------------------------------------------------- 1 | 2 | Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed! 3 | He said, “Wow, that is a really amazing vase! Can I buy it?” 4 | The shopkeeper smiled and said, “Of course you can. You can take it home and show all your friends how amazing it is!” 5 | So Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn't believe how lucky Ben was. 6 | And that's how Ben found an amazing vase in the store! 7 | <|endoftext|> 8 | Once upon a time, there was a reliable otter named Ollie. He lived in a river with his family. They all loved to play and swim together. 9 | One day, Ollie's mom said, "Ollie, hurry and get some fish for dinner!" Ollie swam fast to catch fish. He saw his friend, the duck. "Hi, Ollie!" said the duck. "Hi, duck!" said Ollie. "I need to hurry and catch fish for my family." 10 | While Ollie was catching fish, he found a big shiny stone. He thought, "This is not a fish, but it is so pretty!" Ollie took the shiny stone home to show his family. They all looked at the shiny stone and smiled. The shiny stone made everyone happy, and they forgot about the fish for dinner. 11 | <|endoftext|> 12 | One day, a little boy named Tim went to the park. He saw a big tiger. The tiger was not mean, but very easy to play with. Tim and the tiger played all day. They had lots of fun. 13 | Then, something unexpected happened. The tiger started to shake. Tim was scared. He did not know what was going on. But then, the tiger turned into a nice dog. Tim was very surprised. 14 | Tim and the dog played together now. They were very happy. The dog was easy to play with too. At the end of the day, Tim went home with his new friend. 15 | <|endoftext|> 16 | 17 | Once upon a time there was a friendly little boy called Bob. Bob loved to pick flowers and look for birds. One day he decided to go outside with his friends to pick some more flowers. 18 | He suddenly noticed something weird on the ground. It was a big, green thumb! 
It was so big, Bob had never seen one before. Bob curiously leaned in to take a better look. He told his friends: "look everyone, I picked up this big thumb! What do we do with it?" 19 | His friends were very excited. They told him to pick it up and take it home to show his family. So Bob carefully picked up the friendly thumb and carried it back home. When he arrived, Bob happily showed the thumb to his family. His dad was amazed and hugged Bob to show his appreciation. 20 | From that day on Bob always kept the big, friendly thumb with him as a reminder that special things can be found anywhere. 21 | <|endoftext|> 22 | Once upon a time, in a small house, there lived a little girl named Lucy. Lucy loved the color orange. She had an orange dress, an orange ball, and even an orange cat. One day, Lucy met a new friend. This friend was not like other friends. It was a spirit. The spirit was very nice and liked to play with Lucy. 23 | One day, Lucy and the spirit were playing with her orange ball. They were having so much fun. Then, Lucy's mom called her for dinner. Lucy said to the spirit, "I have to go eat now. Will you play with me later?" The spirit nodded and smiled. 24 | At dinner, Lucy told her mom about the spirit. But her mom did not believe her. She said, "Spirits are not real, Lucy. You have a big imagination." Lucy felt sad that her mom did not believe her. After dinner, she went back to play with the spirit. They played with the orange ball and had lots of fun. Lucy knew that even if others ignore her friend, the spirit was real and they could play together. 25 | <|endoftext|> 26 | -------------------------------------------------------------------------------- /tests/fixtures/train-bpe-reference-merges.txt: -------------------------------------------------------------------------------- 1 | Ġ t 2 | Ġ a 3 | h e 4 | i n 5 | Ġt he 6 | r e 7 | Ġ o 8 | Ġ , 9 | e r 10 | Ġ s 11 | a t 12 | Ġ . 
13 | n d 14 | i s 15 | o r 16 | Ġ w 17 | Ġ c 18 | o n 19 | Ġ b 20 | Ġ f 21 | o u 22 | i t 23 | e n 24 | e s 25 | Ġo f 26 | Ġ p 27 | in g 28 | Ġ in 29 | e d 30 | a l 31 | Ġ m 32 | Ġa nd 33 | Ġ d 34 | a n 35 | a r 36 | Ġt o 37 | o m 38 | Ġt h 39 | i c 40 | i on 41 | Ġ h 42 | Ġ l 43 | Ġ y 44 | Ġ e 45 | a s 46 | o t 47 | i l 48 | Ġ n 49 | Ġ u 50 | en t 51 | Ġb e 52 | Ġ & 53 | Ġ is 54 | Ġy ou 55 | o s 56 | Ġ re 57 | e t 58 | Ġf or 59 | u t 60 | e l 61 | Ġ g 62 | a y 63 | s t 64 | o w 65 | l e 66 | c e 67 | a d 68 | Ġo n 69 | Ġ I 70 | v er 71 | v e 72 | Ġ A 73 | u r 74 | o l 75 | c t 76 | q u 77 | Ġth at 78 | i m 79 | al l 80 | a m 81 | i g 82 | c h 83 | at ion 84 | Ġ P 85 | it h 86 | i r 87 | Ġ S 88 | Ġ it 89 | Ġp r 90 | a p 91 | Ġs h 92 | Ġ C 93 | t h 94 | Ġc om 95 | Ġ @ 96 | Ġw h 97 | - @ 98 | Ġa re 99 | Ġ@ -@ 100 | n t 101 | i d 102 | Ġw ith 103 | Ġa l 104 | o p 105 | Ġu s 106 | er s 107 | Ġa s 108 | t he 109 | a nd 110 | i f 111 | or d 112 | o d 113 | Ġ he 114 | is t 115 | qu ot 116 | m ent 117 | Ġ M 118 | Ġo r 119 | o re 120 | Ġ G 121 | Ġf r 122 | il l 123 | re s 124 | Ġs t 125 | es s 126 | l d 127 | Ġth is 128 | Ġ 2 129 | ar t 130 | Ġ ; 131 | Ġ L 132 | l y 133 | a in 134 | u l 135 | Ġd e 136 | Ġc on 137 | es t 138 | s e 139 | ap os 140 | a g 141 | Ġfr om 142 | Ġa n 143 | Ġw e 144 | Ġ ( 145 | 0 0 146 | t er 147 | Ġ E 148 | e m 149 | a ve 150 | Ġn ot 151 | Ġ ) 152 | Ġ 1 153 | Ġyou r 154 | o c 155 | Ġc an 156 | Ġb y 157 | Ġ D 158 | Ġn e 159 | Ġ v 160 | ig h 161 | ic h 162 | Ġal l 163 | r i 164 | Ġu p 165 | Ġ r 166 | Ġ W 167 | b le 168 | Ġthe y 169 | Ġ B 170 | u n 171 | Ġy e 172 | Ġwh ich 173 | Ġ O 174 | k e 175 | Ġw or 176 | Ġs u 177 | Ġ F 178 | Ġ H 179 | Ġh ave 180 | at e 181 | Ġsh all 182 | Ġc h 183 | e ct 184 | it y 185 | Ġs p 186 | res s 187 | igh t 188 | Ġw ill 189 | Ġcom p 190 | or t 191 | an t 192 | Ġ& # 193 | i ve 194 | a re 195 | . . 196 | Ġe x 197 | ĠA nd 198 | Ġ. .. 
199 | as t 200 | 2 4 201 | Ġ T 202 | ou ld 203 | v en 204 | Ġt r 205 | u st 206 | u m 207 | ou t 208 | c om 209 | Ġu nt 210 | Ġs e 211 | f t 212 | re e 213 | os t 214 | o g 215 | is h 216 | ion s 217 | i z 218 | 1 24 219 | Ġunt o 220 | m er 221 | ing s 222 | Ġa c 223 | th is 224 | at ed 225 | a c 226 | l u 227 | e re 228 | Ġm an 229 | f or 230 | Ġm y 231 | Ġa t 232 | i es 233 | ag e 234 | r ou 235 | l o 236 | an s 237 | p p 238 | in d 239 | Ġwor k 240 | he re 241 | f ore 242 | Ġs it 243 | Ġ ver 244 | -------------------------------------------------------------------------------- /tests/fixtures/train-bpe-reference-vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "<|endoftext|>": 0, 3 | "!": 1, 4 | "\"": 2, 5 | "#": 3, 6 | "$": 4, 7 | "%": 5, 8 | "&": 6, 9 | "'": 7, 10 | "(": 8, 11 | ")": 9, 12 | "*": 10, 13 | "+": 11, 14 | ",": 12, 15 | "-": 13, 16 | ".": 14, 17 | "/": 15, 18 | "0": 16, 19 | "1": 17, 20 | "2": 18, 21 | "3": 19, 22 | "4": 20, 23 | "5": 21, 24 | "6": 22, 25 | "7": 23, 26 | "8": 24, 27 | "9": 25, 28 | ":": 26, 29 | ";": 27, 30 | "<": 28, 31 | "=": 29, 32 | ">": 30, 33 | "?": 31, 34 | "@": 32, 35 | "A": 33, 36 | "B": 34, 37 | "C": 35, 38 | "D": 36, 39 | "E": 37, 40 | "F": 38, 41 | "G": 39, 42 | "H": 40, 43 | "I": 41, 44 | "J": 42, 45 | "K": 43, 46 | "L": 44, 47 | "M": 45, 48 | "N": 46, 49 | "O": 47, 50 | "P": 48, 51 | "Q": 49, 52 | "R": 50, 53 | "S": 51, 54 | "T": 52, 55 | "U": 53, 56 | "V": 54, 57 | "W": 55, 58 | "X": 56, 59 | "Y": 57, 60 | "Z": 58, 61 | "[": 59, 62 | "\\": 60, 63 | "]": 61, 64 | "^": 62, 65 | "_": 63, 66 | "`": 64, 67 | "a": 65, 68 | "b": 66, 69 | "c": 67, 70 | "d": 68, 71 | "e": 69, 72 | "f": 70, 73 | "g": 71, 74 | "h": 72, 75 | "i": 73, 76 | "j": 74, 77 | "k": 75, 78 | "l": 76, 79 | "m": 77, 80 | "n": 78, 81 | "o": 79, 82 | "p": 80, 83 | "q": 81, 84 | "r": 82, 85 | "s": 83, 86 | "t": 84, 87 | "u": 85, 88 | "v": 86, 89 | "w": 87, 90 | "x": 88, 91 | "y": 89, 92 | "z": 90, 93 | "{": 91, 94 | "|": 92, 95 | "}": 93, 96 | "~": 94, 97 | "¡": 95, 98 | "¢": 96, 99 | "£": 97, 100 | "¤": 98, 101 | "¥": 99, 102 | "¦": 100, 103 | "§": 101, 104 | "¨": 102, 105 | "©": 103, 106 | "ª": 104, 107 | "«": 105, 108 | "¬": 106, 109 | "®": 107, 110 | "¯": 108, 111 | "°": 109, 112 | "±": 110, 113 | "²": 111, 114 | "³": 112, 115 | "´": 113, 116 | "µ": 114, 117 | "¶": 115, 118 | "·": 116, 119 | "¸": 117, 120 | "¹": 118, 121 | "º": 119, 122 | "»": 120, 123 | "¼": 121, 124 | "½": 122, 125 | "¾": 123, 126 | "¿": 124, 127 | "À": 125, 128 | "Á": 126, 129 | "Â": 127, 130 | "Ã": 128, 131 | "Ä": 129, 132 | "Å": 130, 133 | "Æ": 131, 134 | "Ç": 132, 135 | "È": 133, 136 | "É": 134, 137 | "Ê": 135, 138 | "Ë": 136, 139 | "Ì": 137, 140 | "Í": 138, 141 | "Î": 139, 142 | "Ï": 140, 143 | "Ð": 141, 144 | "Ñ": 142, 145 | "Ò": 143, 146 | "Ó": 144, 147 | "Ô": 145, 148 | "Õ": 146, 149 | "Ö": 147, 150 | "×": 148, 151 | "Ø": 149, 152 | "Ù": 150, 153 | "Ú": 151, 154 | "Û": 152, 155 | "Ü": 153, 156 | "Ý": 154, 157 | "Þ": 155, 158 | "ß": 156, 159 | "à": 157, 160 | "á": 158, 161 | "â": 159, 162 | "ã": 160, 163 | "ä": 161, 164 | "å": 162, 165 | "æ": 163, 166 | "ç": 164, 167 | "è": 165, 168 | "é": 166, 169 | "ê": 167, 170 | "ë": 168, 171 | "ì": 169, 172 | "í": 170, 173 | "î": 171, 174 | "ï": 172, 175 | "ð": 173, 176 | "ñ": 174, 177 | "ò": 175, 178 | "ó": 176, 179 | "ô": 177, 180 | "õ": 178, 181 | "ö": 179, 182 | "÷": 180, 183 | "ø": 181, 184 | "ù": 182, 185 | "ú": 183, 186 | "û": 184, 187 | "ü": 185, 188 | "ý": 186, 189 | "þ": 187, 190 | "ÿ": 188, 191 | "Ā": 
189, 192 | "ā": 190, 193 | "Ă": 191, 194 | "ă": 192, 195 | "Ą": 193, 196 | "ą": 194, 197 | "Ć": 195, 198 | "ć": 196, 199 | "Ĉ": 197, 200 | "ĉ": 198, 201 | "Ċ": 199, 202 | "ċ": 200, 203 | "Č": 201, 204 | "č": 202, 205 | "Ď": 203, 206 | "ď": 204, 207 | "Đ": 205, 208 | "đ": 206, 209 | "Ē": 207, 210 | "ē": 208, 211 | "Ĕ": 209, 212 | "ĕ": 210, 213 | "Ė": 211, 214 | "ė": 212, 215 | "Ę": 213, 216 | "ę": 214, 217 | "Ě": 215, 218 | "ě": 216, 219 | "Ĝ": 217, 220 | "ĝ": 218, 221 | "Ğ": 219, 222 | "ğ": 220, 223 | "Ġ": 221, 224 | "ġ": 222, 225 | "Ģ": 223, 226 | "ģ": 224, 227 | "Ĥ": 225, 228 | "ĥ": 226, 229 | "Ħ": 227, 230 | "ħ": 228, 231 | "Ĩ": 229, 232 | "ĩ": 230, 233 | "Ī": 231, 234 | "ī": 232, 235 | "Ĭ": 233, 236 | "ĭ": 234, 237 | "Į": 235, 238 | "į": 236, 239 | "İ": 237, 240 | "ı": 238, 241 | "IJ": 239, 242 | "ij": 240, 243 | "Ĵ": 241, 244 | "ĵ": 242, 245 | "Ķ": 243, 246 | "ķ": 244, 247 | "ĸ": 245, 248 | "Ĺ": 246, 249 | "ĺ": 247, 250 | "Ļ": 248, 251 | "ļ": 249, 252 | "Ľ": 250, 253 | "ľ": 251, 254 | "Ŀ": 252, 255 | "ŀ": 253, 256 | "Ł": 254, 257 | "ł": 255, 258 | "Ń": 256, 259 | "Ġt": 257, 260 | "Ġa": 258, 261 | "he": 259, 262 | "in": 260, 263 | "Ġthe": 261, 264 | "re": 262, 265 | "Ġo": 263, 266 | "Ġ,": 264, 267 | "er": 265, 268 | "Ġs": 266, 269 | "at": 267, 270 | "Ġ.": 268, 271 | "nd": 269, 272 | "is": 270, 273 | "or": 271, 274 | "Ġw": 272, 275 | "Ġc": 273, 276 | "on": 274, 277 | "Ġb": 275, 278 | "Ġf": 276, 279 | "ou": 277, 280 | "it": 278, 281 | "en": 279, 282 | "es": 280, 283 | "Ġof": 281, 284 | "Ġp": 282, 285 | "ing": 283, 286 | "Ġin": 284, 287 | "ed": 285, 288 | "al": 286, 289 | "Ġm": 287, 290 | "Ġand": 288, 291 | "Ġd": 289, 292 | "an": 290, 293 | "ar": 291, 294 | "Ġto": 292, 295 | "om": 293, 296 | "Ġth": 294, 297 | "ic": 295, 298 | "ion": 296, 299 | "Ġh": 297, 300 | "Ġl": 298, 301 | "Ġy": 299, 302 | "Ġe": 300, 303 | "as": 301, 304 | "ot": 302, 305 | "il": 303, 306 | "Ġn": 304, 307 | "Ġu": 305, 308 | "ent": 306, 309 | "Ġbe": 307, 310 | "Ġ&": 308, 311 | "Ġis": 309, 312 | "Ġyou": 310, 313 | "os": 311, 314 | "Ġre": 312, 315 | "et": 313, 316 | "Ġfor": 314, 317 | "ut": 315, 318 | "el": 316, 319 | "Ġg": 317, 320 | "ay": 318, 321 | "st": 319, 322 | "ow": 320, 323 | "le": 321, 324 | "ce": 322, 325 | "ad": 323, 326 | "Ġon": 324, 327 | "ĠI": 325, 328 | "ver": 326, 329 | "ve": 327, 330 | "ĠA": 328, 331 | "ur": 329, 332 | "ol": 330, 333 | "ct": 331, 334 | "qu": 332, 335 | "Ġthat": 333, 336 | "im": 334, 337 | "all": 335, 338 | "am": 336, 339 | "ig": 337, 340 | "ch": 338, 341 | "ation": 339, 342 | "ĠP": 340, 343 | "ith": 341, 344 | "ir": 342, 345 | "ĠS": 343, 346 | "Ġit": 344, 347 | "Ġpr": 345, 348 | "ap": 346, 349 | "Ġsh": 347, 350 | "ĠC": 348, 351 | "th": 349, 352 | "Ġcom": 350, 353 | "Ġ@": 351, 354 | "Ġwh": 352, 355 | "-@": 353, 356 | "Ġare": 354, 357 | "Ġ@-@": 355, 358 | "nt": 356, 359 | "id": 357, 360 | "Ġwith": 358, 361 | "Ġal": 359, 362 | "op": 360, 363 | "Ġus": 361, 364 | "ers": 362, 365 | "Ġas": 363, 366 | "the": 364, 367 | "and": 365, 368 | "if": 366, 369 | "ord": 367, 370 | "od": 368, 371 | "Ġhe": 369, 372 | "ist": 370, 373 | "quot": 371, 374 | "ment": 372, 375 | "ĠM": 373, 376 | "Ġor": 374, 377 | "ore": 375, 378 | "ĠG": 376, 379 | "Ġfr": 377, 380 | "ill": 378, 381 | "res": 379, 382 | "Ġst": 380, 383 | "ess": 381, 384 | "ld": 382, 385 | "Ġthis": 383, 386 | "Ġ2": 384, 387 | "art": 385, 388 | "Ġ;": 386, 389 | "ĠL": 387, 390 | "ly": 388, 391 | "ain": 389, 392 | "ul": 390, 393 | "Ġde": 391, 394 | "Ġcon": 392, 395 | "est": 393, 396 | "se": 394, 397 | "apos": 395, 398 | "ag": 396, 399 | "Ġfrom": 397, 400 
| "Ġan": 398, 401 | "Ġwe": 399, 402 | "Ġ(": 400, 403 | "00": 401, 404 | "ter": 402, 405 | "ĠE": 403, 406 | "em": 404, 407 | "ave": 405, 408 | "Ġnot": 406, 409 | "Ġ)": 407, 410 | "Ġ1": 408, 411 | "Ġyour": 409, 412 | "oc": 410, 413 | "Ġcan": 411, 414 | "Ġby": 412, 415 | "ĠD": 413, 416 | "Ġne": 414, 417 | "Ġv": 415, 418 | "igh": 416, 419 | "ich": 417, 420 | "Ġall": 418, 421 | "ri": 419, 422 | "Ġup": 420, 423 | "Ġr": 421, 424 | "ĠW": 422, 425 | "ble": 423, 426 | "Ġthey": 424, 427 | "ĠB": 425, 428 | "un": 426, 429 | "Ġye": 427, 430 | "Ġwhich": 428, 431 | "ĠO": 429, 432 | "ke": 430, 433 | "Ġwor": 431, 434 | "Ġsu": 432, 435 | "ĠF": 433, 436 | "ĠH": 434, 437 | "Ġhave": 435, 438 | "ate": 436, 439 | "Ġshall": 437, 440 | "Ġch": 438, 441 | "ect": 439, 442 | "ity": 440, 443 | "Ġsp": 441, 444 | "ress": 442, 445 | "ight": 443, 446 | "Ġwill": 444, 447 | "Ġcomp": 445, 448 | "ort": 446, 449 | "ant": 447, 450 | "Ġ&#": 448, 451 | "ive": 449, 452 | "are": 450, 453 | "..": 451, 454 | "Ġex": 452, 455 | "ĠAnd": 453, 456 | "Ġ...": 454, 457 | "ast": 455, 458 | "24": 456, 459 | "ĠT": 457, 460 | "ould": 458, 461 | "ven": 459, 462 | "Ġtr": 460, 463 | "ust": 461, 464 | "um": 462, 465 | "out": 463, 466 | "com": 464, 467 | "Ġunt": 465, 468 | "Ġse": 466, 469 | "ft": 467, 470 | "ree": 468, 471 | "ost": 469, 472 | "og": 470, 473 | "ish": 471, 474 | "ions": 472, 475 | "iz": 473, 476 | "124": 474, 477 | "Ġunto": 475, 478 | "mer": 476, 479 | "ings": 477, 480 | "Ġac": 478, 481 | "this": 479, 482 | "ated": 480, 483 | "ac": 481, 484 | "lu": 482, 485 | "ere": 483, 486 | "Ġman": 484, 487 | "for": 485, 488 | "Ġmy": 486, 489 | "Ġat": 487, 490 | "ies": 488, 491 | "age": 489, 492 | "rou": 490, 493 | "lo": 491, 494 | "ans": 492, 495 | "pp": 493, 496 | "ind": 494, 497 | "Ġwork": 495, 498 | "here": 496, 499 | "fore": 497, 500 | "Ġsit": 498, 501 | "Ġver": 499 502 | } -------------------------------------------------------------------------------- /tests/fixtures/transformer_block_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/transformer_block_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/transformer_block_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/transformer_block_weights.pt -------------------------------------------------------------------------------- /tests/fixtures/transformer_lm_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/transformer_lm_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/transformer_lm_truncated_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/transformer_lm_truncated_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/transformer_lm_weights.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/transformer_lm_weights.pt -------------------------------------------------------------------------------- /tests/fixtures/unbatched_multihead_self_attention_expected_output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/unbatched_multihead_self_attention_expected_output.pt -------------------------------------------------------------------------------- /tests/fixtures/unbatched_multihead_self_attention_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-cs336/spring2024-assignment1-basics/dae79d035bf866a71d81a02b0c023980d537b16d/tests/fixtures/unbatched_multihead_self_attention_weights.pt -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | from collections import Counter 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from .adapters import run_get_batch 9 | 10 | 11 | def test_get_batch(): 12 | dataset = np.arange(0, 100) 13 | context_length = 7 14 | batch_size = 32 15 | device = "cpu" 16 | 17 | # Sanity check to make sure that the random samples are indeed somewhat random. 18 | starting_indices = Counter() 19 | num_iters = 1000 20 | for _ in range(num_iters): 21 | x, y = run_get_batch( 22 | dataset=dataset, 23 | batch_size=batch_size, 24 | context_length=context_length, 25 | device=device, 26 | ) 27 | 28 | # Make sure the shape is correct 29 | assert x.shape == (batch_size, context_length) 30 | assert y.shape == (batch_size, context_length) 31 | 32 | # Make sure the y's are always offset by 1 33 | np.testing.assert_allclose((x + 1).detach().numpy(), y.detach().numpy()) 34 | 35 | starting_indices.update(x[:, 0].tolist()) 36 | 37 | # Make sure we never sample an invalid start index 38 | num_possible_starting_indices = len(dataset) - context_length 39 | assert max(starting_indices) == num_possible_starting_indices - 1 40 | assert min(starting_indices) == 0 41 | # Expected # of times that we see each starting index 42 | expected_count = (num_iters * batch_size) / num_possible_starting_indices 43 | standard_deviation = math.sqrt( 44 | (num_iters * batch_size) 45 | * (1 / num_possible_starting_indices) 46 | * (1 - (1 / num_possible_starting_indices)) 47 | ) 48 | # Range for expected outcomes (mu +/- 5sigma). For a given index, 49 | # this should happen 99.99994% of the time.
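# (For concreteness with these constants: num_iters * batch_size = 32,000 samples over 93 possible starting indices gives an expected_count of about 344.1 and a standard_deviation of about 18.5, so the +/- 5 sigma window is roughly 252 to 436 occurrences per index.)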
50 | # So, in the case where we have 93 possible start indices, 51 | # the entire test should pass with 99.9944202% of the time 52 | occurrences_lower_bound = expected_count - 5 * standard_deviation 53 | occurrences_upper_bound = expected_count + 5 * standard_deviation 54 | 55 | for starting_index, count in starting_indices.items(): 56 | if count < occurrences_lower_bound: 57 | raise ValueError( 58 | f"Starting index {starting_index} occurs {count} times, but expected at least {occurrences_lower_bound}" 59 | ) 60 | if count > occurrences_upper_bound: 61 | raise ValueError( 62 | f"Starting index {starting_index} occurs {count} times, but expected at most {occurrences_upper_bound}" 63 | ) 64 | 65 | with pytest.raises((RuntimeError, AssertionError)) as excinfo: 66 | # We're assuming that cuda:99 is an invalid device ordinal. 67 | # Just adding this here to make sure that the device flag is 68 | # being handled. 69 | run_get_batch( 70 | dataset=dataset, 71 | batch_size=batch_size, 72 | context_length=context_length, 73 | device="cuda:99", 74 | ) 75 | assert "CUDA error" in str( 76 | excinfo.value 77 | ) or "Torch not compiled with CUDA enabled" in str(excinfo.value) 78 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .adapters import ( 7 | run_gelu, 8 | run_multihead_self_attention, 9 | run_positionwise_feedforward, 10 | run_rmsnorm, 11 | run_scaled_dot_product_attention, 12 | run_transformer_block, 13 | run_transformer_lm, 14 | ) 15 | from .common import FIXTURES_PATH 16 | 17 | 18 | def test_positionwise_feedforward(): 19 | reference_weights = torch.load( 20 | FIXTURES_PATH / "positionwise_feedforward_weights.pt" 21 | ) 22 | in_features = torch.load(FIXTURES_PATH / "in_features.pt") 23 | expected_output = torch.load( 24 | FIXTURES_PATH / "positionwise_feedforward_expected_output.pt" 25 | ) 26 | d_model = 64 27 | d_ff = 128 28 | 29 | actual_output = run_positionwise_feedforward( 30 | d_model=d_model, d_ff=d_ff, weights=reference_weights, in_features=in_features 31 | ) 32 | numpy.testing.assert_allclose( 33 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 34 | ) 35 | 36 | 37 | def test_scaled_dot_product_attention(): 38 | torch.manual_seed(42) 39 | # Take the first batch item, so we test the 3D case 40 | # (input shape (batch_size, seq_len, d_k)) for scaled dot-product attention. 
41 | K = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_K.pt")[0] 42 | Q = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_Q.pt")[0] 43 | V = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_V.pt")[0] 44 | mask = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_mask.pt") 45 | pdrop = 0.0 46 | expected_output = torch.load( 47 | FIXTURES_PATH / "scaled_dot_product_attention_expected_output.pt" 48 | )[0] 49 | actual_output = run_scaled_dot_product_attention( 50 | K=K, Q=Q, V=V, mask=mask, pdrop=pdrop 51 | ) 52 | numpy.testing.assert_allclose( 53 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 54 | ) 55 | 56 | 57 | def test_4d_scaled_dot_product_attention(): 58 | torch.manual_seed(42) 59 | # Shape: (batch_size, num_heads, seq_len, d_k) 60 | K = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_K.pt") 61 | Q = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_Q.pt") 62 | V = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_V.pt") 63 | mask = torch.load(FIXTURES_PATH / "scaled_dot_product_attention_mask.pt") 64 | pdrop = 0.0 65 | expected_output = torch.load( 66 | FIXTURES_PATH / "scaled_dot_product_attention_expected_output.pt" 67 | ) 68 | actual_output = run_scaled_dot_product_attention( 69 | K=K, Q=Q, V=V, mask=mask, pdrop=pdrop 70 | ) 71 | numpy.testing.assert_allclose( 72 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 73 | ) 74 | 75 | 76 | def test_multihead_self_attention(): 77 | reference_weights = torch.load( 78 | FIXTURES_PATH / "unbatched_multihead_self_attention_weights.pt" 79 | ) 80 | in_features = torch.load(FIXTURES_PATH / "in_features.pt") 81 | expected_output = torch.load( 82 | FIXTURES_PATH / "unbatched_multihead_self_attention_expected_output.pt" 83 | ) 84 | d_model = 64 85 | num_heads = 2 86 | attn_pdrop = 0.0 87 | actual_output = run_multihead_self_attention( 88 | d_model=d_model, 89 | num_heads=num_heads, 90 | attn_pdrop=attn_pdrop, 91 | weights=reference_weights, 92 | in_features=in_features, 93 | ) 94 | numpy.testing.assert_allclose( 95 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 96 | ) 97 | 98 | 99 | def test_transformer_lm(): 100 | torch.manual_seed(42) 101 | vocab_size = 100 102 | context_length = 64 103 | d_model = 128 104 | num_layers = 2 105 | num_heads = 2 106 | d_ff = d_model * 4 107 | attn_pdrop = 0.0 108 | residual_pdrop = 0.0 109 | 110 | reference_weights = torch.load(FIXTURES_PATH / "transformer_lm_weights.pt") 111 | in_indices = torch.load(FIXTURES_PATH / "in_indices.pt") 112 | expected_output = torch.load(FIXTURES_PATH / "transformer_lm_expected_output.pt") 113 | actual_output = run_transformer_lm( 114 | vocab_size=vocab_size, 115 | context_length=context_length, 116 | d_model=d_model, 117 | num_layers=num_layers, 118 | num_heads=num_heads, 119 | d_ff=d_ff, 120 | attn_pdrop=attn_pdrop, 121 | residual_pdrop=residual_pdrop, 122 | weights=reference_weights, 123 | in_indices=in_indices, 124 | ) 125 | numpy.testing.assert_allclose( 126 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-4 127 | ) 128 | 129 | 130 | def test_transformer_lm_truncated_input(): 131 | torch.manual_seed(42) 132 | vocab_size = 100 133 | context_length = 64 134 | d_model = 128 135 | num_layers = 2 136 | num_heads = 2 137 | d_ff = d_model * 4 138 | attn_pdrop = 0.0 139 | residual_pdrop = 0.0 140 | 141 | reference_weights = torch.load(FIXTURES_PATH / "transformer_lm_weights.pt") 142 | in_indices_truncated = 
torch.load(FIXTURES_PATH / "in_indices_truncated.pt") 143 | truncated_expected_output = torch.load( 144 | FIXTURES_PATH / "transformer_lm_truncated_expected_output.pt" 145 | ) 146 | truncated_actual_output = run_transformer_lm( 147 | vocab_size=vocab_size, 148 | context_length=context_length, 149 | d_model=d_model, 150 | num_layers=num_layers, 151 | num_heads=num_heads, 152 | d_ff=d_ff, 153 | attn_pdrop=attn_pdrop, 154 | residual_pdrop=residual_pdrop, 155 | weights=reference_weights, 156 | in_indices=in_indices_truncated, 157 | ) 158 | numpy.testing.assert_allclose( 159 | truncated_actual_output.detach().numpy(), 160 | truncated_expected_output.detach().numpy(), 161 | atol=1e-4, 162 | ) 163 | 164 | 165 | def test_transformer_block(): 166 | torch.manual_seed(42) 167 | reference_weights = torch.load(FIXTURES_PATH / "transformer_block_weights.pt") 168 | in_features = torch.load(FIXTURES_PATH / "in_features.pt") 169 | expected_output = torch.load(FIXTURES_PATH / "transformer_block_expected_output.pt") 170 | d_model = 64 171 | num_heads = 2 172 | d_ff = d_model * 4 173 | attn_pdrop = 0.0 174 | residual_pdrop = 0.0 175 | 176 | actual_output = run_transformer_block( 177 | d_model=d_model, 178 | num_heads=num_heads, 179 | d_ff=d_ff, 180 | attn_pdrop=attn_pdrop, 181 | residual_pdrop=residual_pdrop, 182 | weights=reference_weights, 183 | in_features=in_features, 184 | ) 185 | numpy.testing.assert_allclose( 186 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 187 | ) 188 | 189 | 190 | def test_rmsnorm(): 191 | reference_weights = torch.load(FIXTURES_PATH / "rmsnorm_weights.pt") 192 | in_features = torch.load(FIXTURES_PATH / "in_features.pt") 193 | expected_output = torch.load(FIXTURES_PATH / "rmsnorm_expected_output.pt") 194 | d_model = 64 195 | actual_output = run_rmsnorm( 196 | d_model=d_model, eps=1e-5, weights=reference_weights, in_features=in_features 197 | ) 198 | numpy.testing.assert_allclose( 199 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 200 | ) 201 | 202 | 203 | def test_gelu(): 204 | x = torch.tensor( 205 | [ 206 | [0.2352, 0.9259, 0.5189, 0.4725, 0.9730], 207 | [0.7581, 0.9692, 0.2129, 0.9345, 0.0149], 208 | ] 209 | ) 210 | expected_output = torch.tensor( 211 | [ 212 | [ 213 | 0.13946731388568878, 214 | 0.7617851495742798, 215 | 0.3622361421585083, 216 | 0.3221103549003601, 217 | 0.8121858239173889, 218 | ], 219 | [ 220 | 0.5881373286247253, 221 | 0.8080969452857971, 222 | 0.1243969276547432, 223 | 0.7709409594535828, 224 | 0.007538566831499338, 225 | ], 226 | ] 227 | ) 228 | actual_output = run_gelu(x) 229 | numpy.testing.assert_allclose( 230 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 231 | ) 232 | 233 | 234 | def test_gelu_matches_pytorch(): 235 | x = torch.tensor( 236 | [ 237 | [0.2352, 0.9259, 0.5189, 0.4725, 0.9730], 238 | [0.7581, 0.9692, 0.2129, 0.9345, 0.0149], 239 | ] 240 | ) 241 | expected_output = F.gelu(x) 242 | actual_output = run_gelu(x) 243 | numpy.testing.assert_allclose( 244 | actual_output.detach().numpy(), expected_output.detach().numpy(), atol=1e-6 245 | ) 246 | -------------------------------------------------------------------------------- /tests/test_nn_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .adapters import run_cross_entropy, run_gradient_clipping, run_softmax 7 | 8 | 9 | def test_softmax_matches_pytorch(): 
10 | x = torch.tensor( 11 | [ 12 | [0.4655, 0.8303, 0.9608, 0.9656, 0.6840], 13 | [0.2583, 0.2198, 0.9334, 0.2995, 0.1722], 14 | [0.1573, 0.6860, 0.1327, 0.7284, 0.6811], 15 | ] 16 | ) 17 | expected = F.softmax(x, dim=-1) 18 | numpy.testing.assert_allclose( 19 | run_softmax(x, dim=-1).detach().numpy(), expected.detach().numpy(), atol=1e-6 20 | ) 21 | # Test that softmax handles numerical overflow issues 22 | numpy.testing.assert_allclose( 23 | run_softmax(x + 100, dim=-1).detach().numpy(), 24 | expected.detach().numpy(), 25 | atol=1e-6, 26 | ) 27 | 28 | 29 | def test_cross_entropy(): 30 | inputs = torch.tensor( 31 | [ 32 | [ 33 | [0.1088, 0.1060, 0.6683, 0.5131, 0.0645], 34 | [0.4538, 0.6852, 0.2520, 0.3792, 0.2675], 35 | [0.4578, 0.3357, 0.6384, 0.0481, 0.5612], 36 | [0.9639, 0.8864, 0.1585, 0.3038, 0.0350], 37 | ], 38 | [ 39 | [0.3356, 0.9013, 0.7052, 0.8294, 0.8334], 40 | [0.6333, 0.4434, 0.1428, 0.5739, 0.3810], 41 | [0.9476, 0.5917, 0.7037, 0.2987, 0.6208], 42 | [0.8541, 0.1803, 0.2054, 0.4775, 0.8199], 43 | ], 44 | ] 45 | ) 46 | targets = torch.tensor([[1, 0, 2, 2], [4, 1, 4, 0]]) 47 | expected = F.cross_entropy(inputs.view(-1, inputs.size(-1)), targets.view(-1)) 48 | numpy.testing.assert_allclose( 49 | run_cross_entropy(inputs.view(-1, inputs.size(-1)), targets.view(-1)) 50 | .detach() 51 | .numpy(), 52 | expected.detach().numpy(), 53 | atol=1e-4, 54 | ) 55 | 56 | # Test that cross-entropy handles numerical overflow issues 57 | large_inputs = 1000.0 * inputs 58 | large_expected_cross_entropy = F.cross_entropy( 59 | large_inputs.view(-1, large_inputs.size(-1)), targets.view(-1) 60 | ) 61 | numpy.testing.assert_allclose( 62 | run_cross_entropy( 63 | large_inputs.view(-1, large_inputs.size(-1)), targets.view(-1) 64 | ) 65 | .detach() 66 | .numpy(), 67 | large_expected_cross_entropy.detach().numpy(), 68 | atol=1e-4, 69 | ) 70 | 71 | 72 | def test_gradient_clipping(): 73 | tensors = [torch.randn((5, 5)) for _ in range(6)] 74 | max_norm = 1e-2 75 | 76 | t1 = tuple(torch.nn.Parameter(torch.clone(t)) for t in tensors) 77 | # Test freezing one parameter. 
78 | t1[-1].requires_grad_(False) 79 | 80 | loss = torch.cat(t1).sum() 81 | loss.backward() 82 | torch.nn.utils.clip_grad.clip_grad_norm_(t1, max_norm) 83 | t1_grads = [torch.clone(t.grad) for t in t1 if t.grad is not None] 84 | 85 | t1_c = tuple(torch.nn.Parameter(torch.clone(t)) for t in tensors) 86 | t1_c[-1].requires_grad_(False) 87 | loss_c = torch.cat(t1_c).sum() 88 | loss_c.backward() 89 | run_gradient_clipping(t1_c, max_norm) 90 | t1_c_grads = [torch.clone(t.grad) for t in t1_c if t.grad is not None] 91 | 92 | assert len(t1_grads) == len(t1_c_grads) 93 | 94 | for t1_grad, t1_c_grad in zip(t1_grads, t1_c_grads): 95 | numpy.testing.assert_allclose( 96 | t1_grad.detach().numpy(), 97 | t1_c_grad.detach().numpy(), 98 | atol=1e-6, 99 | ) 100 | -------------------------------------------------------------------------------- /tests/test_optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import torch 4 | 5 | from .adapters import get_adamw_cls, run_get_lr_cosine_schedule 6 | from .common import FIXTURES_PATH 7 | 8 | 9 | def _optimize(opt_class) -> torch.Tensor: 10 | torch.manual_seed(42) 11 | model = torch.nn.Linear(3, 2, bias=False) 12 | opt = opt_class( 13 | model.parameters(), 14 | lr=1e-3, 15 | weight_decay=0.01, 16 | betas=(0.9, 0.999), 17 | eps=1e-8, 18 | ) 19 | # Use 1000 optimization steps for testing 20 | for _ in range(1000): 21 | opt.zero_grad() 22 | x = torch.rand(model.in_features) 23 | y_hat = model(x) 24 | y = torch.tensor([x[0] + x[1], -x[2]]) 25 | loss = ((y - y_hat) ** 2).sum() 26 | loss.backward() 27 | opt.step() 28 | return model.weight.detach() 29 | 30 | 31 | def test_adamw(): 32 | """ 33 | Our reference implementation yields slightly different results than the 34 | PyTorch AdamW, since there are a couple different ways that you can apply 35 | weight decay that are equivalent in principle, but differ in practice due to 36 | floating point behavior. So, we test that the provided implementation matches 37 | _either_ our reference implementation's expected results or those from the PyTorch AdamW. 
38 | """ 39 | expected_weights = torch.load(FIXTURES_PATH / "adamw_expected_params.pt") 40 | pytorch_weights = _optimize(torch.optim.AdamW) 41 | actual_weights = _optimize(get_adamw_cls()) 42 | 43 | matches_expected = torch.allclose(actual_weights, expected_weights, atol=1e-6) 44 | matches_pytorch = torch.allclose(actual_weights, pytorch_weights, atol=1e-6) 45 | if matches_expected or matches_pytorch: 46 | return 47 | # re-raise the error if the provided implementation doesn't 48 | # match either our reference implementation or the PyTorch implementation 49 | numpy.testing.assert_allclose( 50 | actual_weights.detach().numpy(), expected_weights.detach().numpy(), atol=1e-6 51 | ) 52 | 53 | 54 | def test_get_lr_cosine_schedule(): 55 | max_learning_rate = 1 56 | min_learning_rate = 1 * 0.1 57 | warmup_iters = 7 58 | cosine_cycle_iters = 21 59 | 60 | expected_lrs = [ 61 | 0, 62 | 0.14285714285714285, 63 | 0.2857142857142857, 64 | 0.42857142857142855, 65 | 0.5714285714285714, 66 | 0.7142857142857143, 67 | 0.8571428571428571, 68 | 1.0, 69 | 0.9887175604818206, 70 | 0.9554359905560885, 71 | 0.9018241671106134, 72 | 0.8305704108364301, 73 | 0.7452476826029011, 74 | 0.6501344202803414, 75 | 0.55, 76 | 0.44986557971965857, 77 | 0.3547523173970989, 78 | 0.26942958916356996, 79 | 0.19817583288938662, 80 | 0.14456400944391146, 81 | 0.11128243951817937, 82 | 0.1, 83 | 0.1, 84 | 0.1, 85 | 0.1, 86 | ] 87 | actual_lrs = [ 88 | run_get_lr_cosine_schedule( 89 | it=it, 90 | max_learning_rate=max_learning_rate, 91 | min_learning_rate=min_learning_rate, 92 | warmup_iters=warmup_iters, 93 | cosine_cycle_iters=cosine_cycle_iters, 94 | ) 95 | for it in range(25) 96 | ] 97 | numpy.testing.assert_allclose(numpy.array(actual_lrs), numpy.array(expected_lrs)) 98 | -------------------------------------------------------------------------------- /tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .adapters import get_adamw_cls, run_load_checkpoint, run_save_checkpoint 8 | 9 | 10 | class _TestNet(nn.Module): 11 | def __init__(self, d_input: int = 100, d_output: int = 10): 12 | super(_TestNet, self).__init__() 13 | self.fc1 = nn.Linear(d_input, 200) 14 | self.fc2 = nn.Linear(200, 100) 15 | self.fc3 = nn.Linear(100, d_output) 16 | 17 | def forward(self, x): 18 | x = F.relu(self.fc1(x)) 19 | x = F.relu(self.fc2(x)) 20 | x = self.fc3(x) 21 | return x 22 | 23 | 24 | def are_optimizers_equal( 25 | optimizer1_state_dict, optimizer2_state_dict, atol=1e-8, rtol=1e-5 26 | ): 27 | # Check if the keys of the main dictionaries are equal (e.g., 'state', 'param_groups') 28 | if set(optimizer1_state_dict.keys()) != set(optimizer2_state_dict.keys()): 29 | return False 30 | 31 | # Check parameter groups are identical 32 | if optimizer1_state_dict["param_groups"] != optimizer2_state_dict["param_groups"]: 33 | return False 34 | 35 | # Check states 36 | state1 = optimizer1_state_dict["state"] 37 | state2 = optimizer2_state_dict["state"] 38 | if set(state1.keys()) != set(state2.keys()): 39 | return False 40 | 41 | for key in state1: 42 | # Assuming state contents are also dictionaries 43 | if set(state1[key].keys()) != set(state2[key].keys()): 44 | return False 45 | 46 | for sub_key in state1[key]: 47 | item1 = state1[key][sub_key] 48 | item2 = state2[key][sub_key] 49 | 50 | # If both items are tensors, use torch.allclose 51 | if torch.is_tensor(item1) 
and torch.is_tensor(item2): 52 | if not torch.allclose(item1, item2, atol=atol, rtol=rtol): 53 | return False 54 | # For non-tensor items, check for direct equality 55 | elif item1 != item2: 56 | return False 57 | return True 58 | 59 | 60 | def test_checkpointing(tmp_path): 61 | torch.manual_seed(42) 62 | d_input = 100 63 | d_output = 10 64 | num_iters = 10 65 | 66 | model = _TestNet(d_input=d_input, d_output=d_output) 67 | optimizer = get_adamw_cls()( 68 | model.parameters(), 69 | lr=1e-3, 70 | weight_decay=0.01, 71 | betas=(0.9, 0.999), 72 | eps=1e-8, 73 | ) 74 | # Run a few optimization steps so there is model and optimizer state to checkpoint 75 | it = 0 76 | for _ in range(num_iters): 77 | optimizer.zero_grad() 78 | x = torch.rand(d_input) 79 | y = torch.rand(d_output) 80 | y_hat = model(x) 81 | loss = ((y - y_hat) ** 2).sum() 82 | loss.backward() 83 | optimizer.step() 84 | it += 1 85 | 86 | serialization_path = tmp_path / "checkpoint.pt" 87 | # Save the model 88 | run_save_checkpoint( 89 | model, 90 | optimizer, 91 | iteration=it, 92 | out=serialization_path, 93 | ) 94 | 95 | # Load the model back again 96 | new_model = _TestNet(d_input=d_input, d_output=d_output) 97 | new_optimizer = get_adamw_cls()( 98 | new_model.parameters(), 99 | lr=1e-3, 100 | weight_decay=0.01, 101 | betas=(0.9, 0.999), 102 | eps=1e-8, 103 | ) 104 | loaded_iterations = run_load_checkpoint( 105 | src=serialization_path, model=new_model, optimizer=new_optimizer 106 | ) 107 | assert it == loaded_iterations 108 | 109 | # Compare the loaded model state with the original model state 110 | original_model_state = model.state_dict() 111 | original_optimizer_state = optimizer.state_dict() 112 | new_model_state = new_model.state_dict() 113 | new_optimizer_state = new_optimizer.state_dict() 114 | 115 | # Check that state dict keys match 116 | assert set(original_model_state.keys()) == set(new_model_state.keys()) 117 | assert set(original_optimizer_state.keys()) == set(new_optimizer_state.keys()) 118 | 119 | # compare the model state dicts 120 | for key in original_model_state.keys(): 121 | numpy.testing.assert_allclose( 122 | original_model_state[key].detach().numpy(), 123 | new_model_state[key].detach().numpy(), 124 | ) 125 | # compare the optimizer state dicts 126 | assert are_optimizers_equal(original_optimizer_state, new_optimizer_state) 127 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | from __future__ import annotations 4 | 5 | import json 6 | import os 7 | import resource 8 | import sys 9 | from typing import Optional 10 | 11 | import psutil 12 | import pytest 13 | import tiktoken 14 | 15 | from .adapters import get_tokenizer 16 | from .common import FIXTURES_PATH, gpt2_bytes_to_unicode 17 | 18 | VOCAB_PATH = FIXTURES_PATH / "gpt2_vocab.json" 19 | MERGES_PATH = FIXTURES_PATH / "gpt2_merges.txt" 20 | 21 | 22 | def memory_limit(max_mem): 23 | def decorator(f): 24 | def wrapper(*args, **kwargs): 25 | process = psutil.Process(os.getpid()) 26 | prev_limits = resource.getrlimit(resource.RLIMIT_AS) 27 | resource.setrlimit( 28 | resource.RLIMIT_AS, (process.memory_info().rss + max_mem, -1) 29 | ) 30 | try: 31 | result = f(*args, **kwargs) 32 | return result 33 | finally: 34 | # Even if the function above fails (e.g., it exceeds the 35 | # memory limit), reset the memory limit back to the 36 | # previous limit so other tests aren't affected.
37 | resource.setrlimit(resource.RLIMIT_AS, prev_limits) 38 | 39 | return wrapper 40 | 41 | return decorator 42 | 43 | 44 | def get_tokenizer_from_vocab_merges_path( 45 | vocab_path: str | os.PathLike, 46 | merges_path: str | os.PathLike, 47 | special_tokens: Optional[list[str]] = None, 48 | ): 49 | gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()} 50 | with open(vocab_path) as vocab_f: 51 | gpt2_vocab = json.load(vocab_f) 52 | gpt2_bpe_merges = [] 53 | with open(merges_path) as f: 54 | for line in f: 55 | cleaned_line = line.rstrip() 56 | if cleaned_line and len(cleaned_line.split(" ")) == 2: 57 | gpt2_bpe_merges.append(tuple(cleaned_line.split(" "))) 58 | # The GPT-2 tokenizer uses a remapped unicode encoding for bytes. Let's 59 | # just return the original bytes, so we don't force students to use 60 | # any particular encoding scheme. 61 | vocab = { 62 | gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item]) 63 | for gpt2_vocab_item, gpt2_vocab_index in gpt2_vocab.items() 64 | } 65 | # If any of the special tokens don't exist in the vocab, append them to the vocab. 66 | if special_tokens: 67 | for special_token in special_tokens: 68 | byte_encoded_special_token = special_token.encode("utf-8") 69 | if byte_encoded_special_token not in set(vocab.values()): 70 | vocab[len(vocab)] = byte_encoded_special_token 71 | 72 | merges = [ 73 | ( 74 | bytes([gpt2_byte_decoder[token] for token in merge_token_1]), 75 | bytes([gpt2_byte_decoder[token] for token in merge_token_2]), 76 | ) 77 | for merge_token_1, merge_token_2 in gpt2_bpe_merges 78 | ] 79 | return get_tokenizer(vocab, merges, special_tokens) 80 | 81 | 82 | def test_roundtrip_empty(): 83 | tokenizer = get_tokenizer_from_vocab_merges_path( 84 | vocab_path=VOCAB_PATH, 85 | merges_path=MERGES_PATH, 86 | ) 87 | test_string = "" 88 | encoded_ids = tokenizer.encode(test_string) 89 | decoded_string = tokenizer.decode(encoded_ids) 90 | assert test_string == decoded_string 91 | 92 | 93 | def test_empty_matches_tiktoken(): 94 | reference_tokenizer = tiktoken.get_encoding("gpt2") 95 | tokenizer = get_tokenizer_from_vocab_merges_path( 96 | vocab_path=VOCAB_PATH, 97 | merges_path=MERGES_PATH, 98 | ) 99 | test_string = "" 100 | 101 | reference_ids = reference_tokenizer.encode(test_string) 102 | ids = tokenizer.encode(test_string) 103 | assert ids == reference_ids 104 | 105 | tokenized_string = [tokenizer.decode([x]) for x in ids] 106 | assert tokenized_string == [] 107 | 108 | assert tokenizer.decode(ids) == test_string 109 | assert reference_tokenizer.decode(reference_ids) == test_string 110 | 111 | 112 | def test_roundtrip_single_character(): 113 | tokenizer = get_tokenizer_from_vocab_merges_path( 114 | vocab_path=VOCAB_PATH, 115 | merges_path=MERGES_PATH, 116 | ) 117 | test_string = "s" 118 | encoded_ids = tokenizer.encode(test_string) 119 | decoded_string = tokenizer.decode(encoded_ids) 120 | assert test_string == decoded_string 121 | 122 | 123 | def test_single_character_matches_tiktoken(): 124 | reference_tokenizer = tiktoken.get_encoding("gpt2") 125 | tokenizer = get_tokenizer_from_vocab_merges_path( 126 | vocab_path=VOCAB_PATH, 127 | merges_path=MERGES_PATH, 128 | ) 129 | test_string = "s" 130 | 131 | reference_ids = reference_tokenizer.encode(test_string) 132 | ids = tokenizer.encode(test_string) 133 | assert ids == reference_ids 134 | 135 | tokenized_string = [tokenizer.decode([x]) for x in ids] 136 | assert tokenized_string == ["s"] 137 | 138 | assert tokenizer.decode(ids) == test_string 139 | 
assert reference_tokenizer.decode(reference_ids) == test_string 140 | 141 | 142 | def test_roundtrip_single_unicode_character(): 143 | tokenizer = get_tokenizer_from_vocab_merges_path( 144 | vocab_path=VOCAB_PATH, 145 | merges_path=MERGES_PATH, 146 | ) 147 | test_string = "🙃" 148 | encoded_ids = tokenizer.encode(test_string) 149 | decoded_string = tokenizer.decode(encoded_ids) 150 | assert test_string == decoded_string 151 | 152 | 153 | def test_single_unicode_character_matches_tiktoken(): 154 | reference_tokenizer = tiktoken.get_encoding("gpt2") 155 | tokenizer = get_tokenizer_from_vocab_merges_path( 156 | vocab_path=VOCAB_PATH, 157 | merges_path=MERGES_PATH, 158 | ) 159 | test_string = "🙃" 160 | 161 | reference_ids = reference_tokenizer.encode(test_string) 162 | ids = tokenizer.encode(test_string) 163 | assert ids == reference_ids 164 | 165 | assert tokenizer.decode(ids) == test_string 166 | assert reference_tokenizer.decode(reference_ids) == test_string 167 | 168 | 169 | def test_roundtrip_ascii_string(): 170 | tokenizer = get_tokenizer_from_vocab_merges_path( 171 | vocab_path=VOCAB_PATH, 172 | merges_path=MERGES_PATH, 173 | ) 174 | test_string = "Hello, how are you?" 175 | encoded_ids = tokenizer.encode(test_string) 176 | decoded_string = tokenizer.decode(encoded_ids) 177 | assert test_string == decoded_string 178 | 179 | 180 | def test_ascii_string_matches_tiktoken(): 181 | reference_tokenizer = tiktoken.get_encoding("gpt2") 182 | tokenizer = get_tokenizer_from_vocab_merges_path( 183 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 184 | ) 185 | test_string = "Hello, how are you?" 186 | 187 | reference_ids = reference_tokenizer.encode(test_string) 188 | ids = tokenizer.encode(test_string) 189 | # assert ids == reference_ids 190 | 191 | tokenized_string = [tokenizer.decode([x]) for x in ids] 192 | assert tokenized_string == ["Hello", ",", " how", " are", " you", "?"] 193 | 194 | assert tokenizer.decode(ids) == test_string 195 | assert reference_tokenizer.decode(reference_ids) == test_string 196 | 197 | 198 | def test_roundtrip_unicode_string(): 199 | tokenizer = get_tokenizer_from_vocab_merges_path( 200 | vocab_path=VOCAB_PATH, 201 | merges_path=MERGES_PATH, 202 | ) 203 | test_string = "Héllò hôw are ü? 🙃" 204 | encoded_ids = tokenizer.encode(test_string) 205 | decoded_string = tokenizer.decode(encoded_ids) 206 | assert test_string == decoded_string 207 | 208 | 209 | def test_unicode_string_matches_tiktoken(): 210 | reference_tokenizer = tiktoken.get_encoding("gpt2") 211 | tokenizer = get_tokenizer_from_vocab_merges_path( 212 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 213 | ) 214 | test_string = "Héllò hôw are ü? 🙃" 215 | 216 | reference_ids = reference_tokenizer.encode(test_string) 217 | ids = tokenizer.encode(test_string) 218 | assert ids == reference_ids 219 | 220 | assert tokenizer.decode(ids) == test_string 221 | assert reference_tokenizer.decode(reference_ids) == test_string 222 | 223 | 224 | def test_roundtrip_unicode_string_with_special_tokens(): 225 | tokenizer = get_tokenizer_from_vocab_merges_path( 226 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 227 | ) 228 | test_string = "Héllò hôw <|endoftext|><|endoftext|> are ü? 
🙃<|endoftext|>" 229 | encoded_ids = tokenizer.encode(test_string) 230 | tokenized_string = [tokenizer.decode([x]) for x in encoded_ids] 231 | # Ensure the special <|endoftext|> token is preserved 232 | assert tokenized_string.count("<|endoftext|>") == 3 233 | 234 | decoded_string = tokenizer.decode(encoded_ids) 235 | assert test_string == decoded_string 236 | 237 | 238 | def test_unicode_string_with_special_tokens_matches_tiktoken(): 239 | reference_tokenizer = tiktoken.get_encoding("gpt2") 240 | tokenizer = get_tokenizer_from_vocab_merges_path( 241 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 242 | ) 243 | test_string = "Héllò hôw <|endoftext|><|endoftext|> are ü? 🙃<|endoftext|>" 244 | 245 | reference_ids = reference_tokenizer.encode( 246 | test_string, allowed_special={"<|endoftext|>"} 247 | ) 248 | ids = tokenizer.encode(test_string) 249 | assert ids == reference_ids 250 | 251 | assert tokenizer.decode(ids) == test_string 252 | assert reference_tokenizer.decode(reference_ids) == test_string 253 | 254 | 255 | def test_overlapping_special_tokens(): 256 | tokenizer = get_tokenizer_from_vocab_merges_path( 257 | vocab_path=VOCAB_PATH, 258 | merges_path=MERGES_PATH, 259 | special_tokens=["<|endoftext|>", "<|endoftext|><|endoftext|>"], 260 | ) 261 | test_string = "Hello, how <|endoftext|><|endoftext|> are you?<|endoftext|>" 262 | 263 | ids = tokenizer.encode(test_string) 264 | tokenized_string = [tokenizer.decode([x]) for x in ids] 265 | # Ensure the double <|endoftext|><|endoftext|> is preserved as a single token 266 | assert tokenized_string.count("<|endoftext|>") == 1 267 | assert tokenized_string.count("<|endoftext|><|endoftext|>") == 1 268 | # Test roundtrip 269 | assert tokenizer.decode(ids) == test_string 270 | 271 | 272 | def test_address_roundtrip(): 273 | tokenizer = get_tokenizer_from_vocab_merges_path( 274 | vocab_path=VOCAB_PATH, 275 | merges_path=MERGES_PATH, 276 | ) 277 | with open(FIXTURES_PATH / "address.txt") as f: 278 | corpus_contents = f.read() 279 | 280 | ids = tokenizer.encode(corpus_contents) 281 | assert tokenizer.decode(ids) == corpus_contents 282 | 283 | 284 | def test_address_matches_tiktoken(): 285 | reference_tokenizer = tiktoken.get_encoding("gpt2") 286 | tokenizer = get_tokenizer_from_vocab_merges_path( 287 | vocab_path=VOCAB_PATH, 288 | merges_path=MERGES_PATH, 289 | ) 290 | corpus_path = FIXTURES_PATH / "address.txt" 291 | with open(corpus_path) as f: 292 | corpus_contents = f.read() 293 | reference_ids = reference_tokenizer.encode(corpus_contents) 294 | ids = tokenizer.encode(corpus_contents) 295 | assert ids == reference_ids 296 | 297 | assert tokenizer.decode(ids) == corpus_contents 298 | assert reference_tokenizer.decode(reference_ids) == corpus_contents 299 | 300 | 301 | def test_german_roundtrip(): 302 | tokenizer = get_tokenizer_from_vocab_merges_path( 303 | vocab_path=VOCAB_PATH, 304 | merges_path=MERGES_PATH, 305 | ) 306 | with open(FIXTURES_PATH / "german.txt") as f: 307 | corpus_contents = f.read() 308 | 309 | ids = tokenizer.encode(corpus_contents) 310 | assert tokenizer.decode(ids) == corpus_contents 311 | 312 | 313 | def test_german_matches_tiktoken(): 314 | reference_tokenizer = tiktoken.get_encoding("gpt2") 315 | tokenizer = get_tokenizer_from_vocab_merges_path( 316 | vocab_path=VOCAB_PATH, 317 | merges_path=MERGES_PATH, 318 | ) 319 | corpus_path = FIXTURES_PATH / "german.txt" 320 | with open(corpus_path) as f: 321 | corpus_contents = f.read() 322 | reference_ids = 
reference_tokenizer.encode(corpus_contents) 323 | ids = tokenizer.encode(corpus_contents) 324 | assert ids == reference_ids 325 | 326 | assert tokenizer.decode(ids) == corpus_contents 327 | assert reference_tokenizer.decode(reference_ids) == corpus_contents 328 | 329 | 330 | def test_tinystories_sample_roundtrip(): 331 | tokenizer = get_tokenizer_from_vocab_merges_path( 332 | vocab_path=VOCAB_PATH, 333 | merges_path=MERGES_PATH, 334 | ) 335 | with open(FIXTURES_PATH / "tinystories_sample.txt") as f: 336 | corpus_contents = f.read() 337 | 338 | ids = tokenizer.encode(corpus_contents) 339 | assert tokenizer.decode(ids) == corpus_contents 340 | 341 | 342 | def test_tinystories_matches_tiktoken(): 343 | reference_tokenizer = tiktoken.get_encoding("gpt2") 344 | tokenizer = get_tokenizer_from_vocab_merges_path( 345 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 346 | ) 347 | corpus_path = FIXTURES_PATH / "tinystories_sample.txt" 348 | with open(corpus_path) as f: 349 | corpus_contents = f.read() 350 | reference_ids = reference_tokenizer.encode( 351 | corpus_contents, allowed_special={"<|endoftext|>"} 352 | ) 353 | ids = tokenizer.encode(corpus_contents) 354 | assert ids == reference_ids 355 | 356 | assert tokenizer.decode(ids) == corpus_contents 357 | assert reference_tokenizer.decode(reference_ids) == corpus_contents 358 | 359 | 360 | def test_encode_iterable_tinystories_sample_roundtrip(): 361 | tokenizer = get_tokenizer_from_vocab_merges_path( 362 | vocab_path=VOCAB_PATH, 363 | merges_path=MERGES_PATH, 364 | ) 365 | all_ids = [] 366 | with open(FIXTURES_PATH / "tinystories_sample.txt") as f: 367 | for _id in tokenizer.encode_iterable(f): 368 | all_ids.append(_id) 369 | with open(FIXTURES_PATH / "tinystories_sample.txt") as f: 370 | corpus_contents = f.read() 371 | assert tokenizer.decode(all_ids) == corpus_contents 372 | 373 | 374 | def test_encode_iterable_tinystories_matches_tiktoken(): 375 | reference_tokenizer = tiktoken.get_encoding("gpt2") 376 | tokenizer = get_tokenizer_from_vocab_merges_path( 377 | vocab_path=VOCAB_PATH, merges_path=MERGES_PATH, special_tokens=["<|endoftext|>"] 378 | ) 379 | corpus_path = FIXTURES_PATH / "tinystories_sample.txt" 380 | with open(corpus_path) as f: 381 | corpus_contents = f.read() 382 | reference_ids = reference_tokenizer.encode( 383 | corpus_contents, allowed_special={"<|endoftext|>"} 384 | ) 385 | all_ids = [] 386 | with open(FIXTURES_PATH / "tinystories_sample.txt") as f: 387 | for _id in tokenizer.encode_iterable(f): 388 | all_ids.append(_id) 389 | assert all_ids == reference_ids 390 | 391 | assert tokenizer.decode(all_ids) == corpus_contents 392 | assert reference_tokenizer.decode(reference_ids) == corpus_contents 393 | 394 | 395 | @pytest.mark.skipif( 396 | not sys.platform.startswith("linux"), 397 | reason="rlimit support for non-linux systems is spotty.", 398 | ) 399 | def test_encode_iterable_memory_usage(): 400 | tokenizer = get_tokenizer_from_vocab_merges_path( 401 | vocab_path=VOCAB_PATH, 402 | merges_path=MERGES_PATH, 403 | ) 404 | with open(FIXTURES_PATH / "tinystories_sample_5M.txt") as f: 405 | ids = [] 406 | for _id in _encode_iterable(tokenizer, f): 407 | ids.append(_id) 408 | 409 | 410 | @pytest.mark.skipif( 411 | not sys.platform.startswith("linux"), 412 | reason="rlimit support for non-linux systems is spotty.", 413 | ) 414 | @pytest.mark.xfail( 415 | reason="Tokenizer.encode is expected to take more memory than allotted (1MB)." 
416 | ) 417 | def test_encode_memory_usage(): 418 | """ 419 | We expect this test to fail, since Tokenizer.encode is not expected to be memory efficient. 420 | """ 421 | tokenizer = get_tokenizer_from_vocab_merges_path( 422 | vocab_path=VOCAB_PATH, 423 | merges_path=MERGES_PATH, 424 | ) 425 | with open(FIXTURES_PATH / "tinystories_sample_5M.txt") as f: 426 | contents = f.read() 427 | _ = _encode(tokenizer, contents) 428 | 429 | 430 | @memory_limit(int(1e6)) 431 | def _encode_iterable(tokenizer, iterable): 432 | """ 433 | We place tokenizer.encode_iterable into a separate function so we can limit memory 434 | for just this function. We set the memory limit to 1MB. 435 | """ 436 | yield from tokenizer.encode_iterable(iterable) 437 | 438 | 439 | @memory_limit(int(1e6)) 440 | def _encode(tokenizer, text): 441 | """ 442 | We place tokenizer.encode into a separate function so we can limit memory 443 | for just this function. We set the memory limit to 1MB. 444 | """ 445 | return tokenizer.encode(text) 446 | -------------------------------------------------------------------------------- /tests/test_train_bpe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import json 3 | import time 4 | 5 | from .adapters import run_train_bpe 6 | from .common import FIXTURES_PATH, gpt2_bytes_to_unicode 7 | 8 | 9 | def test_train_bpe_speed(): 10 | """ 11 | Ensure that BPE training is relatively efficient by measuring training 12 | time on this small dataset and throwing an error if it takes more than 1.5 seconds. 13 | This is a pretty generous upper bound; it takes 0.38 seconds with the 14 | reference implementation on my laptop. In contrast, the toy implementation 15 | takes around 3 seconds. 16 | """ 17 | input_path = FIXTURES_PATH / "corpus.en" 18 | start_time = time.time() 19 | _, _ = run_train_bpe( 20 | input_path=input_path, 21 | vocab_size=500, 22 | special_tokens=["<|endoftext|>"], 23 | ) 24 | end_time = time.time() 25 | assert end_time - start_time < 1.5 26 | 27 | 28 | def test_train_bpe(): 29 | input_path = FIXTURES_PATH / "corpus.en" 30 | vocab, merges = run_train_bpe( 31 | input_path=input_path, 32 | vocab_size=500, 33 | special_tokens=["<|endoftext|>"], 34 | ) 35 | 36 | # Path to the reference tokenizer vocab and merges 37 | reference_vocab_path = FIXTURES_PATH / "train-bpe-reference-vocab.json" 38 | reference_merges_path = FIXTURES_PATH / "train-bpe-reference-merges.txt" 39 | 40 | # Compare the learned merges to the expected output merges 41 | gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()} 42 | with open(reference_merges_path) as f: 43 | gpt2_reference_merges = [tuple(line.rstrip().split(" ")) for line in f] 44 | reference_merges = [ 45 | ( 46 | bytes([gpt2_byte_decoder[token] for token in merge_token_1]), 47 | bytes([gpt2_byte_decoder[token] for token in merge_token_2]), 48 | ) 49 | for merge_token_1, merge_token_2 in gpt2_reference_merges 50 | ] 51 | assert merges == reference_merges 52 | 53 | # Compare the vocab to the expected output vocab 54 | with open(reference_vocab_path) as f: 55 | gpt2_reference_vocab = json.load(f) 56 | reference_vocab = { 57 | gpt2_vocab_index: bytes( 58 | [gpt2_byte_decoder[token] for token in gpt2_vocab_item] 59 | ) 60 | for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items() 61 | } 62 | # Rather than checking that the vocabs exactly match (since they could 63 | # have been constructed differently), we'll make sure that the vocab keys and values match. 64 | assert
set(vocab.keys()) == set(reference_vocab.keys()) 65 | assert set(vocab.values()) == set(reference_vocab.values()) 66 | --------------------------------------------------------------------------------
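The test files above specify behavior through adapter calls (e.g. run_get_batch, run_softmax, run_gradient_clipping) without showing implementations. For reference, the sketches that follow are minimal illustrations consistent with those tests, not the course's reference solutions; the bare function names and any internal choices (sampling strategy, dictionary keys, and so on) are assumptions. First, a data-loading sketch that would satisfy tests/test_data.py: sample batch_size starting offsets uniformly over the valid range, return inputs x and next-token targets y (x shifted by one position), and place both on the requested device.

import numpy as np
import torch


def get_batch(dataset: np.ndarray, batch_size: int, context_length: int, device: str):
    # Valid starting offsets leave room for context_length inputs plus one extra
    # token for the shifted targets, i.e. starts in [0, len(dataset) - context_length).
    starts = np.random.randint(0, len(dataset) - context_length, size=batch_size)
    x = np.stack([dataset[s : s + context_length] for s in starts])
    y = np.stack([dataset[s + 1 : s + 1 + context_length] for s in starts])
    # Constructing the tensors on `device` is what lets the test exercise an
    # invalid device string such as "cuda:99".
    return (
        torch.tensor(x, dtype=torch.long, device=device),
        torch.tensor(y, dtype=torch.long, device=device),
    )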
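tests/test_nn_utils.py additionally requires softmax and cross-entropy to stay finite when the inputs are shifted by +100 or scaled by 1000x. A sketch of the standard numerically stable formulations (subtract the per-slice max, work in log-space); again the function names are assumed, only the math is implied by the tests:

import torch


def softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Softmax is invariant to subtracting a constant, so subtract the per-slice max
    # to keep exp() from overflowing (this is what makes softmax(x + 100) match).
    x = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)


def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: (batch, vocab_size), targets: (batch,). Stay in log-space via
    # log-sum-exp so that logits scaled by 1000x (as in the test) do not overflow.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    log_probs = shifted - torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
    return -log_probs[torch.arange(targets.shape[0]), targets].mean()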
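The expected values in test_get_lr_cosine_schedule are consistent with the usual warmup-plus-cosine schedule: linear warmup from 0 to max_learning_rate over warmup_iters, cosine decay down to min_learning_rate until cosine_cycle_iters, and the minimum thereafter. A sketch that reproduces those numbers:

import math


def get_lr_cosine_schedule(
    it: int,
    max_learning_rate: float,
    min_learning_rate: float,
    warmup_iters: int,
    cosine_cycle_iters: int,
) -> float:
    if it < warmup_iters:
        # Linear warmup: 0 at it=0, max_learning_rate at it=warmup_iters.
        return max_learning_rate * it / warmup_iters
    if it <= cosine_cycle_iters:
        # Cosine anneal from max_learning_rate down to min_learning_rate.
        progress = (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)
        return min_learning_rate + 0.5 * (1 + math.cos(math.pi * progress)) * (
            max_learning_rate - min_learning_rate
        )
    # Past the end of the cosine cycle, hold the minimum learning rate.
    return min_learning_rate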
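test_gradient_clipping compares against torch.nn.utils.clip_grad_norm_ using a single l2 norm over all parameters, skipping parameters whose gradient is None (e.g. the frozen parameter in the test). One way to write that; the 1e-6 added to the denominator mirrors PyTorch's implementation and keeps the result well within the test's tolerance:

import torch


def gradient_clipping(parameters, max_norm: float) -> None:
    # Gather gradients, skipping parameters that have no gradient (e.g. frozen ones).
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return
    # One global l2 norm over every gradient element, across all parameters.
    total_norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)
        # Rescale every gradient in place so the combined norm is (about) max_norm.
        for g in grads:
            g.mul_(scale)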
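Finally, test_serialization.py only requires that a checkpoint round-trips the model state dict, the optimizer state dict, and the iteration counter through run_save_checkpoint and run_load_checkpoint. A minimal sketch; the dictionary keys here are arbitrary choices, not mandated by the tests:

import torch


def save_checkpoint(model, optimizer, iteration, out) -> None:
    # Bundle everything needed to resume training into one file.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "iteration": iteration,
        },
        out,
    )


def load_checkpoint(src, model, optimizer) -> int:
    checkpoint = torch.load(src)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # Return the saved iteration so training can resume where it left off.
    return checkpoint["iteration"]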