├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── CoordCheck.ipynb
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── coord_checks
│   ├── sp_cnn_adam_lr0.001_nseeds5_bn0_coord.png
│   ├── sp_cnn_adam_lr0.001_nseeds5_bn1_coord.png
│   ├── sp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png
│   ├── sp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png
│   ├── sp_mlp_adam_lr0.001_nseeds5_bn0_coord.png
│   ├── sp_mlp_adam_lr0.001_nseeds5_bn1_coord.png
│   ├── sp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png
│   ├── sp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png
│   ├── μp_cnn_adam_lr0.001_nseeds5_bn0_coord.png
│   ├── μp_cnn_adam_lr0.001_nseeds5_bn1_coord.png
│   ├── μp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png
│   ├── μp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png
│   ├── μp_mlp_adam_lr0.001_nseeds5_bn0_coord.png
│   ├── μp_mlp_adam_lr0.001_nseeds5_bn1_coord.png
│   ├── μp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png
│   └── μp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png
├── examples
│   ├── .gitignore
│   ├── MLP
│   │   ├── README.md
│   │   ├── coord_checks
│   │   │   ├── sp_mlp_sgd_coord.png
│   │   │   └── μp_mlp_sgd_coord.png
│   │   ├── demo.ipynb
│   │   ├── main.py
│   │   └── width64.bsh
│   ├── ResNet
│   │   ├── CoordCheck.ipynb
│   │   ├── README.md
│   │   ├── coord_checks
│   │   │   ├── sp_resnet18_adam_coord.png
│   │   │   ├── sp_resnet18_sgd_coord.png
│   │   │   ├── μp_resnet18_adam_coord.png
│   │   │   └── μp_resnet18_sgd_coord.png
│   │   ├── main.py
│   │   ├── resnet.py
│   │   ├── resnet18.bsh
│   │   └── utils.py
│   └── Transformer
│       ├── CoordCheck.ipynb
│       ├── README.md
│       ├── _overrides.py
│       ├── coord_checks
│       │   ├── sp_trsfmr_adam_coord.png
│       │   ├── sp_trsfmr_sgd_coord.png
│       │   ├── μp_trsfmr_adam_coord.png
│       │   └── μp_trsfmr_sgd_coord.png
│       ├── data.py
│       ├── data
│       │   └── wikitext-2
│       │       ├── README
│       │       ├── dict.pt
│       │       ├── test.pt
│       │       ├── train.pt
│       │       └── valid.pt
│       ├── generate.py
│       ├── main.py
│       ├── model.py
│       └── width256.bsh
├── figures
│   ├── parametrizations.gif
│   ├── sp_vs_mup_dashed.png
│   └── widerbetter.png
├── mup
│   ├── __init__.py
│   ├── coord_check.py
│   ├── infshape.py
│   ├── init.py
│   ├── layer.py
│   ├── optim.py
│   ├── shape.py
│   └── test
│       ├── __main__.py
│       └── models.py
├── requirements.txt
├── setup.cfg
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Datasets 2 | ResNet*/data/ 3 | Transformer*/data/* 4 | dataset/* 5 | 6 | # jupyter checkpoints 7 | **/.ipynb_checkpoints 8 | 9 | # Compiled python modules. 10 | *.pyc 11 | 12 | # Byte-compiled 13 | _pycache__/ 14 | .cache/ 15 | 16 | # Python egg metadata, regenerated from source files by setuptools. 17 | *.egg-info 18 | .eggs/ 19 | 20 | # PyPI distribution artifacts. 21 | build/ 22 | dist/ 23 | 24 | # Environments 25 | .env 26 | .venv 27 | env/ 28 | venv/ 29 | ENV/ 30 | env.bak/ 31 | venv.bak/ 32 | 33 | # PyCharm/vscode 34 | .idea 35 | .vscode 36 | 37 | # Vim 38 | .*.swp 39 | 40 | # Other 41 | *.DS_Store -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "examples/mutransformers"] 2 | path = examples/mutransformers 3 | url = https://github.com/microsoft/mutransformers 4 | branch = main 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer) 2 | 3 | [Paper link](https://arxiv.org/abs/2203.03466) 4 | | 5 | [Blog link](https://www.microsoft.com/en-us/research/blog/%C2%B5transfer-a-technique-for-hyperparameter-tuning-of-enormous-neural-networks/) 6 | | 7 | [YouTube link](https://www.youtube.com/watch?v=z8-C42mAwBc) 8 | 9 | In [*Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer*](https://arxiv.org/abs/2203.03466), we show that optimal hyperparameters become stable across neural network sizes when we parametrize the model in [maximal update parametrization (μP)](http://arxiv.org/abs/2011.14522). 
10 | This can be used to tune extremely large neural networks such as large pretrained transformers, as we have done in our work.
11 | More generally, μP reduces the fragility and uncertainty when transitioning from exploration to scaling up, which are not often talked about explicitly in the deep learning literature.
12 |
13 | ![](figures/sp_vs_mup_dashed.png)
14 | *Figure above: Training loss against learning rate on Transformers of varying `d_model` trained with Adam.*
15 |
16 |
17 | μP turns out to be the *unique* "natural" parametrization that has this hyperparameter stability property across width, as empirically verified in the gif below on MLPs trained with SGD. Here, across time, we interpolate between PyTorch default and μP's learning rate and initialization scalings (right), and we scale up the width-256 model (log2(width)=8) to width 2^13 = 8192 using this interpolated scaling rule (left).
18 |
19 | ![](figures/parametrizations.gif)
20 |
21 | This repo contains the source code for the `mup` package, our tool that makes the implementation of μP in PyTorch models effortless and less error-prone.
22 |
23 | ## Table of Contents
24 |
25 |
26 | - [Installation](#installation)
27 |   - [Install From Source](#install-from-source)
28 | - [Basic Usage](#basic-usage)
29 | - [How `mup` Works Under the Hood](#how-mup-works-under-the-hood)
30 | - [Current Limitations](#current-limitations)
31 | - [Checking Correctness of Parametrization](#checking-correctness-of-parametrization)
32 |   - [Coord Check](#coord-check)
33 |   - [Making Your Own Coord Check Plots](#making-your-own-coord-check-plots)
34 |   - [Wider is Always Better](#wider-is-always-better)
35 | - [Examples](#examples)
36 | - [Running Tests](#running-tests)
37 | - [The Basic Math](#the-basic-math)
38 | - [Contributing](#contributing)
39 | - [Trademarks](#trademarks)
40 |
41 | ## Installation
42 |
43 | ```
44 | pip install mup
45 | ```
46 |
47 | ### Install From Source
48 |
49 | Clone this repo, change to its directory, and do
50 | ```
51 | pip install -r requirements.txt
52 | pip install -e .
53 | ```
54 |
55 | ## Basic Usage
56 |
57 | ```Python
58 | from mup import MuReadout, make_base_shapes, set_base_shapes, MuSGD, MuAdam
59 |
60 | class MyModel(nn.Module):
61 |     def __init__(self, width, ...):
62 |         ...
63 |         ### In model definition, replace output layer with MuReadout
64 |         # readout = nn.Linear(width, d_out)
65 |         readout = MuReadout(width, d_out)
66 |         ### If tying weights with an input nn.Embedding layer, do
67 |         # readout = MuSharedReadout(input_layer.weight)
68 |         ...
69 |     def forward(self, ...):
70 |         ...
71 |         ### If using a transformer, make sure to use
72 |         ### 1/d instead of 1/sqrt(d) attention scaling
73 |         # attention_scores = query @ key.T / d**0.5
74 |         attention_scores = query @ key.T * 8 / d
75 |         ### We use 8/d instead of 1/d here to be backward compatible
76 |         ### with 1/d**0.5 when d=64, a common head dimension.
77 |         ...
78 |
79 | ### Instantiate a base model
80 | base_model = MyModel(width=1)
81 | ### Optionally, use `torchdistx.deferred_init.deferred_init` to avoid instantiating the parameters
82 | ### Simply install `torchdistx` and use
83 | # base_model = torchdistx.deferred_init.deferred_init(MyModel, width=1)
84 | ### Instantiate a "delta" model that differs from the base model
85 | ### in all dimensions ("widths") that one wishes to scale.
86 | ### Here it's simple, but e.g., in a Transformer, you may want to scale
87 | ### both nhead and dhead, so the delta model should differ in both.
88 | delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating
89 |
90 | ### Instantiate the target model (the model you actually want to train).
91 | ### This should be the same as the base model except
92 | ### the widths could be potentially different.
93 | ### In particular, base_model and model should have the same depth.
94 | model = MyModel(width=100)
95 |
96 | ### Set base shapes
97 | ### When `model` has same parameter shapes as `base_model`,
98 | ### `model` behaves exactly the same as `base_model`
99 | ### (which is in PyTorch's default parametrization).
100 | ### This provides backward compatibility at this particular model size.
101 | ### Otherwise, `model`'s init and LR are scaled by μP.
102 | ### IMPORTANT: this should be called as soon as possible,
103 | ### before re-initialization and optimizer definition.
104 | set_base_shapes(model, base_model, delta=delta_model)
105 |
106 | ### Alternatively, one can save the base model shapes in a file
107 | # make_base_shapes(base_model, delta_model, filename)
108 | ### and later set base shapes directly from the filename
109 | # set_base_shapes(model, filename)
110 | ### This is useful when one cannot fit both
111 | ### base_model and model in memory at the same time
112 |
113 | ### Replace your custom init, if any
114 | for param in model.parameters():
115 |     ### If initializing manually with fixed std or bounds,
116 |     ### then replace with same function from mup.init
117 |     # torch.nn.init.uniform_(param, -0.1, 0.1)
118 |     mup.init.uniform_(param, -0.1, 0.1)
119 |     ### Likewise, if using
120 |     ### `xavier_uniform_, xavier_normal_, kaiming_uniform_, kaiming_normal_`
121 |     ### from `torch.nn.init`, replace with the same functions from `mup.init`
122 |
123 | ### Use the optimizers from `mup.optim` instead of `torch.optim`
124 | # optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
125 | optimizer = MuSGD(model.parameters(), lr=0.1)
126 |
127 | ### Then just train normally
128 | ```
129 |
130 | Note the base and delta models *do not need to be trained* --- we are only extracting parameter shape information from them.
131 | Therefore, optionally, we can avoid instantiating these potentially large models by using the `deferred_init` function in `torchdistx`.
132 | After installing [`torchdistx`](https://github.com/pytorch/torchdistx), use `torchdistx.deferred_init.deferred_init(MyModel, **args)` instead of `MyModel(**args)`. See [this page](https://pytorch.org/torchdistx/latest/deferred_init.html) for more detail.
133 | In the MLP and Transformer examples (not `mutransformers`) we provided, you can activate this feature by passing `--deferred_init`.
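To make the role of the base and delta models more concrete, here is a minimal pure-Python sketch of the idea that only *shapes* are compared. It is not the `mup` implementation: plain tuples stand in for parameter shapes, and `mlp_shapes` and `infer_scaled_dims` are hypothetical helpers invented for this illustration. The assumption is that a dimension counts as a "width" to be scaled exactly when it differs between the base and delta models.

```python
# Illustrative only: plain tuples stand in for parameter shapes, and none of
# these helper names exist in the `mup` package.

def mlp_shapes(width, d_in=784, d_out=10):
    """Weight shapes, as (fan_out, fan_in), of a hypothetical small MLP."""
    return {
        "fc1.weight": (width, d_in),
        "fc2.weight": (width, width),
        "readout.weight": (d_out, width),
    }


def infer_scaled_dims(base_shapes, delta_shapes):
    """For each parameter, flag the dimensions that differ between base and delta."""
    scaled = {}
    for name, base in base_shapes.items():
        delta = delta_shapes[name]
        scaled[name] = tuple(b != d for b, d in zip(base, delta))
    return scaled


base = mlp_shapes(width=1)
delta = mlp_shapes(width=2)
scaled = infer_scaled_dims(base, delta)

for name, flags in scaled.items():
    # True marks a dimension that changes with width; False marks a fixed one.
    print(name, flags)
```

In this toy setup the hidden-to-hidden weight has two scaled dimensions, while the input and readout weights have one each. A dimension that the delta model fails to vary would be treated as fixed, which is why the delta model should differ from the base model in every dimension one intends to scale.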
134 | 135 | 136 | ## How `mup` Works Under the Hood 137 | 138 | 139 | By invoking `set_base_shapes(model, ...)`, each parameter tensor `p` of `model` gets a `p.infshape` attribute that stores, for each of its dimensions, the corresponding base dimension and whether that dimension should be considered `infinite` (i.e. will be scaled up/down, e.g., `d_model` of a Transformer) or `finite` (i.e. will be fixed, e.g., vocabulary size). 140 | This information is used in the initializers and optimizers to automatically scale the parameters or learning rates to be compliant with μP. 141 | For example, the Adam learning rate of hidden weights `p` is calculated as `globalLR / p.infshape.width_mult()`, where `p.infshape.width_mult()` essentially calculates `fan_in / base_fan_in`. 142 | 143 | 144 | ## Current Limitations 145 | 146 | - `set_base_shapes(model, ...)` assumes that `model` has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP. 147 | - If you want data parallelism, please use `torch.nn.parallel.DistributedDataParallel` instead of `torch.nn.DataParallel`. This is because the latter removes the attributes the `mup` package adds to each parameter tensor of the model. Also, for performance, `pytorch` [recommends the former anyway](https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead). 148 | - We scale the learning rate according to μP explicitly by creating refined parameter groups from what is passed to the `mup` optimizer and by manipulating the `lr` attribute in those groups. This is compatible with PyTorch's learning rate schedulers. However, if you roll your own, make sure the scheduler sets the learning rate relative to what is currently in the refined parameter groups. 
The following is an example of what *not* to do and what is OK:
149 |   ```python
150 |   optimizer = mup.MuAdam(model.parameters(), lr=1e-3)
151 |   for pg in optimizer.param_groups:
152 |       # what NOT to do: setting learning rate absolutely
153 |       # pg['lr'] = 1e-3 * 2
154 |       # what is an OK alternative: setting it relatively
155 |       pg['lr'] *= 2
156 |   ```
157 | - By default, any parameter matrix that has 2 "infinite" dimensions (i.e. dimensions that are different from base dimensions) is considered by `mup` to have shape (fan_out, fan_in), i.e., in the forward pass, this matrix multiplies its input on the right. This is the case with all `nn.Linear` weights from pytorch. If you have a custom parameter, say `W`, that violates this convention, you can manually set `W.infshape.main_idx = 0; W.infshape.main = W.infshape[0]` to let `mup` know that its shape corresponds to (fan_in, fan_out). A similar discussion applies if you have a parameter *tensor* with many dimensions but exactly 2 "infinite" dimensions, for which the first is fan_in and the second is fan_out.
158 | - Currently, [`torch.save` does not save the `infshape` objects attached to each parameter tensor](https://github.com/pytorch/pytorch/issues/72129). Before this is fixed, you would have to set base shape manually after loading a model checkpoint like so:
159 |   ```python
160 |   model = torch.load('my/model/path.pt')
161 |   # Important: note the flag `rescale_params=False`!
162 |   set_base_shapes(model, 'my/base/shape/path.bsh', rescale_params=False)
163 |   ```
164 |   (`set_base_shapes` by default rescales the parameters of `model`, assuming it's freshly initialized by PyTorch, to be consistent with μP.
165 |   The `rescale_params=False` flag turns off this behavior.)
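The scheduler caveat above can be illustrated without PyTorch. The following is a minimal sketch, not the `mup` optimizer: plain dictionaries stand in for parameter groups, `width_mult` mirrors the `fan_in / base_fan_in` ratio described earlier, and the second group (left at the global learning rate) is an assumption made purely for illustration. It shows why a relative update keeps the width-dependent ratio between groups intact, which an absolute assignment would erase.

```python
import math


def width_mult(fan_in, base_fan_in):
    """Ratio of the target fan-in to the base fan-in."""
    return fan_in / base_fan_in


global_lr = 1e-3
base_fan_in = 256

# Hypothetical stand-ins for the refined parameter groups of a μP Adam optimizer.
param_groups = [
    {"name": "hidden weight, fan_in=1024",
     "lr": global_lr / width_mult(1024, base_fan_in)},
    {"name": "group left at the global rate (assumption)",
     "lr": global_lr},
]


def scale_lr(groups, factor):
    """Relative update: every group keeps its own width-dependent scaling."""
    for pg in groups:
        pg["lr"] *= factor


ratio_before = param_groups[1]["lr"] / param_groups[0]["lr"]
scale_lr(param_groups, 0.5)  # e.g. one step of a hand-rolled decay schedule
ratio_after = param_groups[1]["lr"] / param_groups[0]["lr"]

print(ratio_before, ratio_after)
# An absolute assignment such as pg["lr"] = 5e-4 in every group would instead
# make the two learning rates equal and discard the width scaling.
```

The ratio between the two groups is the same before and after the relative update, which is the property a custom scheduler needs to preserve.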
166 |
167 |
168 | ## Checking Correctness of Parametrization
169 |
170 |
171 | ### Coord Check
172 |
173 | Just like gradient checking is a simple way of verifying the correctness of an autograd implementation, *coordinate checking* is a simple way to verify you have implemented μP correctly: calculate the average size (which we denote in the y-axis below by `l1`) of the coordinates of each activation vector in, and output of, the model, for a few steps of training and a few different widths.
174 | If implemented correctly, then we shall see this `l1` stable over many widths; otherwise, the `l1` can blow up or shrink to 0 with width.
175 | (We are essentially checking desideratum 1 described below.)
176 | (The `l1` calculates `x.abs().mean()` for each activation vector `x` and is just one measure of the "average size" of `x`'s entries; one can also use analogously defined `l2`, `l4`, etc., though they may exhibit greater fluctuation with random seeds.)
177 |
178 | For example, in the following, we plot `width` vs `l1` for 2 steps of training, where t=1 means at initialization, before any gradient update.
179 | Each curve corresponds to a (pre-)activation vector of a layer or the output of the network.
180 | The first set of 3 plots shows an MLP in standard parametrization (SP), trained with Adam.
181 | We see that after 1 step of update, the activation/output `l1` values explode with width.
182 | This means SP is "incorrect."
183 | ![](coord_checks/sp_mlp_adam_lr0.001_nseeds5_bn0_coord.png)
184 | We now do the same for an MLP in maximal update parametrization (μP) (including using `mup.optim.MuAdam` instead of `torch.optim.Adam`).
185 | In contrast to the above, all curves stay horizontal, indicating that μP is implemented correctly.
186 | ![](coord_checks/μp_mlp_adam_lr0.001_nseeds5_bn0_coord.png)
187 | We call this way of checking implementation correctness a *coord check*, short for "coordinate check."
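As a rough illustration of the `l1` statistic itself, the sketch below computes the mean absolute coordinate of a single random linear layer applied to an all-ones input, for a few widths. It is a pure-Python toy under stated assumptions, not the `mup.coord_check` code: `l1` and `linear_on_ones` are hypothetical helpers, only the initialization is examined (the real coord check also looks at the steps after gradient updates), and the two initializations compared are a fixed standard deviation versus one that shrinks like `1/sqrt(width)`.

```python
import math
import random


def l1(x):
    """Average coordinate size: the mean absolute value of the entries."""
    return sum(abs(v) for v in x) / len(x)


def linear_on_ones(width, init_std, rng):
    """Apply a random (width x width) Gaussian matrix to an all-ones input.

    Row i of the matrix dotted with the all-ones vector is the sum of that row.
    """
    return [sum(rng.gauss(0.0, init_std) for _ in range(width))
            for _ in range(width)]


rng = random.Random(0)
results = {}
for width in (64, 128, 256):
    results[width] = {
        # Fixed std: coordinates grow roughly like sqrt(width).
        "fixed": l1(linear_on_ones(width, 1.0, rng)),
        # Std shrinking like 1/sqrt(width): coordinates stay roughly constant.
        "scaled": l1(linear_on_ones(width, 1.0 / math.sqrt(width), rng)),
    }
    print(width, round(results[width]["fixed"], 2), round(results[width]["scaled"], 2))
```

Plotting the `fixed` column against width would give a rising curve, while the `scaled` column stays roughly flat; a coord check looks for the same flat-versus-drifting pattern, per layer and per training step.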
188 |
189 | ### Making Your Own Coord Check Plots
190 | We provide an easy way to implement this check via functions in the `mup.coord_check` module.
191 | The workflow typically looks like the following.
192 |
193 | ```Python
194 | from mup.coord_check import get_coord_data, plot_coord_data
195 | # construct a dictionary of lazy μP models with differing widths
196 | def lazy_model(width):
197 |     # `set_base_shapes` returns the model
198 |     return lambda: set_base_shapes(MyMuModel(width), 'my/base/shape/path.bsh')
199 |     # Note: any custom initialization with `mup.init` would need to
200 |     # be done inside the lambda as well
201 | models = {64: lazy_model(64), ..., 1024: lazy_model(1024)}
202 | # make a dataloader with small batch size/seq len
203 | # just for testing
204 | dataloader = ...
205 | # record data from the model activations over a few steps of training
206 | # this returns a pandas dataframe
207 | df = get_coord_data(models, dataloader)
208 | # This saves the coord check plots to filename.
209 | plot_coord_data(df, save_to=filename)
210 | # If you are in jupyter notebook, you can also do
211 | # `plt.show()`
212 | # to show the plot
213 | ```
214 | For example, the `mup.coord_check.example_plot_coord_check` function is implemented this way for toy MLP and CNN models.
215 |
216 | If you see the curves blow up or shrink to 0 with width after a few steps of training, then there's a bug in your μP implementation (did you forget to vary some dimension, like `d_ffn`, in the delta model?).
217 | If instead you see the curves converge to the right, then most likely your implementation is correct.
218 | However, there are two typical exceptions to this;
219 | the following can shrink to 0 at initialization in μP (at a 1/sqrt(width) rate):
220 | - the network output
221 | - the attention logits in a Transformer
222 |
223 | These are transient, and after a few steps their curves should be roughly flat.
224 | Nevertheless, to remove the discrepancy at init, we recommend
225 | - initializing the output layer
226 |   (should be a `MuReadout` instance) weights to be 0 via
227 |   the `readout_zero_init=True` option and
228 | - initializing the query matrix in a Transformer to 0
229 |   (this has to be done manually). If symmetry-breaking is desired in the attention logits at init, initialize the (relative) position biases with nonzero variance.
230 |
231 | #### Tips for Coord Check
232 |
233 | - Use a large learning rate (larger than you'd use for actual training). This would emphasize any potential exploding coordinates issue, which could be hidden by the initialization if the learning rate is too small.
234 | - If you reuse a module multiple times in the forward pass, then `mup.coord_check.get_coord_data` will only record the statistics from the last usage. In this case, for testing purposes, one can wrap different usages with `nn.Identity` modules of different names to distinguish them.
235 |
236 | ### Wider is Always Better
237 |
238 | ![](figures/widerbetter.png)
239 |
240 | Another sign that μP has not been implemented correctly is if going wider does worse (on training loss) after some width, at some point during training.
241 | The figure above illustrates this in a collection of training curves: (left) the correct implementation should always see performance improve with width, at any point in training; (middle) if you used standard parametrization (SP), sometimes you may see performance improve with width up to some point and then suddenly it becomes worse with wider models; (right) or you may immediately see worsening performance even for narrow models.
242 |
243 | ## Examples
244 | See the `MLP`, `Transformer`, and `ResNet` folders inside `examples/` as well as the tests in `mup/test` for examples.
245 | People familiar with [Huggingface Transformers](https://github.com/huggingface/transformers) may also find the `examples/mutransformers` submodule instructive (obtained via `git submodule update --init`), which is also available standalone at [https://github.com/microsoft/mutransformers](https://github.com/microsoft/mutransformers).
246 |
247 | ## Native Integration With Huggingface
248 |
249 | Frustrated that your [Huggingface Transformer](https://github.com/huggingface/transformers) breaks when you scale up? Want to tune hyperparameters for your large multi-GPU [Huggingface Transformer](https://github.com/huggingface/transformers) on a single GPU, right out of the box? If so, please upvote [this github issue](https://github.com/huggingface/transformers/issues/16157)!
250 |
251 |
252 | ## Running Tests
253 | To run tests, do
254 | ```bash
255 | python -m mup.test
256 | ```
257 |
258 |
259 | ## The Basic Math
260 |
261 | μP is designed so as to satisfy the following desiderata:
262 |
263 | > At any time during training
264 | > 1. Every (pre)activation vector in a network should have Θ(1)-sized coordinates
265 | > 2. Neural network output should be O(1).
266 | > 3. All parameters should be updated as much as possible (in terms of scaling in width) without leading to divergence
267 |
268 | It turns out these desiderata uniquely single out μP.
269 | To derive μP from them, one needs to carefully consider how the *coordinate size* of a vector Av, resulting from a square matrix A multiplying vector v, depends on those of A and v, when A and v are "correlated".
270 | Here you can think of A as weights and v as an activation vector.
271 | This in turn depends on what kind of matrix A is and what kind of vector v is.
272 | In the context of training a wide neural network, it turns out we only need to consider vectors that have approximately iid coordinates, and two kinds of matrices: 1) those that look like outer products of such vectors, and 2) random iid matrices.
273 | Those of type 1 cover things like weight gradients; those of type 2 cover things like weight initialization.
274 | Then, if A and v both have entry size Θ(1) and they are correlated in ways that arise naturally during training, then we have the following table.
275 |
276 | |                  | outer product A (type 1) | iid A (type 2)     |
277 | |------------------|--------------------------|--------------------|
278 | | Entry size of Av | Θ(n)                     | Θ(sqrt(n))         |
279 |
280 | Given this table, one can then trace the forward and backward computation of a network to derive μP straightforwardly.
281 |
282 | See [our blog post](https://www.microsoft.com/en-us/research/blog/%C2%B5transfer-a-technique-for-hyperparameter-tuning-of-enormous-neural-networks/) for a gentle primer and [our paper](https://arxiv.org/abs/2203.03466) for details.
283 |
284 |
285 | ## Contributing
286 |
287 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
288 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
289 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
290 |
291 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
292 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
293 | provided by the bot. You will only need to do this once across all repos using our CLA.
294 |
295 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
296 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
297 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
298 |
299 | ## Trademarks
300 |
301 | This project may contain trademarks or logos for projects, products, or services.
Authorized use of Microsoft 302 | trademarks or logos is subject to and must follow 303 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 304 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 305 | Any use of third-party trademarks or logos are subject to those third-party's policies. 306 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 
16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 
8 | 9 | For help and questions about using this project, please use Github Discussions in this repo. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this project is limited to the resources listed above. 14 | -------------------------------------------------------------------------------- /coord_checks/sp_cnn_adam_lr0.001_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_cnn_adam_lr0.001_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_cnn_adam_lr0.001_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_cnn_adam_lr0.001_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_mlp_adam_lr0.001_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_mlp_adam_lr0.001_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_mlp_adam_lr0.001_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_mlp_adam_lr0.001_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/sp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/sp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_cnn_adam_lr0.001_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_cnn_adam_lr0.001_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_cnn_adam_lr0.001_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_cnn_adam_lr0.001_nseeds5_bn1_coord.png 
-------------------------------------------------------------------------------- /coord_checks/μp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_cnn_sgd_lr0.1_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_cnn_sgd_lr0.1_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_mlp_adam_lr0.001_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_mlp_adam_lr0.001_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_mlp_adam_lr0.001_nseeds5_bn1_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_mlp_adam_lr0.001_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_mlp_sgd_lr0.1_nseeds5_bn0_coord.png -------------------------------------------------------------------------------- /coord_checks/μp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/coord_checks/μp_mlp_sgd_lr0.1_nseeds5_bn1_coord.png -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | dataset/* -------------------------------------------------------------------------------- /examples/MLP/README.md: -------------------------------------------------------------------------------- 1 | # μP MLP 2 | This folder contains the source code for our experiment on MLP, which also serves as an example usage of `mup`. 3 | The script trains a series of MLPs with increasing hidden sizes from 64 to 8192. 4 | 5 | ## Save Model Base Shapes 6 | To train a μP model, one needs to first specify the base shapes. To save base shapes info of the narrowest model, run, 7 | ``` 8 | python main.py --save_base_shapes width64.bsh 9 | ``` 10 | 11 | ## Verify Implementation with Coordinate Check 12 | Before we scale up and start training, it is recommended to check the size of activation coordinates as model width increases. We have integrated such a test in this example using the helper functions in `mup`; you can simply run: 13 | 14 | ```bash 15 | python main.py --load_base_shapes width64.bsh --coord_check 16 | ``` 17 | You should find the generated plots under `./coord_checks`, which show stable coordinate sizes under μP, e.g., 18 | 19 | ![](coord_checks/μp_mlp_sgd_coord.png) 20 | 21 | and growing sizes under SP, e.g., 22 | 23 | ![](coord_checks/sp_mlp_sgd_coord.png) 24 | 25 | 26 | ## Start Training 27 | Having verified our implementation of μP, we can scale up our model and train using the same hyperparameters used for the small model and expect that the wider model performs better on the training data and that the optimal hyperparameters transfer. 
28 | ``` 29 | python main.py --load_base_shapes width64.bsh 30 | ``` 31 | 32 | Note that if you do not specify `--load_base_shapes`, the script will default to training a SP model. -------------------------------------------------------------------------------- /examples/MLP/coord_checks/sp_mlp_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/MLP/coord_checks/sp_mlp_sgd_coord.png -------------------------------------------------------------------------------- /examples/MLP/coord_checks/μp_mlp_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/MLP/coord_checks/μp_mlp_sgd_coord.png -------------------------------------------------------------------------------- /examples/MLP/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import pandas as pd 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torchvision import datasets, transforms 7 | import torch 8 | from torch import nn 9 | import torch.optim as optim 10 | import argparse 11 | import math 12 | 13 | from mup.coord_check import get_coord_data, plot_coord_data 14 | from mup import MuSGD, get_shapes, set_base_shapes, make_base_shapes, MuReadout 15 | 16 | def coord_check(mup, lr, train_loader, nsteps, nseeds, args, plotdir='', legend=False): 17 | 18 | def gen(w, standparam=False): 19 | def f(): 20 | model = MLP(width=w, nonlin=torch.tanh, output_mult=args.output_mult, input_mult=args.input_mult).to(device) 21 | if standparam: 22 | set_base_shapes(model, None) 23 | else: 24 | assert args.load_base_shapes, 'load_base_shapes needs to be nonempty' 25 | set_base_shapes(model, args.load_base_shapes) 26 | return model 27 | return f 28 | 29 | widths = 
2**np.arange(7, 14) 30 | models = {w: gen(w, standparam=not mup) for w in widths} 31 | 32 | df = get_coord_data(models, train_loader, mup=mup, lr=lr, optimizer='sgd', flatten_input=True, nseeds=nseeds, nsteps=nsteps, lossfn='nll') 33 | 34 | prm = 'μP' if mup else 'SP' 35 | return plot_coord_data(df, legend=legend, 36 | save_to=os.path.join(plotdir, f'{prm.lower()}_mlp_sgd_coord.png'), 37 | suptitle=f'{prm} MLP SGD lr={lr} nseeds={nseeds}', 38 | face_color='xkcd:light grey' if not mup else None) 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description=''' 43 | PyTorch MLP on CIFAR-10, with μP. 44 | 45 | This is the script we use in the MLP experiment in our paper. 46 | 47 | To train a μP model, one needs to first specify the base shapes. To save base shapes info, run, for example, 48 | 49 | python main.py --save_base_shapes width64.bsh 50 | 51 | To train using MuSGD, run 52 | 53 | python main.py --load_base_shapes width64.bsh 54 | 55 | To perform coord check, run 56 | 57 | python main.py --load_base_shapes width64.bsh --coord_check 58 | 59 | If you don't specify a base shape file, then you are using standard parametrization 60 | 61 | python main.py 62 | 63 | We provide below some optimal hyperparameters for different activation/loss function combos: 64 | if nonlin == torch.relu and criterion == F.cross_entropy: 65 | args.input_mult = 0.00390625 66 | args.output_mult = 32 67 | elif nonlin == torch.tanh and criterion == F.cross_entropy: 68 | args.input_mult = 0.125 69 | args.output_mult = 32 70 | elif nonlin == torch.relu and criterion == MSE_label: 71 | args.input_mult = 0.03125 72 | args.output_mult = 32 73 | elif nonlin == torch.tanh and criterion == MSE_label: 74 | args.input_mult = 8 75 | args.output_mult = 0.125 76 | ''', formatter_class=argparse.RawTextHelpFormatter) 77 | parser.add_argument('--save_base_shapes', type=str, default='', 78 | help='file location to save base shapes at') 79 | 
parser.add_argument('--load_base_shapes', type=str, default='', 80 | help='file location to load base shapes from') 81 | parser.add_argument('--seed', type=int, default=1) 82 | parser.add_argument('--batch_size', type=int, default=64) 83 | parser.add_argument('--epochs', type=int, default=20) 84 | parser.add_argument('--momentum', type=float, default=0.9) 85 | parser.add_argument('--lr', type=float, default=0.1) 86 | parser.add_argument('--output_mult', type=float, default=1.0) 87 | parser.add_argument('--input_mult', type=float, default=1.0) 88 | parser.add_argument('--init_std', type=float, default=1.0) 89 | parser.add_argument('--no_shuffle', action='store_true') 90 | parser.add_argument('--log_interval', type=int, default=300) 91 | parser.add_argument('--log_dir', type=str, default='.') 92 | parser.add_argument('--data_dir', type=str, default='/tmp') 93 | parser.add_argument('--coord_check', action='store_true', 94 | help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.') 95 | parser.add_argument('--coord_check_nsteps', type=int, default=3, 96 | help='Do coord check with this many steps.') 97 | parser.add_argument('--coord_check_nseeds', type=int, default=5, 98 | help='number of seeds for testing correctness of μ parametrization') 99 | parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. 
Requires torchdistx.') 100 | 101 | args = parser.parse_args() 102 | 103 | torch.manual_seed(args.seed) 104 | 105 | device = torch.device("cuda") 106 | 107 | kwargs = {'num_workers': 1, 'pin_memory': True} 108 | 109 | transform = transforms.Compose( 110 | [transforms.ToTensor(), 111 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 112 | 113 | trainset = datasets.CIFAR10(root=args.data_dir, train=True, 114 | download=True, transform=transform) 115 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 116 | shuffle=not args.no_shuffle, num_workers=2) 117 | 118 | testset = datasets.CIFAR10(root=args.data_dir, train=False, 119 | download=True, transform=transform) 120 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 121 | shuffle=False, num_workers=2) 122 | 123 | classes = ('plane', 'car', 'bird', 'cat', 124 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 125 | 126 | 127 | class MLP(nn.Module): 128 | def __init__(self, width=128, num_classes=10, nonlin=F.relu, output_mult=1.0, input_mult=1.0): 129 | super(MLP, self).__init__() 130 | self.nonlin = nonlin 131 | self.input_mult = input_mult 132 | self.output_mult = output_mult 133 | self.fc_1 = nn.Linear(3072, width, bias=False) 134 | self.fc_2 = nn.Linear(width, width, bias=False) 135 | self.fc_3 = MuReadout(width, num_classes, bias=False, output_mult=args.output_mult) 136 | self.reset_parameters() 137 | 138 | def reset_parameters(self): 139 | nn.init.kaiming_normal_(self.fc_1.weight, a=1, mode='fan_in') 140 | self.fc_1.weight.data /= self.input_mult**0.5 141 | self.fc_1.weight.data *= args.init_std 142 | nn.init.kaiming_normal_(self.fc_2.weight, a=1, mode='fan_in') 143 | self.fc_2.weight.data *= args.init_std 144 | nn.init.zeros_(self.fc_3.weight) 145 | 146 | def forward(self, x): 147 | out = self.nonlin(self.fc_1(x) * self.input_mult**0.5) 148 | out = self.nonlin(self.fc_2(out)) 149 | return self.fc_3(out) 150 | 151 | 152 | def train(args, model, 
device, train_loader, optimizer, epoch, 153 | scheduler=None, criterion=F.cross_entropy): 154 | model.train() 155 | train_loss = 0 156 | correct = 0 157 | start_time = time.time() 158 | for batch_idx, (data, target) in enumerate(train_loader): 159 | data, target = data.to(device), target.to(device) 160 | optimizer.zero_grad() 161 | output = model(data.view(data.size(0), -1)) 162 | 163 | loss = criterion(output, target) 164 | loss.backward() 165 | train_loss += loss.item() * data.shape[0] # sum up batch loss 166 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 167 | correct += pred.eq(target.view_as(pred)).sum().item() 168 | optimizer.step() 169 | if batch_idx % args.log_interval == 0: 170 | elapsed = time.time() - start_time 171 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} | ms/batch {:5.2f}'.format( 172 | epoch, batch_idx * len(data), len(train_loader.dataset), 173 | 100. * batch_idx / len(train_loader), loss.item(), 174 | elapsed * 1000 / args.log_interval)) 175 | start_time = time.time() 176 | if scheduler is not None: 177 | scheduler.step() 178 | train_loss /= len(train_loader.dataset) 179 | train_acc = correct / len(train_loader.dataset) 180 | print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 181 | train_loss, correct, len(train_loader.dataset), 182 | 100. 
* correct / len(train_loader.dataset))) 183 | return train_loss, train_acc 184 | 185 | def test(args, model, device, test_loader, 186 | evalmode=True, criterion=F.cross_entropy): 187 | if evalmode: 188 | model.eval() 189 | test_loss = 0 190 | correct = 0 191 | with torch.no_grad(): 192 | for data, target in test_loader: 193 | data, target = data.to(device), target.to(device) 194 | output = model(data.view(data.size(0), -1)) 195 | test_loss += criterion(output, target, reduction='sum').item() # sum up batch loss 196 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 197 | correct += pred.eq(target.view_as(pred)).sum().item() 198 | 199 | test_loss /= len(test_loader.dataset) 200 | 201 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 202 | test_loss, correct, len(test_loader.dataset), 203 | 100. * correct / len(test_loader.dataset))) 204 | return test_loss, correct / len(test_loader.dataset) 205 | 206 | 207 | def MSE_label(output, target): 208 | y_onehot = output.new_zeros(output.size(0), 10) 209 | y_onehot.scatter_(1, target.unsqueeze(-1), 1) 210 | y_onehot -= 1/10 211 | return F.mse_loss(output, y_onehot) 212 | 213 | if args.coord_check: 214 | print('testing parametrization') 215 | import os 216 | os.makedirs('coord_checks', exist_ok=True) 217 | plotdir = 'coord_checks' 218 | coord_check(mup=True, lr=args.lr, train_loader=train_loader, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args, plotdir=plotdir, legend=False) 219 | coord_check(mup=False, lr=args.lr, train_loader=train_loader, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args, plotdir=plotdir, legend=False) 220 | import sys; sys.exit() 221 | 222 | logs = [] 223 | for nonlin in [torch.relu, torch.tanh]: 224 | for criterion in [F.cross_entropy, MSE_label]: 225 | 226 | for width in [64, 128, 256, 512, 1024, 2048, 4096, 8192]: 227 | # print(f'{nonlin.__name__}_{criterion.__name__}_{str(width)}') 228 | 
if args.save_base_shapes: 229 | print(f'saving base shapes at {args.save_base_shapes}') 230 | if args.deferred_init: 231 | from torchdistx.deferred_init import deferred_init 232 | # We don't need to instantiate the base and delta models 233 | # Note: this only works with torch nightly since unsqueeze isn't supported for fake tensors in stable 234 | base_shapes = get_shapes(deferred_init(MLP, width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)) 235 | delta_shapes = get_shapes( 236 | # just need to change whatever dimension(s) we are scaling 237 | deferred_init(MLP, width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult) 238 | ) 239 | else: 240 | base_shapes = get_shapes(MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)) 241 | delta_shapes = get_shapes( 242 | # just need to change whatever dimension(s) we are scaling 243 | MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult) 244 | ) 245 | make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes) 246 | print('done and exit') 247 | import sys; sys.exit() 248 | mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device) 249 | if args.load_base_shapes: 250 | print(f'loading base shapes from {args.load_base_shapes}') 251 | set_base_shapes(mynet, args.load_base_shapes) 252 | print('done') 253 | else: 254 | print(f'using own shapes') 255 | set_base_shapes(mynet, None) 256 | print('done') 257 | optimizer = MuSGD(mynet.parameters(), lr=args.lr, momentum=args.momentum) 258 | for epoch in range(1, args.epochs+1): 259 | train_loss, train_acc, = train(args, mynet, device, train_loader, optimizer, epoch, criterion=criterion) 260 | test_loss, test_acc = test(args, mynet, device, test_loader) 261 | logs.append(dict( 262 | epoch=epoch, 263 | train_loss=train_loss, 264 | train_acc=train_acc, 265 | test_loss=test_loss, 266 | 
test_acc=test_acc, 267 | width=width, 268 | nonlin=nonlin.__name__, 269 | criterion='xent' if criterion.__name__=='cross_entropy' else 'mse', 270 | )) 271 | if math.isnan(train_loss): 272 | break 273 | 274 | with open(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv'), 'w') as f: 275 | logdf = pd.DataFrame(logs) 276 | print(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv')) 277 | f.write(logdf.to_csv(sep='\t', float_format='%.4f')) 278 | -------------------------------------------------------------------------------- /examples/MLP/width64.bsh: -------------------------------------------------------------------------------- 1 | # This is a base shape file encoded in yaml 2 | # - `null` indicates a dimension is "finite", i.e. a non-"width" dimension 3 | # - a number indicates the base dimension of an "infinite" dimension, i.e. some notion of "width" 4 | fc_1.weight: 5 | - 64 6 | - null 7 | fc_2.weight: 8 | - 64 9 | - 64 10 | fc_3.weight: 11 | - null 12 | - 64 13 | -------------------------------------------------------------------------------- /examples/ResNet/README.md: -------------------------------------------------------------------------------- 1 | # μP ResNet 2 | This folder contains the source code for our experiment on ResNet on CIFAR10, which also serves as an example usage of `mup`. 3 | 4 | ## Save Model Base Shapes 5 | To train a μP model, one needs to first specify the base shapes. To save base shapes info, run, for example, 6 | ``` 7 | python main.py --save_base_shapes resnet18.bsh --width_mult 1 8 | ``` 9 | 10 | ## Verify Implementation with Coordinate Check 11 | Before we scale up and start training, it is recommended to check the size of activation coordinates as model width increases. 
We have integrated such a test in this example using the helper functions in `mup`; you can simply run: 12 | 13 | ```bash 14 | # for SGD 15 | python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check 16 | # for Adam 17 | python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check 18 | ``` 19 | You should find the generated plots under `./coord_checks`, which show stable coordinate sizes under μP, e.g., 20 | 21 | ![](coord_checks/μp_resnet18_adam_coord.png) 22 | 23 | and growing sizes under SP, e.g., 24 | 25 | ![](coord_checks/sp_resnet18_adam_coord.png) 26 | 27 | 28 | ## Start Training 29 | Having verified our implementation of μP, we can scale up the model and train it with the same hyperparameters used for the small model; we expect the wider model to perform better on the training data and the optimal hyperparameters to transfer. 30 | ```bash 31 | # for SGD 32 | python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer musgd 33 | # for Adam 34 | python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer muadam 35 | ``` 36 | 37 | Note that if you do not specify `--load_base_shapes`, the script will default to training an SP model. 
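
To make the role of the base shapes concrete, here is a minimal pure-Python sketch of how a base model and a "delta" model of a different width can be compared to tell width dimensions apart from fixed ones. It mirrors what `main.py` does when it builds the delta model with `width_mult/2`, but it is an illustration only and not the `mup` implementation: `mark_width_dims` and `resnet_shapes` are hypothetical helpers, and the three parameter shapes assume a block expansion of 1 and `feat_scale=1`.

```python
# Illustrative sketch only; mup's get_shapes/make_base_shapes do the real work.

def mark_width_dims(base_shapes, delta_shapes):
    """For each parameter, keep the base size of every dimension that differs
    between the base and delta models, and use None where the sizes match."""
    marked = {}
    for name, base in base_shapes.items():
        delta = delta_shapes[name]
        marked[name] = [b if b != d else None for b, d in zip(base, delta)]
    return marked


def resnet_shapes(wm):
    """Hypothetical helper: shapes of three parameters for width multiplier wm,
    using base_widths = [64, 128, 256, 512] as in resnet.py."""
    widths = [int(w * wm) for w in [64, 128, 256, 512]]
    return {
        "conv1.weight": (widths[0], 3, 3, 3),   # (out_channels, in_channels, kH, kW)
        "bn1.weight": (widths[0],),
        "linear.weight": (10, widths[3]),       # (num_classes, features)
    }


base = resnet_shapes(1)      # corresponds to --width_mult 1
delta = resnet_shapes(0.5)   # main.py uses width_mult / 2 for the delta model
marked = mark_width_dims(base, delta)
for name, dims in marked.items():
    print(name, dims)
```

Dimensions that change with the width multiplier keep their base size, while dimensions that stay fixed (input channels, kernel size, number of classes) come out as `None`, which matches the `null` entries in a `.bsh` file.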
-------------------------------------------------------------------------------- /examples/ResNet/coord_checks/sp_resnet18_adam_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/ResNet/coord_checks/sp_resnet18_adam_coord.png -------------------------------------------------------------------------------- /examples/ResNet/coord_checks/sp_resnet18_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/ResNet/coord_checks/sp_resnet18_sgd_coord.png -------------------------------------------------------------------------------- /examples/ResNet/coord_checks/μp_resnet18_adam_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/ResNet/coord_checks/μp_resnet18_adam_coord.png -------------------------------------------------------------------------------- /examples/ResNet/coord_checks/μp_resnet18_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/ResNet/coord_checks/μp_resnet18_sgd_coord.png -------------------------------------------------------------------------------- /examples/ResNet/main.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | import argparse 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from mup.coord_check import get_coord_data, plot_coord_data 12 | from mup import MuAdam, MuSGD, 
get_shapes, make_base_shapes, set_base_shapes 13 | 14 | import resnet 15 | 16 | 17 | def coord_check(mup, lr, optimizer, nsteps, arch, base_shapes, nseeds, device='cuda', plotdir='', legend=False): 18 | 19 | optimizer = optimizer.replace('mu', '') 20 | 21 | def gen(w, standparam=False): 22 | def f(): 23 | model = getattr(resnet, arch)(wm=w).to(device) 24 | if standparam: 25 | set_base_shapes(model, None) 26 | else: 27 | set_base_shapes(model, base_shapes) 28 | return model 29 | return f 30 | 31 | transform_train = transforms.Compose([ 32 | transforms.RandomCrop(32, padding=4), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 36 | ]) 37 | trainset = torchvision.datasets.CIFAR10( 38 | root='../dataset', train=True, download=True, transform=transform_train) 39 | dataloader = torch.utils.data.DataLoader( 40 | trainset, batch_size=1, shuffle=False) 41 | 42 | widths = 2**np.arange(-2., 2) 43 | models = {w: gen(w, standparam=not mup) for w in widths} 44 | df = get_coord_data(models, dataloader, mup=mup, lr=lr, optimizer=optimizer, nseeds=nseeds, nsteps=nsteps) 45 | 46 | prm = 'μP' if mup else 'SP' 47 | plot_coord_data(df, legend=legend, 48 | save_to=os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_coord.png'), 49 | suptitle=f'{prm} {arch} {optimizer} lr={lr} nseeds={nseeds}', 50 | face_color='xkcd:light grey' if not mup else None) 51 | 52 | 53 | # Training 54 | def train(epoch, net): 55 | from utils import progress_bar 56 | print('\nEpoch: %d' % epoch) 57 | net.train() 58 | train_loss = 0 59 | correct = 0 60 | total = 0 61 | for batch_idx, (inputs, targets) in enumerate(trainloader): 62 | inputs, targets = inputs.to(device), targets.to(device) 63 | optimizer.zero_grad() 64 | outputs = net(inputs) 65 | loss = criterion(outputs, targets) 66 | loss.backward() 67 | optimizer.step() 68 | 69 | train_loss += loss.item() 70 | _, predicted = outputs.max(1) 71 | total += 
targets.size(0) 72 | correct += predicted.eq(targets).sum().item() 73 | 74 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 75 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 76 | 77 | 78 | def test(epoch, net): 79 | from utils import progress_bar 80 | global best_acc 81 | net.eval() 82 | test_loss = 0 83 | correct = 0 84 | total = 0 85 | with torch.no_grad(): 86 | for batch_idx, (inputs, targets) in enumerate(testloader): 87 | inputs, targets = inputs.to(device), targets.to(device) 88 | outputs = net(inputs) 89 | loss = criterion(outputs, targets) 90 | 91 | test_loss += loss.item() 92 | _, predicted = outputs.max(1) 93 | total += targets.size(0) 94 | correct += predicted.eq(targets).sum().item() 95 | 96 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 97 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 98 | 99 | # Save checkpoint. 100 | acc = 100.*correct/total 101 | if acc > best_acc: 102 | print('Saving..') 103 | state = { 104 | 'net': net.state_dict(), 105 | 'acc': acc, 106 | 'epoch': epoch, 107 | } 108 | if not os.path.isdir('checkpoint'): 109 | os.mkdir('checkpoint') 110 | torch.save(state, './checkpoint/ckpt.pth') 111 | best_acc = acc 112 | 113 | 114 | 115 | if __name__ == '__main__': 116 | 117 | parser = argparse.ArgumentParser(description='' 118 | ''' 119 | PyTorch CIFAR10 Training, with μP. 120 | 121 | To save base shapes info, run e.g. 
122 | 123 | python main.py --save_base_shapes resnet18.bsh --width_mult 1 124 | 125 | To train using MuAdam (or MuSGD), run 126 | 127 | python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd} 128 | 129 | To test coords, run 130 | 131 | python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check 132 | 133 | python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check 134 | 135 | If you don't specify a base shape file, then you are using standard parametrization, e.g. 136 | 137 | python main.py --width_mult 2 --optimizer {muadam,musgd} 138 | 139 | Here muadam (resp. musgd) would have the same result as adam (resp. sgd). 140 | 141 | Note that models of different depths need separate `.bsh` files. 142 | ''', formatter_class=argparse.RawTextHelpFormatter) 143 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 144 | parser.add_argument('--resume', '-r', action='store_true', 145 | help='resume from checkpoint') 146 | parser.add_argument('--arch', type=str, default='resnet18') 147 | parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam']) 148 | parser.add_argument('--epochs', type=int, default=150) 149 | parser.add_argument('--width_mult', type=float, default=1) 150 | parser.add_argument('--save_base_shapes', type=str, default='', 151 | help='file location to save base shapes at') 152 | parser.add_argument('--load_base_shapes', type=str, default='', 153 | help='file location to load base shapes from') 154 | parser.add_argument('--batch_size', type=int, default=128) 155 | parser.add_argument('--test_batch_size', type=int, default=128) 156 | parser.add_argument('--weight_decay', type=float, default=5e-4) 157 | parser.add_argument('--num_workers', type=int, default=2) 158 | parser.add_argument('--test_num_workers', type=int, default=2) 159 | parser.add_argument('--momentum', type=float, default=0.9) 160 | 
parser.add_argument('--coord_check', action='store_true', 161 | help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.') 162 | parser.add_argument('--coord_check_nsteps', type=int, default=3, 163 | help='Do coord check with this many steps.') 164 | parser.add_argument('--coord_check_nseeds', type=int, default=1, 165 | help='number of seeds for coord check') 166 | parser.add_argument('--seed', type=int, default=1111, 167 | help='random seed') 168 | args = parser.parse_args() 169 | 170 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 171 | best_acc = 0 # best test accuracy 172 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 173 | 174 | # Set the random seed manually for reproducibility. 175 | torch.manual_seed(args.seed) 176 | 177 | # Data 178 | if not args.save_base_shapes: 179 | print('==> Preparing data..') 180 | transform_train = transforms.Compose([ 181 | transforms.RandomCrop(32, padding=4), 182 | transforms.RandomHorizontalFlip(), 183 | transforms.ToTensor(), 184 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 185 | ]) 186 | 187 | transform_test = transforms.Compose([ 188 | transforms.ToTensor(), 189 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 190 | ]) 191 | 192 | trainset = torchvision.datasets.CIFAR10( 193 | root='../dataset', train=True, download=True, transform=transform_train) 194 | trainloader = torch.utils.data.DataLoader( 195 | trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 196 | 197 | testset = torchvision.datasets.CIFAR10( 198 | root='../dataset', train=False, download=True, transform=transform_test) 199 | testloader = torch.utils.data.DataLoader( 200 | testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers) 201 | 202 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 203 | 'dog', 'frog', 'horse', 'ship', 'truck') 
204 | 205 | if args.coord_check: 206 | print('testing parametrization') 207 | import os 208 | os.makedirs('coord_checks', exist_ok=True) 209 | plotdir = 'coord_checks' 210 | coord_check(mup=True, 211 | lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes=args.load_base_shapes, nseeds=args.coord_check_nseeds, device=device, plotdir=plotdir, legend=False) 212 | coord_check(mup=False, 213 | lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes=args.load_base_shapes, nseeds=args.coord_check_nseeds, device=device,plotdir=plotdir, legend=False) 214 | import sys; sys.exit() 215 | 216 | 217 | # Model 218 | print('==> Building model..') 219 | net = getattr(resnet, args.arch)(wm=args.width_mult) 220 | if args.save_base_shapes: 221 | print(f'saving base shapes at {args.save_base_shapes}') 222 | base_shapes = get_shapes(net) 223 | delta_shapes = get_shapes(getattr(resnet, args.arch)(wm=args.width_mult/2)) 224 | make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes) 225 | # save_shapes(net, args.save_base_shapes) 226 | print('done and exit') 227 | import sys; sys.exit() 228 | 229 | net = net.to(device) 230 | 231 | if args.load_base_shapes: 232 | print(f'loading base shapes from {args.load_base_shapes}') 233 | set_base_shapes(net, args.load_base_shapes) 234 | print('done') 235 | else: 236 | print(f'using standard parametrization') 237 | set_base_shapes(net, None) 238 | print('done') 239 | 240 | if args.resume: 241 | # Load checkpoint. 242 | print('==> Resuming from checkpoint..') 243 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
244 | checkpoint = torch.load('./checkpoint/ckpt.pth') 245 | net.load_state_dict(checkpoint['net']) 246 | best_acc = checkpoint['acc'] 247 | start_epoch = checkpoint['epoch'] 248 | 249 | criterion = nn.CrossEntropyLoss() 250 | if args.optimizer == 'musgd': 251 | optimizer = MuSGD(net.parameters(), lr=args.lr, 252 | momentum=args.momentum, 253 | weight_decay=args.weight_decay) 254 | elif args.optimizer == 'muadam': 255 | optimizer = MuAdam(net.parameters(), lr=args.lr) 256 | elif args.optimizer == 'sgd': 257 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 258 | elif args.optimizer == 'adam': 259 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 260 | else: 261 | raise ValueError() 262 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) 263 | 264 | 265 | for epoch in range(start_epoch, start_epoch+args.epochs): 266 | train(epoch, net) 267 | test(epoch, net) 268 | scheduler.step() -------------------------------------------------------------------------------- /examples/ResNet/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 5 | ''' 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import init 9 | 10 | from mup import MuReadout 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 21 | padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, 28 | stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes)) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self) -> None: 34 | layers = [self.conv1, self.conv2] 35 | if len(self.shortcut) > 1: 36 | layers.append(self.shortcut[0]) 37 | for layer in layers: 38 | init.kaiming_normal_(layer.weight, a=1) 39 | if layer.bias is not None: 40 | init.zeros_(layer.bias) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | out += self.shortcut(x) 46 | return F.relu(out) 47 | 48 | 49 | class Bottleneck(nn.Module): 50 | expansion = 4 51 | 52 | def __init__(self, in_planes, planes, stride=1): 53 | super(Bottleneck, self).__init__() 54 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 57 | padding=1, bias=False) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 60 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 61 | 62 | self.shortcut = nn.Sequential() 63 | if stride != 1 or in_planes != 
self.expansion*planes: 64 | self.shortcut = nn.Sequential( 65 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 66 | nn.BatchNorm2d(self.expansion*planes) 67 | ) 68 | 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self) -> None: 72 | layers = [self.conv1, self.conv2, self.conv3] 73 | if len(self.shortcut) > 1: 74 | layers.append(self.shortcut[0]) 75 | for layer in layers: 76 | init.kaiming_normal_(layer.weight, a=1) 77 | if layer.bias is not None: 78 | init.zeros_(layer.bias) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = F.relu(self.bn2(self.conv2(out))) 83 | out = self.bn3(self.conv3(out)) 84 | out += self.shortcut(x) 85 | return F.relu(out) 86 | 87 | 88 | class ResNet(nn.Module): 89 | # feat_scale lets us deal with CelebA, other non-32x32 datasets 90 | def __init__(self, block, num_blocks, num_classes=10, feat_scale=1, wm=1): 91 | super(ResNet, self).__init__() 92 | 93 | base_widths = [64, 128, 256, 512] 94 | widths = [int(w * wm) for w in base_widths] 95 | 96 | self.in_planes = widths[0] 97 | self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, 98 | padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(self.in_planes) 100 | self.layer1 = self._make_layer(block, widths[0], num_blocks[0], stride=1) 101 | self.layer2 = self._make_layer(block, widths[1], num_blocks[1], stride=2) 102 | self.layer3 = self._make_layer(block, widths[2], num_blocks[2], stride=2) 103 | self.layer4 = self._make_layer(block, widths[3], num_blocks[3], stride=2) 104 | ### This is the only μP related change ### 105 | self.linear = MuReadout(feat_scale*widths[3]*block.expansion, num_classes, readout_zero_init=True) 106 | ########################################### 107 | 108 | def _make_layer(self, block, planes, num_blocks, stride): 109 | strides = [stride] + [1]*(num_blocks-1) 110 | layers = [] 111 | for stride in strides: 112 | layers.append(block(self.in_planes, planes, stride=stride)) 
113 | self.in_planes = planes * block.expansion 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = self.layer4(out) 122 | out = F.avg_pool2d(out, 4) 123 | 124 | pre_out = out.view(out.size(0), -1) 125 | final = self.linear(pre_out) 126 | return final 127 | 128 | def ResNet18(**kwargs): 129 | return ResNet(BasicBlock, [2,2,2,2], **kwargs) 130 | 131 | def ResNet18Wide(**kwargs): 132 | return ResNet(BasicBlock, [2,2,2,2], wm=5, **kwargs) 133 | 134 | def ResNet18Thin(**kwargs): 135 | return ResNet(BasicBlock, [2,2,2,2], wm=.75, **kwargs) 136 | 137 | def ResNet34(**kwargs): 138 | return ResNet(BasicBlock, [3,4,6,3], **kwargs) 139 | 140 | def ResNet50(**kwargs): 141 | return ResNet(Bottleneck, [3,4,6,3], **kwargs) 142 | 143 | def ResNet101(**kwargs): 144 | return ResNet(Bottleneck, [3,4,23,3], **kwargs) 145 | 146 | def ResNet152(**kwargs): 147 | return ResNet(Bottleneck, [3,8,36,3], **kwargs) 148 | 149 | resnet50 = ResNet50 150 | resnet18 = ResNet18 151 | resnet101 = ResNet101 152 | resnet152 = ResNet152 153 | resnet18wide = ResNet18Wide -------------------------------------------------------------------------------- /examples/ResNet/resnet18.bsh: -------------------------------------------------------------------------------- 1 | # This is a base shape file encoded in yaml 2 | # - `null` indicates a dimension is "finite", i.e. a non-"width" dimension 3 | # - a number indicates the base dimension of an "infinite" dimension, i.e. 
some notion of "width" 4 | bn1.bias: 5 | - 64 6 | bn1.weight: 7 | - 64 8 | conv1.weight: 9 | - 64 10 | - null 11 | - null 12 | - null 13 | layer1.0.bn1.bias: 14 | - 64 15 | layer1.0.bn1.weight: 16 | - 64 17 | layer1.0.bn2.bias: 18 | - 64 19 | layer1.0.bn2.weight: 20 | - 64 21 | layer1.0.conv1.weight: 22 | - 64 23 | - 64 24 | - null 25 | - null 26 | layer1.0.conv2.weight: 27 | - 64 28 | - 64 29 | - null 30 | - null 31 | layer1.1.bn1.bias: 32 | - 64 33 | layer1.1.bn1.weight: 34 | - 64 35 | layer1.1.bn2.bias: 36 | - 64 37 | layer1.1.bn2.weight: 38 | - 64 39 | layer1.1.conv1.weight: 40 | - 64 41 | - 64 42 | - null 43 | - null 44 | layer1.1.conv2.weight: 45 | - 64 46 | - 64 47 | - null 48 | - null 49 | layer2.0.bn1.bias: 50 | - 128 51 | layer2.0.bn1.weight: 52 | - 128 53 | layer2.0.bn2.bias: 54 | - 128 55 | layer2.0.bn2.weight: 56 | - 128 57 | layer2.0.conv1.weight: 58 | - 128 59 | - 64 60 | - null 61 | - null 62 | layer2.0.conv2.weight: 63 | - 128 64 | - 128 65 | - null 66 | - null 67 | layer2.0.shortcut.0.weight: 68 | - 128 69 | - 64 70 | - null 71 | - null 72 | layer2.0.shortcut.1.bias: 73 | - 128 74 | layer2.0.shortcut.1.weight: 75 | - 128 76 | layer2.1.bn1.bias: 77 | - 128 78 | layer2.1.bn1.weight: 79 | - 128 80 | layer2.1.bn2.bias: 81 | - 128 82 | layer2.1.bn2.weight: 83 | - 128 84 | layer2.1.conv1.weight: 85 | - 128 86 | - 128 87 | - null 88 | - null 89 | layer2.1.conv2.weight: 90 | - 128 91 | - 128 92 | - null 93 | - null 94 | layer3.0.bn1.bias: 95 | - 256 96 | layer3.0.bn1.weight: 97 | - 256 98 | layer3.0.bn2.bias: 99 | - 256 100 | layer3.0.bn2.weight: 101 | - 256 102 | layer3.0.conv1.weight: 103 | - 256 104 | - 128 105 | - null 106 | - null 107 | layer3.0.conv2.weight: 108 | - 256 109 | - 256 110 | - null 111 | - null 112 | layer3.0.shortcut.0.weight: 113 | - 256 114 | - 128 115 | - null 116 | - null 117 | layer3.0.shortcut.1.bias: 118 | - 256 119 | layer3.0.shortcut.1.weight: 120 | - 256 121 | layer3.1.bn1.bias: 122 | - 256 123 | layer3.1.bn1.weight: 124 | - 
256 125 | layer3.1.bn2.bias: 126 | - 256 127 | layer3.1.bn2.weight: 128 | - 256 129 | layer3.1.conv1.weight: 130 | - 256 131 | - 256 132 | - null 133 | - null 134 | layer3.1.conv2.weight: 135 | - 256 136 | - 256 137 | - null 138 | - null 139 | layer4.0.bn1.bias: 140 | - 512 141 | layer4.0.bn1.weight: 142 | - 512 143 | layer4.0.bn2.bias: 144 | - 512 145 | layer4.0.bn2.weight: 146 | - 512 147 | layer4.0.conv1.weight: 148 | - 512 149 | - 256 150 | - null 151 | - null 152 | layer4.0.conv2.weight: 153 | - 512 154 | - 512 155 | - null 156 | - null 157 | layer4.0.shortcut.0.weight: 158 | - 512 159 | - 256 160 | - null 161 | - null 162 | layer4.0.shortcut.1.bias: 163 | - 512 164 | layer4.0.shortcut.1.weight: 165 | - 512 166 | layer4.1.bn1.bias: 167 | - 512 168 | layer4.1.bn1.weight: 169 | - 512 170 | layer4.1.bn2.bias: 171 | - 512 172 | layer4.1.bn2.weight: 173 | - 512 174 | layer4.1.conv1.weight: 175 | - 512 176 | - 512 177 | - null 178 | - null 179 | layer4.1.conv2.weight: 180 | - 512 181 | - 512 182 | - null 183 | - null 184 | linear.bias: 185 | - null 186 | linear.weight: 187 | - null 188 | - 512 189 | -------------------------------------------------------------------------------- /examples/ResNet/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | 14 | 15 | def get_mean_and_std(dataset): 16 | '''Compute the mean and std value of dataset.''' 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | print('==> Computing mean and std..') 21 | for inputs, targets in dataloader: 22 | for i in range(3): 23 | mean[i] += inputs[:,i,:,:].mean() 24 | std[i] += inputs[:,i,:,:].std() 25 | mean.div_(len(dataset)) 26 | std.div_(len(dataset)) 27 | return mean, std 28 | 29 | def init_params(net): 30 | '''Init layer parameters.''' 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal_(m.weight, mode='fan_out') 34 | if m.bias is not None: 35 | init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | init.constant_(m.weight, 1) 38 | init.constant_(m.bias, 0) 39 | elif isinstance(m, nn.Linear): 40 | init.normal_(m.weight, std=1e-3) 41 | if m.bias is not None: 42 | init.constant_(m.bias, 0) 43 | 44 | 45 | _, term_width = os.popen('stty size', 'r').read().split() 46 | term_width = int(term_width) 47 | 48 | TOTAL_BAR_LENGTH = 65. 49 | last_time = time.time() 50 | begin_time = last_time 51 | def progress_bar(current, total, msg=None): 52 | global last_time, begin_time 53 | if current == 0: 54 | begin_time = time.time() # Reset for new bar.
55 | 56 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 58 | 59 | sys.stdout.write(' [') 60 | for i in range(cur_len): 61 | sys.stdout.write('=') 62 | sys.stdout.write('>') 63 | for i in range(rest_len): 64 | sys.stdout.write('.') 65 | sys.stdout.write(']') 66 | 67 | cur_time = time.time() 68 | step_time = cur_time - last_time 69 | last_time = cur_time 70 | tot_time = cur_time - begin_time 71 | 72 | L = [] 73 | L.append(' Step: %s' % format_time(step_time)) 74 | L.append(' | Tot: %s' % format_time(tot_time)) 75 | if msg: 76 | L.append(' | ' + msg) 77 | 78 | msg = ''.join(L) 79 | sys.stdout.write(msg) 80 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 81 | sys.stdout.write(' ') 82 | 83 | # Go back to the center of the bar. 84 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 85 | sys.stdout.write('\b') 86 | sys.stdout.write(' %d/%d ' % (current+1, total)) 87 | 88 | if current < total-1: 89 | sys.stdout.write('\r') 90 | else: 91 | sys.stdout.write('\n') 92 | sys.stdout.flush() 93 | 94 | def format_time(seconds): 95 | days = int(seconds / 3600/24) 96 | seconds = seconds - days*3600*24 97 | hours = int(seconds / 3600) 98 | seconds = seconds - hours*3600 99 | minutes = int(seconds / 60) 100 | seconds = seconds - minutes*60 101 | secondsf = int(seconds) 102 | seconds = seconds - secondsf 103 | millis = int(seconds*1000) 104 | 105 | f = '' 106 | i = 1 107 | if days > 0: 108 | f += str(days) + 'D' 109 | i += 1 110 | if hours > 0 and i <= 2: 111 | f += str(hours) + 'h' 112 | i += 1 113 | if minutes > 0 and i <= 2: 114 | f += str(minutes) + 'm' 115 | i += 1 116 | if secondsf > 0 and i <= 2: 117 | f += str(secondsf) + 's' 118 | i += 1 119 | if millis > 0 and i <= 2: 120 | f += str(millis) + 'ms' 121 | i += 1 122 | if f == '': 123 | f = '0ms' 124 | return f -------------------------------------------------------------------------------- /examples/Transformer/README.md: 
-------------------------------------------------------------------------------- 1 | # μP Transformer 2 | This folder contains the source code for our experiment on small Transformers, which also serves as an example of how to use `mup`. 3 | 4 | ## Save Model Base Shapes 5 | To train a μP model, one first needs to specify the base shapes. To save the base shape information, run, for example, 6 | ``` 7 | python main.py --d_model 256 --save_base_shapes width256.bsh 8 | ``` 9 | 10 | ## Verify Implementation with Coordinate Check 11 | Before we scale up and start training, it is recommended to check the size of activation coordinates as model width increases. We have integrated such a test in this example using the helper functions in `mup`; you can simply run: 12 | 13 | ```bash 14 | # for SGD 15 | python main.py --load_base_shapes width256.bsh --optimizer sgd --lr 0.5 --cuda --coord_check 16 | # for Adam 17 | python main.py --load_base_shapes width256.bsh --optimizer adam --lr 0.01 --cuda --coord_check 18 | ``` 19 | You should find the generated plots under `./coord_checks`, which show stable coordinate sizes under μP, e.g., 20 | 21 | ![](coord_checks/μp_trsfmr_adam_coord.png) 22 | 23 | and growing sizes under SP, e.g., 24 | 25 | ![](coord_checks/sp_trsfmr_adam_coord.png) 26 | 27 | 28 | ## Start Training 29 | Having verified our implementation of μP, we can scale up the model and train it with the same hyperparameters used for the small model. We expect the wider model to perform better on the training data and the optimal hyperparameters to transfer. 30 | ```bash 31 | # for SGD 32 | python main.py --d_model 4096 --load_base_shapes width256.bsh --optimizer musgd --lr 0.5 --cuda 33 | # for Adam 34 | python main.py --d_model 4096 --load_base_shapes width256.bsh --optimizer muadam --lr 0.01 --cuda 35 | ``` 36 | 37 | Note that if you do not specify `--load_base_shapes`, the script will default to training an SP model.
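
To make the role of the base shape file more concrete, here is a minimal sketch of how a width multiplier can be read off by comparing a model's parameter shapes against its base shapes. This is an illustration only, not `mup`'s actual implementation: the helper `width_multipliers` and the parameter names and shapes are hypothetical, and `None` stands in for the `null` entries of a `.bsh` file (a non-width dimension).

```python
from typing import Dict, List, Optional

# None marks a "finite" (non-width) dimension, like `null` in a .bsh file.
BaseShape = List[Optional[int]]


def width_multipliers(base: Dict[str, BaseShape],
                      target: Dict[str, List[int]]) -> Dict[str, List[float]]:
    """Return target_dim / base_dim per dimension; finite dimensions map to 1.0."""
    mults: Dict[str, List[float]] = {}
    for name, base_shape in base.items():
        target_shape = target[name]
        if len(base_shape) != len(target_shape):
            raise ValueError(f"rank mismatch for {name}")
        mults[name] = [1.0 if b is None else t / b
                       for b, t in zip(base_shape, target_shape)]
    return mults


# Hypothetical shapes: a base model of width 256 and a target model of width 4096.
base_shapes = {"ffn.weight": [256, 256], "readout.weight": [None, 256]}
target_shapes = {"ffn.weight": [4096, 4096], "readout.weight": [1000, 4096]}

for name, mult in width_multipliers(base_shapes, target_shapes).items():
    print(name, mult)
```

With a base width of 256 and a target width of 4096, every width dimension is 16 times its base value, while the non-width dimension keeps a multiplier of 1. A ratio of this kind is what lets per-parameter settings be rescaled as the model gets wider.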
38 | -------------------------------------------------------------------------------- /examples/Transformer/coord_checks/sp_trsfmr_adam_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/coord_checks/sp_trsfmr_adam_coord.png -------------------------------------------------------------------------------- /examples/Transformer/coord_checks/sp_trsfmr_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/coord_checks/sp_trsfmr_sgd_coord.png -------------------------------------------------------------------------------- /examples/Transformer/coord_checks/μp_trsfmr_adam_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/coord_checks/μp_trsfmr_adam_coord.png -------------------------------------------------------------------------------- /examples/Transformer/coord_checks/μp_trsfmr_sgd_coord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/coord_checks/μp_trsfmr_sgd_coord.png -------------------------------------------------------------------------------- /examples/Transformer/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import open 3 | import torch 4 | 5 | class Dictionary(object): 6 | def __init__(self): 7 | self.word2idx = {} 8 | self.idx2word = [] 9 | 10 | def add_word(self, word): 11 | if word not in self.word2idx: 12 | self.idx2word.append(word) 13 | self.word2idx[word] = len(self.idx2word) - 1 14 | 
return self.word2idx[word] 15 | 16 | def __len__(self): 17 | return len(self.idx2word) 18 | 19 | 20 | class Corpus(object): 21 | def __init__(self, path): 22 | self.dictionary = Dictionary() 23 | self.train = None 24 | self.valid = None 25 | self.test = None 26 | if not self.load_cache(path): 27 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 28 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 29 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 30 | self.save_cache(path) 31 | 32 | def load_cache(self, path): 33 | for cache in ['dict.pt', 'train.pt', 'valid.pt', 'test.pt']: 34 | cache_path = os.path.join(path, cache) 35 | if not os.path.exists(cache_path): 36 | return False 37 | self.dictionary = torch.load(os.path.join(path, 'dict.pt')) 38 | self.train = torch.load(os.path.join(path, 'train.pt')) 39 | self.valid = torch.load(os.path.join(path, 'valid.pt')) 40 | self.test = torch.load(os.path.join(path, 'test.pt')) 41 | return True 42 | 43 | def save_cache(self, path): 44 | torch.save(self.dictionary, os.path.join(path, 'dict.pt')) 45 | torch.save(self.train, os.path.join(path, 'train.pt')) 46 | torch.save(self.valid, os.path.join(path, 'valid.pt')) 47 | torch.save(self.test, os.path.join(path, 'test.pt')) 48 | 49 | def tokenize(self, path): 50 | """Tokenizes a text file.""" 51 | assert os.path.exists(path) 52 | # Add words to the dictionary 53 | with open(path, 'r', encoding="utf8") as f: 54 | for line in f: 55 | words = line.split() + [''] 56 | for word in words: 57 | self.dictionary.add_word(word) 58 | 59 | # Tokenize file content 60 | with open(path, 'r', encoding="utf8") as f: 61 | idss = [] 62 | for line in f: 63 | words = line.split() + [''] 64 | ids = [] 65 | for word in words: 66 | ids.append(self.dictionary.word2idx[word]) 67 | idss.append(torch.tensor(ids).type(torch.int64)) 68 | ids = torch.cat(idss) 69 | 70 | return ids 71 | -------------------------------------------------------------------------------- 
/examples/Transformer/data/wikitext-2/README: -------------------------------------------------------------------------------- 1 | This is raw data from the wikitext-2 dataset. 2 | 3 | See https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/ 4 | -------------------------------------------------------------------------------- /examples/Transformer/data/wikitext-2/dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/data/wikitext-2/dict.pt -------------------------------------------------------------------------------- /examples/Transformer/data/wikitext-2/test.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/data/wikitext-2/test.pt -------------------------------------------------------------------------------- /examples/Transformer/data/wikitext-2/train.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/data/wikitext-2/train.pt -------------------------------------------------------------------------------- /examples/Transformer/data/wikitext-2/valid.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/examples/Transformer/data/wikitext-2/valid.pt -------------------------------------------------------------------------------- /examples/Transformer/generate.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Language Modeling on Wikitext-2 3 | 
# 4 | # This file generates new sentences sampled from the language model 5 | # 6 | ############################################################################### 7 | 8 | import argparse 9 | 10 | import torch 11 | 12 | import data 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model') 15 | 16 | # Model parameters. 17 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 18 | help='location of the data corpus') 19 | parser.add_argument('--checkpoint', type=str, default='./model.pt', 20 | help='model checkpoint to use') 21 | parser.add_argument('--outf', type=str, default='generated.txt', 22 | help='output file for generated text') 23 | parser.add_argument('--words', type=int, default=1000, 24 | help='number of words to generate') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--cuda', action='store_true', 28 | help='use CUDA') 29 | parser.add_argument('--temperature', type=float, default=1.0, 30 | help='temperature - higher will increase diversity') 31 | parser.add_argument('--log-interval', type=int, default=100, 32 | help='reporting interval') 33 | args = parser.parse_args() 34 | 35 | # Set the random seed manually for reproducibility.
36 | torch.manual_seed(args.seed) 37 | if torch.cuda.is_available(): 38 | if not args.cuda: 39 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 40 | 41 | device = torch.device("cuda" if args.cuda else "cpu") 42 | 43 | if args.temperature < 1e-3: 44 | parser.error("--temperature has to be greater than or equal to 1e-3") 45 | 46 | with open(args.checkpoint, 'rb') as f: 47 | model = torch.load(f).to(device) 48 | model.eval() 49 | 50 | corpus = data.Corpus(args.data) 51 | ntokens = len(corpus.dictionary) 52 | 53 | is_transformer_model = hasattr(model, 'model_type') and model.model_type == 'Transformer' 54 | if not is_transformer_model: 55 | hidden = model.init_hidden(1) 56 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) 57 | 58 | with open(args.outf, 'w') as outf: 59 | with torch.no_grad(): # no tracking history 60 | for i in range(args.words): 61 | if is_transformer_model: 62 | output = model(input, False) 63 | word_weights = output[-1].squeeze().div(args.temperature).exp().cpu() 64 | word_idx = torch.multinomial(word_weights, 1)[0] 65 | word_tensor = torch.Tensor([[word_idx]]).long().to(device) 66 | input = torch.cat([input, word_tensor], 0) 67 | else: 68 | output, hidden = model(input, hidden) 69 | word_weights = output.squeeze().div(args.temperature).exp().cpu() 70 | word_idx = torch.multinomial(word_weights, 1)[0] 71 | input.fill_(word_idx) 72 | 73 | word = corpus.dictionary.idx2word[word_idx] 74 | 75 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 76 | 77 | if i % args.log_interval == 0: 78 | print('| Generated {}/{} words'.format(i, args.words)) 79 | -------------------------------------------------------------------------------- /examples/Transformer/main.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.nn as nn 10 |
import torch.optim as optim 11 | try: 12 | from apex import amp 13 | except ImportError: 14 | print('Failed to import apex. You can still train with --precision {float|double}.') 15 | 16 | from mup.coord_check import get_coord_data, plot_coord_data 17 | from mup import MuAdam, MuSGD, get_shapes, make_base_shapes, set_base_shapes 18 | 19 | import data 20 | import model as mdl 21 | 22 | 23 | ############################################################################### 24 | # Training code 25 | ############################################################################### 26 | 27 | # get_batch subdivides the source data into chunks of length args.bptt. 28 | # If source is equal to the example output of the batchify function, with 29 | # a bptt-limit of 2, we'd get the following two tensors for i = 0: 30 | # ┌ a g m s ┐ ┌ b h n t ┐ 31 | # └ b h n t ┘ └ c i o u ┘ 32 | # Note that despite the name of the function, the subdivision of data is not 33 | # done along the batch dimension (i.e. dimension 1), since that was handled 34 | # by the batchify function. The chunks are along dimension 0, corresponding 35 | # to the seq_len dimension of the model input. 36 | 37 | def get_batch(source, i, bptt): 38 | seq_len = min(bptt, len(source) - 1 - i) 39 | data = source[i:i+seq_len] 40 | target = source[i+1:i+1+seq_len].view(-1) 41 | return data, target 42 | 43 | def batchloader(train_data, bptt): 44 | for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)): 45 | yield get_batch(train_data, i, bptt) 46 | 47 | def batchify(data, bsz, device): 48 | # Work out how cleanly we can divide the dataset into bsz parts. 49 | nbatch = data.size(0) // bsz 50 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 51 | data = data.narrow(0, 0, nbatch * bsz) 52 | # Evenly divide the data across the bsz batches.
53 | data = data.view(bsz, -1).t().contiguous() 54 | return data.to(device) 55 | 56 | def setprec(t, precision): 57 | if precision == 'half': 58 | # do nothing since this is handled by AMP 59 | return t 60 | elif precision == 'float': 61 | return t.float() 62 | elif precision == 'double': 63 | return t.double() 64 | else: 65 | raise ValueError(f'invalid precision string {precision}') 66 | 67 | def coord_check(mup, lr, optimizer, batch_size, nsteps, nseeds, data_dir, args, plotdir='', legend=False): 68 | 69 | corpus = data.Corpus(data_dir) 70 | ntokens = len(corpus.dictionary) 71 | 72 | def gen(w, standparam=False): 73 | import model as _model 74 | def f(): 75 | model = _model.TransformerModel(args, ntokens, ninp=w, nhead=args.nhead, nhid=w*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout, 76 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 77 | decoder_var=args.init_var, standparam=standparam).to(args.device) 78 | model = setprec(model, args.precision) 79 | if standparam: 80 | set_base_shapes(model, None) 81 | else: 82 | assert args.load_base_shapes, 'load_base_shapes needs to be nonempty' 83 | set_base_shapes(model, args.load_base_shapes) 84 | return model 85 | return f 86 | 87 | optimizer = optimizer.replace('mu', '') 88 | widths = 2**np.arange(7, 14 if optimizer=='sgd' else 12) 89 | models = {w: gen(w, standparam=not mup) for w in widths} 90 | 91 | 92 | train_data = batchify(corpus.train, batch_size, device=args.device) 93 | df = get_coord_data(models, batchloader(train_data, args.bptt), mup=mup, lr=lr, optimizer=optimizer, flatten_output=True, nseeds=nseeds, nsteps=nsteps, lossfn='nll') 94 | 95 | prm = 'μP' if mup else 'SP' 96 | return plot_coord_data(df, legend=legend, 97 | save_to=os.path.join(plotdir, f'{prm.lower()}_trsfmr_{optimizer}_coord.png'), 98 | suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}', 99 | face_color='xkcd:light grey' if not mup else None) 100 | 101 | 102 | if __name__ == '__main__': 103 | 104 |
parser = argparse.ArgumentParser(description= 105 | ''' 106 | PyTorch Wikitext-2 Transformer Language Model, with μP. 107 | 108 | To train a μP model, one needs to first specify the base shapes. To save base shapes info, run, for example, 109 | 110 | python main.py --d_model 256 --save_base_shapes width256.bsh 111 | 112 | To train using MuAdam, run 113 | 114 | python main.py --d_model 256 --load_base_shapes width256.bsh --cuda --optimizer muadam 115 | 116 | To perform coord check, run 117 | 118 | python main.py --load_base_shapes width256.bsh --optimizer sgd --lr 0.5 --cuda --coord_check 119 | 120 | python main.py --load_base_shapes width256.bsh --optimizer adam --lr 0.01 --cuda --coord_check 121 | 122 | If you don't specify a base shape file, then you are using standard parametrization 123 | 124 | python main.py --d_model 256 --cuda --optimizer muadam 125 | 126 | Note that models of different depths need separate `.bsh` files. 127 | ''', formatter_class=argparse.RawTextHelpFormatter) 128 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 129 | help='location of the data corpus') 130 | parser.add_argument('--bias', action='store_true', 131 | help='use bias') 132 | parser.add_argument('--save_base_shapes', type=str, default='', 133 | help='file location to save base shapes at') 134 | parser.add_argument('--load_base_shapes', type=str, default='', 135 | help='file location to load base shapes from') 136 | parser.add_argument('--d_model', type=int, default=256, 137 | help='width of the model') 138 | parser.add_argument('--ffn_ratio', type=int, default=1, 139 | help='the ratio of d_ffn to d_model') 140 | parser.add_argument('--nlayers', type=int, default=2, 141 | help='number of layers') 142 | parser.add_argument('--nhead', type=int, default=2, 143 | help='the number of heads in the encoder/decoder of the transformer model') 144 | parser.add_argument('--lr', type=float, default=0.001, 145 | help='initial learning rate') 146 | 
parser.add_argument('--momentum', type=float, default=0, 147 | help='momentum') 148 | parser.add_argument('--output_mult', type=float, default=1, 149 | help='output is multiplied by sqrt(output_mult/d_model)') 150 | parser.add_argument('--input_mult', type=float, default=1, 151 | help='input is multiplied by sqrt(input_mult*d_model)') 152 | parser.add_argument('--attn_mult', type=float, default=1, 153 | help='attn is multiplied by sqrt(attn_mult)/head_dim') 154 | parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'musgd', 'adam', 'muadam']) 155 | parser.add_argument('--init_var', type=float, default=1, 156 | help='weights are initialized with variance init_var/ninp') 157 | parser.add_argument('--clip', type=float, default=0.25, 158 | help='gradient clipping') 159 | parser.add_argument('--epochs', type=int, default=40, 160 | help='upper epoch limit') 161 | parser.add_argument('--batch_size', type=int, default=20, metavar='N', 162 | help='batch size') 163 | parser.add_argument('--bptt', type=int, default=35, 164 | help='sequence length') 165 | parser.add_argument('--dropout', type=float, default=0.2, 166 | help='dropout applied to layers (0 = no dropout)') 167 | parser.add_argument('--tied', action='store_true', 168 | help='tie the word embedding and softmax weights') 169 | parser.add_argument('--seed', type=int, default=1111, 170 | help='random seed') 171 | parser.add_argument('--cuda', action='store_true', 172 | help='use CUDA') 173 | parser.add_argument('--precision', type=str, default='float', 174 | help='float | double | half') 175 | parser.add_argument('--log_interval', type=int, default=200, metavar='N', 176 | help='report interval') 177 | parser.add_argument('--save_dir', type=str, default=None, 178 | help='path to save the final model') 179 | parser.add_argument('--resume_dir', type=str, default=None, 180 | help='path to resume training') 181 | parser.add_argument('--log_dir', type=str, default='.', 182 | help='path to save logs') 183 | 
parser.add_argument('--coord_check', action='store_true', 184 | help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.') 185 | parser.add_argument('--coord_check_nsteps', type=int, default=3, 186 | help='Do coord check with this many steps.') 187 | parser.add_argument('--coord_check_nseeds', type=int, default=3, 188 | help='number of seeds for testing correctness of μ parametrization') 189 | parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. Requires torchdistx.') 190 | 191 | args = parser.parse_args() 192 | 193 | print(args) 194 | 195 | # Set the random seed manually for reproducibility. 196 | torch.manual_seed(args.seed) 197 | if torch.cuda.is_available(): 198 | if not args.cuda: 199 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 200 | 201 | device = args.device = torch.device("cuda" if args.cuda else "cpu") 202 | 203 | ############################################################################### 204 | # Load data 205 | ############################################################################### 206 | 207 | corpus = data.Corpus(args.data) 208 | 209 | # Starting from sequential data, batchify arranges the dataset into columns. 210 | # For instance, with the alphabet as the sequence and batch size 4, we'd get 211 | # ┌ a g m s ┐ 212 | # │ b h n t │ 213 | # │ c i o u │ 214 | # │ d j p v │ 215 | # │ e k q w │ 216 | # └ f l r x ┘. 217 | # These columns are treated as independent by the model, which means that the 218 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient 219 | # batch processing. 
220 | 221 | eval_batch_size = 10 222 | train_data = batchify(corpus.train, args.batch_size, device) 223 | val_data = batchify(corpus.valid, eval_batch_size, device) 224 | test_data = batchify(corpus.test, eval_batch_size, device) 225 | 226 | ############################################################################### 227 | # Build the model 228 | ############################################################################### 229 | 230 | 231 | ntokens = len(corpus.dictionary) 232 | 233 | 234 | 235 | def evaluate(data_source): 236 | # Turn on evaluation mode which disables dropout. 237 | model.eval() 238 | total_loss = 0. 239 | ntokens = len(corpus.dictionary) 240 | with torch.no_grad(): 241 | for i in range(0, data_source.size(0) - 1, args.bptt): 242 | data, targets = get_batch(data_source, i, args.bptt) 243 | output = model(data) 244 | output = output.view(-1, ntokens) 245 | total_loss += len(data) * criterion(output, targets).item() 246 | return total_loss / (len(data_source) - 1) 247 | 248 | 249 | def train(optimizer, epoch): 250 | # Turn on training mode which enables dropout. 251 | model.train() 252 | total_loss = 0. 253 | epoch_loss = 0. 254 | start_time = time.time() 255 | ntokens = len(corpus.dictionary) 256 | first_loss = None 257 | for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): 258 | data, targets = get_batch(train_data, i, args.bptt) 259 | # Starting each batch, we detach the hidden state from how it was previously produced. 260 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 
261 | 262 | optimizer.zero_grad() 263 | output = model(data) 264 | output = output.view(-1, ntokens) 265 | loss = criterion(output, targets) 266 | if torch.isnan(loss): 267 | exit(0) 268 | if args.precision == 'half': 269 | with amp.scale_loss(loss, optimizer) as scaled_loss: 270 | scaled_loss.backward() 271 | else: 272 | loss.backward() 273 | 274 | if args.clip > 0: 275 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 276 | if args.precision == 'half': 277 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip) 278 | else: 279 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 280 | 281 | optimizer.step() 282 | 283 | total_loss += loss.item() 284 | epoch_loss += len(data) * loss.item() 285 | 286 | if batch % args.log_interval == 0 and batch > 0: 287 | cur_loss = total_loss / args.log_interval 288 | elapsed = time.time() - start_time 289 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | ' 290 | 'loss {:5.2f} | ppl {:8.2f}'.format( 291 | epoch, batch, len(train_data) // args.bptt, lr, 292 | elapsed * 1000 / args.log_interval, cur_loss, np.exp(cur_loss))) 293 | total_loss = 0 294 | start_time = time.time() 295 | if first_loss is None: 296 | first_loss = cur_loss 297 | 298 | return epoch_loss / (len(train_data) - 1), first_loss 299 | 300 | if args.coord_check: 301 | print('testing parametrization') 302 | import os 303 | os.makedirs('coord_checks', exist_ok=True) 304 | plotdir = 'coord_checks' 305 | coord_check(mup=True, lr=args.lr, optimizer=args.optimizer, batch_size=args.batch_size, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, data_dir=args.data, args=args, plotdir=plotdir, legend=False) 306 | coord_check(mup=False, lr=args.lr, optimizer=args.optimizer, batch_size=args.batch_size, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, data_dir=args.data, args=args, plotdir=plotdir, legend=False) 307 | import sys; sys.exit() 308 | 309 | 310 | if 
args.save_base_shapes: 311 | print(f'saving base shapes at {args.save_base_shapes}') 312 | if args.deferred_init: 313 | from torchdistx.deferred_init import deferred_init 314 | # We don't need to instantiate the base and delta models 315 | base_shapes = get_shapes( 316 | deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout, 317 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 318 | decoder_var=args.init_var, standparam=args.load_base_shapes=='') 319 | ) 320 | delta_shapes = get_shapes( 321 | # just need to change whatever dimension(s) we are scaling 322 | deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2, 323 | nlayers=args.nlayers, dropout=args.dropout, 324 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 325 | decoder_var=args.init_var, standparam=args.load_base_shapes=='') 326 | ) 327 | else: 328 | base_shapes = get_shapes( 329 | mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout, 330 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 331 | decoder_var=args.init_var, standparam=args.load_base_shapes=='') 332 | ) 333 | delta_shapes = get_shapes( 334 | # just need to change whatever dimension(s) we are scaling 335 | mdl.TransformerModel(args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2, 336 | nlayers=args.nlayers, dropout=args.dropout, 337 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 338 | decoder_var=args.init_var, standparam=args.load_base_shapes=='') 339 | ) 340 | make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes) 341 | print('done and exit') 342 | import sys; sys.exit() 343 | model = mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, 
nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout, 344 | tied=args.tied, bias=args.bias, encoder_var=args.init_var, 345 | decoder_var=args.init_var, standparam=args.load_base_shapes=='') 346 | if args.load_base_shapes: 347 | print(f'loading base shapes from {args.load_base_shapes}') 348 | set_base_shapes(model, args.load_base_shapes) 349 | print('done') 350 | else: 351 | print(f'using own shapes') 352 | set_base_shapes(model, None) 353 | print('done') 354 | 355 | model = model.to(device) 356 | model = setprec(model, args.precision) 357 | 358 | criterion = nn.NLLLoss() 359 | 360 | if args.save_dir is not None: 361 | os.makedirs(args.save_dir, exist_ok=True) 362 | 363 | # Loop over epochs. 364 | lr = args.lr 365 | best_val_loss = float('inf') 366 | 367 | if args.optimizer == 'sgd': 368 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 369 | elif args.optimizer == 'musgd': 370 | optimizer = MuSGD(model.parameters(), lr=args.lr, momentum=args.momentum) 371 | elif args.optimizer == 'adam': 372 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 373 | elif args.optimizer == 'muadam': 374 | optimizer = MuAdam(model.parameters(), lr=args.lr) 375 | else: 376 | raise ValueError() 377 | 378 | # half-precision black magic 379 | if args.precision == 'half': 380 | model, optimizer = amp.initialize( 381 | model, 382 | optimizer, 383 | opt_level='O1', 384 | min_loss_scale=0.0001, 385 | verbosity=0 386 | ) 387 | 388 | logs = [] 389 | start_epoch = 0 390 | if args.resume_dir and os.path.exists(os.path.join(args.resume_dir, 'checkpoint_last.pt')): 391 | checkpoint = torch.load(os.path.join(args.resume_dir, 'checkpoint_last.pt')) 392 | model.load_state_dict(checkpoint['model']) 393 | optimizer.load_state_dict(checkpoint['optimizer']) 394 | if args.precision == 'half': 395 | amp.load_state_dict(checkpoint['amp']) 396 | start_epoch = checkpoint['epoch'] 397 | best_val_loss = checkpoint['best_val_loss'] 398 | logs = 
checkpoint['logs'] 399 | 400 | # At any point you can hit Ctrl + C to break out of training early. 401 | try: 402 | for epoch in range(start_epoch+1, args.epochs+1): 403 | epoch_start_time = time.time() 404 | train_loss, first_loss = train(optimizer, epoch) 405 | # print(first_loss) 406 | val_loss = evaluate(val_data) 407 | print('-' * 89) 408 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 409 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 410 | val_loss, np.exp(val_loss))) 411 | print('-' * 89) 412 | logs.append(dict( 413 | epoch=epoch, 414 | train_loss=train_loss, 415 | val_loss=val_loss, 416 | first_loss=first_loss 417 | )) 418 | # Save the model if the validation loss is the best we've seen so far. 419 | if args.save_dir is not None: 420 | if val_loss < best_val_loss: 421 | checkpoint = { 422 | 'model': model.state_dict(), 423 | 'optimizer': optimizer.state_dict(), 424 | 'epoch': epoch, 425 | 'best_val_loss': best_val_loss, 426 | 'logs': logs 427 | } 428 | if args.precision == 'half': 429 | checkpoint['amp'] = amp.state_dict(), 430 | with open(os.path.join(args.save_dir, 'checkpoint_best.pt'), 'wb') as f: 431 | torch.save(checkpoint, f) 432 | best_val_loss = val_loss 433 | else: 434 | checkpoint = { 435 | 'model': model.state_dict(), 436 | 'optimizer': optimizer.state_dict(), 437 | 'epoch': epoch, 438 | 'best_val_loss': best_val_loss, 439 | 'logs': logs 440 | } 441 | if args.precision == 'half': 442 | checkpoint['amp'] = amp.state_dict() 443 | with open(os.path.join(args.save_dir, 'checkpoint_last.pt'), 'wb') as f: 444 | torch.save(checkpoint, f) 445 | 446 | except KeyboardInterrupt: 447 | print('-' * 89) 448 | print('Exiting from training early') 449 | 450 | # Load the best saved model. 
451 | if args.save_dir is not None: 452 | with open(os.path.join(args.save_dir, 'checkpoint_best.pt'), 'rb') as f: 453 | checkpoint = torch.load(f) 454 | model.load_state_dict(checkpoint['model']) 455 | optimizer.load_state_dict(checkpoint['optimizer']) 456 | if args.precision == 'half': 457 | amp.load_state_dict(checkpoint['amp'][0]) 458 | # Run on test data. 459 | test_loss = evaluate(test_data) 460 | print('=' * 89) 461 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 462 | test_loss, np.exp(test_loss))) 463 | print('=' * 89) 464 | logs.append(dict( 465 | epoch='-1', 466 | test_loss=test_loss 467 | )) 468 | 469 | 470 | with open(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv'), 'w') as f: 471 | logdf = pd.DataFrame(logs) 472 | print(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv')) 473 | f.write(logdf.to_csv(sep='\t', float_format='%.4f')) 474 | -------------------------------------------------------------------------------- /examples/Transformer/width256.bsh: -------------------------------------------------------------------------------- 1 | # This is a base shape file encoded in yaml 2 | # - `null` indicates a dimension is "finite", i.e. a non-"width" dimension 3 | # - a number indicates the base dimension of an "infinite" dimension, i.e. 
some notion of "width" 4 | decoder.weight: 5 | - null 6 | - 256 7 | encoder.weight: 8 | - null 9 | - 256 10 | transformer_encoder.layers.0.linear1.weight: 11 | - 256 12 | - 256 13 | transformer_encoder.layers.0.linear2.weight: 14 | - 256 15 | - 256 16 | transformer_encoder.layers.0.norm1.bias: 17 | - 256 18 | transformer_encoder.layers.0.norm1.weight: 19 | - 256 20 | transformer_encoder.layers.0.norm2.bias: 21 | - 256 22 | transformer_encoder.layers.0.norm2.weight: 23 | - 256 24 | transformer_encoder.layers.0.self_attn.in_proj_weight: 25 | - 768 26 | - 256 27 | transformer_encoder.layers.0.self_attn.out_proj.weight: 28 | - 256 29 | - 256 30 | transformer_encoder.layers.1.linear1.weight: 31 | - 256 32 | - 256 33 | transformer_encoder.layers.1.linear2.weight: 34 | - 256 35 | - 256 36 | transformer_encoder.layers.1.norm1.bias: 37 | - 256 38 | transformer_encoder.layers.1.norm1.weight: 39 | - 256 40 | transformer_encoder.layers.1.norm2.bias: 41 | - 256 42 | transformer_encoder.layers.1.norm2.weight: 43 | - 256 44 | transformer_encoder.layers.1.self_attn.in_proj_weight: 45 | - 768 46 | - 256 47 | transformer_encoder.layers.1.self_attn.out_proj.weight: 48 | - 256 49 | - 256 50 | -------------------------------------------------------------------------------- /figures/parametrizations.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/figures/parametrizations.gif -------------------------------------------------------------------------------- /figures/sp_vs_mup_dashed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/figures/sp_vs_mup_dashed.png -------------------------------------------------------------------------------- /figures/widerbetter.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mup/19814971934ef91dd546f88e913fc963e096d11c/figures/widerbetter.png -------------------------------------------------------------------------------- /mup/__init__.py: -------------------------------------------------------------------------------- 1 | name = "mup" 2 | 3 | from mup.shape import * 4 | from mup.infshape import * 5 | from mup.init import * 6 | from mup.layer import * 7 | from mup.optim import * -------------------------------------------------------------------------------- /mup/coord_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 2 | ''' 3 | Helper functions for performing coord check. 4 | ''' 5 | import os 6 | from copy import copy 7 | from itertools import product 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def cov(x): 16 | '''Treat `x` as a collection of vectors and its Gram matrix. 17 | Input: 18 | x: If it has shape [..., d], then it's treated as 19 | a collection of d-dimensional vectors 20 | Output: 21 | cov: a matrix of size N x N where N is the product of 22 | the non-last dimensions of `x`. 23 | ''' 24 | if x.nelement() == 1: 25 | width = 1 26 | xx = x.reshape(1, 1) 27 | else: 28 | width = x.shape[-1] 29 | xx = x.reshape(-1, x.shape[-1]) 30 | return xx @ xx.T / width 31 | 32 | def covoffdiag(x): 33 | '''Get off-diagonal entries of `cov(x)` in a vector. 
34 | Input: 35 | x: If it has shape [..., d], then it's treated as 36 | a collection of d-dimensional vectors 37 | Output: 38 | Off-diagonal entries of `cov(x)` in a vector.''' 39 | c = cov(x) 40 | return c[~torch.eye(c.shape[0], dtype=bool)] 41 | 42 | #: dict of provided functions for use in coord check 43 | FDICT = { 44 | 'l1': lambda x: torch.abs(x).mean(dtype=torch.float32), 45 | 'l2': lambda x: (x**2).mean(dtype=torch.float32)**0.5, 46 | 'mean': lambda x: x.mean(dtype=torch.float32), 47 | 'std': lambda x: x.std(dtype=torch.float32), 48 | 'covl1': lambda x: torch.abs(cov(x)).mean(dtype=torch.float32), 49 | 'covl2': lambda x: (cov(x)**2).mean(dtype=torch.float32)**0.5, 50 | 'covoffdiagl1': lambda x: torch.abs(covoffdiag(x)).mean(dtype=torch.float32), 51 | 'covoffdiagl2': lambda x: (covoffdiag(x)**2).mean(dtype=torch.float32)**0.5 52 | } 53 | 54 | def convert_fdict(d): 55 | '''convert a dict `d` with string values to function values. 56 | Input: 57 | d: a dict whose values are either strings or functions 58 | Output: 59 | a new dict, with the same keys as `d`, but the string values are 60 | converted to functions using `FDICT`. 61 | ''' 62 | return dict([ 63 | ((k, FDICT[v]) if isinstance(v, str) else (k, v)) 64 | for k, v in d.items()]) 65 | 66 | def _record_coords(records, width, modulename, t, 67 | output_fdict=None, input_fdict=None, param_fdict=None): 68 | '''Returns a forward hook that records coordinate statistics. 69 | 70 | Returns a forward hook that records statistics regarding the output, input, 71 | and/or parameters of a `nn.Module`. This hook is intended to run only once, 72 | on the timestep specified by `t`. 73 | 74 | On forward pass, the returned hook calculates statistics specified in 75 | `output_fdict`, `input_fdict`, and `param_fdict`, such as the normalized l1 76 | norm, of output, input, and/or parameters of the module. 
The statistics are 77 | recorded along with the `width`, `modulename`, and `t` (the time step) as a 78 | dict and inserted into `records` (which should be a list). More precisely, 79 | for each output, input, and/or parameter, the inserted dict is of the form 80 | 81 | { 82 | 'width': width, 'module': modified_modulename, 't': t, 83 | # keys are keys in fdict 84 | 'l1': 0.241, 'l2': 0.420, 'mean': 0.0, ... 85 | } 86 | 87 | where `modified_modulename` is a string that combines the `modulename` with 88 | an indicator of which output, input, or parameter tensor the statistics are 89 | computed over. 90 | 91 | The `*_fdict` inputs should be dictionaries with string keys whose 92 | values can either be functions or strings. The string values are converted 93 | to functions via `convert_fdict`. The default values of `*_fdict` inputs are 94 | converted to `output_fdict = dict(l1=FDICT['l1'])`, `input_fdict = {}`, 95 | `param_fdict = {}`, i.e., only the average coordinate size (`l1`) of the 96 | output activations is recorded. 97 | 98 | Inputs: 99 | records: 100 | list to append coordinate data to 101 | width: 102 | width of the model. This is used only for plotting coord check later 103 | on, so it can be any notion of width. 104 | modulename: 105 | string name of the module. This is used only for plotting coord check. 106 | t: 107 | timestep of training. This is used only for plotting coord check. 108 | output_fdict, input_fdict, param_fdict: 109 | dicts with string keys whose values can either be functions or 110 | strings. The string values are converted to functions via 111 | `convert_fdict` 112 | Output: 113 | a forward hook that records statistics regarding the output, input, 114 | and/or parameters of a `nn.Module`, as discussed above. 
115 | ''' 116 | if output_fdict is None: 117 | output_fdict = dict(l1=FDICT['l1']) 118 | else: 119 | output_fdict = convert_fdict(output_fdict) 120 | if input_fdict is None: 121 | input_fdict = {} 122 | else: 123 | input_fdict = convert_fdict(input_fdict) 124 | if param_fdict is None: 125 | param_fdict = {} 126 | else: 127 | param_fdict = convert_fdict(param_fdict) 128 | def f(module, input, output): 129 | def get_stat(d, x, fdict): 130 | if isinstance(x, (tuple, list)): 131 | for i, _x in enumerate(x): 132 | _d = copy(d) 133 | _d['module'] += f'[{i}]' 134 | get_stat(_d, _x, fdict) 135 | elif isinstance(x, dict): 136 | for name, _x in x.items(): 137 | _d = copy(d) 138 | _d['module'] += f'[{name}]' 139 | get_stat(_d, _x, fdict) 140 | elif isinstance(x, torch.Tensor): 141 | _d = copy(d) 142 | for fname, f in fdict.items(): 143 | _d[fname] = f(x).item() 144 | records.append(_d) 145 | elif x is None: 146 | pass 147 | else: 148 | raise NotImplementedError(f'Unexpected output type: {type(x)}') 149 | with torch.no_grad(): 150 | ret = { 151 | 'width': width, 152 | 'module': modulename, 153 | 't': t 154 | } 155 | 156 | # output stats 157 | if isinstance(output, (tuple, list)): 158 | for i, out in enumerate(output): 159 | _ret = copy(ret) 160 | _ret['module'] += f':out[{i}]' 161 | get_stat(_ret, out, output_fdict) 162 | elif isinstance(output, dict): 163 | for name, out in output.items(): 164 | _ret = copy(ret) 165 | _ret['module'] += f':out[{name}]' 166 | get_stat(_ret, out, output_fdict) 167 | elif isinstance(output, torch.Tensor): 168 | _ret = copy(ret) 169 | for fname, f in output_fdict.items(): 170 | _ret[fname] = f(output).item() 171 | records.append(_ret) 172 | else: 173 | raise NotImplementedError(f'Unexpected output type: {type(output)}') 174 | 175 | # input stats 176 | if input_fdict: 177 | if isinstance(input, (tuple, list)): 178 | for i, out in enumerate(input): 179 | _ret = copy(ret) 180 | _ret['module'] += f':in[{i}]' 181 | get_stat(_ret, out, input_fdict) 182 
| elif isinstance(input, dict): 183 | for name, out in input.items(): 184 | _ret = copy(ret) 185 | _ret['module'] += f':in[{name}]' 186 | get_stat(_ret, out, input_fdict) 187 | elif isinstance(input, torch.Tensor): 188 | _ret = copy(ret) 189 | for fname, f in input_fdict.items(): 190 | _ret[fname] = f(input).item() 191 | records.append(_ret) 192 | else: 193 | raise NotImplementedError(f'Unexpected input type: {type(input)}') 194 | 195 | # param stats 196 | if param_fdict: 197 | for name, p in module.named_parameters(): 198 | _ret = copy(ret) 199 | _ret['module'] += f':param[{name}]' 200 | for fname, f in param_fdict.items(): 201 | _ret[fname] = f(p).item() 202 | records.append(_ret) 203 | 204 | return f 205 | 206 | def _get_coord_data(models, dataloader, optcls, nsteps=3, 207 | dict_in_out=False, flatten_input=False, flatten_output=False, 208 | output_name='loss', lossfn='xent', filter_module_by_name=None, 209 | fix_data=True, cuda=True, nseeds=1, 210 | output_fdict=None, input_fdict=None, param_fdict=None, 211 | show_progress=True, one_hot_target=False): 212 | '''Inner method for `get_coord_data`. 213 | 214 | Train the models in `models` with optimizer given by `optcls` and data from 215 | `dataloader` for `nsteps` steps, and record coordinate statistics specified 216 | by `output_fdict`, `input_fdict`, `param_fdict`. By default, only `l1` is 217 | computed for output activations of each module. 218 | 219 | Inputs: 220 | models: 221 | a dict of lazy models, where the keys are numbers indicating width. 222 | Each entry of `models` is a function that takes no arguments and 223 | instantiates a model. 224 | dataloader: 225 | an iterator whose elements are either Huggingface style dicts, if 226 | `dict_in_out` is True, or (input, label). If `fix_data` is True 227 | (which is the default), then only the first element of `dataloader` 228 | is used in a loop and the rest of `dataloader` is ignored. 
229 | optcls: 230 | a function so that `optcls(model)` gives an optimizer used to train 231 | the model. 232 | nsteps: 233 | number of steps to train the model 234 | dict_in_out: 235 | whether the data loader contains Huggingface-style dict input and 236 | output. Default: False 237 | flatten_input: 238 | if not `dict_in_out`, reshape the input to be 239 | `input.view(input.shape[0], -1)`. Typically used for testing MLPs. 240 | flatten_output: 241 | if not `dict_in_out`, reshape the output to be `output.view(-1, 242 | output.shape[-1])`. 243 | output_name: 244 | if `dict_in_out`, this is the key for the loss value if the output 245 | is a dict. If the output is not a dict, then we assume the first 246 | element of the output is the loss. 247 | lossfn: 248 | loss function to use if not `dict_in_out`. Can be either a string from 249 | ['xent', 'mse', 'nll', 'l1'] or a python `callable` such that 250 | `lossfn(output, target)` returns the loss value. Examples of valid 251 | `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is 252 | `torch.nn.functional`. Default: 'xent' 253 | filter_module_by_name: 254 | a function that returns a bool given module names (from 255 | `model.named_modules()`), or None. If not None, then only modules 256 | whose name yields True will be recorded. 257 | cuda: 258 | whether to use cuda or not. Default: True 259 | nseeds: 260 | number of times to repeat the training, each with different seeds. 261 | output_fdict, input_fdict, param_fdict: 262 | function dicts to be used in `_record_coords`. By default, only `l1` 263 | is computed for output activations of each module. 264 | show_progress: 265 | show progress using tqdm. Default: True 266 | one_hot_target: 267 | convert target label into a one-hot vector. This typically is only 268 | used for `'mse'` or `'l1'` losses in classification tasks. 269 | Default: False 270 | Output: 271 | a pandas DataFrame containing recorded results. 
The column names are 272 | `'width', 'module', 't'` as well as names of statistics recorded, such 273 | as `'l1'` (see `FDICT` for other premade statistics that can be 274 | collected). 275 | 276 | Breaking Changes: 277 | In v1.0.0, when `lossfn=='mse'`, the target is automatically converted 278 | to a one hot vector before loss computation. Starting in v1.1.0, this 279 | behavior is turned off, and the user needs to explicitly turn on this 280 | behavior by setting `one_hot_target=True`. 281 | 282 | ''' 283 | df = [] 284 | if fix_data: 285 | batch = next(iter(dataloader)) 286 | dataloader = [batch] * nsteps 287 | if show_progress: 288 | from tqdm import tqdm 289 | pbar = tqdm(total=nseeds * len(models)) 290 | 291 | for i in range(nseeds): 292 | torch.manual_seed(i) 293 | for width, model in models.items(): 294 | model = model() 295 | model = model.train() 296 | if cuda: 297 | model = model.cuda() 298 | optimizer = optcls(model) 299 | for batch_idx, batch in enumerate(dataloader, 1): 300 | remove_hooks = [] 301 | # add hooks 302 | for name, module in model.named_modules(): 303 | if filter_module_by_name and not filter_module_by_name(name): 304 | continue 305 | remove_hooks.append(module.register_forward_hook( 306 | _record_coords(df, width, name, batch_idx, 307 | output_fdict=output_fdict, 308 | input_fdict=input_fdict, 309 | param_fdict=param_fdict))) 310 | if dict_in_out: 311 | if cuda: 312 | for k, v in batch.items(): 313 | if isinstance(v, torch.Tensor): 314 | batch[k] = v.cuda() 315 | outputs = model(**batch) 316 | loss = outputs[output_name] if isinstance(outputs, dict) else outputs[0] 317 | else: 318 | (data, target) = batch 319 | if cuda: 320 | data, target = data.cuda(), target.cuda() 321 | if flatten_input: 322 | data = data.view(data.size(0), -1) 323 | output = model(data) 324 | if flatten_output: 325 | output = output.view(-1, output.shape[-1]) 326 | if one_hot_target: 327 | target = F.one_hot(target, 328 | num_classes=output.size(-1)).float() 329 | if 
lossfn == 'xent': 330 | loss = F.cross_entropy(output, target) 331 | elif lossfn == 'mse': 332 | loss = F.mse_loss(output, target) 333 | elif lossfn == 'nll': 334 | loss = F.nll_loss(output, target) 335 | elif lossfn == 'l1': 336 | loss = F.l1_loss(output, target) 337 | elif callable(lossfn): 338 | loss = lossfn(output, target) 339 | else: 340 | raise NotImplementedError(f'unknown `lossfn`: {lossfn}') 341 | optimizer.zero_grad() 342 | loss.backward() 343 | optimizer.step() 344 | 345 | # remove hooks 346 | for handle in remove_hooks: 347 | handle.remove() 348 | 349 | if batch_idx == nsteps: break 350 | if show_progress: 351 | pbar.update(1) 352 | if show_progress: 353 | pbar.close() 354 | return pd.DataFrame(df) 355 | 356 | 357 | def get_coord_data(models, dataloader, optimizer='sgd', lr=None, mup=True, 358 | filter_trainable_by_name=None, 359 | **kwargs): 360 | '''Get coord data for coord check. 361 | 362 | Train the models in `models` with data from `dataloader` and optimizer 363 | specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate 364 | statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By 365 | default, only `l1` is computed for output activations of each module. 366 | 367 | This function wraps around `_get_coord_data`, with the main difference being 368 | that the user can specify common optimizers via a more convenient interface. 369 | 370 | Inputs: 371 | models: 372 | a dict of lazy models, where the keys are numbers indicating width. 373 | Each entry of `models` is a function that takes no arguments and 374 | instantiates a model. 375 | dataloader: 376 | an iterator whose elements are either Huggingface style dicts, if 377 | `dict_in_out` is True, or (input, label). If `fix_data` is True 378 | (which is the default), then only the first element of `dataloader` 379 | is used in a loop and the rest of `dataloader` is ignored. 380 | optimizer: 381 | a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`. 
382 | lr: 383 | learning rate. The default is 0.1 for `'sgd'` and 1e-3 for others. 384 | mup: 385 | If True, then use the optimizer from `mup.optim`; otherwise, use the 386 | one from `torch.optim`. 387 | filter_trainable_by_name: 388 | a function that returns a bool given parameter names (from 389 | `model.named_parameters()`), or None. If not None, then only parameters 390 | whose name yields True will be trained. 391 | nsteps: 392 | number of steps to train the model 393 | dict_in_out: 394 | whether the data loader contains Huggingface-style dict input and 395 | output. Default: False 396 | flatten_input: 397 | if not `dict_in_out`, reshape the input to be 398 | `input.view(input.shape[0], -1)`. Typically used for testing MLPs. 399 | flatten_output: 400 | if not `dict_in_out`, reshape the output to be `output.view(-1, 401 | output.shape[-1])`. 402 | output_name: 403 | if `dict_in_out`, this is the key for the loss value if the output 404 | is a dict. If the output is not a dict, then we assume the first 405 | element of the output is the loss. 406 | lossfn: 407 | loss function to use if not `dict_in_out`. Can be either a string from 408 | ['xent', 'mse', 'nll', 'l1'] or a python `callable` such that 409 | `lossfn(output, target)` returns the loss value. Examples of valid 410 | `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is 411 | `torch.nn.functional`. Default: 'xent' 412 | filter_module_by_name: 413 | a function that returns a bool given module names (from 414 | `model.named_modules()`), or None. If not None, then only modules 415 | whose name yields True will be recorded. 416 | cuda: 417 | whether to use cuda or not. Default: True 418 | nseeds: 419 | number of times to repeat the training, each with different seeds. 420 | output_fdict, input_fdict, param_fdict: 421 | function dicts to be used in `_record_coords`. By default, only `l1` 422 | is computed for output activations of each module. 423 | show_progress: 424 | show progress using tqdm. 
Default: True 425 | one_hot_target: 426 | convert target label into a one-hot vector. This typically is only 427 | used for `'mse'` or `'l1'` losses in classification tasks. 428 | Default: False 429 | Output: 430 | a pandas DataFrame containing recorded results. The column names are 431 | `'width', 'module', 't'` as well as names of statistics recorded, such 432 | as `'l1'` (see `FDICT` for other premade statistics that can be 433 | collected). 434 | 435 | Breaking Changes: 436 | In v1.0.0, when `lossfn=='mse'`, the target was automatically converted 437 | to a one hot vector before loss computation. Starting in v1.1.0, this 438 | behavior is turned off, and the user needs to explicitly turn on this 439 | behavior by setting `one_hot_target=True`. 440 | ''' 441 | if lr is None: 442 | lr = 0.1 if optimizer == 'sgd' else 1e-3 443 | if mup: 444 | from mup.optim import MuAdam as Adam 445 | from mup.optim import MuAdamW as AdamW 446 | from mup.optim import MuSGD as SGD 447 | else: 448 | from torch.optim import SGD, Adam, AdamW 449 | def get_trainable(model): 450 | params = model.parameters() 451 | if filter_trainable_by_name is not None: 452 | params = [] 453 | for name, p in model.named_parameters(): 454 | if filter_trainable_by_name(name): 455 | params.append(p) 456 | return params 457 | if optimizer == 'sgd': 458 | optcls = lambda model: SGD(get_trainable(model), lr=lr) 459 | elif optimizer == 'adam': 460 | optcls = lambda model: Adam(get_trainable(model), lr=lr) 461 | elif optimizer == 'adamw': 462 | optcls = lambda model: AdamW(get_trainable(model), lr=lr) 463 | else: 464 | raise ValueError('optimizer should be one of sgd|adam|adamw') 465 | 466 | data = _get_coord_data(models, dataloader, optcls, **kwargs) 467 | data['optimizer'] = optimizer 468 | data['lr'] = lr 469 | return data 470 | 471 | 472 | def plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module', 473 | legend='full', name_contains=None, 
name_not_contains=None, module_list=None, 474 | loglog=True, logbase=2, face_color=None, subplot_width=5, 475 | subplot_height=4): 476 | '''Plot coord check data `df` obtained from `get_coord_data`. 477 | 478 | Input: 479 | df: 480 | a pandas DataFrame obtained from `get_coord_data` 481 | y: 482 | the column of `df` to plot on the y-axis. Default: `'l1'` 483 | save_to: 484 | path to save the resulting figure, or None. Default: None. 485 | suptitle: 486 | The title of the entire figure. 487 | x: 488 | the column of `df` to plot on the x-axis. Default: `'width'` 489 | hue: 490 | the column of `df` to represent as color. Default: `'module'` 491 | legend: 492 | 'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`. 493 | name_contains, name_not_contains: 494 | only plot modules whose name contains `name_contains` and does not contain `name_not_contains` 495 | module_list: 496 | only plot modules that are given in the list, overrides `name_contains` and `name_not_contains` 497 | loglog: 498 | whether to use loglog scale. Default: True 499 | logbase: 500 | the log base, if using loglog scale. Default: 2 501 | face_color: 502 | background color of the plot. Default: None (which means white) 503 | subplot_width, subplot_height: 504 | The width and height for each timestep's subplot. More precisely, 505 | the figure size will be 506 | `(subplot_width*number_of_time_steps, subplot_height)`. 
507 | Default: 5, 4 508 | 509 | Output: 510 | the `matplotlib` figure object 511 | ''' 512 | ### preprocessing 513 | df = copy(df) 514 | # nn.Sequential has name '', which duplicates the output layer 515 | df = df[df.module != ''] 516 | if module_list is not None: 517 | df = df[df['module'].isin(module_list)] 518 | else: 519 | if name_contains is not None: 520 | df = df[df['module'].str.contains(name_contains)] 521 | if name_not_contains is not None: 522 | df = df[~(df['module'].str.contains(name_not_contains))] 523 | # for nn.Sequential, module names are numerical 524 | try: 525 | df['module'] = pd.to_numeric(df['module']) 526 | except ValueError: 527 | pass 528 | 529 | ts = df.t.unique() 530 | 531 | import matplotlib.pyplot as plt 532 | import seaborn as sns 533 | sns.set() 534 | 535 | def tight_layout(plt): 536 | plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 537 | 538 | ### plot 539 | fig = plt.figure(figsize=(subplot_width * len(ts), subplot_height)) 540 | hue_order = sorted(set(df['module'])) 541 | if face_color is not None: 542 | fig.patch.set_facecolor(face_color) 543 | ymin, ymax = min(df[y]), max(df[y]) 544 | for t in ts: 545 | t = int(t) 546 | plt.subplot(1, len(ts), t) 547 | sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue, hue_order=hue_order, legend=legend if t == 1 else None) 548 | plt.title(f't={t}') 549 | if t != 1: 550 | plt.ylabel('') 551 | if loglog: 552 | plt.loglog(base=logbase) 553 | ax = plt.gca() 554 | ax.set_ylim([ymin, ymax]) 555 | if suptitle: 556 | plt.suptitle(suptitle) 557 | tight_layout(plt) 558 | if save_to is not None: 559 | plt.savefig(save_to) 560 | print(f'coord check plot saved to {save_to}') 561 | 562 | return fig 563 | 564 | # example of how to plot coord check results 565 | # for the CNN and MLP models in mup.test 566 | def example_plot_coord_check( 567 | arch='mlp', optimizer='sgd', lr=None, widths=None, mup=True, 568 | nsteps=3, nseeds=10, plotdir='', batchnorm=False, batch_size=1, 569 | init='kaiming_fan_in_normal', 
download_cifar=True, legend='full', 570 | dict_in_out=False, name_contains=None, name_not_contains=None): 571 | 572 | from mup.test.models import get_lazy_models, get_train_loader 573 | if batchnorm: 574 | batch_size = 5 575 | train_loader = get_train_loader(batch_size=batch_size, download=download_cifar) 576 | 577 | if widths is None: 578 | widths = 2**np.arange(7, 14) if arch == 'mlp' else 2**np.arange(3, 10) 579 | models = get_lazy_models(arch, widths, mup=mup, batchnorm=batchnorm, init=init, readout_zero_init=True) 580 | df = get_coord_data(models, train_loader, mup=mup, lr=lr, optimizer=optimizer, flatten_input=arch == 'mlp', nseeds=nseeds, nsteps=nsteps, dict_in_out=dict_in_out) 581 | 582 | prm = 'μP' if mup else 'SP' 583 | bn = 'on' if batchnorm else 'off' 584 | if lr is None: 585 | lr = 0.1 if optimizer == 'sgd' else 1e-3 586 | return plot_coord_data(df, legend=legend, 587 | name_contains=name_contains, name_not_contains=name_not_contains, 588 | save_to=os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_lr{lr}_nseeds{nseeds}_bn{int(batchnorm)}_coord.png'), 589 | suptitle=f'{prm} {arch.upper()} {optimizer} lr={lr} bn={bn} nseeds={nseeds}', 590 | face_color='xkcd:light grey' if not mup else None) 591 | 592 | 593 | if __name__ == '__main__': 594 | import os 595 | os.makedirs('coord_checks', exist_ok=True) 596 | plotdir = 'coord_checks' 597 | 598 | nseeds = 5 599 | 600 | for arch, opt, bn, mup in product(['mlp', 'cnn'], ['sgd', 'adam'], [False, True], [False, True]): 601 | example_plot_coord_check(arch, opt, batchnorm=bn, mup=mup, nseeds=nseeds, download_cifar=True, legend=None, plotdir=plotdir) 602 | 603 | 604 | -------------------------------------------------------------------------------- /mup/infshape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 
2 | 3 | from copy import copy 4 | 5 | 6 | class InfDim: 7 | '''A dimension with a base dimension, used for calculating μP scaling. 8 | 9 | An `InfDim` object is made up of 2 numbers: a dimension and a base 10 | dimension. If the base dimension is None, then this object represents a 11 | "finite", or "non-width" dimension. Otherwise, it represents an "infinite", 12 | or "width" dimension. 13 | ''' 14 | 15 | def __init__(self, base_dim, dim): 16 | self.base_dim = base_dim 17 | self.dim = dim 18 | 19 | def isinf(self): 20 | return self.base_dim is not None 21 | 22 | def width_mult(self): 23 | '''Width multiplier used for calculating μP scaling. 24 | 25 | If finite, return 1. 26 | If infinite, return dim / base_dim. 27 | ''' 28 | if self.isinf(): 29 | return self.dim / self.base_dim 30 | return 1 31 | 32 | def __repr__(self): 33 | return f'InfDim({self.base_dim}, {self.dim})' 34 | 35 | def __str__(self): 36 | if self.isinf(): 37 | return repr(self) 38 | return f'FinDim({self.dim})' 39 | 40 | def __eq__(self, other: object) -> bool: 41 | if not isinstance(other, InfDim): 42 | return False 43 | return self.base_dim == other.base_dim and \ 44 | self.dim == other.dim 45 | 46 | 47 | class InfShape(tuple): 48 | '''A tuple of `InfDim`s. 49 | 50 | This is intended to be attached to each parameter tensor `p` as `p.infshape`. 
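As a rough illustration only (a standalone sketch, not part of this class): the shape-level width multiplier is taken from the last infinite dimension. The helper name `main_width_mult` and the use of plain `(base_dim, dim)` pairs in place of `InfDim` objects are assumptions made for this example.

```python
# Standalone sketch; `main_width_mult` is a hypothetical helper, not part of mup.
# Each dimension is a (base_dim, dim) pair; base_dim=None marks a finite dimension.
def main_width_mult(pairs):
    # Scan from the last dimension backwards for the first infinite one.
    for base_dim, dim in reversed(pairs):
        if base_dim is not None:
            return dim / base_dim
    # No infinite dimension at all: the multiplier is 1.
    return 1


pairs = [(None, 100), (128, 1024), (128, 128)]
print(main_width_mult(pairs))  # 1.0, since the last infinite dimension is (128, 128)
```

For an inf x inf weight this picks the fan-in dimension, which is the convention the constructor comments describe.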
51 | ''' 52 | 53 | def __init__(self, *args, **kwargs): 54 | tuple.__init__(*args, **kwargs) 55 | for dim in self: 56 | if not isinstance(dim, InfDim): 57 | raise ValueError('Elements of InfShape need to be of class InfDim') 58 | # set main to be the last dimension that is infinite 59 | # for inf x inf this is fanin 60 | # for inf x fin or fin x inf it's the unique inf dim 61 | # user can set this manually if necessary 62 | self.main_idx = self.main = None 63 | for i, dim in list(enumerate(self))[::-1]: 64 | if dim.isinf(): 65 | self.main_idx = i 66 | self.main = dim 67 | break 68 | 69 | def fanin_fanout(self): 70 | assert len(self) >= 2, 'fanin, fanout undefined for 1-dimensional weights' 71 | return self[1], self[0] 72 | 73 | def fanin_fanout_mult_ratio(self): 74 | fanin, fanout = self.fanin_fanout() 75 | return fanin.width_mult() / fanout.width_mult() 76 | 77 | def ninf(self): 78 | return sum(1 for dim in self if dim.isinf()) 79 | 80 | def width_mult(self): 81 | if self.main is not None: 82 | return self.main.width_mult() 83 | return 1 84 | 85 | def base_shape(self): 86 | return [d.base_dim for d in self] 87 | 88 | def shape(self): 89 | return [d.dim for d in self] 90 | 91 | def __repr__(self): 92 | r = tuple.__repr__(self)[1:-1] 93 | return f'InfShape([{r}])' 94 | 95 | def serialize(self): 96 | d = {'base_shape': [], 'shape': []} 97 | for infdim in self: 98 | d['shape'].append(infdim.dim) 99 | d['base_shape'].append(infdim.base_dim) 100 | return d 101 | 102 | def __eq__(self, other: object) -> bool: 103 | if not isinstance(other, InfShape): 104 | return False 105 | return len(self) == len(other) and all(d == dd for d, dd in zip(self, other)) 106 | 107 | @classmethod 108 | def deserialize(cls, d): 109 | infshape = [] 110 | for base_dim, dim in zip(d['base_shape'], d['shape']): 111 | infshape.append(InfDim(base_dim, dim)) 112 | return InfShape(infshape) 113 | 114 | @classmethod 115 | def from_base_shape(cls, bsh): 116 | return InfShape([InfDim(bd, None) for bd in bsh]) 117 | 118 | def 
zip_infshape(base_dims, dims, fin_if_same=True): 119 | infshape = [] 120 | for bd, d in zip(base_dims, dims): 121 | if isinstance(bd, InfDim): 122 | # retain bd's base_dim but overwrite dim 123 | infdim = copy(bd) 124 | infdim.dim = d 125 | infshape.append(infdim) 126 | elif isinstance(bd, int): 127 | if bd == d and fin_if_same: 128 | infshape.append(InfDim(None, d)) 129 | else: 130 | infshape.append(InfDim(bd, d)) 131 | else: 132 | raise ValueError(f'unhandled base_dim type: {type(bd)}') 133 | return InfShape(infshape) 134 | 135 | if __name__ == '__main__': 136 | infshape = InfShape([InfDim(None, 100), InfDim(128, 1024), InfDim(128, 128)]) 137 | print(infshape) 138 | print(f'{infshape.ninf()} dims are inf') 139 | print(f'width_mult {infshape.width_mult()}') 140 | 141 | print(zip_infshape([64, 128, 1024], [32, 128, 2048])) -------------------------------------------------------------------------------- /mup/init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 2 | ''' 3 | Initializer functions mirroring those of `torch.nn.init`. They serve as 4 | drop-in replacements after the user has called `set_base_shapes` on their 5 | model. 6 | 7 | All of the initializers here are designed to 1) behave exactly the same 8 | as the torch versions when the model shapes are equal to their base shapes, 9 | and 2) to scale with width correctly (according to μP), when the model shapes 10 | differ from the base shapes. In general, this means deviating from the 11 | torch version behaviors. 
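As a hedged illustration of point 2 (a standalone sketch, not a function of this module): for the plain samplers routed through `constant_std_init_`, the extra factor depends only on the number of infinite dimensions and on the width multiplier. The name `init_std_scale` is an assumption made for this example. The Xavier and Kaiming variants apply their own adjustments and are not covered here.

```python
# Standalone sketch; `init_std_scale` is a hypothetical name, not a mup function.
def init_std_scale(ninf, width_mult):
    # Tensors with at most one infinite dimension keep the torch scale.
    if ninf <= 1:
        return 1.0
    # Tensors with two infinite dimensions shrink like 1/sqrt(width multiplier).
    if ninf == 2:
        return width_mult ** -0.5
    raise NotImplementedError("more than 2 infinite dimensions")


print(init_std_scale(2, 4))  # 0.5
print(init_std_scale(2, 1))  # 1.0: at the base shape nothing changes
```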
12 | ''' 13 | import math 14 | import warnings 15 | 16 | import torch 17 | from torch.nn.init import (_calculate_correct_fan, 18 | _calculate_fan_in_and_fan_out, _no_grad_fill_, 19 | _no_grad_normal_, _no_grad_uniform_, calculate_gain) 20 | 21 | 22 | def constant_std_init_(tensor, sampler_): 23 | assert hasattr(tensor, 'infshape'), 'Please call set_base_shapes(...)' 24 | if tensor.infshape.ninf() <= 1: 25 | sampler_(tensor) 26 | elif tensor.infshape.ninf() == 2: 27 | sampler_(tensor, scale=tensor.infshape.width_mult()**-0.5) 28 | else: 29 | raise NotImplementedError() 30 | return tensor 31 | 32 | def uniform_(tensor, a=0, b=1): 33 | '''Drop-in replacement of `torch.nn.init.uniform_`. 34 | Note: 35 | - if using this function, ensure `a` and `b` do not depend on fan-in, 36 | fan-out, or other notions of width, e.g. if a = 0, b = 1. 37 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 38 | ''' 39 | assert hasattr(tensor, 'infshape'), 'Please call set_base_shapes(...)' 40 | if a != -b: 41 | assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' 42 | def sampler_(tensor, scale=1): 43 | _no_grad_uniform_(tensor, a * scale, b * scale) 44 | return constant_std_init_(tensor, sampler_) 45 | 46 | def normal_(tensor, mean=0, std=1): 47 | '''Drop-in replacement of `torch.nn.init.normal_`. 48 | Note: 49 | - if using this function, ensure `mean` and `std` do not depend on 50 | fan-in, fan-out, or other notions of width, e.g. if mean = 0, std = 51 | 1. 52 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 53 | ''' 54 | if mean != 0: 55 | assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' 56 | def sampler_(tensor, scale=1): 57 | _no_grad_normal_(tensor, mean=mean*scale, std=std*scale) 58 | return constant_std_init_(tensor, sampler_) 59 | 60 | def ones_(tensor): 61 | '''Same as `torch.nn.init.ones_`. 
62 | Note: 63 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 64 | ''' 65 | assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' 66 | def sampler_(tensor, scale=1): 67 | _no_grad_fill_(tensor, scale) 68 | return constant_std_init_(tensor, sampler_) 69 | 70 | def eye_(tensor): 71 | '''Same as `torch.nn.init.eye_`. 72 | Note: 73 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 74 | ''' 75 | assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' 76 | return torch.nn.init.eye_(tensor) 77 | 78 | 79 | def _inf_fan_adjust_xavier(scale, tensor): 80 | fan_out, fan_in = tensor.infshape[:2] 81 | # the following is needed to accommodate SP models, where all infshapes are finite and so every base_dim is None 82 | fan_out_base_dim = fan_out.base_dim or fan_out.dim 83 | fan_in_base_dim = fan_in.base_dim or fan_in.dim 84 | scale *= math.sqrt( 85 | (fan_out.dim + fan_in.dim) 86 | / (fan_out_base_dim + fan_in_base_dim)) 87 | if tensor.infshape.ninf() <= 1: 88 | # should have fixed scale 89 | pass 90 | elif tensor.infshape.ninf() == 2: 91 | # should scale like fanin 92 | assert fan_out.isinf() and fan_in.isinf() 93 | scale /= math.sqrt(fan_in.width_mult()) 94 | else: 95 | raise NotImplementedError('can only handle 2 inf dimensions currently') 96 | return scale 97 | 98 | 99 | def xavier_uniform_(tensor, gain=1.): 100 | '''Drop-in replacement of `torch.nn.init.xavier_uniform_`. 101 | Note: 102 | - if using this function, ensure `gain` does not depend on fan-in, 103 | fan-out, or other notions of width, e.g. if gain = 1. 104 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 
105 | ''' 106 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 107 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 108 | std = _inf_fan_adjust_xavier(std, tensor) 109 | a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 110 | return _no_grad_uniform_(tensor, -a, a) 111 | 112 | 113 | def xavier_normal_(tensor, gain=1.): 114 | '''Drop-in replacement of `torch.nn.init.xavier_normal_`. 115 | Note: 116 | - if using this function, ensure `gain` does not depend on fan-in, 117 | fan-out, or other notions of width, e.g. if gain = 1. 118 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 119 | ''' 120 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 121 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 122 | std = _inf_fan_adjust_xavier(std, tensor) 123 | return _no_grad_normal_(tensor, 0., std) 124 | 125 | 126 | def _inf_fan_adjust_kaiming(scale, tensor, mode): 127 | fan_out, fan_in = tensor.infshape[:2] 128 | if tensor.infshape.ninf() == 0: 129 | return scale 130 | elif tensor.infshape.ninf() == 1: 131 | # should have fixed scale 132 | if mode == 'fan_in' and fan_in.isinf(): 133 | scale *= fan_in.width_mult()**0.5 134 | elif mode == 'fan_out' and fan_out.isinf(): 135 | scale *= fan_out.width_mult()**0.5 136 | elif tensor.infshape.ninf() == 2: 137 | # should scale like fanin 138 | assert fan_out.isinf() and fan_in.isinf() 139 | if mode == 'fan_out': 140 | scale *= math.sqrt(fan_out.width_mult() / fan_in.width_mult()) 141 | else: 142 | raise NotImplementedError('can only handle <=2 inf dimensions currently') 143 | return scale 144 | 145 | def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): 146 | '''Drop-in replacement of `torch.nn.init.kaiming_normal_`. 147 | Note: 148 | - if using this function, ensure `a` does not depend on fan-in, 149 | fan-out, or other notions of width, e.g. if a = 0. 150 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 
151 | ''' 152 | if 0 in tensor.shape: 153 | warnings.warn("Initializing zero-element tensors is a no-op") 154 | return tensor 155 | fan = _calculate_correct_fan(tensor, mode) 156 | gain = calculate_gain(nonlinearity, a) 157 | std = _inf_fan_adjust_kaiming(gain / math.sqrt(fan), tensor, mode) 158 | with torch.no_grad(): 159 | return tensor.normal_(0, std) 160 | 161 | 162 | def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): 163 | '''Drop-in replacement of `torch.nn.init.kaiming_uniform_`. 164 | Note: 165 | - if using this function, ensure `a` does not depend on fan-in, 166 | fan-out, or other notions of width, e.g. if a = 0. 167 | - `tensor` should have `infshape` attribute set by `set_base_shapes`. 168 | ''' 169 | if 0 in tensor.shape: 170 | warnings.warn("Initializing zero-element tensors is a no-op") 171 | return tensor 172 | fan = _calculate_correct_fan(tensor, mode) 173 | gain = calculate_gain(nonlinearity, a) 174 | std = _inf_fan_adjust_kaiming(gain / math.sqrt(fan), tensor, mode) 175 | bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 176 | with torch.no_grad(): 177 | return tensor.uniform_(-bound, bound) 178 | 179 | 180 | try: 181 | from torch.nn.init import _no_grad_trunc_normal_ 182 | def trunc_normal_(tensor, mean=0, std=1, a=-2, b=2): 183 | '''Drop-in replacement of `torch.nn.init.trunc_normal_`. 184 | Note: 185 | - if using this function, ensure `mean`, `std`, `a`, `b` do not 186 | depend on fan-in, fan-out, or other notions of width, e.g. if 187 | mean = 0, std = 1, a = -2, b = 2. 188 | - `tensor` should have `infshape` attribute set by 189 | `set_base_shapes`. 
190 | ''' 191 | if mean != 0 or a != -b: 192 | assert tensor.infshape.ninf() == 1, 'Sampler for (inf, inf) tensors should have mean 0' 193 | def sampler_(tensor, scale=1): 194 | _no_grad_trunc_normal_(tensor, mean=mean*scale, std=std*scale, a=a*scale, b=b*scale) 195 | return constant_std_init_(tensor, sampler_) 196 | except ImportError: 197 | warnings.warn( 198 | 'Failed to import _no_grad_trunc_normal_ from torch.nn.init; ' 199 | 'you might be running an older version of torch. trunc_normal_ will not work.') 200 | def trunc_normal_(tensor, mean=0, std=1, a=-2, b=2): 201 | warnings.warn('Please upgrade your PyTorch version before using truncated normal.') 202 | pass 203 | -------------------------------------------------------------------------------- /mup/layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 2 | from torch.nn import Linear 3 | 4 | 5 | class MuReadout(Linear): 6 | '''Drop-in replacement for all output linear layers. 7 | 8 | An "output" linear layer is one that maps from a width dimension (e.g., 9 | `d_model` in a Transformer) to a non-width dimension (e.g., vocab size). 10 | 11 | This layer implements the version of μP with a 1/width multiplier and a 12 | constant variance initialization for both weights and biases. 13 | ''' 14 | def __init__(self, *args, readout_zero_init=False, output_mult=1.0, **kwargs): 15 | self.output_mult = output_mult 16 | self.readout_zero_init = readout_zero_init 17 | super().__init__(*args, **kwargs) 18 | 19 | def reset_parameters(self) -> None: 20 | if self.readout_zero_init: 21 | self.weight.data[:] = 0 22 | if self.bias is not None: 23 | self.bias.data[:] = 0 24 | else: 25 | super().reset_parameters() 26 | 27 | def width_mult(self): 28 | assert hasattr(self.weight, 'infshape'), ( 29 | 'Please call set_base_shapes(...). 
If using torch.nn.DataParallel, ' 30 | 'switch to distributed training with ' 31 | 'torch.nn.parallel.DistributedDataParallel instead' 32 | ) 33 | return self.weight.infshape.width_mult() 34 | 35 | def _rescale_parameters(self): 36 | '''Rescale parameters to convert SP initialization to μP initialization. 37 | 38 | Warning: This method is NOT idempotent and should be called only once 39 | unless you know what you are doing. 40 | ''' 41 | if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params: 42 | raise RuntimeError( 43 | "`_rescale_parameters` has been called once before already. " 44 | "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" 45 | "If you called `set_base_shapes` on a model loaded from a checkpoint, " 46 | "or just want to re-set the base shapes of an existing model, " 47 | "make sure to set the flag `rescale_params=False`.\n" 48 | "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.") 49 | if self.bias is not None: 50 | self.bias.data *= self.width_mult()**0.5 51 | self.weight.data *= self.width_mult()**0.5 52 | self._has_rescaled_params = True 53 | 54 | def forward(self, x): 55 | return super().forward( 56 | self.output_mult * x / self.width_mult()) 57 | 58 | 59 | class MuSharedReadout(MuReadout): 60 | '''`MuReadout` with weights shared with an `nn.Embedding` layer. 61 | 62 | Inputs: 63 | weight: should be weight of an `nn.Embedding` layer 64 | other inputs are fed to `MuReadout` 65 | ''' 66 | def __init__(self, weight, bias=True, **kwargs): 67 | super().__init__(*weight.shape, bias=bias, **kwargs) 68 | self.weight = weight 69 | 70 | def rescale_linear_bias(linear): 71 | '''Rescale bias in nn.Linear layers to convert SP initialization to μP initialization. 72 | 73 | Warning: This method is NOT idempotent and should be called only once 74 | unless you know what you are doing. 
75 | ''' 76 | if hasattr(linear, '_has_rescaled_params') and linear._has_rescaled_params: 77 | raise RuntimeError("`rescale_linear_bias` has been called once before already. Unless you know what you are doing, usually you should not be calling `rescale_linear_bias` more than once.\n" 78 | "If you called `set_base_shapes` on a model loaded from a checkpoint, or just want to re-set the base shapes of an existing model, make sure to set the flag `rescale_params=False`.\n" 79 | "To bypass this error and *still rescale biases*, set `linear._has_rescaled_params=False` before this call.") 80 | if linear.bias is None: 81 | return 82 | fanin_mult = linear.weight.infshape[1].width_mult() 83 | linear.bias.data *= fanin_mult**0.5 84 | linear._has_rescaled_params = True 85 | -------------------------------------------------------------------------------- /mup/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 2 | ''' 3 | Optimizers with μP scaling. 4 | 5 | Here we provide 3 ready-to-go optimizers MuAdam, MuAdamW, and MuSGD. 6 | However, the user can easily convert their own optimizer to a μP 7 | optimizer: if your `optimizer` is "Adam-like", such as RMSProp and Adagrad, 8 | that involves normalizing the gradient entrywise, then the following creates 9 | the desired μP optimizer: 10 | 11 | def MuOptimizer(params, **kwargs): 12 | return MuAdam(params, impl=optimizer, **kwargs) 13 | 14 | On the other hand, if your `optimizer` is "SGD-like", such as ASGD, then 15 | the following creates the desired μP optimizer: 16 | 17 | def MuOptimizer(params, **kwargs): 18 | return MuSGD(params, impl=optimizer, **kwargs) 19 | 20 | See Appendix B in our paper for discussions of other optimizers. 
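As a hedged illustration (a standalone sketch, not part of this module): the per-group adjustment that `MuAdam` makes in the code below can be summarized in plain Python. The name `adam_like_group` and its flat argument list are assumptions made for this example. The real function works on parameter groups and reads `ninf` and the width multiplier from `p.infshape`.

```python
# Standalone sketch; `adam_like_group` is a hypothetical name, not a mup function.
def adam_like_group(lr, weight_decay, ninf, width_mult, decoupled_wd=False):
    if ninf > 2:
        raise NotImplementedError("more than 2 inf dimensions")
    # Only matrix-like parameters (two infinite dimensions) are rescaled.
    if ninf == 2:
        lr = lr / width_mult
        if not decoupled_wd:
            weight_decay = weight_decay * width_mult
    # Everything else keeps the learning rate and weight decay it was given.
    return lr, weight_decay


print(adam_like_group(1.0, 0.5, ninf=2, width_mult=4))  # (0.25, 2.0)
print(adam_like_group(1.0, 0.5, ninf=1, width_mult=4))  # (1.0, 0.5)
```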
21 | ''' 22 | from collections import defaultdict 23 | 24 | from torch.optim import SGD, Adam, AdamW 25 | 26 | 27 | def process_param_groups(params, **kwargs): 28 | param_groups = list(params) 29 | if not isinstance(param_groups[0], dict): 30 | param_groups = [{'params': param_groups}] 31 | for param_group in param_groups: 32 | if 'lr' not in param_group: 33 | param_group['lr'] = kwargs['lr'] 34 | if 'weight_decay' not in param_group: 35 | param_group['weight_decay'] = kwargs.get('weight_decay', 0.) 36 | return param_groups 37 | 38 | def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs): 39 | '''Adam with μP scaling. 40 | 41 | Note for this to work properly, your model needs to have its base shapes set 42 | already using `mup.set_base_shapes`. 43 | 44 | Inputs: 45 | impl: the specific Adam-like optimizer implementation from torch.optim or 46 | elsewhere 47 | decoupled_wd: if True, skips the mup scaling for weight decay, which should 48 | be used for optimizer implementations that decouple weight decay from 49 | learning rate. See https://github.com/microsoft/mup/issues/1 for a use case. 50 | Outputs: 51 | An instance of `impl` with refined parameter groups, each of which has the correctly 52 | scaled learning rate according to mup. 53 | ''' 54 | new_param_groups = [] 55 | for param_group in process_param_groups(params, **kwargs): 56 | # For every existing param group, we split into several new groups 57 | def new_group(): 58 | new_g = {k:v for k, v in param_group.items() if k != 'params'} 59 | new_g['params'] = [] 60 | return new_g 61 | # The matrix-like weights might need multiple groups since weights 62 | # might have different width multipliers 63 | matrix_like_p = defaultdict(new_group) # key is width_mult 64 | vector_like_p = new_group() 65 | for p in param_group['params']: 66 | assert hasattr(p, 'infshape'), ( 67 | f'A parameter with shape {p.shape} does not have `infshape` attribute. 
' 68 | 'Did you forget to call `mup.set_base_shapes` on the model?') 69 | if p.infshape.ninf() == 2: 70 | matrix_like_p[p.infshape.width_mult()]['params'].append(p) 71 | elif p.infshape.ninf() > 2: 72 | raise NotImplementedError('more than 2 inf dimensions') 73 | else: 74 | vector_like_p['params'].append(p) 75 | for width_mult, group in matrix_like_p.items(): 76 | # Scale learning rate and weight decay accordingly 77 | group['lr'] /= width_mult 78 | if not decoupled_wd: 79 | group['weight_decay'] *= width_mult 80 | new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p]) 81 | return impl(new_param_groups, **kwargs) 82 | 83 | def MuAdamW(params, **kwargs): 84 | '''AdamW with μP scaling. 85 | 86 | Note for this to work properly, your model needs to have its base shapes set 87 | already using `mup.set_base_shapes`. 88 | ''' 89 | return MuAdam(params, impl=AdamW, **kwargs) 90 | 91 | def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs): 92 | '''SGD with μP scaling. 93 | 94 | Note for this to work properly, your model needs to have its base shapes set 95 | already using `mup.set_base_shapes`. 96 | 97 | Inputs: 98 | impl: the specific SGD-like optimizer implementation from torch.optim or 99 | elsewhere 100 | decoupled_wd: if True, skips the mup scaling for weight decay, which should 101 | be used for optimizer implementations that decouple weight decay from 102 | learning rate. See https://github.com/microsoft/mup/issues/1 for a use case. 103 | Outputs: 104 | An instance of `impl` with refined parameter groups, each of which has the correctly 105 | scaled learning rate according to mup. 
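As a hedged illustration (a standalone sketch, not part of this module): the SGD rule differs from the Adam rule in that vector-like parameters are rescaled too, and matrix-like parameters use the ratio of the fan-in multiplier to the fan-out multiplier. The name `sgd_like_group` and its flat argument list are assumptions made for this example.

```python
# Standalone sketch; `sgd_like_group` is a hypothetical name, not a mup function.
def sgd_like_group(lr, weight_decay, ninf, width_mult=1.0,
                   fanin_fanout_ratio=1.0, decoupled_wd=False):
    if ninf == 1:
        # Vector-like: the learning rate grows with the width multiplier.
        lr = lr * width_mult
        if not decoupled_wd:
            weight_decay = weight_decay / width_mult
    elif ninf == 2:
        # Matrix-like: divide by (fan-in multiplier / fan-out multiplier).
        lr = lr / fanin_fanout_ratio
        if not decoupled_wd:
            weight_decay = weight_decay * fanin_fanout_ratio
    elif ninf > 2:
        raise NotImplementedError("more than 2 inf dimensions")
    # ninf == 0: finite parameters are left untouched.
    return lr, weight_decay


print(sgd_like_group(0.5, 1.0, ninf=1, width_mult=4))          # (2.0, 0.25)
print(sgd_like_group(0.5, 1.0, ninf=2, fanin_fanout_ratio=2))  # (0.25, 2.0)
```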
106 | ''' 107 | new_param_groups = [] 108 | for param_group in process_param_groups(params, **kwargs): 109 | # For every existing param group, we split into several new groups 110 | def new_group(): 111 | new_g = {k:v for k, v in param_group.items() if k != 'params'} 112 | new_g['params'] = [] 113 | return new_g 114 | # The matrix-like weights might need multiple groups since weights 115 | # might have different width multipliers 116 | vector_like_p = defaultdict(new_group) # key is width mult 117 | matrix_like_p = defaultdict(new_group) # key is fan_in/out ratio 118 | fixed_p = new_group() 119 | for p in param_group['params']: 120 | assert hasattr(p, 'infshape'), ( 121 | f'A parameter with shape {p.shape} does not have `infshape` attribute. ' 122 | 'Did you forget to call `mup.set_base_shapes` on the model?') 123 | if p.infshape.ninf() == 1: 124 | vector_like_p[p.infshape.width_mult()]['params'].append(p) 125 | elif p.infshape.ninf() == 2: 126 | matrix_like_p[p.infshape.fanin_fanout_mult_ratio()]['params'].append(p) 127 | elif p.infshape.ninf() > 2: 128 | raise NotImplementedError('more than 2 inf dimensions') 129 | else: 130 | fixed_p['params'].append(p) 131 | for width_mult, group in vector_like_p.items(): 132 | # Scale learning rate and weight decay accordingly 133 | group['lr'] *= width_mult 134 | if not decoupled_wd: 135 | group['weight_decay'] /= width_mult 136 | for shape_ratio, group in matrix_like_p.items(): 137 | group['lr'] /= shape_ratio 138 | if not decoupled_wd: 139 | group['weight_decay'] *= shape_ratio 140 | new_param_groups.extend(list(matrix_like_p.values()) + \ 141 | list(vector_like_p.values()) + [fixed_p]) 142 | return impl(new_param_groups, **kwargs) 143 | -------------------------------------------------------------------------------- /mup/shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 
2 | from copy import deepcopy 3 | 4 | import yaml 5 | from torch import nn 6 | from torch.nn import Linear 7 | from torch.nn.modules.conv import _ConvNd 8 | 9 | from mup.infshape import InfShape, zip_infshape 10 | from mup.layer import MuReadout, rescale_linear_bias 11 | 12 | __BSH_COMMENT__ = '''\ 13 | # This is a base shape file encoded in yaml 14 | # - `null` indicates a dimension is "finite", i.e. a non-"width" dimension 15 | # - a number indicates the base dimension of an "infinite" dimension, i.e. some notion of "width" 16 | ''' 17 | 18 | def get_shapes(model): 19 | # If you want to implement a custom shapes function, you can use this name 20 | if hasattr(model, "get_shapes"): 21 | return model.get_shapes() 22 | return {name: param.shape for name, param in model.named_parameters()} 23 | 24 | def get_infshapes(model): 25 | return {name: param.infshape for name, param in model.named_parameters()} 26 | 27 | def save_base_shapes(model_or_shapes, file): 28 | if isinstance(model_or_shapes, nn.Module): 29 | sh = get_infshapes(model_or_shapes) 30 | elif isinstance(model_or_shapes, dict): 31 | sh = deepcopy(model_or_shapes) 32 | else: 33 | raise ValueError() 34 | sh = {k: s.base_shape() for k, s in sh.items()} 35 | s = yaml.dump(sh, None, indent=4) 36 | s = __BSH_COMMENT__ + s 37 | with open(file, 'w') as f: 38 | f.write(s) 39 | 40 | def load_base_shapes(filename): 41 | '''Get a dict of `InfShape` from a filename.''' 42 | with open(filename, 'r') as f: 43 | d = yaml.safe_load(f) 44 | return {k: InfShape.from_base_shape(v) for k, v in d.items()} 45 | 46 | def _dataparallel_hack(base_shapes, shapes): 47 | '''Fix module name discrepancy caused by (Distributed)DataParallel module. 48 | 49 | The parameters of a (Distributed)DataParallel module all have names that 50 | start with 'module'. This causes a mismatch from non-DataParallel modules. 
51 | This function tries to match `base_shapes` to `shapes`: if the latter starts 52 | with 'module', then make the former too; likewise if not. 53 | ''' 54 | if all(k.startswith('module.') for k in shapes) and \ 55 | all(not k.startswith('module.') for k in base_shapes): 56 | return {'module.' + k: v for k, v in base_shapes.items()}, shapes 57 | if all(not k.startswith('module.') for k in shapes) and \ 58 | all(k.startswith('module.') for k in base_shapes): 59 | return {k[len('module.'):]: v for k, v in base_shapes.items()}, shapes 60 | return base_shapes, shapes 61 | 62 | 63 | def _extract_shapes(x): 64 | ''' 65 | Input: 66 | x: can be any of the following: 67 | - `nn.Module` 68 | - dict of shapes 69 | - dict of `InfShape` 70 | - str of path to a base shapes (.bsh) file 71 | Output: 72 | If `x` is dict of `InfShape`, then output itself. 73 | If `x` is path, then output a dict of `InfShapes` loaded from `x`. 74 | Else, output the shapes (not `InfShape`) associated to `x` 75 | ''' 76 | if isinstance(x, nn.Module): 77 | x_shapes = get_shapes(x) 78 | elif isinstance(x, dict): 79 | x_shapes = deepcopy(x) 80 | elif isinstance(x, str): 81 | # x is file name 82 | x_shapes = load_base_shapes(x) 83 | else: 84 | raise ValueError(f'unhandled x type: {type(x)}') 85 | return x_shapes 86 | 87 | def _zip_infshape_dict(base_shapes, shapes): 88 | '''make a dict of `InfShape` from two dicts of shapes. 89 | Inputs: 90 | base_shapes: dict of base shapes or InfShape objects 91 | shapes: dict of shapes 92 | Output: 93 | dict of `InfShape` using `zip_infshape` 94 | ''' 95 | base_shapes, shapes = _dataparallel_hack(base_shapes, shapes) 96 | basenames = set(base_shapes.keys()) 97 | names = set(shapes.keys()) 98 | assert basenames == names, ( 99 | f'`base_shapes` has extra names {basenames - names}. ' 100 | f'`shapes` has extra names {names - basenames}.' 
101 | ) 102 | infshapes = {} 103 | for name, bsh in base_shapes.items(): 104 | infshapes[name] = zip_infshape(bsh, shapes[name]) 105 | return infshapes 106 | 107 | def zip_infshapes(base, target): 108 | '''Make a dict of `InfShape` from models or dicts. 109 | Inputs: 110 | base: a base `nn.Module` or a dict of shapes 111 | target: a target `nn.Module` or a dict of shapes 112 | Output: 113 | dict of `InfShape` using `zip_infshape` 114 | ''' 115 | base_shapes = _extract_shapes(base) 116 | target_shapes = _extract_shapes(target) 117 | return _zip_infshape_dict(base_shapes, target_shapes) 118 | 119 | def clear_dims(infshape_dict): 120 | ''' 121 | Input: 122 | infshape_dict: dict of `InfShape` 123 | Output: 124 | a deep copy of the dict in which all `InfDim` in all `InfShape` 125 | have their `dim` attribute set to None 126 | ''' 127 | d = deepcopy(infshape_dict) 128 | for _, v in d.items(): 129 | for infdim in v: 130 | infdim.dim = None 131 | return d 132 | 133 | def make_base_shapes(base_shapes, delta_shapes, savefile=None): 134 | '''Make a base shape object from a base model/shapes and a delta model/shapes. 135 | 136 | Inputs: 137 | base_shapes: 138 | a base `nn.Module` or a dict of shapes 139 | delta_shapes: 140 | a "delta" model or a dict of shapes, for the sole purpose of 141 | determining which dimensions are "width" and will be scaled up and 142 | down in the target model. 143 | savefile: 144 | if a string, then the resulting base shape object is serialized to 145 | this location via yaml encoding. 
146 | Outputs: 147 | base infshapes 148 | ''' 149 | bsh = clear_dims(zip_infshapes(base_shapes, delta_shapes)) 150 | if savefile is not None: 151 | save_base_shapes(bsh, savefile) 152 | return bsh 153 | 154 | 155 | def apply_infshapes(model, infshapes): 156 | for name, p in model.named_parameters(): 157 | p.infshape = infshapes[name] 158 | 159 | def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True): 160 | '''Sets the `p.infshape` attribute for each parameter `p` of `model`. 161 | 162 | Inputs: 163 | model: nn.Module instance 164 | base: The base model. 165 | Can be nn.Module, a dict of shapes, a str, or None. 166 | If None, then defaults to `model` 167 | If str, then treated as filename for yaml encoding of a dict of base shapes. 168 | rescale_params: 169 | assuming the model is initialized using the default pytorch init (or 170 | He initialization etc that scale the same way with fanin): If True 171 | (default), rescales parameters to have the correct (μP) variances. 172 | do_assert: if True (default), run `assert_hidden_size_inf` on `model` after the infshapes are set. 173 | Output: 174 | same object as `model`, after setting the `infshape` attribute of each parameter.
175 | ''' 176 | if base is None: 177 | base = model 178 | base_shapes = _extract_shapes(base) 179 | if delta is not None: 180 | delta_shapes = _extract_shapes(delta) 181 | base_shapes = _zip_infshape_dict(base_shapes, delta_shapes) 182 | shapes = get_shapes(model) 183 | infshapes = _zip_infshape_dict(base_shapes, shapes) 184 | if savefile is not None: 185 | save_base_shapes(infshapes, savefile) 186 | apply_infshapes(model, infshapes) 187 | if do_assert: 188 | assert_hidden_size_inf(model) 189 | if rescale_params: 190 | for name, module in model.named_modules(): 191 | if isinstance(module, MuReadout): 192 | module._rescale_parameters() 193 | elif isinstance(module, (Linear, _ConvNd)): 194 | rescale_linear_bias(module) 195 | return model 196 | 197 | def assert_hidden_size_inf(model): 198 | ''' 199 | This tests for any `nn.Linear` whose output dimension is finite but input 200 | dimension is infinite and is not of type `MuReadout`. Such `nn.Linear` 201 | modules should not exist in a correctly parametrized model. 202 | ''' 203 | for name, module in model.named_modules(): 204 | if isinstance(module, Linear) and not isinstance(module, MuReadout): 205 | if not module.weight.infshape[0].isinf() and module.weight.infshape[1].isinf(): 206 | assert False, ( 207 | f'{name} has infinite fan-in and finite fan-out dimensions but is not type `MuReadout`. ' 208 | 'To resolve this, either change the module to `MuReadout` or change the fan-out to an infinite dimension.'
209 | ) 210 | -------------------------------------------------------------------------------- /mup/test/__main__.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import unittest 3 | from functools import partial 4 | from itertools import cycle 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.nn.functional as F 10 | from mup.coord_check import get_coord_data 11 | from mup.optim import MuAdam, MuSGD 12 | from mup.shape import get_infshapes, get_shapes, make_base_shapes, set_base_shapes 13 | from mup.test.models import (generate_CNN, generate_MLP, _generate_MLP, get_lazy_models, 14 | get_train_loader, init_methods) 15 | 16 | train_loader = get_train_loader(batch_size=32, num_workers=4, download=True) 17 | 18 | def reset_seed(): 19 | torch.manual_seed(0) 20 | 21 | class SetBaseShapeCase(unittest.TestCase): 22 | mlp_base_shapes_file = 'mlp64.bsh.test' 23 | 24 | def get_mlp_infshapes1(self): 25 | base_model = _generate_MLP(64, True, True, True) 26 | delta_model = _generate_MLP(65, True, True, True) 27 | target_model = _generate_MLP(128, True, True, True) 28 | set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file) 29 | return get_infshapes(target_model) 30 | 31 | def get_mlp_infshapes1meta(self): 32 | base_model = _generate_MLP(64, True, True, True, device='meta') 33 | delta_model = _generate_MLP(65, True, True, True, device='meta') 34 | target_model = _generate_MLP(128, True, True, True) 35 | set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file) 36 | return get_infshapes(target_model) 37 | 38 | def get_mlp_infshapes2(self): 39 | target_model = _generate_MLP(128, True, True, True) 40 | set_base_shapes(target_model, self.mlp_base_shapes_file) 41 | return get_infshapes(target_model) 42 | 43 | def get_mlp_infshapes3(self): 44 | base_model = _generate_MLP(64, True, True, True) 45 | delta_model = 
_generate_MLP(65, True, True, True) 46 | base_infshapes = make_base_shapes(base_model, delta_model) 47 | target_model = _generate_MLP(128, True, True, True) 48 | set_base_shapes(target_model, base_infshapes) 49 | return get_infshapes(target_model) 50 | 51 | def get_mlp_infshapes3meta(self): 52 | base_model = _generate_MLP(64, True, True, True, device='meta') 53 | delta_model = _generate_MLP(65, True, True, True, device='meta') 54 | base_infshapes = make_base_shapes(base_model, delta_model) 55 | target_model = _generate_MLP(128, True, True, True) 56 | set_base_shapes(target_model, base_infshapes) 57 | return get_infshapes(target_model) 58 | 59 | def get_mlp_infshapes4(self): 60 | base_model = _generate_MLP(64, True, True, True) 61 | delta_model = _generate_MLP(65, True, True, True) 62 | target_model = _generate_MLP(128, True, True, True) 63 | set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model)) 64 | return get_infshapes(target_model) 65 | 66 | def get_mlp_infshapes4meta(self): 67 | base_model = _generate_MLP(64, True, True, True) 68 | delta_model = _generate_MLP(65, True, True, True, device='meta') 69 | target_model = _generate_MLP(128, True, True, True, device='meta') 70 | set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model)) 71 | return get_infshapes(target_model) 72 | 73 | def get_mlp_infshapes5(self): 74 | delta_model = _generate_MLP(65, True, True, True) 75 | target_model = _generate_MLP(128, True, True, True) 76 | # `delta` here doesn't do anything because of base shape file 77 | set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model)) 78 | return get_infshapes(target_model) 79 | 80 | def get_mlp_infshapes5meta(self): 81 | delta_model = _generate_MLP(65, True, True, True, device='meta') 82 | target_model = _generate_MLP(128, True, True, True) 83 | # `delta` here doesn't do anything because of base shape file 84 | set_base_shapes(target_model, 
self.mlp_base_shapes_file, delta=get_shapes(delta_model)) 85 | return get_infshapes(target_model) 86 | 87 | def get_mlp_infshapes_bad(self): 88 | base_model = _generate_MLP(64, True, True, True) 89 | target_model = _generate_MLP(128, True, True, True) 90 | set_base_shapes(target_model, base_model, delta=base_model) 91 | return get_infshapes(target_model) 92 | 93 | def test_set_base_shape(self): 94 | self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes1meta()) 95 | self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes2()) 96 | self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes2()) 97 | self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes4()) 98 | self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes3meta()) 99 | self.assertEqual(self.get_mlp_infshapes4(), self.get_mlp_infshapes4meta()) 100 | self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes4()) 101 | self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes5meta()) 102 | self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad()) 103 | 104 | 105 | class BackwardCompatibleCase(unittest.TestCase): 106 | 107 | def gen_model(self, arch, width, batchnorm=False, mup=True): 108 | if arch == 'mlp': 109 | return generate_MLP(width=width, batchnorm=batchnorm, readout_zero_init=False, base_width=256, mup=mup) 110 | elif arch == 'cnn': 111 | return generate_CNN(width=width, batchnorm=batchnorm, readout_zero_init=False, base_width=8, mup=mup) 112 | else: 113 | raise ValueError(f'unknown arch: {arch}') 114 | 115 | def test_MLP_CNN_at_base_width(self): 116 | for arch, batchnorm in itertools.product(['mlp', 'cnn'], [False, True]): 117 | for init_name, init in init_methods.items(): 118 | reset_seed() 119 | mup_model = self.gen_model(arch, 256 if arch == 'mlp' else 8, mup=True, batchnorm=batchnorm) # build each arch at its own base width 120 | reset_seed() 121 | init(mup_model) 122 | reset_seed() 123 | SP_model = self.gen_model(arch, 256 if arch == 'mlp' else 8, mup=False, batchnorm=batchnorm) 124 | reset_seed() 125 |
init(SP_model) 126 | for (name, mup_param), (_, SP_param) in zip( 127 | mup_model.named_parameters(), SP_model.named_parameters()): 128 | with self.subTest(name=f'{arch}, {name}, {init_name}, bn={batchnorm}'): 129 | self.assertEqual((mup_param.data - SP_param.data).abs().sum().item(), 0) 130 | 131 | def test_MLP_at_diff_width_init(self): 132 | for init_name, init in init_methods.items(): 133 | reset_seed() 134 | mup_model = self.gen_model('mlp', 128, mup=True) 135 | reset_seed() 136 | init(mup_model) 137 | reset_seed() 138 | SP_model = self.gen_model('mlp', 128, mup=False) 139 | reset_seed() 140 | init(SP_model) 141 | 142 | mup_params = dict(mup_model.named_parameters()) 143 | SP_params = dict(SP_model.named_parameters()) 144 | 145 | if init_name == 'default' or 'fan_in' in init_name: 146 | diff_names = ['2.bias', '4.bias', '4.weight'] 147 | same_names = ['0.weight', '0.bias', '2.weight'] 148 | elif 'fan_out' in init_name: 149 | diff_names = ['2.bias', '4.bias', '0.weight'] 150 | same_names = ['4.weight', '0.bias', '2.weight'] 151 | elif 'xavier' in init_name: 152 | diff_names = ['2.bias', '4.bias', '0.weight', '4.weight'] 153 | same_names = ['0.bias', '2.weight'] 154 | elif 'const' in init_name: 155 | diff_names = ['2.bias', '4.bias', '2.weight'] 156 | same_names = ['0.weight', '0.bias', '4.weight'] 157 | else: 158 | raise ValueError() 159 | 160 | for name in diff_names: 161 | with self.subTest(name=f'{name}, {init_name}'): 162 | self.assertNotEqual( 163 | (mup_params[name] - SP_params[name]).abs().sum().item(), 0) 164 | for name in same_names: 165 | with self.subTest(name=f'{name}, {init_name}'): 166 | self.assertEqual( 167 | (mup_params[name] - SP_params[name]).abs().sum().item(), 0) 168 | 169 | def test_CNN_at_diff_width_init(self): 170 | for init_name, init in init_methods.items(): 171 | reset_seed() 172 | mup_model = self.gen_model('cnn', 16, mup=True) 173 | reset_seed() 174 | init(mup_model) 175 | reset_seed() 176 | SP_model = self.gen_model('cnn', 16, 
mup=False) 177 | reset_seed() 178 | init(SP_model) 179 | 180 | mup_params = dict(mup_model.named_parameters()) 181 | SP_params = dict(SP_model.named_parameters()) 182 | 183 | if init_name == 'default' or 'fan_in' in init_name: 184 | diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '11.weight'] 185 | same_names = ['0.bias', '0.weight', '3.weight', '7.weight', '9.weight'] 186 | elif 'fan_out' in init_name: 187 | diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '0.weight'] 188 | same_names = ['0.bias', '3.weight', '7.weight', '9.weight', '11.weight'] 189 | elif 'xavier' in init_name: 190 | diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '0.weight', '11.weight'] 191 | same_names = ['0.bias', '3.weight', '7.weight', '9.weight'] 192 | elif 'const' in init_name: 193 | diff_names = ['3.bias', '7.bias', '9.bias', '11.bias', '3.weight', '7.weight', '9.weight'] 194 | same_names = ['0.bias', '0.weight', '11.weight'] 195 | else: 196 | raise ValueError() 197 | 198 | for name in diff_names: 199 | with self.subTest(name=f'{name}, {init_name}'): 200 | self.assertNotEqual( 201 | (mup_params[name] - SP_params[name]).abs().sum().item(), 0) 202 | for name in same_names: 203 | with self.subTest(name=f'{name}, {init_name}'): 204 | self.assertEqual( 205 | (mup_params[name] - SP_params[name]).abs().sum().item(), 0) 206 | 207 | def train_model(model, train_loader, step=-1, optcls=MuSGD, lr=0.1, flatten_input=False, cuda=True): 208 | model.train() 209 | train_loss = 0 210 | train_losses = [] 211 | optimizer = optcls(model.parameters(), lr=lr) 212 | for batch_idx, (data, target) in enumerate(cycle(iter(train_loader)), 1): 213 | if cuda: 214 | data, target = data.cuda(), target.cuda() 215 | optimizer.zero_grad() 216 | if flatten_input: 217 | data = data.view(data.size(0), -1) 218 | output = model(data) 219 | loss = F.cross_entropy(output, target) 220 | loss.backward() 221 | train_loss += loss.item() 222 | train_losses.append(train_loss / batch_idx) 223 | optimizer.step() 224 
| if batch_idx == step: break 225 | # train_loss /= batch_idx 226 | return train_losses 227 | 228 | train_model_MuSGD = partial(train_model, optcls=MuSGD, lr=0.1) 229 | train_model_MuAdam = partial(train_model, optcls=MuAdam, lr=1e-3) 230 | 231 | class CoordCheckCase(unittest.TestCase): 232 | 233 | def test_MLP_CNN(self): 234 | combos = list(itertools.product(['mlp', 'cnn'], [True], [False, True], ['sgd', 'adam'], init_methods.keys())) 235 | # comment out the following 2 lines to do all tests 236 | idx = np.random.choice(np.arange(len(combos)), size=10) 237 | combos = [combos[i] for i in idx] # index the list directly; np.array(combos) would coerce the bools to strings 238 | for arch, mup, batchnorm, optimizer, init in combos: 239 | widths = [128, 512] if arch == 'cnn' else [1000, 4000] 240 | models = get_lazy_models(arch, widths, mup=mup, batchnorm=batchnorm, init=init) 241 | df = get_coord_data(models, train_loader, mup=mup, optimizer=optimizer, flatten_input=arch == 'mlp') 242 | df = df[df.module != ''] 243 | df['module'] = pd.to_numeric(df['module']) 244 | for t, module in itertools.product([1, 2, 3], df['module'].unique()): 245 | with self.subTest( 246 | name=f'{arch}, mup={mup}, bn={batchnorm}, {optimizer}, {init}, t={t}, module={module}'): 247 | data = df[(df['module'] == module) & (df['t'] == t)] 248 | std0 = data[data.width==widths[0]]['l1'].unique()[0] 249 | std1 = data[data.width==widths[1]]['l1'].unique()[0] 250 | if t == 1 and module == df['module'].max(): 251 | self.assertTrue(std0 == std1 == 0, 252 | f'output should be 0 due to readout_zero_init: {std0}, {std1}') 253 | else: 254 | tol = 1.2 255 | self.assertGreater(std1/std0, 1/tol, f'{std0}, {std1}') 256 | self.assertLess(std1/std0, tol, f'{std0}, {std1}') 257 | 258 | 259 | class MLPTrainCase(unittest.TestCase): 260 | 261 | def train_adam(self, model, step): 262 | return train_model_MuAdam(model, train_loader, step=step, flatten_input=True) 263 | 264 | def train_sgd(self, model, step): 265 | return train_model_MuSGD(model, train_loader, step=step, flatten_input=True) 266 | 267 |
def setUp(self): 268 | self.models = {w: generate_MLP(w, bias=True, readout_zero_init=True, base_width=256, init='kaiming_fan_in_normal', bias_zero_init=True).cuda() for w in [64, 256, 1024]} 269 | 270 | def test_init(self): 271 | stds = {} 272 | for w, model in self.models.items(): 273 | for i, module in enumerate(list(model.modules())[1::2]): 274 | stds[(w, i+1, 'weight')] = module.weight.data.std() 275 | stds[(w, i+1, 'bias')] = module.bias.data.std() 276 | 277 | for w in [64, 256]: 278 | self.assertLess( 279 | torch.abs( 280 | stds[(1024, 1, 'weight')] - stds[(w, 1, 'weight')] 281 | ) / stds[(1024, 1, 'weight')], 3e-3) 282 | # for l in [1, 2]: 283 | # self.assertLess( 284 | # torch.abs( 285 | # stds[(1024, l, 'bias')] - stds[(w, l, 'bias')] 286 | # ) / stds[(1024, l, 'bias')], 1e-1) 287 | self.assertTrue( 288 | stds[(1024, 2, 'weight')] < stds[(256, 2, 'weight')] < stds[(64, 2, 'weight')]) 289 | for w in [64, 256, 1024]: 290 | self.assertEqual(stds[(w, 3, 'weight')], 0) 291 | self.assertEqual(stds[(w, 3, 'bias')], 0) 292 | 293 | def _test_train(self, opt): 294 | loss = {w: getattr(self, f'train_{opt}')(model, 201) for w, model in self.models.items()} 295 | with self.subTest(name=f'{opt}, step 1'): 296 | self.assertTrue( 297 | loss[64][0] == loss[256][0] == loss[1024][0], 298 | {k: v[0] for k, v in loss.items()}) 299 | for t in [100, 200]: 300 | with self.subTest(name=f'{opt}, step {t+1}'): 301 | self.assertTrue( 302 | loss[64][t] > loss[256][t] > loss[1024][t], 303 | {k: v[t] for k, v in loss.items()}) 304 | 305 | def test_sgd(self): 306 | self._test_train('sgd') 307 | 308 | def test_adam(self): 309 | self._test_train('adam') 310 | 311 | class CNNTrainCase(unittest.TestCase): 312 | 313 | def train_adam(self, model, step): 314 | return train_model_MuAdam(model, train_loader, step=step, flatten_input=False) 315 | 316 | def train_sgd(self, model, step): 317 | return train_model_MuSGD(model, train_loader, step=step, flatten_input=False) 318 | 319 | def setUp(self): 
320 | self.models = {w: generate_CNN(w, mup=True, bias=True, readout_zero_init=True, base_width=8, init='kaiming_fan_in_normal', bias_zero_init=False).cuda() for w in [8, 32, 128]} 321 | 322 | def test_init(self): 323 | stds = {} 324 | names = [0, 3, 7, 9, 11] 325 | for w, model in self.models.items(): 326 | for i, module in enumerate(model): 327 | if i in names: 328 | stds[(w, i, 'weight')] = module.weight.data.std() 329 | stds[(w, i, 'bias')] = module.bias.data.std() 330 | 331 | for w in [8, 32]: 332 | self.assertLess( 333 | torch.abs( 334 | stds[(128, 0, 'weight')] - stds[(128, 0, 'weight')] 335 | ) / stds[(128, 0, 'weight')], 3e-3) 336 | for name in names[:-1]: 337 | self.assertLess( 338 | torch.abs( 339 | stds[(128, 0, 'bias')] - stds[(w, 0, 'bias')] 340 | ) / stds[(128, 0, 'bias')], 2e-1) 341 | for name in names[1:-1]: 342 | self.assertTrue( 343 | stds[(128, name, 'weight')] < stds[(32, name, 'weight')] < stds[(8, name, 'weight')]) 344 | for w in [8, 32, 128]: 345 | self.assertEqual(stds[(w, 11, 'weight')], 0) 346 | self.assertEqual(stds[(w, 11, 'bias')], 0) 347 | 348 | def _test_train(self, opt): 349 | loss = {w: getattr(self, f'train_{opt}')(model, 201) for w, model in self.models.items()} 350 | with self.subTest(name=f'{opt}, step 1'): 351 | self.assertTrue( 352 | loss[8][0] == loss[32][0] == loss[128][0], 353 | {k: v[0] for k, v in loss.items()}) 354 | for t in [200]: 355 | with self.subTest(name=f'{opt}, step {t+1}'): 356 | losses = {k: v[t] for k, v in loss.items()} 357 | # print(losses) 358 | self.assertTrue( 359 | loss[8][t] > loss[32][t] > loss[128][t], 360 | losses) 361 | 362 | def test_sgd(self): 363 | self._test_train('sgd') 364 | 365 | def test_adam(self): 366 | self._test_train('adam') 367 | 368 | def suite(): 369 | suite = unittest.TestSuite() 370 | suite.addTests(unittest.makeSuite(BackwardCompatibleCase)) 371 | suite.addTests(unittest.makeSuite(MLPTrainCase)) 372 | suite.addTests(unittest.makeSuite(CNNTrainCase)) 373 | 
suite.addTests(unittest.makeSuite(CoordCheckCase)) 374 | suite.addTests(unittest.makeSuite(SetBaseShapeCase)) 375 | return suite 376 | 377 | if __name__ == '__main__': 378 | runner = unittest.TextTestRunner(failfast=False) 379 | runner.run(suite()) 380 | -------------------------------------------------------------------------------- /mup/test/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torchvision import transforms, datasets 4 | from mup.shape import set_base_shapes 5 | from torch import nn 6 | from torch.nn import Linear 7 | from mup.layer import MuReadout 8 | from functools import partial 9 | from mup.init import (kaiming_normal_, kaiming_uniform_, normal_, 10 | trunc_normal_, uniform_, xavier_normal_, 11 | xavier_uniform_) 12 | from torch.nn.modules.conv import _ConvNd 13 | 14 | samplers = { 15 | 'default': lambda x: x, 16 | 'const_uniform': partial(uniform_, a=-0.1, b=0.1), 17 | 'const_normal': partial(normal_, std=0.1), 18 | 'const_trunc_normal': partial(trunc_normal_, std=0.1, a=-0.2, b=0.2), 19 | 'xavier_uniform': xavier_uniform_, 20 | 'xavier_normal': xavier_normal_, 21 | 'kaiming_fan_in_uniform': partial(kaiming_uniform_, mode='fan_in'), 22 | 'kaiming_fan_in_normal': partial(kaiming_normal_, mode='fan_in'), 23 | 'kaiming_fan_out_uniform': partial(kaiming_uniform_, mode='fan_out'), 24 | 'kaiming_fan_out_normal': partial(kaiming_normal_, mode='fan_out') 25 | } 26 | 27 | 28 | def init_model(model, sampler): 29 | for param in model.parameters(): 30 | if len(param.shape) >= 2: 31 | sampler(param) 32 | return model 33 | 34 | init_methods = { 35 | k: partial(init_model, sampler=s) for k, s in samplers.items() 36 | } 37 | 38 | def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'): 39 | mods = [Linear(3072, width, bias=bias, device=device), 40 | nn.ReLU(), 41 | Linear(width, width, bias=bias, device=device), 42 | nn.ReLU() 43 | ] 44 | if mup: 45 | 
mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device)) 46 | else: 47 | mods.append(Linear(width, 10, bias=bias, device=device)) 48 | if batchnorm: 49 | mods.insert(1, nn.BatchNorm1d(width, device=device)) 50 | mods.insert(4, nn.BatchNorm1d(width, device=device)) 51 | model = nn.Sequential(*mods) 52 | return model 53 | 54 | def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=256): 55 | if not mup: 56 | model = _generate_MLP(width, bias, mup, batchnorm) 57 | # set base shapes to model's own shapes, so we get SP 58 | return set_base_shapes(model, None) 59 | # it's important we make `model` first, because of random seed 60 | model = _generate_MLP(width, bias, mup, batchnorm) 61 | base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta') 62 | set_base_shapes(model, base_model) 63 | init_methods[init](model) 64 | if readout_zero_init: 65 | readout = list(model.modules())[-1] 66 | readout.weight.data.zero_() 67 | if readout.bias is not None: 68 | readout.bias.data.zero_() 69 | if bias_zero_init: 70 | for module in model.modules(): 71 | if isinstance(module, nn.Linear) and module.bias is not None: 72 | module.bias.data.zero_() 73 | return model 74 | 75 | 76 | def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'): 77 | mods = [ 78 | nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device), 79 | nn.ReLU(inplace=True), 80 | nn.MaxPool2d(kernel_size=2, stride=2), 81 | nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device), 82 | nn.ReLU(inplace=True), 83 | nn.MaxPool2d(kernel_size=2, stride=2), 84 | nn.Flatten(), 85 | nn.Linear(2*width*25, width*16, bias=bias, device=device), 86 | nn.ReLU(inplace=True), 87 | nn.Linear(width*16, width*10, bias=bias, device=device), 88 | nn.ReLU(inplace=True), 89 | ] 90 | if mup: 91 | mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device)) 92 | 
else: 93 | mods.append(nn.Linear(width*10, 10, bias=bias, device=device)) 94 | if batchnorm: 95 | mods.insert(1, nn.BatchNorm2d(width, device=device)) 96 | mods.insert(5, nn.BatchNorm2d(2*width, device=device)) 97 | mods.insert(10, nn.BatchNorm1d(16*width, device=device)) 98 | mods.insert(13, nn.BatchNorm1d(10*width, device=device)) 99 | return nn.Sequential(*mods) 100 | 101 | def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8): 102 | if not mup: 103 | model = _generate_CNN(width, bias, mup, batchnorm) 104 | # set base shapes to model's own shapes, so we get SP 105 | return set_base_shapes(model, None) 106 | # it's important we make `model` first, because of random seed 107 | model = _generate_CNN(width, bias, mup, batchnorm) 108 | base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta') 109 | set_base_shapes(model, base_model) 110 | init_methods[init](model) 111 | if readout_zero_init: 112 | readout = list(model.modules())[-1] 113 | readout.weight.data.zero_() 114 | if readout.bias is not None: 115 | readout.bias.data.zero_() 116 | if bias_zero_init: 117 | for module in model.modules(): 118 | if isinstance(module, (nn.Linear, _ConvNd)) and module.bias is not None: 119 | module.bias.data.zero_() 120 | return model 121 | 122 | def get_lazy_models(arch, widths, mup=True, init='kaiming_fan_in_normal', readout_zero_init=True, batchnorm=True, base_width=None): 123 | '''if mup is False, then `init`, `readout_zero_init`, `base_width` don't matter.''' 124 | if arch == 'mlp': 125 | base_width = base_width or 256 126 | generate = generate_MLP 127 | elif arch == 'cnn': 128 | base_width = base_width or 8 129 | generate = generate_CNN 130 | def gen(w): 131 | def f(): 132 | model = generate(w, mup=mup, init=init, readout_zero_init=readout_zero_init, batchnorm=batchnorm, base_width=base_width) 133 | return model 134 | return f 135 | return {w: gen(w) for w in widths} 136 | 137 
| 138 | def get_train_loader(batch_size, num_workers=0, shuffle=False, train=True, download=False): 139 | 140 | transform = transforms.Compose( 141 | [transforms.ToTensor(), 142 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 143 | trainset = datasets.CIFAR10(root='dataset', train=train, 144 | download=download, transform=transform) 145 | return torch.utils.data.DataLoader(trainset, batch_size=batch_size, 146 | shuffle=shuffle, num_workers=num_workers) 147 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.5 2 | pandas>=1.1.2 3 | torch>=1.6.0 4 | torchvision>=0.7.0 5 | seaborn>=0.11.2 6 | tqdm 7 | pyyaml -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | license_files=LICENSE -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="mup", 8 | version="1.0.0", 9 | author="Edward J Hu, Greg Yang", 10 | author_email="edwardjhu@edwardjhu.com, gregyang@microsoft.com", 11 | description="Maximal Update Parametrization", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/microsoft/mup", 15 | download_url="https://github.com/microsoft/mup/archive/refs/tags/v1.0.0.tar.gz", 16 | install_requires=[ 17 | 'numpy', 18 | 'pandas', 19 | 'torch', 20 | 'torchvision', 21 | 'seaborn', 22 | 'tqdm', 23 | 'pyyaml' 24 | ], 25 | packages=setuptools.find_packages(), 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | 
"License :: OSI Approved :: MIT License", 29 | "Operating System :: OS Independent", 30 | ], 31 | python_requires='>=3.6', 32 | ) --------------------------------------------------------------------------------