├── __init__.py
├── LICENSE
├── .gitignore
├── README.md
├── kfac.py
└── ekfac.py

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Thomas George, César Laurent and Université de Montréal.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EKFAC and K-FAC Preconditioners for PyTorch
2 | This repo contains a PyTorch implementation of the EKFAC and K-FAC preconditioners. If you find this software useful, please check the references below and cite accordingly!
3 | 
4 | ### Presentation
5 | 
6 | We implemented K-FAC and EKFAC as `preconditioners`. Preconditioners are similar to PyTorch's `optimizer` class, except that they do not update the parameters themselves: they only modify the gradients of those parameters. They can thus be used in combination with your favorite optimizer (we used SGD in our experiments). Note that we only implemented them for `Linear` and `Conv2d` modules, so all other modules of your network are silently skipped.
7 | 
8 | ### Usage
9 | 
10 | Here is a simple example showing how to add K-FAC or EKFAC to your code:
11 | 
12 | ```python
13 | # 1. Instantiate the preconditioner
14 | preconditioner = EKFAC(network, 0.1, update_freq=100)
15 | 
16 | # 2. During the training loop, simply call preconditioner.step() before optimizer.step().
17 | #    The optimizer is usually SGD.
18 | for i, (inputs, targets) in enumerate(train_loader):
19 |     optimizer.zero_grad()
20 |     outputs = network(inputs)
21 |     loss = criterion(outputs, targets)
22 |     loss.backward()
23 |     preconditioner.step()  # Add a step of preconditioner before the optimizer step.
24 |     optimizer.step()
25 | ```
26 | 
27 | ### References
28 | 
29 | #### EKFAC:
30 | - Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent, _[Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884)_, NIPS 2018
31 | 
32 | #### K-FAC:
33 | - James Martens, Roger Grosse, _[Optimizing Neural Networks with Kronecker-factored Approximate Curvature](https://arxiv.org/abs/1503.05671)_, ICML 2015
34 | #### K-FAC for Convolutions:
35 | - Roger Grosse, James Martens, _[A Kronecker-factored Approximate Fisher Matrix for Convolution Layers](https://arxiv.org/abs/1602.01407)_, ICML 2016
36 | - César Laurent, Thomas George, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent, _[An Evaluation of Fisher Approximations Beyond Kronecker Factorization](https://openreview.net/pdf?id=ryVC6tkwG)_, ICLR Workshop 2018
37 | #### Norm Constraint:
38 | - Jimmy Ba, Roger Grosse, James Martens, _[Distributed Second-order Optimization using Kronecker-Factored Approximations](https://jimmylba.github.io/papers/nsync.pdf)_, ICLR 2017
39 | - Jean Lafond, Nicolas Vasilache, Léon Bottou, _[Diagonal Rescaling For Neural Networks](https://arxiv.org/abs/1705.09319)_, arXiv 2017
40 | 
--------------------------------------------------------------------------------
/kfac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from torch.optim.optimizer import Optimizer
5 | 
6 | 
7 | class KFAC(Optimizer):
8 | 
9 |     def __init__(self, net, eps, sua=False, pi=False, update_freq=1,
10 |                  alpha=1.0, constraint_norm=False):
11 |         """ K-FAC Preconditioner for Linear and Conv2d layers.
12 | 
13 |         Computes the K-FAC of the second moment of the gradients.
14 |         It works for Linear and Conv2d layers and silently skips other layers.
15 | 
16 |         Args:
17 |             net (torch.nn.Module): Network to precondition.
18 |             eps (float): Tikhonov regularization parameter for the inverses.
19 |             sua (bool): Applies SUA approximation.
20 |             pi (bool): Computes pi correction for Tikhonov regularization.
21 |             update_freq (int): Perform inverses every update_freq updates.
22 |             alpha (float): Running average parameter (if 1.0, no running average).
23 |             constraint_norm (bool): Scale the gradients by the inverse square
24 |                 root of the Fisher norm (norm constraint).
25 |         """
26 |         self.eps = eps
27 |         self.sua = sua
28 |         self.pi = pi
29 |         self.update_freq = update_freq
30 |         self.alpha = alpha
31 |         self.constraint_norm = constraint_norm
32 |         self.params = []
33 |         self._fwd_handles = []
34 |         self._bwd_handles = []
35 |         self._iteration_counter = 0
36 |         for mod in net.modules():
37 |             mod_class = mod.__class__.__name__
38 |             if mod_class in ['Linear', 'Conv2d']:
39 |                 handle = mod.register_forward_pre_hook(self._save_input)
40 |                 self._fwd_handles.append(handle)
41 |                 handle = mod.register_full_backward_hook(self._save_grad_output)
42 |                 self._bwd_handles.append(handle)
43 |                 params = [mod.weight]
44 |                 if mod.bias is not None:
45 |                     params.append(mod.bias)
46 |                 d = {'params': params, 'mod': mod, 'layer_type': mod_class}
47 |                 self.params.append(d)
48 |         super(KFAC, self).__init__(self.params, {})
49 | 
50 |     def step(self, update_stats=True, update_params=True):
51 |         """Performs one step of preconditioning."""
52 |         fisher_norm = 0.
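        # Each pass through the loop below preconditions one layer's gradient with
        # the cached inverse Kronecker factors computed in _inv_covs, i.e. it
        # applies g <- iggt @ g @ ixxt (see _precond), an approximate
        # natural-gradient step under the K-FAC model of the Fisher.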
53 |         for group in self.param_groups:
54 |             # Getting parameters
55 |             if len(group['params']) == 2:
56 |                 weight, bias = group['params']
57 |             else:
58 |                 weight = group['params'][0]
59 |                 bias = None
60 |             state = self.state[weight]
61 |             # Update covariances and inverses
62 |             if update_stats:
63 |                 if self._iteration_counter % self.update_freq == 0:
64 |                     self._compute_covs(group, state)
65 |                     ixxt, iggt = self._inv_covs(state['xxt'], state['ggt'],
66 |                                                 state['num_locations'])
67 |                     state['ixxt'] = ixxt
68 |                     state['iggt'] = iggt
69 |                 else:
70 |                     if self.alpha != 1:
71 |                         self._compute_covs(group, state)
72 |             if update_params:
73 |                 # Preconditioning
74 |                 gw, gb = self._precond(weight, bias, group, state)
75 |                 # Updating gradients
76 |                 if self.constraint_norm:
77 |                     fisher_norm += (weight.grad * gw).sum()
78 |                 weight.grad.data = gw
79 |                 if bias is not None:
80 |                     if self.constraint_norm:
81 |                         fisher_norm += (bias.grad * gb).sum()
82 |                     bias.grad.data = gb
83 |             # Cleaning
84 |             if 'x' in self.state[group['mod']]:
85 |                 del self.state[group['mod']]['x']
86 |             if 'gy' in self.state[group['mod']]:
87 |                 del self.state[group['mod']]['gy']
88 |         # Optionally scale the norm of the gradients (norm constraint)
89 |         if update_params and self.constraint_norm:
90 |             scale = (1. / fisher_norm) ** 0.5
91 |             for group in self.param_groups:
92 |                 for param in group['params']:
93 |                     param.grad.data *= scale
94 |         if update_stats:
95 |             self._iteration_counter += 1
96 | 
97 |     def _save_input(self, mod, i):
98 |         """Saves input of layer to compute covariance."""
99 |         if mod.training:
100 |             self.state[mod]['x'] = i[0]
101 | 
102 |     def _save_grad_output(self, mod, grad_input, grad_output):
103 |         """Saves grad on output of layer to compute covariance."""
104 |         if mod.training:
105 |             self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)
106 | 
107 |     def _precond(self, weight, bias, group, state):
108 |         """Applies preconditioning."""
109 |         if group['layer_type'] == 'Conv2d' and self.sua:
110 |             return self._precond_sua(weight, bias, group, state)
111 |         ixxt = state['ixxt']
112 |         iggt = state['iggt']
113 |         g = weight.grad.data
114 |         s = g.shape
115 |         if group['layer_type'] == 'Conv2d':
116 |             g = g.contiguous().view(s[0], s[1]*s[2]*s[3])
117 |         if bias is not None:
118 |             gb = bias.grad.data
119 |             g = torch.cat([g, gb.view(gb.shape[0], 1)], dim=1)
120 |         g = torch.mm(torch.mm(iggt, g), ixxt)
121 |         if group['layer_type'] == 'Conv2d':
122 |             g /= state['num_locations']
123 |         if bias is not None:
124 |             gb = g[:, -1].contiguous().view(*bias.shape)
125 |             g = g[:, :-1]
126 |         else:
127 |             gb = None
128 |         g = g.contiguous().view(*s)
129 |         return g, gb
130 | 
131 |     def _precond_sua(self, weight, bias, group, state):
132 |         """Preconditioning for KFAC SUA."""
133 |         ixxt = state['ixxt']
134 |         iggt = state['iggt']
135 |         g = weight.grad.data
136 |         s = g.shape
137 |         mod = group['mod']
138 |         g = g.permute(1, 0, 2, 3).contiguous()
139 |         if bias is not None:
140 |             gb = bias.grad.view(1, -1, 1, 1).expand(1, -1, s[2], s[3])
141 |             g = torch.cat([g, gb], dim=0)
142 |         g = torch.mm(ixxt, g.contiguous().view(-1, s[0]*s[2]*s[3]))
143 |         g = g.view(-1, s[0], s[2], s[3]).permute(1, 0, 2, 3).contiguous()
144 |         g = torch.mm(iggt, g.view(s[0], -1)).view(s[0], -1, s[2], s[3])
145 |         g /= state['num_locations']
146 |         if bias is not None:
147 |             gb = g[:, -1, s[2]//2, s[3]//2]
148 |             g = g[:, :-1]
149 |         else:
150 |             gb = None
151 |         return g, gb
152 | 
153 |     def _compute_covs(self, group, state):
154 |         """Computes the covariances."""
155 |         mod = group['mod']
156 |         x = self.state[group['mod']]['x']
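        # x (layer input) and gy (gradient w.r.t. the layer output) were cached by
        # the forward / backward hooks; they are reshaped below so that
        # xxt ~ E[x x^T] and ggt ~ E[gy gy^T] estimate the two Kronecker factors.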
157 |         gy = self.state[group['mod']]['gy']
158 |         # Computation of xxt
159 |         if group['layer_type'] == 'Conv2d':
160 |             if not self.sua:
161 |                 x = F.unfold(x, mod.kernel_size, padding=mod.padding,
162 |                              stride=mod.stride)
163 |             else:
164 |                 x = x.view(x.shape[0], x.shape[1], -1)
165 |             x = x.data.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
166 |         else:
167 |             x = x.data.t()
168 |         if mod.bias is not None:
169 |             ones = torch.ones_like(x[:1])
170 |             x = torch.cat([x, ones], dim=0)
171 |         if self._iteration_counter == 0:
172 |             state['xxt'] = torch.mm(x, x.t()) / float(x.shape[1])
173 |         else:
174 |             state['xxt'].addmm_(mat1=x, mat2=x.t(),
175 |                                 beta=(1. - self.alpha),
176 |                                 alpha=self.alpha / float(x.shape[1]))
177 |         # Computation of ggt
178 |         if group['layer_type'] == 'Conv2d':
179 |             gy = gy.data.permute(1, 0, 2, 3)
180 |             state['num_locations'] = gy.shape[2] * gy.shape[3]
181 |             gy = gy.contiguous().view(gy.shape[0], -1)
182 |         else:
183 |             gy = gy.data.t()
184 |             state['num_locations'] = 1
185 |         if self._iteration_counter == 0:
186 |             state['ggt'] = torch.mm(gy, gy.t()) / float(gy.shape[1])
187 |         else:
188 |             state['ggt'].addmm_(mat1=gy, mat2=gy.t(),
189 |                                 beta=(1. - self.alpha),
190 |                                 alpha=self.alpha / float(gy.shape[1]))
191 | 
192 |     def _inv_covs(self, xxt, ggt, num_locations):
193 |         """Inverts the covariances."""
194 |         # Computes pi
195 |         pi = 1.0
196 |         if self.pi:
197 |             tx = torch.trace(xxt) * ggt.shape[0]
198 |             tg = torch.trace(ggt) * xxt.shape[0]
199 |             pi = (tx / tg)
200 |         # Regularize and invert
201 |         eps = self.eps / num_locations
202 |         diag_xxt = xxt.new(xxt.shape[0]).fill_((eps * pi) ** 0.5)
203 |         diag_ggt = ggt.new(ggt.shape[0]).fill_((eps / pi) ** 0.5)
204 |         ixxt = (xxt + torch.diag(diag_xxt)).inverse()
205 |         iggt = (ggt + torch.diag(diag_ggt)).inverse()
206 |         return ixxt, iggt
207 | 
208 |     def __del__(self):
209 |         for handle in self._fwd_handles + self._bwd_handles:
210 |             handle.remove()
211 | 
--------------------------------------------------------------------------------
/ekfac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from torch.optim.optimizer import Optimizer
5 | 
6 | 
7 | class EKFAC(Optimizer):
8 | 
9 |     def __init__(self, net, eps, sua=False, ra=False, update_freq=1,
10 |                  alpha=.75):
11 |         """ EKFAC Preconditioner for Linear and Conv2d layers.
12 | 
13 |         Computes the EKFAC of the second moment of the gradients.
14 |         It works for Linear and Conv2d layers and silently skips other layers.
15 | 
16 |         Args:
17 |             net (torch.nn.Module): Network to precondition.
18 |             eps (float): Tikhonov regularization parameter for the inverses.
19 |             sua (bool): Applies SUA approximation.
20 |             ra (bool): Computes stats using a running average of averaged gradients
21 |                 instead of an intra-minibatch estimate.
22 |             update_freq (int): Perform inverses every update_freq updates.
23 |             alpha (float): Running average parameter.
24 | 
25 |         """
26 |         self.eps = eps
27 |         self.sua = sua
28 |         self.ra = ra
29 |         self.update_freq = update_freq
30 |         self.alpha = alpha
31 |         self.params = []
32 |         self._fwd_handles = []
33 |         self._bwd_handles = []
34 |         self._iteration_counter = 0
35 |         if not self.ra and self.alpha != 1.:
36 |             raise NotImplementedError
37 |         for mod in net.modules():
38 |             mod_class = mod.__class__.__name__
39 |             if mod_class in ['Linear', 'Conv2d']:
40 |                 handle = mod.register_forward_pre_hook(self._save_input)
41 |                 self._fwd_handles.append(handle)
42 |                 handle = mod.register_full_backward_hook(self._save_grad_output)
43 |                 self._bwd_handles.append(handle)
44 |                 params = [mod.weight]
45 |                 if mod.bias is not None:
46 |                     params.append(mod.bias)
47 |                 d = {'params': params, 'mod': mod, 'layer_type': mod_class}
48 |                 if mod_class == 'Conv2d':
49 |                     if not self.sua:
50 |                         # Adding gathering filter for convolution
51 |                         d['gathering_filter'] = self._get_gathering_filter(mod)
52 |                 self.params.append(d)
53 |         super(EKFAC, self).__init__(self.params, {})
54 | 
55 |     def step(self, update_stats=True, update_params=True):
56 |         """Performs one step of preconditioning."""
57 |         for group in self.param_groups:
58 |             # Getting parameters
59 |             if len(group['params']) == 2:
60 |                 weight, bias = group['params']
61 |             else:
62 |                 weight = group['params'][0]
63 |                 bias = None
64 |             state = self.state[weight]
65 |             # Update covariances and inverses
66 |             if self._iteration_counter % self.update_freq == 0:
67 |                 self._compute_kfe(group, state)
68 |             # Preconditioning
69 |             if group['layer_type'] == 'Conv2d' and self.sua:
70 |                 if self.ra:
71 |                     self._precond_sua_ra(weight, bias, group, state)
72 |                 else:
73 |                     self._precond_intra_sua(weight, bias, group, state)
74 |             else:
75 |                 if self.ra:
76 |                     self._precond_ra(weight, bias, group, state)
77 |                 else:
78 |                     self._precond_intra(weight, bias, group, state)
79 |         self._iteration_counter += 1
80 | 
81 |     def _save_input(self, mod, i):
82 |         """Saves input of layer to compute covariance."""
83 |         self.state[mod]['x'] = i[0]
84 | 
85 |     def _save_grad_output(self, mod, grad_input, grad_output):
86 |         """Saves grad on output of layer to compute covariance."""
87 |         self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)
88 | 
89 |     def _precond_ra(self, weight, bias, group, state):
90 |         """Applies preconditioning."""
91 |         kfe_x = state['kfe_x']
92 |         kfe_gy = state['kfe_gy']
93 |         m2 = state['m2']
94 |         g = weight.grad.data
95 |         s = g.shape
96 |         bs = self.state[group['mod']]['x'].size(0)
97 |         if group['layer_type'] == 'Conv2d':
98 |             g = g.contiguous().view(s[0], s[1]*s[2]*s[3])
99 |         if bias is not None:
100 |             gb = bias.grad.data
101 |             g = torch.cat([g, gb.view(gb.shape[0], 1)], dim=1)
102 |         g_kfe = torch.mm(torch.mm(kfe_gy.t(), g), kfe_x)
103 |         m2.mul_(self.alpha).add_(g_kfe**2, alpha=(1. - self.alpha) * bs)
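        # m2 now holds a running estimate of the second moment of the gradients
        # expressed in the Kronecker-factored eigenbasis; dividing by (m2 + eps)
        # below rescales each eigen-coordinate individually, which is the
        # eigenbasis correction that distinguishes EKFAC from K-FAC.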
104 |         g_nat_kfe = g_kfe / (m2 + self.eps)
105 |         g_nat = torch.mm(torch.mm(kfe_gy, g_nat_kfe), kfe_x.t())
106 |         if bias is not None:
107 |             gb = g_nat[:, -1].contiguous().view(*bias.shape)
108 |             bias.grad.data = gb
109 |             g_nat = g_nat[:, :-1]
110 |         g_nat = g_nat.contiguous().view(*s)
111 |         weight.grad.data = g_nat
112 | 
113 |     def _precond_intra(self, weight, bias, group, state):
114 |         """Applies preconditioning."""
115 |         kfe_x = state['kfe_x']
116 |         kfe_gy = state['kfe_gy']
117 |         mod = group['mod']
118 |         x = self.state[mod]['x']
119 |         gy = self.state[mod]['gy']
120 |         g = weight.grad.data
121 |         s = g.shape
122 |         s_x = x.size()
123 |         s_cin = 0
124 |         s_gy = gy.size()
125 |         bs = x.size(0)
126 |         if group['layer_type'] == 'Conv2d':
127 |             x = F.conv2d(x, group['gathering_filter'],
128 |                          stride=mod.stride, padding=mod.padding,
129 |                          groups=mod.in_channels)
130 |             s_x = x.size()
131 |             x = x.data.permute(1, 0, 2, 3).contiguous().view(x.shape[1], -1)
132 |             if mod.bias is not None:
133 |                 ones = torch.ones_like(x[:1])
134 |                 x = torch.cat([x, ones], dim=0)
135 |                 s_cin = 1  # adding a channel in dim for the bias
136 |             # intra minibatch m2
137 |             x_kfe = torch.mm(kfe_x.t(), x).view(s_x[1]+s_cin, -1, s_x[2], s_x[3]).permute(1, 0, 2, 3)
138 |             gy = gy.permute(1, 0, 2, 3).contiguous().view(s_gy[1], -1)
139 |             gy_kfe = torch.mm(kfe_gy.t(), gy).view(s_gy[1], -1, s_gy[2], s_gy[3]).permute(1, 0, 2, 3)
140 |             m2 = torch.zeros((s[0], s[1]*s[2]*s[3]+s_cin), device=g.device)
141 |             g_kfe = torch.zeros((s[0], s[1]*s[2]*s[3]+s_cin), device=g.device)
142 |             for i in range(x_kfe.size(0)):
143 |                 g_this = torch.mm(gy_kfe[i].view(s_gy[1], -1),
144 |                                   x_kfe[i].permute(1, 2, 0).view(-1, s_x[1]+s_cin))
145 |                 m2 += g_this**2
146 |             m2 /= bs
147 |             g_kfe = torch.mm(gy_kfe.permute(1, 0, 2, 3).view(s_gy[1], -1),
148 |                              x_kfe.permute(0, 2, 3, 1).contiguous().view(-1, s_x[1]+s_cin)) / bs
149 |             ## sanity check: did we obtain the same grad?
150 |             # g = torch.mm(torch.mm(kfe_gy, g_kfe), kfe_x.t())
151 |             # gb = g[:,-1]
152 |             # gw = g[:,:-1].view(*s)
153 |             # print('bias', torch.dist(gb, bias.grad.data))
154 |             # print('weight', torch.dist(gw, weight.grad.data))
155 |             ## end sanity check
156 |             g_nat_kfe = g_kfe / (m2 + self.eps)
157 |             g_nat = torch.mm(torch.mm(kfe_gy, g_nat_kfe), kfe_x.t())
158 |             if bias is not None:
159 |                 gb = g_nat[:, -1].contiguous().view(*bias.shape)
160 |                 bias.grad.data = gb
161 |                 g_nat = g_nat[:, :-1]
162 |             g_nat = g_nat.contiguous().view(*s)
163 |             weight.grad.data = g_nat
164 |         else:
165 |             if bias is not None:
166 |                 ones = torch.ones_like(x[:, :1])
167 |                 x = torch.cat([x, ones], dim=1)
168 |             x_kfe = torch.mm(x, kfe_x)
169 |             gy_kfe = torch.mm(gy, kfe_gy)
170 |             m2 = torch.mm(gy_kfe.t()**2, x_kfe**2) / bs
171 |             g_kfe = torch.mm(gy_kfe.t(), x_kfe) / bs
172 |             g_nat_kfe = g_kfe / (m2 + self.eps)
173 |             g_nat = torch.mm(torch.mm(kfe_gy, g_nat_kfe), kfe_x.t())
174 |             if bias is not None:
175 |                 gb = g_nat[:, -1].contiguous().view(*bias.shape)
176 |                 bias.grad.data = gb
177 |                 g_nat = g_nat[:, :-1]
178 |             g_nat = g_nat.contiguous().view(*s)
179 |             weight.grad.data = g_nat
180 | 
181 |     def _precond_sua_ra(self, weight, bias, group, state):
182 |         """Preconditioning for KFAC SUA."""
183 |         kfe_x = state['kfe_x']
184 |         kfe_gy = state['kfe_gy']
185 |         m2 = state['m2']
186 |         g = weight.grad.data
187 |         s = g.shape
188 |         bs = self.state[group['mod']]['x'].size(0)
189 |         mod = group['mod']
190 |         if bias is not None:
191 |             gb = bias.grad.view(-1, 1, 1, 1).expand(-1, -1, s[2], s[3])
192 |             g = torch.cat([g, gb], dim=1)
193 |         g_kfe = self._to_kfe_sua(g, kfe_x, kfe_gy)
194 |         m2.mul_(self.alpha).add_(g_kfe**2, alpha=(1. - self.alpha) * bs)
195 |         g_nat_kfe = g_kfe / (m2 + self.eps)
196 |         g_nat = self._to_kfe_sua(g_nat_kfe, kfe_x.t(), kfe_gy.t())
197 |         if bias is not None:
198 |             gb = g_nat[:, -1, s[2]//2, s[3]//2]
199 |             bias.grad.data = gb
200 |             g_nat = g_nat[:, :-1]
201 |         weight.grad.data = g_nat
202 | 
203 |     def _precond_intra_sua(self, weight, bias, group, state):
204 |         """Preconditioning for KFAC SUA."""
205 |         kfe_x = state['kfe_x']
206 |         kfe_gy = state['kfe_gy']
207 |         mod = group['mod']
208 |         x = self.state[mod]['x']
209 |         gy = self.state[mod]['gy']
210 |         g = weight.grad.data
211 |         s = g.shape
212 |         s_x = x.size()
213 |         s_gy = gy.size()
214 |         s_cin = 0
215 |         bs = x.size(0)
216 |         if bias is not None:
217 |             ones = torch.ones_like(x[:, :1])
218 |             x = torch.cat([x, ones], dim=1)
219 |             s_cin += 1
220 |         # intra minibatch m2
221 |         x = x.permute(1, 0, 2, 3).contiguous().view(s_x[1]+s_cin, -1)
222 |         x_kfe = torch.mm(kfe_x.t(), x).view(s_x[1]+s_cin, -1, s_x[2], s_x[3]).permute(1, 0, 2, 3)
223 |         gy = gy.permute(1, 0, 2, 3).contiguous().view(s_gy[1], -1)
224 |         gy_kfe = torch.mm(kfe_gy.t(), gy).view(s_gy[1], -1, s_gy[2], s_gy[3]).permute(1, 0, 2, 3)
225 |         m2 = torch.zeros((s[0], s[1]+s_cin, s[2], s[3]), device=g.device)
226 |         g_kfe = torch.zeros((s[0], s[1]+s_cin, s[2], s[3]), device=g.device)
227 |         for i in range(x_kfe.size(0)):
228 |             g_this = grad_wrt_kernel(x_kfe[i:i+1], gy_kfe[i:i+1], mod.padding, mod.stride)
229 |             m2 += g_this**2
230 |         m2 /= bs
231 |         g_kfe = grad_wrt_kernel(x_kfe, gy_kfe, mod.padding, mod.stride) / bs
232 |         ## sanity check: did we obtain the same grad?
233 |         # g = self._to_kfe_sua(g_kfe, kfe_x.t(), kfe_gy.t())
234 |         # gb = g[:, -1, s[2]//2, s[3]//2]
235 |         # gw = g[:,:-1].view(*s)
236 |         # print('bias', torch.dist(gb, bias.grad.data))
237 |         # print('weight', torch.dist(gw, weight.grad.data))
238 |         ## end sanity check
239 |         g_nat_kfe = g_kfe / (m2 + self.eps)
240 |         g_nat = self._to_kfe_sua(g_nat_kfe, kfe_x.t(), kfe_gy.t())
241 |         if bias is not None:
242 |             gb = g_nat[:, -1, s[2]//2, s[3]//2]
243 |             bias.grad.data = gb
244 |             g_nat = g_nat[:, :-1]
245 |         weight.grad.data = g_nat
246 | 
247 |     def _compute_kfe(self, group, state):
248 |         """Computes the Kronecker-factored eigenbasis (KFE) and second moments."""
249 |         mod = group['mod']
250 |         x = self.state[group['mod']]['x']
251 |         gy = self.state[group['mod']]['gy']
252 |         # Computation of xxt
253 |         if group['layer_type'] == 'Conv2d':
254 |             if not self.sua:
255 |                 x = F.conv2d(x, group['gathering_filter'],
256 |                              stride=mod.stride, padding=mod.padding,
257 |                              groups=mod.in_channels)
258 |             x = x.data.permute(1, 0, 2, 3).contiguous().view(x.shape[1], -1)
259 |         else:
260 |             x = x.data.t()
261 |         if mod.bias is not None:
262 |             ones = torch.ones_like(x[:1])
263 |             x = torch.cat([x, ones], dim=0)
264 |         xxt = torch.mm(x, x.t()) / float(x.shape[1])
265 |         Ex, state['kfe_x'] = torch.linalg.eigh(xxt, UPLO='U')
266 |         # Computation of ggt
267 |         if group['layer_type'] == 'Conv2d':
268 |             gy = gy.data.permute(1, 0, 2, 3)
269 |             state['num_locations'] = gy.shape[2] * gy.shape[3]
270 |             gy = gy.contiguous().view(gy.shape[0], -1)
271 |         else:
272 |             gy = gy.data.t()
273 |             state['num_locations'] = 1
274 |         ggt = torch.mm(gy, gy.t()) / float(gy.shape[1])
275 |         Eg, state['kfe_gy'] = torch.linalg.eigh(ggt, UPLO='U')
276 |         state['m2'] = Eg.unsqueeze(1) * Ex.unsqueeze(0) * state['num_locations']
277 |         if group['layer_type'] == 'Conv2d' and self.sua:
278 |             ws = group['params'][0].grad.data.size()
279 |             state['m2'] = state['m2'].view(Eg.size(0), Ex.size(0), 1, 1).expand(-1, -1, ws[2], ws[3])
280 | 
281 |     def _get_gathering_filter(self, mod):
282 |         """Convolution filter that extracts input patches."""
283 |         kw, kh = mod.kernel_size
284 |         g_filter = mod.weight.data.new(kw * kh * mod.in_channels, 1, kw, kh)
285 |         g_filter.fill_(0)
286 |         for i in range(mod.in_channels):
287 |             for j in range(kw):
288 |                 for k in range(kh):
289 |                     g_filter[k + kh*j + kw*kh*i, 0, j, k] = 1
290 |         return g_filter
291 | 
292 |     def _to_kfe_sua(self, g, vx, vg):
293 |         """Projects g into the Kronecker-factored eigenbasis."""
294 |         sg = g.size()
295 |         g = torch.mm(vg.t(), g.view(sg[0], -1)).view(vg.size(1), sg[1], sg[2], sg[3])
296 |         g = torch.mm(g.permute(0, 2, 3, 1).contiguous().view(-1, sg[1]), vx)
297 |         g = g.view(vg.size(1), sg[2], sg[3], vx.size(1)).permute(0, 3, 1, 2)
298 |         return g
299 | 
300 |     def __del__(self):
301 |         for handle in self._fwd_handles + self._bwd_handles:
302 |             handle.remove()
303 | 
304 | 
305 | def grad_wrt_kernel(a, g, padding, stride, target_size=None):
306 |     gk = F.conv2d(a.transpose(0, 1), g.transpose(0, 1).contiguous(),
307 |                   padding=padding, dilation=stride).transpose(0, 1)
308 |     if target_size is not None and target_size != gk.size():
309 |         return gk[:, :, :target_size[2], :target_size[3]].contiguous()
310 |     return gk
311 | 
--------------------------------------------------------------------------------
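Below is a minimal, self-contained sketch of how `kfac.py` can be combined with SGD, in the spirit of the README example above; the toy network and dummy batch are illustrative assumptions, not part of the repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from kfac import KFAC

# Toy CNN standing in for a real model; only Linear and Conv2d layers are
# preconditioned, every other module type is silently skipped.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 8 * 8, 10),
)

optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# eps is the Tikhonov damping, update_freq controls how often the Kronecker
# factors are re-estimated and inverted, and constraint_norm rescales the
# update by the inverse square root of its Fisher norm (Ba et al., 2017).
preconditioner = KFAC(net, eps=0.1, pi=True, update_freq=10,
                      constraint_norm=True)

# Dummy batch standing in for one iteration of a real training loop.
inputs = torch.randn(8, 3, 8, 8)
targets = torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = F.cross_entropy(net(inputs), targets)
loss.backward()
preconditioner.step()  # precondition the gradients in-place ...
optimizer.step()       # ... then apply the SGD update
```

EKFAC from `ekfac.py` is used the same way; swap in, for example, `EKFAC(net, eps=0.1, ra=True, update_freq=10)`.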