├── .gitignore ├── LICENSE ├── README.md ├── setup.py └── src └── grad_cache ├── __init__.py ├── cachex ├── __init__.py ├── functional.py ├── training.py └── tree_utils.py ├── context_managers.py ├── functional.py ├── grad_cache.py ├── loss.py └── pytorch_lightning ├── pl_example.py ├── pl_gradcache.py ├── readme.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Gradient Cache 6 | Gradient Cache is a simple technique for unlimitedly scaling contrastive learning batch far beyond GPU/TPU memory constraint. This means training that used to take heavy hardware, e.g. 8 V100 GPU, can be done on a single GPU. In addition, Gradient Cache allow users to replace big RAM GPU/TPU with much more cost efficient high FLOP low RAM systems. 
7 | 
8 | This repo holds a generic implementation of Gradient Cache described in our paper [Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup
9 | ](https://arxiv.org/abs/2101.06983). Both Pytorch and JAX frameworks are supported.
10 | ```
11 | @inproceedings{gao2021scaling,
12 | title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
13 | author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
14 | booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
15 | year={2021},
16 | }
17 | ```
18 | 
19 | **NEW: We now support JAX and TPU!**
20 | 
21 | Gradient Cache has also been integrated into dense passage retrieval (DPR). Check out our [GC-DPR toolkit](https://github.com/luyug/GC-DPR).
22 | ## Installation
23 | First install your desired deep learning backend, either Pytorch or JAX. To install GradCache, clone this repo and run pip.
24 | ```
25 | git clone https://github.com/luyug/GradCache
26 | cd GradCache
27 | pip install .
28 | ```
29 | For development,
30 | ```
31 | pip install --editable .
32 | ```
33 | 
34 | ## Usage
35 | Gradient caching functionality is implemented in the `GradCache` class. If you are developing a **new project** rather than patching an existing one, also check out our [functional approach](#functional-approach), which requires less integration effort.
36 | 
37 | For JAX/Flax users, take a look at a simple train function [here](https://github.com/luyug/GradCache/blob/8463340a15a2395fc33b9a1f40f5f4946b7cbad8/src/grad_cache/cachex/training.py#L9).
38 | 
39 | ### Initialization
40 | The class's `__init__` method defines the cache and exposes several functional parameters `*_fn` for easily adjusting model behavior. Alternatively, you can sub-class GradCache.
41 | ```
42 | grad_cache.GradCache(
43 |   models: List[nn.Module],
44 |   chunk_sizes: Union[int, List[int]],
45 |   loss_fn: Callable[..., Tensor],
46 |   split_input_fn: Callable[[Any, int], Any] = None,
47 |   get_rep_fn: Callable[..., Tensor] = None,
48 |   fp16: bool = False,
49 |   scaler: GradScaler = None,
50 | )
51 | ```
52 | **models** - A list of encoder models to be updated with the Gradient Cache.
53 | 
54 | **chunk_sizes** - An integer indicating the chunk size, or a list of integers specifying the chunk size for each model. This controls, for each model, the sub-batch size used for the forward-backward passes and should be set based on available GPU memory. A value too small will leave the GPU underutilized.
55 | 
56 | **loss_fn** - A loss function that takes as many representation tensors as there are models in `models`, plus arbitrary keyword arguments. It should compute the loss based on the input tensors and must in no case modify the input tensors' relations in the autograd graph, which are later relied upon to create the gradient cache.
57 | 
58 | **split_input_fn** - An optional function that splits generic model input into chunks based on the defined chunk_sizes. If not provided, this class will try its best to split the inputs of supported types. See the `split_inputs` function.
59 | 
60 | **get_rep_fn** - An optional function that takes generic model output and returns the representation tensor. If not provided, the generic output is assumed to be the representation tensor.
61 | 
62 | **fp16** - If True, run mixed precision training, which requires scaler to also be set.
63 | 
64 | **scaler** - A GradScaler object for automatic mixed precision training.
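To make the mixed-precision options concrete, here is a minimal sketch (not part of the repository) of the `fp16`/`scaler` path for the class-based API. The toy linear encoders, dimensions, batch size, and chunk size are illustrative assumptions; the scaler handling after the cache step follows the class's documented contract that gradient unscaling and `scaler.update()` are left to the user.
```
import torch
from torch import nn
from torch.cuda.amp import GradScaler

from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss

# Toy stand-ins for real encoders; AMP requires a CUDA device.
encoder_q = nn.Linear(768, 128).cuda()
encoder_p = nn.Linear(768, 128).cuda()
optimizer = torch.optim.AdamW(
    list(encoder_q.parameters()) + list(encoder_p.parameters()), lr=1e-4
)
scaler = GradScaler()

gc = GradCache(
    models=[encoder_q, encoder_p],
    chunk_sizes=8,                      # sub-batch size per forward-backward pass
    loss_fn=SimpleContrastiveLoss(),
    fp16=True,
    scaler=scaler,
)

queries = torch.randn(32, 768).cuda()   # a "big" batch processed 8 examples at a time
passages = torch.randn(32, 768).cuda()

optimizer.zero_grad()
loss = gc(queries, passages, reduction='mean')  # chunked forwards + cached backward
scaler.step(optimizer)                  # unscales gradients, then steps
scaler.update()
```
For the non-AMP path, drop `fp16` and `scaler` and call `optimizer.step()` directly, as in the bi-encoder example further below.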
65 | 
66 | ### Cache Gradient Step
67 | To run a cached gradient computation step, call the `cache_step` function,
68 | 
69 | ```
70 | cache_step(
71 |   *model_inputs,
72 |   no_sync_except_last: bool = False,
73 |   **loss_kwargs
74 | )
75 | ```
76 | Run a single gradient cache step. Upon function return, gradients are populated on the weights of each model in `self.models`, as if the `model_inputs` had been run as one huge batch on sufficiently large hardware. Calling a GradCache object with `__call__` will also invoke this function.
77 | 
78 | **model_inputs** - A list of inputs to each encoder model. Should be in the same order as `self.models`.
79 | 
80 | **no_sync_except_last** - If True, under a distributed setup, for each model, only trigger gradient reduction across processes for the last sub-batch's forward-backward pass. This can come in handy when dealing with a) a large model and/or b) a non-trivial number of sub-batches.
81 | 
82 | **loss_kwargs** - Additional keyword arguments to the loss function `loss_fn`. This is intended to enable flexible loss computation (thanks to Pytorch's dynamic graph), such as reduction, weighting, etc. Potentially, using `loss_kwargs` you can incorporate outputs from encoder models not tracked by the cache.
83 | 
84 | **Return** - loss, the current step's loss as a scalar tensor (detached from the graph).
85 | 
86 | ### Natively Supported Input Types
87 | - x: Tensor - will be passed in as `model(x)`
88 | - x: List[Tensor] - will be passed in as `model(*x)`
89 | - x: Dict[str, Tensor] (or UserDict[str, Tensor]) - will be passed in as `model(**x)`
90 | - x: Tuple[List[Tensor], Dict[str, Tensor]] - will be passed in as `model(*x[0], **x[1])`
91 | 
92 | Other generic inputs are not fully supported; we perform the model call using the following heuristics,
93 | 
94 | - x: List[Any] - will be passed in as `model(*x)`
95 | - x: Dict[str, Any] - will be passed in as `model(**x)`
96 | - x: Tuple[List[Any], Dict[str, Any]] - will be passed in as `model(*x[0], **x[1])`
97 | 
98 | To run with them, `split_input_fn` should be specified during cache initialization to break these inputs into smaller batches. In some rare cases, you may also need to override `get_input_tensors` when its heuristic cannot grab enough tensors to cover all CUDA devices that hold tensors in the input.
99 | 
100 | 
101 | ## Example Usage with Huggingface Transformers
102 | ### Learning a Bi-encoder
103 | Say we want to learn an embedding space of labels and text. Consider the following four pairs. (In practice, you will have many more and much longer text entries.)
104 | ```
105 | labels = ['fruit', 'meat', 'school', 'company']
106 | texts = [
107 |     'this is an apple',
108 |     'steak should be cooked medium rare',
109 |     'cmu is pittsburgh',
110 |     'apple sells laptop'
111 | ]
112 | ```
113 | 
114 | Initialize our encoder models,
115 | ```
116 | from transformers import AutoTokenizer, AutoModel
117 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
118 | encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
119 | encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
120 | ```
121 | Initialize the GradCache object,
122 | ```
123 | from grad_cache import GradCache
124 | from grad_cache.loss import SimpleContrastiveLoss
125 | 
126 | loss_fn = SimpleContrastiveLoss()
127 | gc = GradCache(
128 |     models=[encoder1, encoder2],
129 |     chunk_sizes=2,
130 |     loss_fn=loss_fn,
131 |     get_rep_fn=lambda v: v.pooler_output
132 | )
133 | ```
134 | Here we use the **get_rep_fn** argument to specify a function that takes the generic Huggingface model output and returns the actual representation tensor.
135 | 
136 | Create the model inputs by tokenizing the labels and texts,
137 | ```
138 | xx = tokenizer(labels, return_tensors='pt', padding=True)
139 | yy = tokenizer(texts, return_tensors='pt', padding=True)
140 | ```
141 | Run a cache step,
142 | ```
143 | gc(xx, yy, reduction='mean')
144 | ```
145 | Here we use `reduction='mean'` as a **loss_kwargs** to control the loss behavior. With a defined `optimizer`, the full gradient update can be done as,
146 | ```
147 | optimizer.zero_grad()
148 | gc(xx, yy, reduction='mean')
149 | optimizer.step()
150 | ```
151 | 
152 | ### Use Tied Encoder?
153 | This is naturally handled by the (magic of the) dynamic graph. You pass shallow copies of the same encoder model to the GradCache init method.
154 | ```
155 | tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
156 | gc = GradCache(
157 |     models=[tied_encoder, tied_encoder],
158 |     chunk_sizes=2,
159 |     loss_fn=loss_fn,
160 |     get_rep_fn=lambda v: v.pooler_output
161 | )
162 | ```
163 | Under the hood, distinct hooks will be registered to ensure correct gradient computation.
164 | ### Distributed Training with Multiple GPUs?
165 | We expect cross-process communication of representations to be handled by the `loss_fn`.
166 | ```
167 | from grad_cache.loss import DistributedContrastiveLoss
168 | loss_fn_dist = DistributedContrastiveLoss()
169 | ```
170 | Properly wrap the encoder models for gradient reduction,
171 | ```
172 | encoder1_ddp = DistributedDataParallel(
173 |     encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
174 | encoder2_ddp = DistributedDataParallel(
175 |     encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
176 | ```
177 | You can then initialize the cache using the distributed loss and the DDP models,
178 | ```
179 | gc = GradCache(
180 |     models=[encoder1_ddp, encoder2_ddp],
181 |     chunk_sizes=2,
182 |     loss_fn=loss_fn_dist,
183 |     get_rep_fn=lambda v: v.pooler_output
184 | )
185 | ```
186 | Run a cache step,
187 | ```
188 | gc(xx, yy, no_sync_except_last=True, reduction='mean')
189 | ```
190 | Set `no_sync_except_last=True` to avoid unnecessary gradient reduction.
191 | 
192 | ## Functional Approach
193 | ### Decorators
194 | If you are developing a new project, we recommend also checking out the decorators we provide for creating higher-order functions for caching.
195 | ```
196 | grad_cache.functional.cached(func: Callable[..., Tensor])
197 | ```
198 | A decorator that turns a model call function into a cache-compatible version.
199 | 
200 | **func** - A function that calls the model and returns a representation tensor.
201 | 
202 | **Return** - A function that returns 1) representation leaf tensors for cache construction, and 2) a closure function for the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor.
203 | ```
204 | grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
205 | ```
206 | A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor on the 0th dimension. This can come in handy when dealing with representation tensors collected from multiple cached forward passes.
207 | 
208 | **func** - A loss function
209 | 
210 | **Return** - Decorated loss function for cached results.
211 | 
212 | ```
213 | grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
214 | ```
215 | A decorator that all-gathers positional and keyword arguments of type Tensor and concatenates them along `axis`. Intended to be used to create a distributed contrastive learning loss.
216 | 
217 | **func** - A loss function
218 | 
219 | **Return** - Decorated loss function for distributed training.
220 | ### Usage
221 | The functional decorators are particularly useful if your data loader emits small batches, from which you can construct the big batch. Say you also want to use automatic mixed precision; we first define the model call function and the loss function,
222 | ```
223 | from grad_cache.functional import cached, cat_input_tensor
224 | 
225 | import torch
226 | import torch.nn.functional as F
227 | from torch.cuda.amp import autocast
228 | 
229 | @cached
230 | @autocast()
231 | def call_model(model, input):
232 |     return model(**input).pooler_output
233 | 
234 | @cat_input_tensor
235 | @autocast()
236 | def contrastive_loss(x, y):
237 |     target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
238 |     scores = torch.matmul(x, y.transpose(0, 1))
239 |     return F.cross_entropy(scores, target=target)
240 | ```
241 | Say you have a DataLoader `loader` emitting small batches of tuples `(xx, yy)` of size (M * N), and that you want to train by aggregating 16 small batches to get a batch of (16M * 16N),
242 | 
243 | ```
244 | cache_x = []
245 | cache_y = []
246 | closures_x = []
247 | closures_y = []
248 | 
249 | for step, sub_batch in enumerate(loader):
250 |     xx, yy = sub_batch
251 |     rx, cx = call_model(bert, xx)
252 |     ry, cy = call_model(bert, yy)
253 | 
254 |     cache_x.append(rx)
255 |     cache_y.append(ry)
256 |     closures_x.append(cx)
257 |     closures_y.append(cy)
258 | 
259 |     if (step + 1) % 16 == 0:
260 |         loss = contrastive_loss(cache_x, cache_y)
261 |         scaler.scale(loss).backward()
262 | 
263 |         for f, r in zip(closures_x, cache_x):
264 |             f(r)
265 |         for f, r in zip(closures_y, cache_y):
266 |             f(r)
267 | 
268 |         cache_x = []
269 |         cache_y = []
270 |         closures_x = []
271 |         closures_y = []
272 | 
273 |         scaler.step(optimizer)
274 |         scaler.update()
275 |         optimizer.zero_grad()
276 | ```
277 | ### Distributed Training
278 | Running distributed multi-process training requires 1) (all-)gathering representations across devices and 2) (all-)reducing gradients across devices. Both steps happen **outside** the cache-decorated functions.
279 | 
280 | The latter is easy to achieve by wrapping encoders, e.g. a `bert`, in `DistributedDataParallel`.
281 | ``` 282 | bert = DistributedDataParallel( 283 | bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 284 | ``` 285 | 286 | The former requires extra distributed ops in the loss function, which should be done according the original loss definition. For example, 287 | ``` 288 | from torch import distributed as dist 289 | from grad_cache.functional import cat_input_tensor, gather_input_tensor 290 | 291 | @cat_input_tensor 292 | @gather_input_tensor 293 | @autocast() 294 | def contrastive_loss(x, y): 295 | target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device) 296 | scores = torch.matmul(x, y.transpose(0, 1)) 297 | # scale the loss as DistributedDataParallel will do mean reduce 298 | return F.cross_entropy(scores, target=target) * dist.get_world_size() 299 | ``` 300 | ## Code Structure 301 | [grad_cache/grad_cache.py](src/grad_cache/grad_cache.py) - Define the GradCache class. The code is under 300 lines including comments. For development, we encourage you to read through it. 302 | 303 | [grad_cache/functional.py](src/grad_cache/functional.py) - Define decorators to create higher order function for gradient caching from ordinary model call functions and loss functions. 304 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='GradCache', 5 | version='0.1.0', 6 | packages=['grad_cache', 'grad_cache.cachex', 'grad_cache.pytorch_lightning'], 7 | package_dir={'': 'src', 'grad_cache': 'src/grad_cache'}, 8 | url='https://github.com/luyug/GradCache', 9 | license='Apache-2.0', 10 | author='Luyu Gao', 11 | author_email='luyug@cs.cmu.edu', 12 | description='' 13 | ) 14 | -------------------------------------------------------------------------------- /src/grad_cache/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .grad_cache import GradCache 3 | except ModuleNotFoundError: 4 | pass 5 | -------------------------------------------------------------------------------- /src/grad_cache/cachex/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import chunk_encode, cache_grad, unchunk_args, grad_cached 2 | from .tree_utils import tree_chunk, tree_unchunk 3 | 4 | -------------------------------------------------------------------------------- /src/grad_cache/cachex/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Any, Callable 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from .tree_utils import tree_unchunk 8 | 9 | Array = jax.Array 10 | 11 | 12 | def grad_with_cache(f, **grad_kwargs): 13 | def cache_f(params, cache, *args, **kwargs): 14 | return jnp.sum(f(params, *args, **kwargs) * cache) 15 | return jax.grad(cache_f, **grad_kwargs) 16 | 17 | 18 | def encode_scan_fn(f, carry, x): 19 | return carry, f(**x) 20 | 21 | 22 | def cache_grad_scan_fn(f, params, acc, x): 23 | cached_grad, kwargs = x 24 | 25 | def fwd_fn(w): 26 | return f(params=w, **kwargs) 27 | 28 | chunk_grad = grad_with_cache(fwd_fn)(params, cached_grad) 29 | acc = jax.tree_multimap(lambda u, v: u + v, acc, chunk_grad) 30 | return acc, None 31 | 32 | 33 | def chunk_encode(encode_fn): 34 | def f(**xx): 35 | _, hh = jax.lax.scan(partial(encode_scan_fn, encode_fn), 
0, xx) 36 | return hh 37 | return f 38 | 39 | 40 | def cache_grad(encode_fn): 41 | def f(params, grad_accumulator, cached_grad, **xx): 42 | grads, _ = jax.lax.scan( 43 | partial(cache_grad_scan_fn, encode_fn, params), grad_accumulator, [cached_grad, xx] 44 | ) 45 | return grads 46 | return f 47 | 48 | 49 | def unchunk_args(axis: int = 0, argnums: Iterable[int] = ()): 50 | def decorator_unchunk(f): 51 | def g(*args, **kwargs): 52 | new_args = list(args) 53 | for i in argnums: 54 | new_args[i] = tree_unchunk(args[i], axis) 55 | return f(*new_args, **kwargs) 56 | 57 | return g 58 | 59 | return decorator_unchunk 60 | 61 | def grad_cached( 62 | f: Callable[..., Array], 63 | policy: Callable[..., bool] = jax.checkpoint_policies.nothing_saveable, 64 | prevent_cse: bool = True 65 | ): 66 | """ 67 | Single-decorator version of grad cache that uses XLA to infer backward pass. 68 | 69 | The forward pass is manually split into chunks and performed sequentially with lax.scan. 70 | We rely on XLA to infer the backward pass and run it in a similar fashion. 71 | 72 | Args: 73 | f: Function to be differentiated. It should take in a single argument and return a jax array of representations. 74 | policy: The sub-batch rematerialization policy. 75 | prevent_cse: Whether to prevent common subexpression elimination. 76 | 77 | Returns: 78 | Decorated gradient cached `f` that expects input to have an extra leading sub-batch dimension, potentially produced by `tree_chunk`. 79 | 80 | A example of usage on a apply function that takes multiple arguments: 81 | 82 | >>> @cachex.grad_cached 83 | ... def fwd(params, batch): 84 | ... return apply(params, **batch) 85 | 86 | >>> src = cachex.tree_chunk(src, 8) 87 | >>> tgt = cachex.tree_chunk(tgt, 8) 88 | 89 | >>> def compute_loss(params, src, tgt): 90 | ... h_src = fwd(params, src) 91 | ... h_tgt = fwd(params, tgt) 92 | ... return loss(h_src, h_tgt) 93 | 94 | >>> grads = jax.grad(compute_loss)(params, src, tgt) 95 | 96 | Here the `compute_loss` function can typically be dropped into a larger training step function. 
97 | """ 98 | def cached_f(params, batch): 99 | def scan_f(_, sub_batch): 100 | return None, f(params, sub_batch) 101 | _, reps = jax.lax.scan(jax.checkpoint(scan_f, policy=policy, prevent_cse=prevent_cse), None, batch) 102 | return jnp.reshape(reps, (-1,) + reps.shape[2:]) 103 | return cached_f -------------------------------------------------------------------------------- /src/grad_cache/cachex/training.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from .functional import chunk_encode, cache_grad, unchunk_args 7 | 8 | 9 | def cache_train_step(loss_fn, state, ss, tt, axis='device'): 10 | def encode_with_params(params, **kwargs): 11 | return state.apply_fn(params=params, **kwargs) 12 | 13 | encode_fn = chunk_encode(partial(encode_with_params, state.params)) 14 | grad_fn = cache_grad(encode_with_params) 15 | 16 | s_reps = encode_fn(**ss) 17 | t_reps = encode_fn(**tt) 18 | 19 | @unchunk_args(axis=0, argnums=(0, 1)) 20 | def grad_cache_fn(xx, yy): 21 | return jnp.mean(loss_fn(xx, yy, axis=axis)) 22 | loss, (s_grads, t_grads) = jax.value_and_grad(grad_cache_fn, argnums=(0, 1))(s_reps, t_reps) 23 | 24 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params) 25 | grads = grad_fn(state.params, grads, s_grads, **ss) 26 | grads = grad_fn(state.params, grads, t_grads, **tt) 27 | 28 | loss, grads = jax.lax.pmean([loss, grads], axis) 29 | new_state = state.apply_gradients(grads=grads) 30 | return loss, new_state 31 | -------------------------------------------------------------------------------- /src/grad_cache/cachex/tree_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jax 4 | 5 | 6 | def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any: 7 | return jax.tree_map( 8 | lambda v: v.reshape(v.shape[:axis] + (n_chunk, -1) + v.shape[axis + 1:]), 9 | tree 10 | ) 11 | 12 | 13 | def tree_unchunk(tree: Any, axis: int = 0) -> Any: 14 | return jax.tree_map( 15 | lambda x: x.reshape(x.shape[:axis] + (-1,) + x.shape[axis + 2:]), 16 | tree 17 | ) 18 | -------------------------------------------------------------------------------- /src/grad_cache/context_managers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import get_device_states, set_device_states 3 | 4 | 5 | class RandContext: 6 | def __init__(self, *tensors): 7 | self.fwd_cpu_state = torch.get_rng_state() 8 | self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors) 9 | 10 | def __enter__(self): 11 | self._fork = torch.random.fork_rng( 12 | devices=self.fwd_gpu_devices, 13 | enabled=True 14 | ) 15 | self._fork.__enter__() 16 | torch.set_rng_state(self.fwd_cpu_state) 17 | set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states) 18 | 19 | def __exit__(self, exc_type, exc_val, exc_tb): 20 | self._fork.__exit__(exc_type, exc_val, exc_tb) 21 | self._fork = None -------------------------------------------------------------------------------- /src/grad_cache/functional.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Callable, Union, Tuple, Any 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch import distributed as dist 7 | 8 | from .context_managers import RandContext 9 | 10 | 11 | def cached(func: Callable[..., Tensor]): 12 | 
""" 13 | A decorator that takes a model call function into a cached compatible version. 14 | :param func: A function that calls the model and return representation tensor. 15 | :return: A function that returns 1) representation leaf tensors for cache construction, 2) a closure function for 16 | the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor. 17 | """ 18 | @wraps(func) 19 | def cache_func(*args, **kwargs): 20 | rnd_state = RandContext() 21 | with torch.no_grad(): 22 | reps_no_grad = func(*args, **kwargs) 23 | if isinstance(reps_no_grad, Tensor): 24 | reps_no_grad = (reps_no_grad, ) 25 | else: 26 | assert all(isinstance(v, Tensor) for v in reps_no_grad) 27 | leaf_reps = tuple(t.detach().requires_grad_() for t in reps_no_grad) 28 | 29 | @wraps(func) 30 | def forward_backward_func(cache_reps: Union[Tensor, Tuple[Tensor]]): 31 | with rnd_state: 32 | reps = func(*args, **kwargs) 33 | if isinstance(reps, Tensor): 34 | reps = (reps,) 35 | if isinstance(cache_reps, Tensor): 36 | cache_reps = (cache_reps,) 37 | assert len(reps) == len(cache_reps) 38 | 39 | surrogate = sum(map(lambda u, v: torch.dot(u.flatten(), v.grad.flatten()), reps, cache_reps), 0) 40 | surrogate.backward() 41 | 42 | return leaf_reps + (forward_backward_func,) 43 | return cache_func 44 | 45 | 46 | def _cat_tensor_list(xx): 47 | if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx): 48 | return torch.cat(xx) 49 | else: 50 | return xx 51 | 52 | 53 | def cat_input_tensor(func: Callable[..., Tensor]): 54 | """ 55 | A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor 56 | on the 0 dimension. This can come in handy dealing with results of representation tensors from multiple 57 | cached forward. 58 | :param func: A loss function 59 | :return: Decorated loss function for cached results. 60 | """ 61 | @wraps(func) 62 | def cat_f(*args, **kwargs): 63 | args_cat = [_cat_tensor_list(x) for x in args] 64 | kwargs_cat = dict((k, _cat_tensor_list(v)) for k, v in kwargs.values()) 65 | return func(*args_cat, **kwargs_cat) 66 | return cat_f 67 | 68 | 69 | def _maybe_gather_tensor(t: Any, axis: int): 70 | if not isinstance(t, Tensor): 71 | return t 72 | gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())] 73 | dist.all_gather(gathered, t) 74 | gathered[dist.get_rank()] = t 75 | return torch.cat(gathered, dim=axis) 76 | 77 | 78 | def gather_input_tensor(func: Callable[..., Tensor], axis=0): 79 | """ 80 | A decorator that all-gather positional and keyword arguments of type Tensor and concatenate them on axis. 81 | Intended to be used with distributed contrastive learning loss. 82 | :param func: A loss function 83 | :param axis: The axis the gathered tensors are concatenated. 84 | :return: Decorated loss function for distributed training. 
85 | """ 86 | @wraps(func) 87 | def f(*args, **kwargs): 88 | args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args] 89 | kwargs_gathered = dict((k, _maybe_gather_tensor(v, axis=axis)) for k, v in kwargs.values()) 90 | return func(*args_gathered, **kwargs_gathered) 91 | return f 92 | -------------------------------------------------------------------------------- /src/grad_cache/grad_cache.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Callable, Any 2 | from contextlib import nullcontext 3 | from itertools import repeat 4 | from collections import UserDict 5 | import logging 6 | 7 | import torch 8 | from torch import nn, Tensor 9 | from torch.cuda.amp import GradScaler, autocast 10 | 11 | from grad_cache.context_managers import RandContext 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class GradCache: 17 | """ 18 | Gradient Cache class. Implements input chunking, first graph-less forward pass, Gradient Cache creation, second 19 | forward & backward gradient computation. Optimizer step is not included. Native torch automatic mixed precision is 20 | supported. User needs to handle gradient unscaling and scaler update after a gradeitn cache step. 21 | """ 22 | def __init__( 23 | self, 24 | models: List[nn.Module], 25 | chunk_sizes: Union[int, List[int]], 26 | loss_fn: Callable[..., Tensor], 27 | split_input_fn: Callable[[Any, int], Any] = None, 28 | get_rep_fn: Callable[..., Tensor] = None, 29 | fp16: bool = False, 30 | scaler: GradScaler = None, 31 | ): 32 | """ 33 | Initialize the Gradient Cache class instance. 34 | :param models: A list of all encoder models to be updated by the current cache. 35 | :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model. 36 | :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and 37 | arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations 38 | in the autograd graph, which are later relied upon to create the gradient cache. 39 | :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this 40 | class will try its best to split the inputs of supported types. See `split_inputs` function. 41 | :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If 42 | not provided, the generic output is assumed to be the representation tensor. 43 | :param fp16: If True, run mixed precision training, which requires scaler to also be set. 44 | :param scaler: A GradScaler object for automatic mixed precision training. 45 | """ 46 | self.models = models 47 | 48 | if isinstance(chunk_sizes, int): 49 | self.chunk_sizes = [chunk_sizes for _ in range(len(models))] 50 | else: 51 | self.chunk_sizes = chunk_sizes 52 | 53 | self.split_input_fn = split_input_fn 54 | self.get_rep_fn = get_rep_fn 55 | self.loss_fn = loss_fn 56 | 57 | if fp16: 58 | assert scaler is not None, "mixed precision training requires a gradient scaler passed in" 59 | 60 | self.fp16 = fp16 61 | self.scaler = scaler 62 | 63 | self._get_input_tensors_strict = False 64 | 65 | def __call__(self, *args, **kwargs): 66 | """ 67 | Call the cache_step function. 68 | :return: Current step loss. 69 | """ 70 | return self.cache_step(*args, **kwargs) 71 | 72 | def split_inputs(self, model_input, chunk_size: int) -> List: 73 | """ 74 | Split input into chunks. 
Will call user provided `split_input_fn` if specified. Otherwise, 75 | it can handle input types of tensor, list of tensors and dictionary of tensors. 76 | :param model_input: Generic model input. 77 | :param chunk_size: Size of each chunk. 78 | :return: A list of chunked model input. 79 | """ 80 | # delegate splitting to user provided function 81 | if self.split_input_fn is not None: 82 | return self.split_input_fn(model_input, chunk_size) 83 | 84 | if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()): 85 | keys = list(model_input.keys()) 86 | chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] 87 | return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))] 88 | 89 | elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input): 90 | chunked_x = [t.split(chunk_size, dim=0) for t in model_input] 91 | return [list(s) for s in zip(*chunked_x)] 92 | 93 | elif isinstance(model_input, Tensor): 94 | return list(model_input.split(chunk_size, dim=0)) 95 | 96 | elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]: 97 | args_chunks = self.split_inputs(model_input[0], chunk_size) 98 | kwargs_chunks = self.split_inputs(model_input[1], chunk_size) 99 | return list(zip(args_chunks, kwargs_chunks)) 100 | 101 | else: 102 | raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}') 103 | 104 | def get_input_tensors(self, model_input) -> List[Tensor]: 105 | """ 106 | Recursively go through model input and grab all tensors, which are then used to record current device random 107 | states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other types will 108 | be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will be raised. 109 | :param model_input: input to model 110 | :return: all torch tensors in model_input 111 | """ 112 | if isinstance(model_input, Tensor): 113 | return [model_input] 114 | 115 | elif isinstance(model_input, (list, tuple)): 116 | return sum((self.get_input_tensors(x) for x in model_input), []) 117 | 118 | elif isinstance(model_input, (dict, UserDict)): 119 | return sum((self.get_input_tensors(x) for x in model_input.values()), []) 120 | 121 | elif self._get_input_tensors_strict: 122 | raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}') 123 | 124 | else: 125 | return [] 126 | 127 | def model_call(self, model: nn.Module, model_input): 128 | """ 129 | Literally call the model's __call__ method. 
130 | :param model: model to be called 131 | :param model_input: input to the model call 132 | :return: model output 133 | """ 134 | with autocast() if self.fp16 else nullcontext(): 135 | if isinstance(model_input, Tensor): 136 | return model(model_input) 137 | elif isinstance(model_input, list): 138 | return model(*model_input) 139 | elif isinstance(model_input, (dict, UserDict)): 140 | return model(**model_input) 141 | elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]: 142 | model_args, model_kwargs = model_input 143 | return model(*model_args, **model_kwargs) 144 | else: 145 | raise NotImplementedError 146 | 147 | def get_reps(self, model_out) -> Tensor: 148 | """ 149 | Return representation tensor from generic model output 150 | :param model_out: generic model output 151 | :return: a single tensor corresponding to the model representation output 152 | """ 153 | if self.get_rep_fn is not None: 154 | return self.get_rep_fn(model_out) 155 | else: 156 | return model_out 157 | 158 | def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor: 159 | """ 160 | Compute the loss based on the representation tensors. The tensors should be ordered same as the list of models 161 | registered in this GradCache class instance. 162 | :param reps: Representations for computing the loss. 163 | :param loss_kwargs: Keyword arguments input to the loss function. 164 | :return: the loss tensor. 165 | """ 166 | loss = self.loss_fn(*reps, **loss_kwargs) 167 | return loss 168 | 169 | def forward_no_grad( 170 | self, 171 | model: nn.Module, 172 | model_inputs, 173 | ) -> [Tensor, List[RandContext]]: 174 | """ 175 | The first forward pass without gradient computation. 176 | :param model: Encoder model. 177 | :param model_inputs: Model input already broken into chunks. 178 | :return: A tuple of a) representations and b) recorded random states. 179 | """ 180 | rnd_states = [] 181 | model_reps = [] 182 | 183 | with torch.no_grad(): 184 | for x in model_inputs: 185 | rnd_states.append(RandContext(*self.get_input_tensors(x))) 186 | y = self.model_call(model, x) 187 | model_reps.append(self.get_reps(y)) 188 | 189 | # concatenate all sub-batch representations 190 | model_reps = torch.cat(model_reps, dim=0) 191 | return model_reps, rnd_states 192 | 193 | def build_cache(self, *reps: Tensor, **loss_kwargs) -> [List[Tensor], Tensor]: 194 | """ 195 | Compute the gradient cache 196 | :param reps: Computed representations from all encoder models 197 | :param loss_kwargs: Extra keyword arguments to the loss function 198 | :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor 199 | """ 200 | reps = [r.detach().requires_grad_() for r in reps] 201 | with autocast() if self.fp16 else nullcontext(): 202 | loss = self.compute_loss(*reps, **loss_kwargs) 203 | 204 | if self.fp16: 205 | self.scaler.scale(loss).backward() 206 | else: 207 | loss.backward() 208 | 209 | cache = [r.grad for r in reps] 210 | 211 | return cache, loss.detach() 212 | 213 | def forward_backward( 214 | self, 215 | model: nn.Module, 216 | model_inputs, 217 | cached_gradients: List[Tensor], 218 | random_states: List[RandContext], 219 | no_sync_except_last: bool = False 220 | ): 221 | """ 222 | Run the second forward and the backward pass to compute gradient for a model. 223 | :param model: Encoder model. 224 | :param model_inputs: Chunked input to the encoder model. 225 | :param cached_gradients: Chunked gradient cache tensor for each input. 
226 | :param random_states: Each input's device random state during the first forward. 227 | :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes 228 | for the last sub-batch's forward-backward pass. 229 | """ 230 | if no_sync_except_last: 231 | sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext] 232 | else: 233 | sync_contexts = [nullcontext for _ in range(len(model_inputs))] 234 | 235 | for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts): 236 | with sync_context(): 237 | with state: 238 | y = self.model_call(model, x) 239 | reps = self.get_reps(y) 240 | 241 | surrogate = torch.dot(reps.flatten(), gradient.flatten()) 242 | surrogate.backward() 243 | 244 | def cache_step( 245 | self, 246 | *model_inputs, 247 | no_sync_except_last: bool = False, 248 | **loss_kwargs 249 | ) -> Tensor: 250 | """ 251 | Run a cached step to compute gradient over the inputs. 252 | :param model_inputs: Input to each encoder model. Should be in similar order as the class's model. 253 | :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction 254 | across processes for the last sub-batch's forward-backward pass. 255 | :param loss_kwargs: Additional keyword arguments to the loss function. 256 | :return: The current's loss. 257 | """ 258 | all_reps = [] 259 | all_rnd_states = [] 260 | 261 | if no_sync_except_last: 262 | assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \ 263 | 'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \ 264 | 'proper initializations.' 265 | 266 | model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)] 267 | 268 | for model, x in zip(self.models, model_inputs): 269 | model_reps, rnd_states = self.forward_no_grad(model, x) 270 | all_reps.append(model_reps) 271 | all_rnd_states.append(rnd_states) 272 | 273 | cache, loss = self.build_cache(*all_reps, **loss_kwargs) 274 | cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)] 275 | 276 | for model, x, model_cache, rnd_states in zip( 277 | self.models, model_inputs, cache, all_rnd_states): 278 | self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last) 279 | 280 | return loss 281 | -------------------------------------------------------------------------------- /src/grad_cache/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import functional as F 6 | from torch import distributed as dist 7 | 8 | 9 | class SimpleContrastiveLoss: 10 | def __init__(self, n_hard_negatives: int = 0): 11 | self.target_per_qry = n_hard_negatives + 1 12 | 13 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'): 14 | if target is None: 15 | assert x.size(0) * self.target_per_qry == y.size(0) 16 | target = torch.arange(0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device) 17 | 18 | logits = torch.matmul(x, y.transpose(0, 1)) 19 | return F.cross_entropy(logits, target, reduction=reduction) 20 | 21 | 22 | class DistributedContrastiveLoss(SimpleContrastiveLoss): 23 | def __init__(self, n_hard_negatives: int = 0): 24 | assert dist.is_initialized(), "Distributed training has not been 
properly initialized." 25 | 26 | super().__init__(n_hard_negatives=n_hard_negatives) 27 | self.word_size = dist.get_world_size() 28 | self.rank = dist.get_rank() 29 | 30 | def __call__(self, x: Tensor, y: Tensor, **kwargs): 31 | dist_x = self.gather_tensor(x) 32 | dist_y = self.gather_tensor(y) 33 | 34 | return super().__call__(dist_x, dist_y, **kwargs) 35 | 36 | def gather_tensor(self, t): 37 | gathered = [torch.empty_like(t) for _ in range(self.word_size)] 38 | dist.all_gather(gathered, t) 39 | gathered[self.rank] = t 40 | return torch.cat(gathered, dim=0) 41 | 42 | 43 | class ContrastiveLossWithQueryClosure(SimpleContrastiveLoss): 44 | def __call__( 45 | self, 46 | *reps: Tensor, 47 | query_closure: Callable[[], Tensor] = None, 48 | target: Tensor = None, 49 | reduction: str = 'mean' 50 | ): 51 | if len(reps) == 0 or len(reps) > 2: 52 | raise ValueError(f'Expecting 1 or 2 tensor input, got {len(reps)} tensors') 53 | 54 | # no closure evaluation 55 | if len(reps) == 2: 56 | assert query_closure is None, 'received 2 representation tensors while query_closure is also set' 57 | return super().__call__(*reps, target=target, reduction=reduction) 58 | 59 | # run the closure 60 | assert query_closure is not None 61 | x = query_closure() 62 | y = reps[0] 63 | return super().__call__(x, y, target=target, reduction=reduction) 64 | -------------------------------------------------------------------------------- /src/grad_cache/pytorch_lightning/pl_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Lightning Example of using Grad Cache, tested on PyTorch Lightning version '2.2.0.post0' with Multi-GPUs and Mix-Precision (fp-16). 3 | Required to install Pytorch Metric Learning as well for contrastive loss calculation. 4 | """ 5 | 6 | import os 7 | import argparse 8 | import torch 9 | import lightning as pl 10 | from contextlib import nullcontext 11 | from lightning.pytorch.loggers import WandbLogger 12 | from lightning.pytorch.strategies import DDPStrategy 13 | from pytorch_metric_learning.utils import distributed as pml_dist 14 | from pytorch_metric_learning.losses import SupConLoss 15 | 16 | from grad_cache.pytorch_lightning.pl_gradcache import PLGradCache 17 | 18 | 19 | class RandomDataset(torch.utils.data.Dataset): 20 | def __init__(self, params): 21 | self.params = params 22 | 23 | def __len__(self): 24 | return self.params.data_size 25 | 26 | def __getitem__(self, idx): 27 | # Generate random float inputs with shape [2, input_dim] for contrastive learning 28 | input_data = torch.randn(2, self.params.input_dim) 29 | # Generate a random integer label for binary classification (0 or 1), replicate it to have shape [2] 30 | label = torch.randint(0, 2, (1,), dtype=torch.long) 31 | label = torch.tensor([label, label], dtype=torch.long) 32 | return input_data, label 33 | 34 | 35 | class SimpleLitModel(pl.LightningModule): 36 | def __init__(self, params): 37 | super().__init__() 38 | self.params = params 39 | self.loss = SupConLoss(temperature=params.temperature) 40 | if params.gpus > 1: 41 | self.loss = pml_dist.DistributedLossWrapper(self.loss) 42 | self.automatic_optimization = (not self.params.use_gc) # needed when use_gc is on 43 | self.fp16 = (self.params.precision == 16) 44 | self.linear = torch.nn.Linear(params.input_dim, params.embed_dim) # our simple model 45 | 46 | def init_gc(self, scaler, ddp_module): 47 | """Sets up the required components of GradCache. 
This method is called after the model is initialized.""" 48 | assert self.params.use_gc 49 | if self.fp16 and self.params.use_gc: 50 | # pytorch lightning autocast wraps everything in it 51 | # it needs to be disabled in gradcache because we do forward twice, and one with no grad 52 | # then we do autocast manually in gradcache when we need to 53 | # original post: https://discuss.pytorch.org/t/autocast-and-torch-no-grad-unexpected-behaviour/93475/3 54 | # pl source code: your_venv_name/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/amp.py::forward_context 55 | self.trainer.strategy.precision_plugin.forward_context = nullcontext 56 | 57 | print(f"*** initializing gradcache with ddp_module={type(ddp_module)}, minibatch_size={self.params.gc_minibatch_size}") 58 | self.gc = PLGradCache( 59 | models=[ddp_module], 60 | chunk_sizes=self.params.gc_minibatch_size, 61 | loss_fn=self.calculate_loss, 62 | fp16=self.fp16, 63 | scaler=(scaler if self.fp16 else None), # needed when using automatic_optimization is off and fp16 is on 64 | backward_fn=self.manual_backward, # needed when automatic_optimization is off 65 | ) 66 | 67 | def train_dataloader(self): 68 | train_dataset = RandomDataset(params) 69 | train_loader = torch.utils.data.DataLoader( 70 | train_dataset, 71 | batch_size=params.batch_size, 72 | num_workers=params.num_workers, 73 | drop_last=True, 74 | ) 75 | return train_loader 76 | 77 | def configure_optimizers(self): 78 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 79 | return optimizer 80 | 81 | def calculate_loss(self, embeddings, labels): 82 | # embeddings shape [batch_size, 2, embed_dim] 83 | # labels shape [batch_size, 2] 84 | embeddings = embeddings.flatten(0, 1) 85 | labels = labels.flatten() 86 | return self.loss(embeddings, labels) 87 | 88 | def forward(self, inputs): # needed for grad cache 89 | return self.linear(inputs) 90 | 91 | def on_train_start(self): # initialize grad cache here 92 | if self.params.use_gc: 93 | self.init_gc(self.trainer.scaler, self.trainer.strategy.model) 94 | # self.init_gc(self.trainer.scaler, self.trainer.lightning_module) # we can use this if nccl strategy is available 95 | 96 | def training_step(self, batch, batch_idx): 97 | # inputs shape [batch_size, 2, input_dim] 98 | # labels shape [batch_size, 2] 99 | inputs, labels = batch 100 | if self.params.use_gc: 101 | assert self.gc is not None 102 | optimizer = self.optimizers() 103 | optimizer.zero_grad() 104 | loss = self.gc( 105 | inputs, 106 | no_sync_except_last=(self.params.gpus > 1), 107 | labels=labels.flatten(), 108 | ) 109 | loss /= max(1, self.params.gpus) # needed when automatic_optimization is off 110 | log_loss = loss 111 | optimizer.step() 112 | else: 113 | outputs = self.linear(inputs) 114 | loss = self.calculate_loss(outputs, labels) 115 | log_loss = loss / max(1, self.params.gpus) 116 | self.log( 117 | "train_loss", 118 | log_loss, 119 | on_step=True, 120 | on_epoch=True, 121 | sync_dist=self.params.use_gc, # needed when automatic_optimization is off 122 | ) 123 | print(f"batch_idx={batch_idx}, loss={loss}") 124 | return loss 125 | 126 | 127 | def get_argument_parser(): 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument("--random_seed", type=int, default=42) 130 | parser.add_argument("--num_workers", type=int, default=4) 131 | parser.add_argument("--gpus", type=int, default=0) 132 | parser.add_argument("--precision", type=int, default=32) 133 | parser.add_argument("--ddp_backend", type=str, default="nccl", help="torch distributed backend 
(Default: nccl), use 'gloo' if nccl doesn't work") 134 | parser.add_argument("--project_name", type=str, default="debug_gradcache") 135 | 136 | # training params 137 | parser.add_argument("--data_size", type=int, default=100) 138 | parser.add_argument("--epochs", type=int, default=10) 139 | parser.add_argument("--batch_size", type=int, default=16) 140 | parser.add_argument("--temperature", type=float, default=0.1) 141 | 142 | # model hyperparams 143 | parser.add_argument("--input_dim", type=int, default=784) 144 | parser.add_argument("--embed_dim", type=int, default=512) 145 | 146 | # grad cache params 147 | parser.add_argument("--use_gc", action="store_true", default=False, help="whether to use grad cache") 148 | parser.add_argument("--gc_minibatch_size", type=int, default=2, help="mini batch size of grad cache, must be provided if use_gc is on") 149 | 150 | return parser 151 | 152 | 153 | def main(params): 154 | # set random seeds reproduceability 155 | torch.backends.cudnn.deterministic = True 156 | torch.backends.cudnn.benchmark = False 157 | 158 | # set different random seeds for each worker 159 | pl.seed_everything(seed=params.random_seed, workers=True) 160 | 161 | # weirdness with HuggingFace tokenizer when processing things in parallel 162 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 163 | torch.multiprocessing.set_sharing_strategy("file_system") 164 | 165 | # load model 166 | model = SimpleLitModel(params) 167 | 168 | # load trainer 169 | experiment_id = f"gpus-{params.gpus}_precision-{params.precision}" 170 | if params.use_gc: 171 | experiment_id += "_gc" 172 | experiment_id += "_pl" 173 | wandb_logger = WandbLogger( 174 | project=params.project_name, 175 | name=experiment_id, 176 | ) 177 | ddp = DDPStrategy(process_group_backend=params.ddp_backend) 178 | trainer = pl.Trainer( 179 | accelerator="gpu" if params.gpus > 0 else "cpu", 180 | strategy=ddp if params.gpus > 1 else "auto", 181 | devices=params.gpus if params.gpus > 0 else "auto", 182 | precision=params.precision, 183 | logger=wandb_logger, 184 | max_epochs=params.epochs, 185 | log_every_n_steps=1, 186 | ) 187 | trainer.fit(model) 188 | 189 | 190 | if __name__ == "__main__": 191 | params = get_argument_parser().parse_args() 192 | main(params) 193 | -------------------------------------------------------------------------------- /src/grad_cache/pytorch_lightning/pl_gradcache.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from typing import Any, Callable, List, Tuple, Union 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.cuda.amp import GradScaler, autocast 7 | 8 | from ..grad_cache import GradCache, RandContext 9 | 10 | class PLGradCache(GradCache): 11 | """ 12 | Gradient Cache class with PyTorch Lightning Support. 13 | Implements input chunking, first graph-less forward pass, Gradient Cache creation, second forward & backward gradient computation. 14 | Optimizer step is not included. Native torch automatic mixed precision is supported. 15 | Gradient unscaling and scaler update are handled internally. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | models: List[nn.Module], 21 | chunk_sizes: Union[int, List[int]], 22 | loss_fn: Callable[..., Tensor], 23 | split_input_fn: Callable[[Any, int], Any] = None, 24 | get_rep_fn: Callable[..., Tensor] = None, 25 | fp16: bool = False, 26 | scaler: GradScaler = None, 27 | backward_fn=None, # [added] 28 | ): 29 | """ 30 | Initialize the Gradient Cache class instance. 
31 | :param models: A list of all encoder models to be updated by the current cache. 32 | :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model. 33 | :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and 34 | arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations 35 | in the autograd graph, which are later relied upon to create the gradient cache. 36 | :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this 37 | class will try its best to split the inputs of supported types. See `split_inputs` function. 38 | :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If 39 | not provided, the generic output is assumed to be the representation tensor. 40 | :param fp16: If True, run mixed precision training, which requires scaler to also be set. 41 | :param scaler: A GradScaler object for automatic mixed precision training. 42 | :[added] param backward_fn: The `manual_backward` function of pytorch lightning trainer when automatic_optimization is disabled. 43 | """ 44 | super().__init__(models, chunk_sizes, loss_fn, split_input_fn, get_rep_fn, fp16, scaler) 45 | self.backward_fn = backward_fn 46 | 47 | def build_cache(self, *reps: Tensor, **loss_kwargs) -> Union[List[Tensor], Tensor]: 48 | """ 49 | Compute the gradient cache 50 | :param reps: Computed representations from all encoder models 51 | :param loss_kwargs: Extra keyword arguments to the loss function 52 | :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor 53 | """ 54 | reps = [r.detach().requires_grad_() for r in reps] 55 | with autocast() if self.fp16 else nullcontext(): 56 | loss = self.compute_loss(*reps, **loss_kwargs) 57 | 58 | self.backward_fn(loss) # [modified] 59 | 60 | cache = [r.grad for r in reps] 61 | 62 | return cache, loss.detach() 63 | 64 | def forward_backward( 65 | self, 66 | model: nn.Module, 67 | model_inputs, 68 | cached_gradients: List[Tensor], 69 | random_states: List[RandContext], 70 | no_sync_except_last: bool = False, 71 | ): 72 | """ 73 | Run the second forward and the backward pass to compute gradient for a model. 74 | :param model: Encoder model. 75 | :param model_inputs: Chunked input to the encoder model. 76 | :param cached_gradients: Chunked gradient cache tensor for each input. 77 | :param random_states: Each input's device random state during the first forward. 78 | :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes 79 | for the last sub-batch's forward-backward pass. 
80 | """ 81 | if isinstance( 82 | model, nn.parallel.DistributedDataParallel 83 | ): # [use ddp_model] 84 | 85 | if no_sync_except_last: 86 | sync_contexts = [ 87 | model.no_sync for _ in range(len(model_inputs) - 1) 88 | ] + [nullcontext] 89 | sync_flags = [True] * (len(model_inputs)) # [added] 90 | else: 91 | sync_contexts = [nullcontext for _ in range(len(model_inputs))] 92 | sync_flags = [False] * (len(model_inputs)) # [added] 93 | 94 | # [modified] 95 | for x, state, gradient, sync_context, sync_flag in zip( 96 | model_inputs, random_states, cached_gradients, sync_contexts, sync_flags 97 | ): 98 | with sync_context(): 99 | with state: 100 | y = self.model_call(model, x) 101 | reps = self.get_reps(y) 102 | surrogate = torch.dot(reps.flatten(), gradient.flatten()) 103 | if sync_flag: 104 | model.require_backward_grad_sync = True 105 | if self.fp16: # [added] 106 | self.scaler._enabled = False 107 | self.backward_fn(surrogate) 108 | self.scaler._enabled = True 109 | else: 110 | self.backward_fn(surrogate) # [modified] 111 | else: # [use base model (i.e. SimpleLitModel)] 112 | 113 | # [remove no_sync_except_last: pytorch lightning would handle gradient sync automatically] 114 | for x, state, gradient in zip( 115 | model_inputs, random_states, cached_gradients 116 | ): 117 | with state: 118 | y = self.model_call(model, x) 119 | reps = self.get_reps(y) 120 | surrogate = torch.dot(reps.flatten(), gradient.flatten()) 121 | if self.fp16: # [added] 122 | self.scaler._enabled = False 123 | self.backward_fn(surrogate) 124 | self.scaler._enabled = True 125 | else: 126 | self.backward_fn(surrogate) # [added] 127 | 128 | def cache_step( 129 | self, *model_inputs, no_sync_except_last: bool = False, **loss_kwargs 130 | ) -> Tuple[Tensor, Tensor]: 131 | """ 132 | Run a cached step to compute gradient over the inputs. 133 | :param model_inputs: Input to each encoder model. Should be in similar order as the class's model. 134 | :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction 135 | across processes for the last sub-batch's forward-backward pass. 136 | :param loss_kwargs: Additional keyword arguments to the loss function. 137 | :return: A tuple of the current's loss and the model's representation. 138 | """ 139 | all_reps = [] 140 | all_rnd_states = [] 141 | 142 | # [removed: we check it in forward_backward(.)] 143 | # if no_sync_except_last: 144 | # assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \ 145 | # 'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \ 146 | # 'proper initializations.' 
147 | 
148 |         model_inputs = [
149 |             self.split_inputs(x, chunk_size)
150 |             for x, chunk_size in zip(model_inputs, self.chunk_sizes)
151 |         ]
152 | 
153 |         for model, x in zip(self.models, model_inputs):
154 |             model_reps, rnd_states = self.forward_no_grad(model, x)
155 |             all_reps.append(model_reps)
156 |             all_rnd_states.append(rnd_states)
157 | 
158 |         # all_reps: len(self.models) x [batch_size, 2, embed_dim]
159 |         # cache: len(self.models) x (batch_size / gc_minibatch_size) chunks x [gc_minibatch_size, 2, embed_dim]
160 | 
161 |         cache, loss = self.build_cache(*all_reps, **loss_kwargs)
162 |         cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
163 | 
164 |         for model, x, model_cache, rnd_states in zip(
165 |             self.models, model_inputs, cache, all_rnd_states
166 |         ):
167 |             self.forward_backward(
168 |                 model,
169 |                 x,
170 |                 model_cache,
171 |                 rnd_states,
172 |                 no_sync_except_last=no_sync_except_last,
173 |             )
174 | 
175 |         return loss
176 | 
--------------------------------------------------------------------------------
/src/grad_cache/pytorch_lightning/readme.md:
--------------------------------------------------------------------------------
1 | # PL_GradCache
2 | 
3 | This is an experimental folder providing an example of using Grad Cache with PyTorch Lightning (pl), tested on pl version '2.2.0.post0' with multiple GPUs and mixed precision (fp16). PyTorch Metric Learning must also be installed for the contrastive loss calculation.
4 | 
5 | - [Wandb Logging Experiments for Sanity Test](https://api.wandb.ai/links/xyznlp/nmf8d551)
6 | 
7 | ### Installation
8 | 
9 | After GradCache is installed, do
10 | 
11 | ```
12 | cd GradCache/src/grad_cache/pytorch_lightning
13 | python -m venv plgc
14 | . ./plgc/bin/activate
15 | pip3 install -U pip
16 | pip3 install -r requirements.txt
17 | ```
18 | 
19 | ### Reproducing Wandb Experiments
20 | 
21 | ```
22 | # 1-gpu
23 | python pl_example.py --gpus 1 --batch_size 16
24 | # 2-gpus
25 | python pl_example.py --gpus 2 --batch_size 8
26 | # 1-gpu, gradcache
27 | python pl_example.py --gpus 1 --batch_size 16 --use_gc --gc_minibatch_size 2
28 | # 2-gpus, gradcache
29 | python pl_example.py --gpus 2 --batch_size 8 --use_gc --gc_minibatch_size 2
30 | ```
31 | 
32 | Optionally, enable mixed-precision training with `--precision 16`, or use a different ddp backend with `--ddp_backend {gloo/nccl/etc.}`.
33 | 
34 | ### Example
35 | 
36 | Run `python pl_example.py` with the following flags.
37 | 
38 | * `--use_gc` activates GradCache.
39 | * `--gc_minibatch_size {minibatch_size}` defines the mini batch size that each GPU holds in memory at a time. If we specify `--gpus 2 --batch_size 8 --gc_minibatch_size 2`, for example, the model is trained with batch size 8 * 2 = 16, and the trainer splits each per-GPU batch (8 data samples) into 4 mini batches (2 data samples per mini batch). Setting this to 1 gives the minimal possible GPU memory usage.
40 | 
41 | ### Summary
42 | 
43 | - Add `pl_gradcache.py` as a customized GradCache for PyTorch Lightning.
44 | - Use manual backward in gradcache by calling `lightning_trainer.manual_backward(loss)` instead of using `loss.backward()` (this requires changing gradcache).
45 | - Set gradcache `no_sync_except_last=True` in the multi-GPU case.
46 | 
47 | ### Changes to the original GradCache
48 | 
49 | #### File Change
50 | - `pl_gradcache.py` is the GradCache we will run on PyTorch Lightning (pl) with Distributed Data Parallel (ddp); a condensed usage sketch follows below.
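For orientation, the following is a condensed sketch of how `PLGradCache` from `pl_gradcache.py` plugs into a `LightningModule`, distilled from `pl_example.py`. The module name, placeholder loss, and dimensions are illustrative assumptions, not part of the library API.

```
import torch
from lightning import LightningModule

from grad_cache.pytorch_lightning.pl_gradcache import PLGradCache


class SketchModel(LightningModule):
    """Illustrative stand-in for SimpleLitModel in pl_example.py."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False       # we call backward/step ourselves
        self.encoder = torch.nn.Linear(784, 512)  # toy encoder

    def forward(self, x):                         # PLGradCache reaches the model through forward
        return self.encoder(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_train_start(self):
        # build the cache wrapper once the (possibly DDP-wrapped) model exists
        self.gc = PLGradCache(
            models=[self.trainer.strategy.model],  # or the LightningModule itself, see below
            chunk_sizes=2,                         # mini batch size held in memory at a time
            loss_fn=self.compute_loss,
            backward_fn=self.manual_backward,      # replaces loss.backward() inside GradCache
        )

    def compute_loss(self, reps, labels=None):
        # placeholder loss over the full batch of representations; swap in a real contrastive loss
        logits = reps.flatten(1) @ reps.flatten(1).T
        targets = torch.arange(logits.size(0), device=logits.device)
        return torch.nn.functional.cross_entropy(logits, targets)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.gc(
            inputs,
            no_sync_except_last=True,  # True when training on multiple GPUs, otherwise False
            labels=labels,
        )
        opt.step()
        return loss
```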
51 | 
52 | #### Change in Optimization
53 | - In the pl ddp setting, we first need to set `lightning_trainer.automatic_optimization=False` so that we can customize how backward is called.
54 | - See [the pl optimization doc](https://lightning.ai/docs/pytorch/stable/common/optimization.html) for implementation details. Make sure we call `self.optimizers()` instead of keeping our own reference: doing `self.optimizer = optimizer` in `self.configure_optimizers()` is not correct, because that is the base PyTorch optimizer, while `self.optimizers()` returns Lightning's wrapper around it. The base optimizer does not have proper access to ddp and logging.
55 | - Then, replace all `loss.backward()` calls in GradCache with `lightning_trainer.manual_backward(loss)`.
56 | 
57 | #### Change in GradCache
58 | - Set `no_sync_except_last=True` in the multi-GPU case so that gradient reduction is triggered only for the last mini batch, avoiding unnecessary gradient reduction on the earlier mini batches.
59 | 
60 | #### If you want to run GradCache in PyTorch Lightning with Multi-GPUs
61 | - In short, you are good to go without worrying about this part. But here are the key changes to the original gradcache that are necessary for this to work.
62 | - We have two options for the `gc_model` argument when using gradcache in the pl ddp setting and calling `self.init_gc(scaler, gc_model)`.
63 | - Option 1: we can set `gc_model` to the base LightningModule (e.g. `self.trainer.lightning_module`).
64 |   - PyTorch Lightning then wraps the base model (the transformer) in its own DDP implementation.
65 |   - In this case, just set `no_sync_except_last=False`, because Lightning will handle gradient sync before `optimizer.step()`.
66 |   - Setting `no_sync_except_last=True` does not work in this case: the model gradcache sees is the bare transformer, so gradcache's DDP assertion fails and `model.no_sync` is not available.
67 |   - Alternatively, we can change gradcache instead (remove the DDP assertion and the `model.no_sync` contexts).
68 |   - The only downside of this approach is that training may take a little longer, because gradient sync is done over the full batch (as PyTorch Lightning expects) instead of only on the last mini batch (as gradcache expects). But based on some sanity runs, it is OK (less than a 10% runtime increase).
69 | - Option 2: we can set `gc_model = trainer.strategy.model`, i.e. the base model wrapped by PyTorch DDP.
70 |   - This is tricky, as PyTorch Lightning uses the flag `require_backward_grad_sync` to determine whether gradients are synced across GPUs.
71 |   - First, PyTorch Lightning overrides PyTorch DDP with its own implementation and sets `require_backward_grad_sync=False` before each training step (when `automatic_optimization=False`). It then sets it back to True **after** each training step.
72 |   - The issue is that gradcache needs gradients to be synced on the last backward step, which happens inside PyTorch Lightning's training step hook. Thus, the only thing we can do is set this variable manually before the last backward step inside gradcache - we cannot set it outside of gradcache either, because the first, cache-building backward of gradcache should NOT sync gradients (that is essentially the point of gradcache).
73 |   - Thus, we set `model.require_backward_grad_sync=True` at the very end of gradcache - right before the backward of the last mini batch's surrogate.
74 |   - The advantage is that we can use `no_sync_except_last` as gradcache intends (no runtime increase). The downside is that we have to modify gradcache in a rather hacky way. This is the default setup.
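To make the two options above concrete, here is how the corresponding `on_train_start` wiring looks, following `pl_example.py`; treat it as a sketch of that example rather than a fixed API.

```
def on_train_start(self):
    # Default setup (Option 2): pass the DDP-wrapped model so gradcache controls gradient sync
    # itself; pair this with no_sync_except_last=True on multi-GPU runs.
    self.init_gc(self.trainer.scaler, self.trainer.strategy.model)

    # Alternative (Option 1): pass the bare LightningModule and let Lightning handle gradient
    # sync; pair this with no_sync_except_last=False.
    # self.init_gc(self.trainer.scaler, self.trainer.lightning_module)
```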
75 | -------------------------------------------------------------------------------- /src/grad_cache/pytorch_lightning/requirements.txt: -------------------------------------------------------------------------------- 1 | lightning 2 | pytorch_metric_learning 3 | --------------------------------------------------------------------------------