├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── setup.py ├── test └── test_reparam_module.py └── torchreparam ├── __init__.py └── reparam_module.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504 4 | exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # gdb 107 | .gdb_history 108 | 109 | # sftp 110 | sftp-config.json 111 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tongzhou Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-Reparam-Module 2 | Reparameterize your PyTorch modules 3 | 4 | ## Requirements 5 | 6 | + [PyTorch](https://pytorch.org) `>= 1.2.0` 7 | + Python 3 8 | 9 | ## Example 10 | 11 | ```py 12 | import torch 13 | import torch.nn.functional as F 14 | import torchvision 15 | from torchreparam import ReparamModule 16 | 17 | device = torch.device('cuda') 18 | 19 | # A regular network 20 | net = torchvision.models.resnet18().to(device) 21 | 22 | # Reparameterize it! 23 | reparam_net = ReparamModule(net) 24 | 25 | print(f"reparam_net has {reparam_net.param_numel} parameters") 26 | 27 | assert tuple(reparam_net.parameters()) == (reparam_net.flat_param,) 28 | print(f"reparam_net now has **only one** vector parameter of shape {reparam_net.flat_param.shape}") 29 | 30 | # The reparameterized module is equivalent to the original one. 31 | # In fact, the weights share storage. 32 | dummy_input_image = torch.randn(1, 3, 224, 224, device=device) 33 | print(f'original net output for class 746: {net(dummy_input_image)[0, 746]}') 34 | print(f'reparam_net output for class 746: {reparam_net(dummy_input_image)[0, 746]}') 35 | 36 | # We can optionally trace the forward method with PyTorch JIT so it runs faster. 37 | # To do so, we can call `.trace` on the reparameterized module with dummy inputs 38 | # expected by the module. 39 | # Comment out the following line if you do not want to trace. 40 | reparam_net = reparam_net.trace(dummy_input_image) 41 | 42 | 43 | # Example of a MAML loss that 44 | # 1. trains `theta_0` on `inner_train` for `num_gd_steps` gradient descent steps with learning rate `lr`. 45 | # 2. computes the loss of the updated parameters on `inner_val`. 46 | # 47 | # This assumes classification with cross-entropy loss, but can be easily adapted 48 | # to other loss functions. 49 | def maml_loss(reparam_net, theta_0, inner_train, inner_val, num_gd_steps=5, lr=0.01): 50 | # train stage 51 | train_data, train_label = inner_train 52 | theta = theta_0 53 | for _ in range(num_gd_steps): 54 | # perform a GD update on (train_data, train_label) w.r.t. theta
55 | loss = F.cross_entropy(reparam_net(train_data, flat_param=theta), train_label) 56 | gtheta, = torch.autograd.grad(loss, theta, create_graph=True) # create_graph=True so we can backprop through this update 57 | # update 58 | theta = theta - lr * gtheta 59 | 60 | # val stage 61 | # theta is now the final updated set of parameters 62 | val_data, val_label = inner_val 63 | return F.cross_entropy(reparam_net(val_data, flat_param=theta), val_label) 64 | 65 | # Let's use the above function: 66 | 67 | # Initialize the theta_0 that we want to train 68 | theta_0 = torch.randn_like(reparam_net.flat_param).mul_(0.001).requires_grad_() 69 | # Make dummy data 70 | inner_train = ( 71 | torch.randn(2, 3, 224, 224, device=device), # input 72 | torch.randint(low=0, high=1000, size=(2,), device=device), # label 73 | ) 74 | inner_val = ( 75 | torch.randn(5, 3, 224, 224, device=device), # input 76 | torch.randint(low=0, high=1000, size=(5,), device=device), # label 77 | ) 78 | l = maml_loss(reparam_net, theta_0, inner_train, inner_val) 79 | l.backward() 80 | 81 | # The gradient w.r.t. theta_0 is now available through ordinary autograd. 82 | print(f'MAML loss gradient for theta_0:\n{theta_0.grad}') 83 | ``` 84 | 85 | ## Installation 86 | 87 | ```sh 88 | python setup.py install 89 | ``` 90 | 91 | ## Documentation 92 | 93 | For a `ReparamModule`, the following fields are available: 94 | 95 | + `.flat_param`: a flattened parameter vector representing all parameters of the wrapped module. 96 | + `.param_numel`: the total number of parameters, i.e., the size of `.flat_param`. 97 | 98 | A `ReparamModule` can be called with the following signature: 99 | 100 | ```py 101 | reparam_module(self, *inputs, flat_param=None, buffers=None) 102 | ``` 103 | 104 | where 105 | + `inputs` will be passed through as inputs to the wrapped module. 106 | + `flat_param` will be used as the parameters for this forward pass, if specified. Note that this allows you to easily run the network with an entirely different set of parameters, and to backprop to them. 107 | + `buffers` will be used as the buffers for this forward pass, if specified (experimental). 108 | 109 | Note 110 | + `ReparamModule` currently does not work properly with Batch Normalization layers that use the default `track_running_stats=True`. A possible workaround is sketched below.
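One possible workaround, not part of this library, is to construct the network with normalization layers that use `track_running_stats=False`, so the reparameterized forward pass has no running-statistics buffers to maintain. The sketch below is only an illustration under that assumption: the helper `bn2d_no_stats`, the toy architecture, and the layer sizes are all made up for the example. Note that such layers always normalize with batch statistics, even in eval mode, which may or may not be acceptable for your use case.

```py
import torch
import torch.nn as nn
from torchreparam import ReparamModule

def bn2d_no_stats(num_features):
    # BatchNorm2d configured to keep no running statistics: it always
    # normalizes with the current batch statistics, so there are no
    # running-stat buffers for the reparameterized forward pass to update.
    return nn.BatchNorm2d(num_features, track_running_stats=False)

toy_net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    bn2d_no_stats(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

reparam_toy_net = ReparamModule(toy_net)
out = reparam_toy_net(torch.randn(4, 3, 32, 32))  # forward pass works as in the example above
```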
111 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.2.0 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import shutil 4 | import sys 5 | from setuptools import setup, find_packages 6 | 7 | 8 | readme = open('README.md').read() 9 | 10 | VERSION = '0.0.1' 11 | 12 | setup( 13 | # Metadata 14 | name='torchreparam', 15 | version=VERSION, 16 | author='Tongzhou Wang', 17 | author_email='tongzhou.wang.1994@gmail.com', 18 | url='https://github.com/SsnL/PyTorch-Reparam-Module', 19 | description='Reparameterize your PyTorch modules', 20 | long_description=readme, 21 | license='MIT', 22 | 23 | # Package info 24 | packages=find_packages(exclude=('test',)), 25 | 26 | zip_safe=True, 27 | ) 28 | -------------------------------------------------------------------------------- /test/test_reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchreparam 3 | import unittest 4 | import copy 5 | 6 | torch.set_default_dtype(torch.double) 7 | 8 | 9 | class TestMixin(object): 10 | @staticmethod 11 | def assertTensorEqual(a, b): 12 | assert bool((a.detach() == b.detach()).all().item()) 13 | 14 | def _test(self, module, input_shapes): 15 | def get_random_input(): 16 | return tuple(torch.randn(s) for s in input_shapes) 17 | 18 | inp1 = get_random_input() 19 | 20 | ref_m = module 21 | ref_out = ref_m(*inp1).detach() 22 | 23 | reparam_m = copy.deepcopy(module) 24 | reparam_m = torchreparam.ReparamModule(reparam_m) 25 | if self.traced: 26 | reparam_m = reparam_m.trace(inp1) 27 | 28 | self.assertTensorEqual(ref_out, reparam_m(*inp1)) 29 | 30 | def sgd(flat_p1, inp1, inp2): 31 | out1 = reparam_m(*inp1, flat_param=flat_p1, 32 | buffers=tuple(b.clone().detach() for _, _, b in reparam_m._buffer_infos)) 33 | l1 = (out1 * ref_out).mean() 34 | flat_p2 = flat_p1 - torch.autograd.grad(l1, flat_p1, create_graph=True)[0] * 0.02 35 | out2 = reparam_m(*inp2, flat_param=flat_p2, 36 | buffers=tuple(b.clone().detach() for _, _, b in reparam_m._buffer_infos)) 37 | return out2 38 | 39 | # assert that a fully reparameterized forward pass does not change the module state 40 | ref_state_dict = {k: v.detach().clone() for k, v in reparam_m.state_dict().items()} 41 | 42 | sgd_inp = (torch.randn_like(reparam_m.flat_param, requires_grad=True), 43 | get_random_input(), get_random_input()) 44 | sgd(*sgd_inp) 45 | 46 | for k, v in reparam_m.state_dict().items(): 47 | self.assertTensorEqual(ref_state_dict[k], v) 48 | torch.autograd.gradcheck(sgd, sgd_inp) 49 | torch.autograd.gradgradcheck(sgd, sgd_inp) 50 | 51 | def test_conv(self): 52 | self._test(torch.nn.Conv2d(3, 3, 3), ((1, 3, 3, 4),)) 53 | 54 | def test_simple_network(self): 55 | class MyNet(torch.nn.Module): 56 | def __init__(self): 57 | super().__init__() 58 | self.feature = torch.nn.Sequential( 59 | torch.nn.Linear(10, 15), 60 | torch.nn.LeakyReLU(0.2), 61 | # torch.nn.BatchNorm1d(15), 62 | torch.nn.Linear(15, 10), 63 | ) 64 | self.register_buffer('target', torch.tensor(2.)) 65 | 66 | def forward(self, x): 67 | out = self.feature(x) 68 | return out * self.target 69 | 70 | self._test(MyNet(), ((2, 10),)) 71 | 72 | def test_shared_params(self): 73 | 74 | def get_net_and_input(): 75 | torch.manual_seed(0) # deterministic 76
| net = torch.nn.Sequential( 77 | torch.nn.Linear(10, 15, bias=False), 78 | torch.nn.Linear(15, 10, bias=False), 79 | torch.nn.Linear(10, 15, bias=False), 80 | ) 81 | # first and last layer share weights 82 | net[-1].weight = net[0].weight 83 | input = torch.rand(2, 10) 84 | return net, input 85 | 86 | def get_param_norm_after_step(reparam): 87 | net, input = get_net_and_input() 88 | if reparam: 89 | net = torchreparam.ReparamModule(net) 90 | if self.traced: 91 | net = net.trace(input) 92 | optim = torch.optim.SGD(net.parameters(), lr=1.0) 93 | loss = net(input).sum() 94 | loss.backward() 95 | optim.step() 96 | if reparam: 97 | # the first 300 params cover layers 1 and 2 98 | return net.flat_param[:300].norm() 99 | else: 100 | return torch.norm(torch.stack([p.norm() for p in net.parameters()])) 101 | 102 | ref_pnorm = get_param_norm_after_step(reparam=False) 103 | reparam_pnorm = get_param_norm_after_step(reparam=True) 104 | 105 | torch.testing.assert_allclose(ref_pnorm, reparam_pnorm) 106 | 107 | 108 | class TestTraced(unittest.TestCase, TestMixin): 109 | traced = True 110 | 111 | 112 | class TestNotTraced(unittest.TestCase, TestMixin): 113 | traced = False 114 | 115 | if __name__ == '__main__': 116 | unittest.main() 117 | -------------------------------------------------------------------------------- /torchreparam/__init__.py: -------------------------------------------------------------------------------- 1 | from .reparam_module import ReparamModule 2 | -------------------------------------------------------------------------------- /torchreparam/reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def __init__(self, module): 11 | super(ReparamModule, self).__init__() 12 | self.module = module 13 | 14 | param_infos = [] 15 | shared_param_memo = {} 16 | shared_param_infos = [] 17 | params = [] 18 | param_numels = [] 19 | param_shapes = [] 20 | for m in self.modules(): 21 | for n, p in m.named_parameters(recurse=False): 22 | if p is not None: 23 | if p in shared_param_memo: 24 | shared_m, shared_n = shared_param_memo[p] 25 | shared_param_infos.append((m, n, shared_m, shared_n)) 26 | else: 27 | shared_param_memo[p] = (m, n) 28 | param_infos.append((m, n)) 29 | params.append(p.detach()) 30 | param_numels.append(p.numel()) 31 | param_shapes.append(p.size()) 32 | 33 | assert len(set(p.dtype for p in params)) <= 1, \ 34 | "expects all parameters in module to have same dtype" 35 | 36 | # store the info for unflatten 37 | self._param_infos = tuple(param_infos) 38 | self._shared_param_infos = tuple(shared_param_infos) 39 | self._param_numels = tuple(param_numels) 40 | self._param_shapes = tuple(param_shapes) 41 | 42 | # flatten 43 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 44 | self.register_parameter('flat_param', flat_param) 45 | self.param_numel = flat_param.numel() 46 | del params 47 | del shared_param_memo 48 | 49 | # deregister the names as parameters 50 | for m, n in self._param_infos: 51 | delattr(m, n) 52 | for m, n, _, _ in self._shared_param_infos: 53 | delattr(m, n) 54 | 55 | # register the views as plain attributes 56 | self._unflatten_param(self.flat_param) 57 | 58 | # now buffers 59 | # they are not reparametrized. 
just store info as (module, name, buffer) 60 | buffer_infos = [] 61 | for m in self.modules(): 62 | for n, b in m.named_buffers(recurse=False): 63 | if b is not None: 64 | buffer_infos.append((m, n, b)) 65 | 66 | self._buffer_infos = tuple(buffer_infos) 67 | self._traced_self = None 68 | 69 | def trace(self, example_input, **trace_kwargs): 70 | assert self._traced_self is None, 'This ReparamModule is already traced' 71 | 72 | if isinstance(example_input, torch.Tensor): 73 | example_input = (example_input,) 74 | example_input = tuple(example_input) 75 | example_param = (self.flat_param.detach().clone(),) 76 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 77 | 78 | self._traced_self = torch.jit.trace_module( 79 | self, 80 | inputs=dict( 81 | _forward_with_param=example_param + example_input, 82 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 83 | ), 84 | **trace_kwargs, 85 | ) 86 | 87 | # replace forwards with traced versions 88 | self._forward_with_param = self._traced_self._forward_with_param 89 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 90 | return self 91 | 92 | def clear_views(self): 93 | for m, n in self._param_infos: 94 | setattr(m, n, None) # This will set as plain attr 95 | 96 | def _apply(self, *args, **kwargs): 97 | if self._traced_self is not None: 98 | self._traced_self._apply(*args, **kwargs) 99 | return self 100 | return super(ReparamModule, self)._apply(*args, **kwargs) 101 | 102 | def _unflatten_param(self, flat_param): 103 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 104 | for (m, n), p in zip(self._param_infos, ps): 105 | setattr(m, n, p) # This will set as plain attr 106 | for (m, n, shared_m, shared_n) in self._shared_param_infos: 107 | setattr(m, n, getattr(shared_m, shared_n)) 108 | 109 | @contextmanager 110 | def unflattened_param(self, flat_param): 111 | saved_views = [getattr(m, n) for m, n in self._param_infos] 112 | self._unflatten_param(flat_param) 113 | yield 114 | # Why not just `self._unflatten_param(self.flat_param)`? 115 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 116 | # 2. 
slightly faster since it does not require reconstructing the split+view 117 | # graph 118 | for (m, n), p in zip(self._param_infos, saved_views): 119 | setattr(m, n, p) 120 | for (m, n, shared_m, shared_n) in self._shared_param_infos: 121 | setattr(m, n, getattr(shared_m, shared_n)) 122 | 123 | @contextmanager 124 | def replaced_buffers(self, buffers): 125 | for (m, n, _), new_b in zip(self._buffer_infos, buffers): 126 | setattr(m, n, new_b) 127 | yield 128 | for m, n, old_b in self._buffer_infos: 129 | setattr(m, n, old_b) 130 | 131 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 132 | with self.unflattened_param(flat_param): 133 | with self.replaced_buffers(buffers): 134 | return self.module(*inputs, **kwinputs) 135 | 136 | def _forward_with_param(self, flat_param, *inputs, **kwinputs): 137 | with self.unflattened_param(flat_param): 138 | return self.module(*inputs, **kwinputs) 139 | 140 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 141 | if flat_param is None: 142 | flat_param = self.flat_param 143 | if buffers is None: 144 | return self._forward_with_param(flat_param, *inputs, **kwinputs) 145 | else: 146 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) 147 | --------------------------------------------------------------------------------