├── .gitignore ├── LICENSE ├── README.md ├── benchmark.py ├── benchmark_data ├── dim.csv ├── n_classes.csv ├── n_tokens.csv └── plots.png ├── modules.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mgmalek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Memory-Efficient Cross Entropy Loss 2 | 3 | ## TL;DR 4 | 5 | This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor. 6 | 7 | ## Overview 8 | 9 | In networks trained to perform classification tasks, such as language models, the final layer is generally a linear projection from `dim` channels to `n_classes` channels to compute the logits, followed by cross-entropy loss. When `n_classes` is large relative to `dim`, the logits consume a large amount of GPU memory compared to other activations in the network. 
For example, [Mistral 7B](https://arxiv.org/abs/2310.06825) has a vocabulary size of 32,000 compared to a much lower hidden dimension of 4096, so the logits take up roughly 8x as much GPU memory as the preceding activations. 10 | 11 | This repo contains two optimizations to reduce the memory usage of a linear projection followed by cross-entropy loss, implemented in PyTorch + Triton. These optimizations primarily focus on reducing the memory usage of the logits tensor and its gradient since these tensors can dominate overall memory usage: 12 | * **Optimization 1**: Overwrite the logits tensor with its gradient in-place to avoid allocating more memory for the gradients 13 | * **Optimization 2**: Compute the loss and gradients in a loop of $K$ micro-batches **in the forward pass** so that we only materialize $\frac{1}{K}$ of the full logits tensor 14 | 15 | These optimizations can reduce peak memory usage of a linear projection + cross-entropy loss by several times with almost no additional compute cost. 16 | 17 | 18 | ## Performance Analysis 19 | 20 | Figure 1 plots the peak memory usage (top row) and median wall clock time (bottom row) before and after applying these optimizations. 21 | 22 | ![Figure 1](benchmark_data/plots.png) 23 | **Figure 1 (generated by running `python ./benchmark.py`)** 24 | 25 | ### Optimization 1: Overwrite logits with their gradients 26 | 27 | During the backward pass of a linear projection + cross-entropy loss module, we no longer need to keep the logits in memory after computing their gradients. So, we overwrite the logits in-place with their gradients in the backward pass to avoid allocating any new memory for the gradients. 28 | 29 | The memory savings from this optimization are (approximately) represented by the difference between the blue line (without this optimization) and the orange line (with this optimization) in Figure 1, above. 30 | 31 | ### Optimization 2: Avoid materializing full logits tensor 32 | 33 | To avoid materializing the full logits tensor, we split the batch into $K$ micro-batches. Then, we can compute both the loss and the logit gradients (up to a scale factor) during the forward pass. With the logit gradients, we can compute the gradients w.r.t. the input features and linear projection weights. This way, we do not need to materialize the entire logits tensor - we materialize $\frac{1}{K}$ of the logits at a time, compute the gradients we need, then discard those logits. Note that this requires no additional recomputation since this is all done in the forward pass. 34 | 35 | The reason we can compute the logit gradients in the forward pass is that the output of this module is a scalar (since we assume we will do either a `mean` or `sum` reduction on the loss). Therefore, to get the correct gradients in the backward pass, we can simply multiply the gradients we computed in the forward pass by the `grad_output` scalar. 36 | 37 | The top row in Figure 1 shows the memory savings from this optimization for different values of $K$ (`n_loop_iters` in Figure 1 refers to the number of micro-batches $K$). In the bottom row of Figure 1, we can see that the median wall clock time is nearly identical before vs. after these optimizations. 38 | 39 | Note that we see diminishing returns in peak memory usage as we scale the hidden dim (`dim`). This is because peak memory usage becomes dominated by the size of the linear projection's weights & gradients, rather than the logits, once the hidden dim is sufficiently large (right column in Figure 1).
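
## Example Usage

The fused module stands in for a final `nn.Linear(dim, n_classes, bias=False)` followed by `nn.CrossEntropyLoss` (see `PyTorchProjectionPlusCrossEntropyLoss` in `modules.py` for the reference implementation). Below is a minimal usage sketch of the interface defined in `modules.py`; the tensor shapes and the choice of `n_loop_iters` are illustrative only.

```python
import torch

from modules import FusedProjectionPlusCrossEntropyLoss

# Illustrative sizes; n_tokens must be divisible by n_loop_iters
dim, n_classes, n_tokens = 2048, 32768, 8192

# n_loop_iters=K materializes only 1/K of the logits tensor at any one time
loss_fn = FusedProjectionPlusCrossEntropyLoss(dim, n_classes, n_loop_iters=4).cuda()

feat = torch.randn(n_tokens, dim, device="cuda", requires_grad=True)
targ = torch.randint(0, n_classes, (n_tokens,), device="cuda")

loss = loss_fn(feat, targ)  # scalar loss ("mean" reduction by default)
loss.backward()             # logit gradients were already computed in the forward pass,
                            # so backward only rescales the cached input/weight gradients
```

As with `nn.CrossEntropyLoss`, targets equal to `ignore_index` (default `-100`) contribute neither to the loss nor to the gradients.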
40 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import copy 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import torch 8 | import triton 9 | from tqdm.auto import tqdm 10 | 11 | from modules import ( 12 | FusedProjectionPlusCrossEntropyLoss, 13 | PyTorchProjectionPlusCrossEntropyLoss, 14 | ) 15 | 16 | 17 | def _get_peak_memory_consumed( 18 | n_tokens, dim, n_classes, init_fn, dtype, device="cuda" 19 | ): 20 | torch.cuda.reset_peak_memory_stats() 21 | initial_peak_mem_bytes = torch.cuda.max_memory_allocated(device=device) 22 | 23 | feat = torch.randn(n_tokens, dim, dtype=dtype, requires_grad=True, device=device) 24 | targ = torch.randint(0, n_classes, (n_tokens,), device=device) 25 | 26 | forward_fn = init_fn() # initialize module here so that we capture the memory consumption of the weights 27 | 28 | _ = forward_fn(feat, targ).mean().backward() 29 | 30 | final_peak_mem_bytes = torch.cuda.max_memory_allocated(device=device) 31 | 32 | peak_mem_consumed_bytes = final_peak_mem_bytes - initial_peak_mem_bytes 33 | peak_mem_consumed_gb = peak_mem_consumed_bytes / 1e9 34 | 35 | median_time_ms = triton.testing.do_bench( 36 | lambda: forward_fn(feat, targ).mean().backward(), return_mode="median" 37 | ) 38 | 39 | return peak_mem_consumed_gb, median_time_ms 40 | 41 | 42 | def run_test( 43 | line_configs, test_constant_kwargs, test_key, test_vals, mem_ax, time_ax, debug 44 | ) -> pd.DataFrame: 45 | if debug: 46 | test_vals = test_vals[:2] 47 | line_configs = line_configs[:2] 48 | 49 | constant_kwargs_label = ", ".join(f"{k}={v}" for k, v in test_constant_kwargs.items()) 50 | 51 | mem_ax.set_title(f"Peak Memory Usage\n({constant_kwargs_label})") 52 | mem_ax.set_xlabel(test_key) 53 | 54 | time_ax.set_title(f"Median Wall Clock Time\n({constant_kwargs_label})") 55 | time_ax.set_xlabel(test_key) 56 | 57 | mem_ax.set_ylabel("Peak Memory Usage (GB)") 58 | time_ax.set_ylabel("Median Wall Clock Time (ms)") 59 | 60 | df_rows = [] 61 | for line_config in tqdm(line_configs): 62 | line_label = ", ".join(f"{k}={v}" for k, v in line_config.items()) 63 | 64 | line_config = copy(line_config) 65 | fn = line_config.pop("fn") 66 | 67 | peak_mems = [] 68 | median_times = [] 69 | 70 | for test_val in test_vals: 71 | params = {**test_constant_kwargs, test_key: test_val} 72 | dim = params["dim"] 73 | n_classes = params["n_classes"] 74 | n_tokens = params["n_tokens"] 75 | dtype = params["dtype"] 76 | 77 | if fn == "torch": 78 | module_cls = PyTorchProjectionPlusCrossEntropyLoss 79 | elif fn == "triton": 80 | module_cls = FusedProjectionPlusCrossEntropyLoss 81 | else: 82 | raise ValueError(f"Unknown {fn=}") 83 | 84 | init_fn = lambda: module_cls(dim, n_classes, **line_config).cuda().to(dtype) 85 | peak_mem, median_time = _get_peak_memory_consumed( 86 | n_tokens, dim, n_classes, init_fn=init_fn, dtype=dtype 87 | ) 88 | peak_mems.append(peak_mem) 89 | median_times.append(median_time) 90 | df_rows.append( 91 | dict( 92 | fn=fn, 93 | dim=dim, 94 | n_classes=n_classes, 95 | n_tokens=n_tokens, 96 | dtype=str(dtype), 97 | peak_mem=peak_mem, 98 | median_time=median_time, 99 | n_loop_iters=line_config.get("n_loop_iters", None), 100 | ) 101 | ) 102 | 103 | mem_ax.plot(test_vals, peak_mems, "--o", label=line_label) 104 | time_ax.plot(test_vals, median_times, "--o", label=line_label) 105 | 106 | for ax in (mem_ax, time_ax): 
107 | ax.legend(loc="upper left") 108 | 109 | return pd.DataFrame(df_rows) 110 | 111 | 112 | def benchmark(line_configs, output_dir, debug): 113 | output_dir = Path(output_dir) 114 | output_dir.mkdir(exist_ok=True, parents=True) 115 | 116 | _, (mem_axs, time_axs) = plt.subplots(2, 3, figsize=(20, 10), dpi=200) 117 | ax_mem_nclasses, ax_mem_tokens, ax_mem_dim = mem_axs 118 | ax_time_nclasses, ax_time_tokens, ax_time_dim = time_axs 119 | 120 | # Benchmark performance wrt n_classes 121 | test_key = "n_classes" 122 | test_vals = [2048, 4096, 8192, 16384, 32768, 65536] 123 | test_constant_kwargs = dict(dim=2048, n_tokens=8192, dtype=torch.float32) 124 | df_n_classes = run_test( 125 | line_configs, 126 | test_constant_kwargs, 127 | test_key, 128 | test_vals, 129 | ax_mem_nclasses, 130 | ax_time_nclasses, 131 | debug=debug, 132 | ) 133 | df_n_classes.to_csv(output_dir / "n_classes.csv", index=False) 134 | 135 | # Benchmark performance wrt n_tokens 136 | test_key = "n_tokens" 137 | test_vals = [1024, 2048, 4096, 8192, 16384] 138 | test_constant_kwargs = dict(n_classes=32768, dim=2048, dtype=torch.float32) 139 | df_n_tokens = run_test( 140 | line_configs, 141 | test_constant_kwargs, 142 | test_key, 143 | test_vals, 144 | ax_mem_tokens, 145 | ax_time_tokens, 146 | debug=debug, 147 | ) 148 | df_n_tokens.to_csv(output_dir / "n_tokens.csv", index=False) 149 | 150 | # Benchmark performance wrt dim 151 | test_key = "dim" 152 | test_vals = [1024, 2048, 4096, 8192] 153 | test_constant_kwargs = dict(n_classes=32768, n_tokens=8192, dtype=torch.float32) 154 | df_dim = run_test( 155 | line_configs, 156 | test_constant_kwargs, 157 | test_key, 158 | test_vals, 159 | ax_mem_dim, 160 | ax_time_dim, 161 | debug=debug, 162 | ) 163 | df_dim.to_csv(output_dir / "dim.csv", index=False) 164 | 165 | plt.tight_layout() 166 | plt.savefig(output_dir / "plots.png") 167 | plt.close() 168 | 169 | 170 | def main(): 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--debug", action="store_true") 173 | args = parser.parse_args() 174 | 175 | # Benchmark torch vs custom kernel 176 | line_configs = [ 177 | dict(fn="torch"), 178 | dict(fn="triton", n_loop_iters=1), 179 | dict(fn="triton", n_loop_iters=2), 180 | dict(fn="triton", n_loop_iters=4), 181 | dict(fn="triton", n_loop_iters=8), 182 | ] 183 | benchmark( 184 | line_configs=line_configs, 185 | output_dir="./benchmark_data", 186 | debug=args.debug, 187 | ) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /benchmark_data/dim.csv: -------------------------------------------------------------------------------- 1 | fn,dim,n_classes,n_tokens,dtype,peak_mem,median_time,n_loop_iters 2 | torch,1024,32768,8192,torch.float32,3.389064192,98.735107421875, 3 | torch,2048,32768,8192,torch.float32,3.556836352,188.45184326171875, 4 | torch,4096,32768,8192,torch.float32,3.892380672,376.32818603515625, 5 | torch,8192,32768,8192,torch.float32,4.563469312,754.8159790039062, 6 | triton,1024,32768,8192,torch.float32,1.442939392,94.4906234741211,1.0 7 | triton,2048,32768,8192,torch.float32,1.812038144,184.9047088623047,1.0 8 | triton,4096,32768,8192,torch.float32,2.550235648,375.2806396484375,1.0 9 | triton,8192,32768,8192,torch.float32,4.026630656,747.92138671875,1.0 10 | triton,1024,32768,8192,torch.float32,0.889291264,95.13369750976562,2.0 11 | triton,2048,32768,8192,torch.float32,1.2416128,183.87149047851562,2.0 12 | 
triton,4096,32768,8192,torch.float32,1.946255872,365.42156982421875,2.0 13 | triton,8192,32768,8192,torch.float32,3.355542016,744.4859008789062,2.0 14 | triton,1024,32768,8192,torch.float32,0.6124672,93.00479888916016,4.0 15 | triton,2048,32768,8192,torch.float32,0.956400128,188.79385375976562,4.0 16 | triton,4096,32768,8192,torch.float32,1.644265984,365.6519775390625,4.0 17 | triton,8192,32768,8192,torch.float32,3.019997696,745.1135864257812,4.0 18 | triton,1024,32768,8192,torch.float32,0.474055168,92.41600036621094,8.0 19 | triton,2048,32768,8192,torch.float32,0.813793792,186.0474853515625,8.0 20 | triton,4096,32768,8192,torch.float32,1.49327104,373.6524658203125,8.0 21 | triton,8192,32768,8192,torch.float32,2.852225536,745.8744506835938,8.0 22 | -------------------------------------------------------------------------------- /benchmark_data/n_classes.csv: -------------------------------------------------------------------------------- 1 | fn,dim,n_classes,n_tokens,dtype,peak_mem,median_time,n_loop_iters 2 | torch,2048,2048,8192,torch.float32,0.293798912,11.92959976196289, 3 | torch,2048,4096,8192,torch.float32,0.50338304,23.48646354675293, 4 | torch,2048,8192,8192,torch.float32,0.939590656,46.320640563964844, 5 | torch,2048,16384,8192,torch.float32,1.812005888,94.18956756591797, 6 | torch,2048,32768,8192,torch.float32,3.556836352,182.78707885742188, 7 | torch,2048,65536,8192,torch.float32,7.04649728,397.3591003417969, 8 | triton,2048,2048,8192,torch.float32,0.302088704,11.811840057373047,1.0 9 | triton,2048,4096,8192,torch.float32,0.402752,23.0512638092041,1.0 10 | triton,2048,8192,8192,torch.float32,0.604078592,46.302207946777344,1.0 11 | triton,2048,16384,8192,torch.float32,1.006731776,92.71807861328125,1.0 12 | triton,2048,32768,8192,torch.float32,1.812038144,181.0339813232422,1.0 13 | triton,2048,65536,8192,torch.float32,3.42265088,408.7091064453125,1.0 14 | triton,2048,2048,8192,torch.float32,0.23497984,12.84607982635498,2.0 15 | triton,2048,4096,8192,torch.float32,0.302088704,23.48953628540039,2.0 16 | triton,2048,8192,8192,torch.float32,0.436306432,46.282752990722656,2.0 17 | triton,2048,16384,8192,torch.float32,0.704741888,91.109375,2.0 18 | triton,2048,32768,8192,torch.float32,1.2416128,181.41798400878906,2.0 19 | triton,2048,65536,8192,torch.float32,2.315354624,410.96600341796875,2.0 20 | triton,2048,2048,8192,torch.float32,0.201425408,12.398591995239258,4.0 21 | triton,2048,4096,8192,torch.float32,0.251757056,24.31283187866211,4.0 22 | triton,2048,8192,8192,torch.float32,0.352420352,46.60940933227539,4.0 23 | triton,2048,16384,8192,torch.float32,0.553746944,93.11539459228516,4.0 24 | triton,2048,32768,8192,torch.float32,0.956400128,184.9681854248047,4.0 25 | triton,2048,65536,8192,torch.float32,1.761706496,399.52178955078125,4.0 26 | triton,2048,2048,8192,torch.float32,0.184648192,13.105152130126953,8.0 27 | triton,2048,4096,8192,torch.float32,0.226591232,24.52992057800293,8.0 28 | triton,2048,8192,8192,torch.float32,0.310477312,49.33427047729492,8.0 29 | triton,2048,16384,8192,torch.float32,0.478249472,92.16000366210938,8.0 30 | triton,2048,32768,8192,torch.float32,0.813793792,183.77420043945312,8.0 31 | triton,2048,65536,8192,torch.float32,1.484882432,385.5206298828125,8.0 32 | -------------------------------------------------------------------------------- /benchmark_data/n_tokens.csv: -------------------------------------------------------------------------------- 1 | fn,dim,n_classes,n_tokens,dtype,peak_mem,median_time,n_loop_iters 2 | 
torch,2048,32768,1024,torch.float32,0.687875072,23.853055953979492, 3 | torch,2048,32768,2048,torch.float32,1.090536448,47.36716842651367, 4 | torch,2048,32768,4096,torch.float32,1.912636416,92.95769500732422, 5 | torch,2048,32768,8192,torch.float32,3.556836352,187.725830078125, 6 | torch,2048,32768,16384,torch.float32,6.845236224,372.83123779296875, 7 | triton,2048,32768,1024,torch.float32,0.696267264,23.89299201965332,1.0 8 | triton,2048,32768,2048,torch.float32,0.855663104,47.528961181640625,1.0 9 | triton,2048,32768,4096,torch.float32,1.174454784,92.35865783691406,1.0 10 | triton,2048,32768,8192,torch.float32,1.812038144,183.9411163330078,1.0 11 | triton,2048,32768,16384,torch.float32,3.087204864,367.54022216796875,1.0 12 | triton,2048,32768,1024,torch.float32,0.624964096,25.177087783813477,2.0 13 | triton,2048,32768,2048,torch.float32,0.713056768,47.15315246582031,2.0 14 | triton,2048,32768,4096,torch.float32,0.889242112,94.12198638916016,2.0 15 | triton,2048,32768,8192,torch.float32,1.2416128,183.43423461914062,2.0 16 | triton,2048,32768,16384,torch.float32,1.946354176,368.1321105957031,2.0 17 | triton,2048,32768,1024,torch.float32,0.589312512,25.584640502929688,4.0 18 | triton,2048,32768,2048,torch.float32,0.6417536,49.3199348449707,4.0 19 | triton,2048,32768,4096,torch.float32,0.746635776,93.23725128173828,4.0 20 | triton,2048,32768,8192,torch.float32,0.956400128,187.01210021972656,4.0 21 | triton,2048,32768,16384,torch.float32,1.375928832,368.1167297363281,4.0 22 | triton,2048,32768,1024,torch.float32,0.57148672,25.878528594970703,8.0 23 | triton,2048,32768,2048,torch.float32,0.606102016,50.14118576049805,8.0 24 | triton,2048,32768,4096,torch.float32,0.675332608,98.6798095703125,8.0 25 | triton,2048,32768,8192,torch.float32,0.813793792,185.58770751953125,8.0 26 | triton,2048,32768,16384,torch.float32,1.09071616,374.0549011230469,8.0 27 | -------------------------------------------------------------------------------- /benchmark_data/plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mgmalek/efficient_cross_entropy/049d44460051a82f58f7ff49a2ad0653ecf026d8/benchmark_data/plots.png -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | import torch.nn as nn 5 | import triton 6 | 7 | import triton 8 | import triton.language as tl 9 | 10 | 11 | @triton.jit 12 | def fused_cross_entropy_fwd_bwd_kernel( 13 | output_loss_ptr, 14 | output_logit_grad_ptr, 15 | input_logit_ptr, 16 | input_targ_ptr, 17 | input_divisor_ptr, 18 | output_loss_stride, 19 | output_logit_grad_stride, 20 | input_logit_stride, 21 | input_targ_stride, 22 | n_cols, 23 | ignore_index, 24 | BLOCK_SIZE: tl.constexpr, 25 | ): 26 | # Get pointers to current row for all inputs/outputs 27 | row_idx = tl.program_id(0) 28 | logit_grad_row_start_ptr = output_logit_grad_ptr + row_idx * output_logit_grad_stride 29 | logit_row_start_ptr = input_logit_ptr + row_idx * input_logit_stride 30 | targ_ptr = input_targ_ptr + row_idx * input_targ_stride 31 | loss_ptr = output_loss_ptr + row_idx * output_loss_stride 32 | 33 | col_offsets = tl.arange(0, BLOCK_SIZE) 34 | logit_row_ptrs = logit_row_start_ptr + col_offsets 35 | logit_grad_row_ptrs = logit_grad_row_start_ptr + col_offsets 36 | 37 | # Load data into SRAM 38 | logit_row_unnormalized = tl.load( 39 | logit_row_ptrs, mask=col_offsets < n_cols, 
other=float("-Inf") 40 | ) 41 | targ = tl.load(targ_ptr) 42 | divisor = tl.load(input_divisor_ptr) 43 | 44 | # Normalize logits and compute some useful intermediate values 45 | logit_row = logit_row_unnormalized - tl.max( 46 | logit_row_unnormalized, axis=0 47 | ) # Subtract max value for numerical stability 48 | exp_logit_row = tl.exp(logit_row) 49 | sum_exp_logit_row = tl.sum(exp_logit_row, axis=0) 50 | 51 | # Compute loss 52 | log_sum_exp_logit_row = tl.log(sum_exp_logit_row) 53 | logit_gt_logit = tl.sum(tl.where(targ == col_offsets, logit_row, 0.0)) 54 | loss = log_sum_exp_logit_row - logit_gt_logit 55 | loss = loss / divisor 56 | loss = tl.where(targ == ignore_index, 0.0, loss) 57 | tl.store(loss_ptr, loss) 58 | 59 | # Compute gradients 60 | targ_one_hot = tl.where(targ == col_offsets, 1.0, 0.0) 61 | grad = (exp_logit_row / sum_exp_logit_row - targ_one_hot) 62 | grad = grad / divisor 63 | grad = tl.where(targ == ignore_index, 0.0, grad) 64 | tl.store(logit_grad_row_ptrs, grad, mask=col_offsets < n_cols) 65 | 66 | 67 | class FusedCrossEntropyLossFunction(torch.autograd.Function): 68 | # NOTE: We put the linear projection in the same autograd Function as the loss computation 69 | # because we overwrite the logits with their gradients inplace to avoid allocating more 70 | # memory for the gradients, and so we keep the logits completely contained within this 71 | # Functionto avoid possible side-effects if they were exposed. 72 | 73 | @staticmethod 74 | def forward( 75 | ctx, 76 | in_feat: torch.Tensor, 77 | proj_weight: torch.Tensor, 78 | targ: torch.Tensor, 79 | n_loop_iters: int, 80 | ignore_index: int, 81 | reduction: str, 82 | ): 83 | n_tokens = in_feat.shape[0] 84 | n_classes = proj_weight.shape[0] 85 | 86 | assert in_feat.ndim == 2, in_feat.ndim 87 | assert proj_weight.ndim == 2, proj_weight.ndim 88 | assert targ.ndim == 1, targ.shape 89 | assert in_feat.shape[0] == targ.shape[0], f"Number of tokens in in_feat and targ is not equal: {(in_feat.shape, targ.shape) = }" 90 | assert reduction in ("mean", "sum"), reduction 91 | assert n_loop_iters > 0, n_loop_iters 92 | assert n_tokens % n_loop_iters == 0, (n_tokens, n_loop_iters) 93 | 94 | NUM_WARPS = 16 95 | 96 | BLOCK_SIZE = triton.next_power_of_2(n_classes) 97 | 98 | loss = torch.empty(n_tokens, dtype=in_feat.dtype, device=in_feat.device) 99 | dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else in_feat.dtype 100 | 101 | if proj_weight.requires_grad: 102 | grad_proj_weight = torch.zeros_like(proj_weight, dtype=dtype) 103 | else: 104 | grad_proj_weight = None 105 | 106 | if in_feat.requires_grad: 107 | grad_in_feat = torch.zeros_like(in_feat) 108 | else: 109 | grad_in_feat = None 110 | 111 | divisor = (targ != ignore_index).sum().to(dtype) if reduction == "mean" else torch.ones(1, dtype=dtype, device=in_feat.device) 112 | 113 | # Divide the input into chunks of size num_tokens // n_loop_iters, then compute the loss for each of these groups 114 | proj_weight_cast = proj_weight.to(dtype) 115 | 116 | loop_chunk_size = triton.cdiv(n_tokens, n_loop_iters) 117 | logits_chunk_cast = torch.zeros((loop_chunk_size, n_classes), dtype=dtype, device=in_feat.device) 118 | for i, in_feat_chunk in enumerate(torch.split(in_feat, loop_chunk_size)): 119 | token_start_idx = i * loop_chunk_size 120 | token_end_idx = (i + 1) * loop_chunk_size 121 | 122 | in_feat_chunk = in_feat_chunk.to(dtype) 123 | 124 | # Compute logits 125 | torch.matmul(in_feat_chunk, proj_weight_cast.T, out=logits_chunk_cast) 126 | logits_chunk = 
logits_chunk_cast.float() 127 | 128 | # Compute loss 129 | loss_chunk = loss[token_start_idx:token_end_idx] 130 | targ_chunk = targ[token_start_idx:token_end_idx] 131 | 132 | n_tokens_chunk = logits_chunk.shape[0] 133 | grad_logits_chunk = logits_chunk # NOTE: we overwrite the logits with their gradients 134 | fused_cross_entropy_fwd_bwd_kernel[(n_tokens_chunk,)]( 135 | loss_chunk, 136 | grad_logits_chunk, 137 | logits_chunk, 138 | targ_chunk, 139 | divisor, 140 | loss_chunk.stride(0), 141 | grad_logits_chunk.stride(0), 142 | logits_chunk.stride(0), 143 | targ_chunk.stride(0), 144 | n_classes, 145 | ignore_index, 146 | num_warps=NUM_WARPS, 147 | BLOCK_SIZE=BLOCK_SIZE, 148 | ) 149 | 150 | grad_logits_chunk = grad_logits_chunk.to(dtype) 151 | 152 | if in_feat.requires_grad: 153 | grad_in_feat[token_start_idx:token_end_idx] = grad_logits_chunk @ proj_weight_cast 154 | 155 | if proj_weight.requires_grad: 156 | torch.addmm( 157 | grad_proj_weight, 158 | grad_logits_chunk.T, 159 | in_feat_chunk, 160 | out=grad_proj_weight, 161 | ) 162 | 163 | # NOTE: if reduction == "mean" we already divide by an appropriate normalization factor in the kernel so we can always sum here 164 | loss = loss.sum() 165 | 166 | # Save data for backward 167 | ctx.in_feat_requires_grad = in_feat.requires_grad 168 | ctx.proj_weight_requires_grad = proj_weight.requires_grad 169 | 170 | if proj_weight.requires_grad and in_feat.requires_grad: 171 | ctx.save_for_backward(grad_in_feat, grad_proj_weight) 172 | elif proj_weight.requires_grad and not in_feat.requires_grad: 173 | ctx.save_for_backward(grad_proj_weight) 174 | elif not proj_weight.requires_grad and in_feat.requires_grad: 175 | ctx.save_for_backward(grad_in_feat) 176 | 177 | return loss 178 | 179 | @staticmethod 180 | def backward(ctx, grad_output): 181 | grad_in_feat, grad_proj_weight = None, None 182 | if ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad: 183 | grad_in_feat, grad_proj_weight = ctx.saved_tensors 184 | elif not ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad: 185 | grad_proj_weight, = ctx.saved_tensors 186 | elif ctx.in_feat_requires_grad and not ctx.proj_weight_requires_grad: 187 | grad_in_feat, = ctx.saved_tensors 188 | 189 | assert grad_output.shape == tuple(), grad_output.shape 190 | if grad_in_feat is not None: grad_in_feat *= grad_output 191 | if grad_proj_weight is not None: grad_proj_weight *= grad_output 192 | return grad_in_feat, grad_proj_weight, None, None, None, None 193 | 194 | 195 | class FusedProjectionPlusCrossEntropyLoss(nn.Module): 196 | """Fused implementation of linear projection + cross entropy loss""" 197 | 198 | def __init__( 199 | self, 200 | dim: int, 201 | n_classes: int, 202 | n_loop_iters: int = 1, 203 | ignore_index: int = -100, 204 | reduction: str = "mean", 205 | ): 206 | super().__init__() 207 | self.n_loop_iters = n_loop_iters 208 | self.ignore_index = ignore_index 209 | self.reduction = reduction 210 | self.proj_weight = nn.Parameter(torch.empty(n_classes, dim)) 211 | self.reset_parameters() 212 | 213 | def reset_parameters(self): 214 | nn.init.kaiming_uniform_(self.proj_weight, a=sqrt(5)) 215 | 216 | def forward(self, x, targ): 217 | return FusedCrossEntropyLossFunction.apply( 218 | x, 219 | self.proj_weight, 220 | targ, 221 | self.n_loop_iters, 222 | self.ignore_index, 223 | self.reduction, 224 | ) 225 | 226 | 227 | class PyTorchProjectionPlusCrossEntropyLoss(nn.Module): 228 | """Simple PyTorch implementation of linear projection + cross entropy loss. 
Intended only for testing and benchmarking.""" 229 | 230 | def __init__(self, dim: int, n_classes: int, ignore_index: int = -100, reduction: str = "mean"): 231 | super().__init__() 232 | self.proj = nn.Linear(dim, n_classes, bias=False) 233 | self.loss = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) 234 | 235 | def forward(self, x, targ): 236 | logits = self.proj(x) 237 | return self.loss(logits, targ) 238 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | 7 | from modules import ( 8 | FusedProjectionPlusCrossEntropyLoss, 9 | PyTorchProjectionPlusCrossEntropyLoss, 10 | ) 11 | 12 | 13 | def fwd_bwd(module: nn.Module, x: torch.Tensor, targ: torch.Tensor): 14 | x = x.clone() 15 | x.requires_grad_(True) 16 | loss = module(x, targ) 17 | loss.mean().backward() 18 | grad = x.grad 19 | if isinstance(module, FusedProjectionPlusCrossEntropyLoss): 20 | proj_weight_grad = module.proj_weight.grad 21 | else: 22 | proj_weight_grad = module.proj.weight.grad 23 | return loss, grad, proj_weight_grad 24 | 25 | 26 | @pytest.mark.parametrize("n_tokens", [8, 1536]) 27 | @pytest.mark.parametrize("n_classes", [8, 2048]) 28 | @pytest.mark.parametrize("dim", [8, 2048]) 29 | @pytest.mark.parametrize("n_loop_iters", [1, 2, 4]) 30 | @pytest.mark.parametrize("reduction", ["sum", "mean"]) 31 | @pytest.mark.parametrize("use_ignore_index", [False]) 32 | @pytest.mark.parametrize("autocast", [True, False]) 33 | @pytest.mark.parametrize("dtype", [torch.float32]) 34 | @pytest.mark.parametrize("device", ["cuda"]) 35 | def test_correctness( 36 | n_tokens, n_classes, dim, n_loop_iters, reduction, use_ignore_index, autocast, dtype, device 37 | ): 38 | torch.manual_seed(0) 39 | 40 | x = torch.randn(n_tokens, dim, device=device, dtype=dtype) 41 | targ = torch.randint(low=0, high=n_classes, size=(n_tokens,), device=device) 42 | 43 | ignore_index = -1 44 | if use_ignore_index: 45 | targ = torch.where( 46 | torch.rand(targ.shape, device=targ.device) < 0.1, targ, ignore_index 47 | ) 48 | 49 | torch_module = PyTorchProjectionPlusCrossEntropyLoss( 50 | dim, n_classes, ignore_index=ignore_index, reduction=reduction, 51 | ).to(device, dtype=dtype) 52 | 53 | triton_module = FusedProjectionPlusCrossEntropyLoss( 54 | dim, n_classes, n_loop_iters, ignore_index=ignore_index, reduction=reduction 55 | ).to(device, dtype=dtype) 56 | 57 | torch_fp32_module = deepcopy(torch_module) 58 | 59 | assert triton_module.proj_weight.data.shape == torch_module.proj.weight.data.shape 60 | triton_module.proj_weight.data = torch_module.proj.weight.data 61 | 62 | with torch.cuda.amp.autocast(enabled=autocast, dtype=torch.bfloat16): 63 | torch_loss, torch_grad, torch_proj_weight_grad = fwd_bwd(torch_module, x, targ) 64 | triton_loss, triton_grad, triton_proj_weight_grad = fwd_bwd(triton_module, x, targ) 65 | 66 | assert torch_grad is not None 67 | assert torch_loss.dtype == triton_loss.dtype, (torch_loss.dtype, triton_loss.dtype) 68 | assert torch_grad.dtype == triton_grad.dtype, (torch_grad.dtype, triton_grad.dtype) 69 | assert torch_proj_weight_grad.dtype == triton_proj_weight_grad.dtype, (torch_proj_weight_grad.dtype, triton_proj_weight_grad.dtype) 70 | 71 | if autocast: 72 | # autocast correctness is validated by checking that the norm of the loss and gradients 73 | # between pytorch fp32 and pytorch autocast is similar 
to the norm of the loss and gradients 74 | # between pytorch fp32 and triton autocast 75 | torch_fp32_loss, torch_fp32_grad, torch_fp32_proj_weight_grad = fwd_bwd(torch_fp32_module, x, targ) 76 | 77 | torch_loss_norm = torch.linalg.norm(torch_fp32_loss - torch_loss).item() 78 | torch_grad_norm = torch.linalg.norm(torch_fp32_grad - torch_grad).item() 79 | torch_proj_weight_grad_norm = torch.linalg.norm(torch_fp32_proj_weight_grad - torch_proj_weight_grad).item() 80 | 81 | triton_loss_norm = torch.linalg.norm(torch_fp32_loss - triton_loss).item() 82 | triton_grad_norm = torch.linalg.norm(torch_fp32_grad - triton_grad).item() 83 | triton_proj_weight_grad_norm = torch.linalg.norm(torch_fp32_proj_weight_grad - triton_proj_weight_grad).item() 84 | 85 | assert triton_loss_norm < 2 * torch_loss_norm, (triton_loss_norm, torch_loss_norm) 86 | assert triton_grad_norm < 2 * torch_grad_norm, (triton_grad_norm, torch_grad_norm) 87 | assert triton_proj_weight_grad_norm < 2 * torch_proj_weight_grad_norm, (triton_proj_weight_grad_norm, torch_proj_weight_grad_norm) 88 | 89 | else: 90 | assert torch.allclose(torch_loss, triton_loss, rtol=1e-4) 91 | assert torch.allclose(torch_grad, triton_grad, atol=1e-3, rtol=1e-4) 92 | assert torch.allclose(torch_proj_weight_grad, triton_proj_weight_grad, atol=1e-2, rtol=1e-2) 93 | --------------------------------------------------------------------------------