├── 1_tensor_hooks
│   ├── 1_without_hooks.py
│   ├── 2_with_hooks.py
│   ├── 3_with_hooks_and_remove.py
│   ├── 4_addition_instead_of_multiplication.py
│   └── 5_hook_with_in_place_operation.py
├── 2_module_forward_hooks
│   ├── 1_module_without_hooks.py
│   ├── 2_module_with_hooks.py
│   ├── 3_module_with_hooks_and_prints.py
│   └── 4_remove_module_hooks.py
├── 3_module_backward_hooks
│   ├── 1_working_module_backward_hook.py
│   ├── 2_broken_module_backward_hook.py
│   └── 3_tensor_hooks_workaround.py
├── LICENSE
└── README.md
--------------------------------------------------------------------------------
/1_tensor_hooks/1_without_hooks.py:
--------------------------------------------------------------------------------
import torch


a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b
c.retain_grad()  # keep the gradient on this non-leaf tensor after backward()
d = torch.tensor(4.0, requires_grad=True)

e = c * d
e.retain_grad()
e.backward()

print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
--------------------------------------------------------------------------------
/1_tensor_hooks/2_with_hooks.py:
--------------------------------------------------------------------------------
import torch


a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b


def c_hook(grad):
    print(grad)
    return grad + 2  # returning a tensor replaces the gradient


c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))  # returns None: gradient unchanged
c.retain_grad()

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d

e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()  # a second retain_grad() call is a no-op

e.backward()

print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
--------------------------------------------------------------------------------
/1_tensor_hooks/3_with_hooks_and_remove.py:
--------------------------------------------------------------------------------
import torch


a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b


def c_hook(grad):
    print(grad)
    return grad + 2


h = c.register_hook(c_hook)  # keep the handle so the hook can be removed
c.register_hook(lambda grad: print(grad))
c.retain_grad()

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d

e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

h.remove()  # detach c_hook before backward(), so only the print hook fires

e.backward()
print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
--------------------------------------------------------------------------------
/1_tensor_hooks/4_addition_instead_of_multiplication.py:
--------------------------------------------------------------------------------
import torch


a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b
c.retain_grad()
d = torch.tensor(4.0, requires_grad=True)

e = c + d  # addition passes the incoming gradient through unchanged
e.retain_grad()
e.backward()

print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
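# A hand-derived sanity check (not part of the original script): since
# e = c + d routes the incoming gradient of 1.0 through unchanged, and
# dc/da = b, dc/db = a, this should print roughly the following (exact
# formatting may vary by PyTorch version):
#   a.grad 3.0
#   b.grad 2.0
#   c.grad 1.0
#   d.grad 1.0
#   e.grad 1.0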
--------------------------------------------------------------------------------
/1_tensor_hooks/5_hook_with_in_place_operation.py:
--------------------------------------------------------------------------------
import torch


a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b
c.retain_grad()
d = torch.tensor(4.0, requires_grad=True)


def d_hook(grad):
    # Mutates the incoming gradient in place and returns None. This is the
    # pitfall this script demonstrates: the gradient tensor may be shared
    # with other branches of the graph, so mutating it here can silently
    # corrupt their gradients too.
    grad *= 100


d.register_hook(d_hook)

e = c + d
e.retain_grad()
e.backward()

print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
--------------------------------------------------------------------------------
/2_module_forward_hooks/1_module_without_hooks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SumNet(nn.Module):
    """A stateless module whose forward() simply sums its three inputs."""

    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        return a + b + c


def main():
    sum_net = SumNet()

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    d = sum_net(a, b, c)

    print('d:', d)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/2_module_forward_hooks/2_module_with_hooks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        d = a + b + c
        return d


def forward_pre_hook(module, inputs):
    # Runs before forward() and receives the positional args; returning a
    # tuple replaces them.
    a, b = inputs
    return a + 10, b


def forward_hook(module, inputs, output):
    # Runs after forward(); returning a value replaces the output.
    return output + 100


def main():
    sum_net = SumNet()

    sum_net.register_forward_pre_hook(forward_pre_hook)
    sum_net.register_forward_hook(forward_hook)

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    d = sum_net(a, b, c=c)  # c is passed as a keyword, so the hooks never see it

    print('d:', d)


if __name__ == '__main__':
    main()
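# A hand-derived sanity check (not part of the original script): the
# pre-hook sees only the positional args (a, b) -- c is deliberately passed
# as a keyword so it bypasses the hook -- so forward() computes
# (1 + 10) + 2 + 3 = 16 and the forward hook then adds 100. Running this
# should print roughly:
#   d: tensor(116., grad_fn=<AddBackward0>)
# (the exact grad_fn name may differ across PyTorch versions)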
--------------------------------------------------------------------------------
/2_module_forward_hooks/3_module_with_hooks_and_prints.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        d = a + b + c

        print('forward():')
        print('  a:', a)
        print('  b:', b)
        print('  c:', c)
        print()
        print('  d:', d)
        print()

        return d


def forward_pre_hook(module, input_positional_args):
    a, b = input_positional_args
    new_input_positional_args = a + 10, b

    print('forward_pre_hook():')
    print('  module:', module)
    print('  input_positional_args:', input_positional_args)
    print()
    print('  new_input_positional_args:', new_input_positional_args)
    print()

    return new_input_positional_args


def forward_hook(module, input_positional_args, output):
    new_output = output + 100

    print('forward_hook():')
    print('  module:', module)
    print('  input_positional_args:', input_positional_args)
    print('  output:', output)
    print()
    print('  new_output:', new_output)
    print()

    return new_output


def main():
    sum_net = SumNet()
    sum_net.register_forward_pre_hook(forward_pre_hook)
    sum_net.register_forward_hook(forward_hook)

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    print('start')
    print()
    print('a:', a)
    print('b:', b)
    print('c:', c)
    print()
    print('before model')
    print()

    d = sum_net(a, b, c=c)

    print('after model')
    print()
    print('d:', d)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/2_module_forward_hooks/4_remove_module_hooks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        d = a + b + c
        return d


def forward_pre_hook(module, input_positional_args):
    a, b = input_positional_args
    return a + 10, b


def forward_hook(module, input_positional_args, output):
    return output + 100


def main():
    sum_net = SumNet()

    # Both registration methods return handles that can later remove the hooks.
    forward_pre_hook_handle = sum_net.register_forward_pre_hook(forward_pre_hook)
    forward_hook_handle = sum_net.register_forward_hook(forward_hook)

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    d = sum_net(a, b, c=c)

    print('d:', d)

    forward_pre_hook_handle.remove()
    forward_hook_handle.remove()

    d = sum_net(a, b, c=c)  # with the hooks removed, this is the plain sum

    print('d:', d)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/3_module_backward_hooks/1_working_module_backward_hook.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MyMultiply(nn.Module):
    def __init__(self):
        super(MyMultiply, self).__init__()

    @staticmethod
    def forward(a, b):
        # A single operation, so the legacy module backward hook behaves as
        # expected here.
        return a * b


def backward_hook(module, grad_input, grad_output):
    print('module:', module)
    print('grad_input:', grad_input)
    print('grad_output:', grad_output)


def main():
    my_multiply = MyMultiply()
    my_multiply.register_backward_hook(backward_hook)

    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)

    c = my_multiply(a, b)

    c.backward()


if __name__ == '__main__':
    main()
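# A hand-derived sanity check (not part of the original script): forward()
# here is a single multiply, so the legacy module backward hook happens to
# see the right tensors: grad_output should be (tensor(1.),) and grad_input
# the gradients w.r.t. a and b (tensors 3. and 2.; their order in the tuple
# is an autograd implementation detail). Note that register_backward_hook()
# was later deprecated in favor of register_full_backward_hook()
# (PyTorch 1.8+), which fixes the multi-op breakage shown in the next script.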
--------------------------------------------------------------------------------
/3_module_backward_hooks/2_broken_module_backward_hook.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MyMultiply(nn.Module):
    def __init__(self):
        super(MyMultiply, self).__init__()

    @staticmethod
    def forward(a, b, c):
        # Two multiplications. The legacy register_backward_hook() only
        # attaches to the last autograd node of the forward pass, so
        # grad_input below corresponds to the inputs of the final multiply
        # ((a * b) and c), not to a, b, and c -- which is why this hook is
        # "broken".
        return (a * b) * c


def backward_hook(module, grad_input, grad_output):
    print('module:', module)
    print('grad_input:', grad_input)
    print('grad_output:', grad_output)


def main():
    my_multiply = MyMultiply()
    my_multiply.register_backward_hook(backward_hook)

    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    c = torch.tensor(4.0, requires_grad=True)

    d = my_multiply(a, b, c)

    d.backward()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/3_module_backward_hooks/3_tensor_hooks_workaround.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MyMultiply(nn.Module):
    def __init__(self):
        super(MyMultiply, self).__init__()

    @staticmethod
    def forward(a, b, c):
        return a * b * c


def main():
    my_multiply = MyMultiply()

    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    c = torch.tensor(4.0, requires_grad=True)

    d = my_multiply(a, b, c)

    # The workaround: register a hook directly on each tensor of interest.
    # Each tensor hook reliably receives that tensor's own gradient, no
    # matter how many operations forward() performs.
    a.register_hook(lambda grad: print('a grad:', grad))
    b.register_hook(lambda grad: print('b grad:', grad))
    c.register_hook(lambda grad: print('c grad:', grad))
    d.register_hook(lambda grad: print('d grad:', grad))

    d.backward()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Elliot Waite

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PyTorch Hooks Explained - In-depth Tutorial

This repo contains the example code I used in my YouTube tutorial video about hooks in PyTorch:

https://youtu.be/syLFCVYua6Q
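Below is a minimal sketch of the tensor-hook API that the scripts in
`1_tensor_hooks/` walk through (illustrative only -- see the scripts and the
video for the full story):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x

# A hook receives the gradient flowing into its tensor. Returning a new
# tensor replaces that gradient; returning None leaves it unchanged.
handle = y.register_hook(lambda grad: grad * 10)

y.backward()
print(x.grad)  # tensor(40.): d(x*x)/dx = 4, scaled by the hook's *10

handle.remove()  # register_hook() returns a handle for removing the hook
```
--------------------------------------------------------------------------------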