├── README.md
└── recompute.py

/README.md:
--------------------------------------------------------------------------------
# recompute

Saves GPU memory through recomputation (gradient checkpointing), following the paper [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174).

This functionality is already built into [bert4keras](https://github.com/bojone/bert4keras).

## Usage

First, make sure the environment variable `RECOMPUTE=1` is set (before `recompute` is imported).

Then, when writing a custom layer, decorate its `call` method with `recompute_grad`:
```python
from recompute import recompute_grad

class MyLayer(Layer):
    @recompute_grad
    def call(self, inputs):
        return inputs * 2
```

For an existing layer, apply the decorator by subclassing:
```python
from recompute import recompute_grad

class MyDense(Dense):
    @recompute_grad
    def call(self, inputs):
        return super(MyDense, self).call(inputs)
```

A complete end-to-end sketch is given at the end of this README.

## Requirements

Tested and working in the following environments:
```
tensorflow 1.14 + keras 2.3.1
tensorflow 1.15 + keras 2.3.1
tensorflow 2.0 + keras 2.3.1
tensorflow 2.1 + keras 2.3.1
tensorflow 2.0 + built-in tf.keras
tensorflow 2.1 + built-in tf.keras
```

Confirmed not to work:
```
tensorflow 1.x + built-in tf.keras
```

Reports of further test results are welcome.

**It is strongly recommended to run with keras 2.3.1 on top of tensorflow; running with the tf.keras bundled in tensorflow 2.x is strongly discouraged.**

## Results

- With BERT Base, batch_size can be increased to roughly 3x the original;
- With BERT Large, batch_size can be increased to roughly 4x the original;
- Training time per sample increases by about 25%;
- In theory, the more layers the model has, the larger the achievable batch_size multiplier.

## References
- https://kexue.fm
- https://github.com/bojone/bert4keras
- https://github.com/davisyoshida/tf2-gradient-checkpointing
- https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/ops/custom_gradient.py#L454-L499
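## Complete example

For reference, here is a minimal end-to-end sketch of the intended workflow under keras 2.3.1 + tensorflow 1.x (the recommended setup above). The layer name `RecomputeDense`, the toy model and the random data are illustrative assumptions, not part of this repo:

```python
import os
# RECOMPUTE must be set before `recompute` is imported,
# because the flag is read once at import time.
os.environ['RECOMPUTE'] = '1'

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from recompute import recompute_grad


class RecomputeDense(Dense):
    """A Dense layer whose activations are recomputed during backprop."""
    @recompute_grad
    def call(self, inputs):
        return super(RecomputeDense, self).call(inputs)


x_in = Input(shape=(32,))
x = RecomputeDense(64, activation='relu')(x_in)
x = RecomputeDense(1)(x)
model = Model(x_in, x)
model.compile(loss='mse', optimizer='adam')

# Toy data, only to show that training runs with recomputation enabled.
model.fit(np.random.rand(128, 32), np.random.rand(128, 1), epochs=1)
```

The memory saving only becomes noticeable on deep models such as BERT; a toy model like this merely demonstrates the API.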
--------------------------------------------------------------------------------
/recompute.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# recompute for keras/tf

import os
import sys
from distutils.util import strtobool

import tensorflow as tf
from tensorflow.python.util import nest, tf_inspect
from tensorflow.python.eager import tape
from tensorflow.python.ops.custom_gradient import _graph_mode_decorator

# Flag indicating whether tf.keras or standalone keras is used
is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))

if is_tf_keras:
    import tensorflow.keras as keras
    import tensorflow.keras.backend as K
    sys.modules['keras'] = keras
else:
    import keras
    import keras.backend as K

# Flag indicating whether recomputation is enabled (trades time for memory)
do_recompute = strtobool(os.environ.get('RECOMPUTE', '0'))


def graph_mode_decorator(f, *args, **kwargs):
    """The calling convention of _graph_mode_decorator changed in tf 2.1;
    unify both variants here.
    """
    if tf.__version__ < '2.1':
        return _graph_mode_decorator(f, *args, **kwargs)
    else:
        return _graph_mode_decorator(f, args, kwargs)


def recompute_grad(call):
    """Recomputation decorator (for decorating the call function of a Keras layer).
    For background on recomputation, see: https://arxiv.org/abs/1604.06174
    """
    if not do_recompute:
        return call

    def inner(self, inputs, **kwargs):
        """Define the function whose gradient is needed and redefine how its
        gradient is computed (adapted from the built-in tf.recompute_grad).
        """
        flat_inputs = nest.flatten(inputs)
        call_args = tf_inspect.getfullargspec(call).args
        for key in ['mask', 'training']:
            if key not in call_args and key in kwargs:
                del kwargs[key]

        def kernel_call():
            """Define the forward computation."""
            return call(self, inputs, **kwargs)

        def call_and_grad(*inputs):
            """Define the forward computation and the backward computation."""
            if is_tf_keras:
                with tape.stop_recording():
                    outputs = kernel_call()
                    outputs = tf.identity(outputs)
            else:
                outputs = kernel_call()

            def grad_fn(doutputs, variables=None):
                watches = list(inputs)
                if variables is not None:
                    watches += list(variables)
                with tf.GradientTape() as t:
                    t.watch(watches)
                    with tf.control_dependencies([doutputs]):
                        outputs = kernel_call()
                grads = t.gradient(
                    outputs, watches, output_gradients=[doutputs]
                )
                del t
                return grads[:len(inputs)], grads[len(inputs):]

            return outputs, grad_fn

        if is_tf_keras:  # only available with tf >= 2.0
            outputs, grad_fn = call_and_grad(*flat_inputs)
            flat_outputs = nest.flatten(outputs)

            def actual_grad_fn(*doutputs):
                grads = grad_fn(*doutputs, variables=self.trainable_weights)
                return grads[0] + grads[1]

            watches = flat_inputs + self.trainable_weights
            watches = [tf.convert_to_tensor(x) for x in watches]
            tape.record_operation(
                call.__name__, flat_outputs, watches, actual_grad_fn
            )
            return outputs
        else:  # works with keras + tf >= 1.14
            return graph_mode_decorator(call_and_grad, *flat_inputs)

    return inner
--------------------------------------------------------------------------------
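The core trick in `recompute.py` above is that `grad_fn` re-runs the layer's forward computation inside a fresh `tf.GradientTape` during backpropagation, so the activations from the first forward pass do not have to be kept around. The following standalone sketch shows the same pattern with `tf.custom_gradient` under TF 2.x eager execution; the name `recomputed_dense` and the toy shapes are illustrative assumptions, not part of this repo:

```python
import tensorflow as tf

w = tf.Variable(tf.random.normal([32, 64]))

@tf.custom_gradient
def recomputed_dense(x):
    # Forward pass: only x and w need to survive until backprop.
    y = tf.nn.relu(tf.matmul(x, w))

    def grad_fn(dy, variables=None):
        # Backward pass: rebuild the forward computation under a tape and
        # differentiate the rebuilt graph, instead of relying on activations
        # stored during the first forward pass.
        with tf.GradientTape() as t:
            t.watch(x)
            y2 = tf.nn.relu(tf.matmul(x, w))
        grads = t.gradient(y2, [x] + list(variables or []), output_gradients=dy)
        return grads[:1], grads[1:]

    return y, grad_fn

with tf.GradientTape() as outer_tape:
    out = recomputed_dense(tf.random.normal([8, 32]))
    loss = tf.reduce_sum(out)
print(outer_tape.gradient(loss, [w]))
```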