├── LICENSE
├── README.md
└── checkpointing.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 davisyoshida

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TensorFlow 2 Gradient Checkpointing
This is a simple decorator to enable gradient checkpointing (e.g. [Chen et al. (2016)](https://arxiv.org/pdf/1604.06174.pdf)) in TF2. It isn't very polished, but it's been letting me train bigger GPT-2 models on smaller hardware, so I thought I'd share it.


## Basic Usage
Use the `checkpointable` decorator to allow a function (or callable object such as a Keras `Layer`) to use gradient checkpointing. If checkpointing is desired, call the decorated function with the `_checkpoint` keyword argument set to `True`.

The example below shows a model with 40,000 "layers", but checkpointing means only about 400 of them need to be held in memory at any point. On a GTX 1070 Ti, this code results in an OOM error when the `_checkpoint` argument is set to `False`.

```python
import tensorflow as tf

from checkpointing import checkpointable

@checkpointable
def f(x, y, some_str, some_bool, z=None):
    for _ in range(200):
        x += y * z
    return x

initial = tf.ones(100000, dtype=tf.float32)
y = tf.ones(100000, dtype=tf.float32) + 1e-7
z = tf.ones(100000, dtype=tf.float32) + 1e-7
with tf.GradientTape() as g:
    g.watch(initial)
    x = initial
    for _ in range(200):
        x = f(x, y, 'a', True, z=z, _checkpoint=True)
    loss = tf.reduce_sum(x)

print(g.gradient(loss, initial))
```
Arguments which are not float32 tensors (or nested list/tuple structures of such tensors) are allowed, but they are ignored for the purposes of gradient computation.
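As a quick sanity check, gradients computed with `_checkpoint=True` should match the ones computed without checkpointing. Here is a minimal sketch of that comparison (the `block` function, sizes, and seed are just illustrative):

```python
import tensorflow as tf

from checkpointing import checkpointable

@checkpointable
def block(x, scale=None):
    for _ in range(10):
        x = tf.tanh(x * scale)
    return x

x0 = tf.random.normal([8])
scale = tf.constant(1.5)

grads = []
for use_checkpoint in (False, True):
    with tf.GradientTape() as g:
        g.watch(x0)
        out = block(x0, scale=scale, _checkpoint=use_checkpoint)
        loss = tf.reduce_sum(out)
    grads.append(g.gradient(loss, x0))

# The two gradients should agree up to numerical precision.
print(tf.reduce_max(tf.abs(grads[0] - grads[1])))
```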
## Variables
If the decorated function uses variables which are not arguments, pass a list of them via the `_watch_vars` keyword argument, as shown below.

```python
layer = SomeKerasLayer()
wrapped_layer = checkpointable(layer)

with tf.GradientTape() as g:
    g.watch(layer.trainable_variables)
    output = wrapped_layer(*args, **kwargs, _checkpoint=True, _watch_vars=layer.trainable_variables)

print(g.gradient(output, layer.trainable_variables))
```

## Warning: Dropout
Because gradient checkpointing relies on re-running the forward pass, stochastic layers such as dropout will give different results on each pass. There is a hacky workaround, which you can enable by passing `_force_seed=True` to the decorated function. This uses Python's `random` library to draw a random number and sets it as TensorFlow's random seed before each forward pass (alternatively, you can pass an iterator of seeds instead of `True`). If you have a better idea for addressing this issue, please do let me know.
--------------------------------------------------------------------------------
/checkpointing.py:
--------------------------------------------------------------------------------
from collections.abc import Iterator
from functools import wraps
import random

import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.python.eager import tape

def checkpointable(f):
    @wraps(f)
    def inner(*args, _checkpoint=False, _watch_vars=None, _force_seed=False, **kwargs):
        if _force_seed:
            # Either draw the seed from a user-supplied iterator or pick one at random.
            if isinstance(_force_seed, Iterator):
                seed = next(_force_seed)
            else:
                seed = random.randint(1, 1 << 31)

        if _checkpoint and tape.could_possibly_record():
            if _watch_vars is None:
                _watch_vars = []

            # Only float32 tensor arguments participate in gradient computation.
            flat_inputs = nest.flatten(args) + nest.flatten(list(kwargs.values()))
            flat_inputs = [x for x in flat_inputs if tf.is_tensor(x)]
            flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32]
            unique_inputs = [x.deref() for x in set(x.experimental_ref() for x in flat_inputs)]

            # Watch any extra variables that weren't already passed as inputs.
            unique_vars = [
                v.deref() for v in set(v.experimental_ref() for v in _watch_vars)
                if not any(v.deref() is inp for inp in flat_inputs)
            ]

            watches = unique_inputs + unique_vars
            tensor_watches = [tf.convert_to_tensor(x) for x in watches]

            # Run the forward pass without recording it, so intermediate
            # activations can be freed immediately.
            with tape.stop_recording():
                if _force_seed:
                    tf.random.set_seed(seed)

                result = f(*args, **kwargs)
                flat_result = nest.flatten(result)
                # No idea what the point of this is but they do it in tf.custom_gradient so I'm doing it too
                flat_result = [tf.identity(x) for x in flat_result]
                output = nest.pack_sequence_as(result, flat_result)
            del flat_inputs
            del result
            del unique_inputs
            del unique_vars

            def grad(*output_grads):
                # Re-run the forward pass under a fresh tape, then backpropagate
                # through the recomputation to get gradients for the watched tensors.
                with tf.GradientTape() as g:
                    g.watch(watches)
                    if _force_seed:
                        tf.random.set_seed(seed)
                    recomputed_output = f(*args, **kwargs)
                    recomputed_output = [tf.identity(x) for x in nest.flatten(recomputed_output)]

                grads = g.gradient(recomputed_output, watches, output_gradients=output_grads)
                del g
                return grads

            # Register the whole call as a single operation on the enclosing tape,
            # with `grad` as its backward function.
            tape.record_operation(str(f), flat_result, tensor_watches, grad)

            return output
        else:
            if _force_seed:
                tf.random.set_seed(seed)
            return f(*args, **kwargs)
    return inner
--------------------------------------------------------------------------------
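Since `checkpointable` also accepts an iterator for `_force_seed` (the `isinstance(_force_seed, Iterator)` branch above), the dropout workaround described in the README might look something like the sketch below; the layers, sizes, and seed values are arbitrary.

```python
import itertools

import tensorflow as tf

from checkpointing import checkpointable

dense = tf.keras.layers.Dense(16)
dropout = checkpointable(tf.keras.layers.Dropout(0.5))

# One seed per decorated call; the decorator reuses it for the recomputation,
# so the dropout mask is the same on the forward pass and the recompute.
seeds = itertools.count(1234)

x = tf.random.normal([4, 16])
with tf.GradientTape() as g:
    h = dense(x)
    h = dropout(h, training=True, _checkpoint=True, _force_seed=seeds)
    loss = tf.reduce_sum(h)

print(g.gradient(loss, dense.trainable_variables))
```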