├── README.md
├── adafactor.py
└── mnist_test.ipynb

/README.md:
--------------------------------------------------------------------------------
# adafactor-pytorch
A PyTorch implementation of AdaFactor (https://arxiv.org/pdf/1804.04235.pdf).

# Notes
1) Factorization can be applied to weights of any dimensionality: when a weight tensor has more than two dimensions, it is reshaped to 2D first. To turn this off, change the `if len(shape) > 2:` branch in `_check_shape` so that it returns `False, True` instead of `True, True`.

2) Weight decay is applied at the decoupled position proposed in AdamW (https://arxiv.org/abs/1711.05101).

# Parameter descriptions:
lr - learning rate; either a scalar or a function of the step number. If it is a function (or left as None, which selects the schedule from the paper), the relative step size is used.

beta1, beta2 - can likewise be scalars or functions of the step number. With scalar betas and non_constant_decay=False the algorithm behaves like AMSGrad. Setting beta1 to zero turns off the first-moment (momentum) update.

non_constant_decay - boolean. If True, AMSGrad is disabled and scalar betas are replaced by the time-dependent decay functions from section 7.1 of the paper.

enable_factorization - boolean. If True, the second moment of matrix-shaped (2D, or reshaped to 2D) weights is stored in factored form.

clipping_threshold - scalar. Threshold value for update clipping (section 6 of the paper).
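
# Usage example
A minimal sketch of constructing the optimizer and taking a few steps. The small model and random data below are placeholders rather than part of this repository; the arguments shown mirror the defaults described above, and `lr=None` keeps the relative step size enabled.

```python
import torch
import torch.nn as nn

from adafactor import AdaFactor

# Placeholder model and data, used only to exercise the optimizer.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
inputs = torch.randn(8, 32)
targets = torch.randint(0, 10, (8,))

# lr=None uses the relative step size; with non_constant_decay=True the scalar
# betas are replaced by the time-dependent decay functions from the paper.
optimizer = AdaFactor(model.parameters(), lr=None, beta1=0.9, beta2=0.999,
                      non_constant_decay=True, enable_factorization=True,
                      weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()
for step in range(5):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
```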

--------------------------------------------------------------------------------
/adafactor.py:
--------------------------------------------------------------------------------
import torch
import operator
import functools
from copy import copy
from math import sqrt


class AdaFactor(torch.optim.Optimizer):
    def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30,
                 eps2=1e-3, clipping_threshold=1, non_constant_decay=True,
                 enable_factorization=True, ams_grad=True, weight_decay=0):

        enable_momentum = beta1 != 0
        self.beta1_glob = copy(beta1)
        self.beta2_glob = copy(beta2)
        self.lr_glob = copy(lr)

        # Scalars are wrapped into constant functions of the step number.
        beta1 = self.beta1_glob if callable(beta1) else lambda x: self.beta1_glob
        beta2 = self.beta2_glob if callable(beta2) else lambda x: self.beta2_glob

        if non_constant_decay:
            # Time-dependent decay (section 7.1 of the paper) replaces AMSGrad.
            ams_grad = False
            if isinstance(self.beta1_glob, float):
                beta1 = lambda t: self.beta1_glob * (1 - self.beta1_glob ** (t - 1)) / (1 - self.beta1_glob ** t)
            if isinstance(self.beta2_glob, float):
                beta2 = lambda t: self.beta2_glob * (1 - self.beta2_glob ** (t - 1)) / (1 - self.beta2_glob ** t)

        relative_step_size = True

        if lr is None:
            # Default schedule from the paper.
            lr = lambda t: min(1e-2, 1 / sqrt(t))

        if isinstance(self.lr_glob, float):
            lr = lambda x: self.lr_glob
            relative_step_size = False

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps1=eps1,
                        eps2=eps2, clipping_threshold=clipping_threshold,
                        weight_decay=weight_decay, ams_grad=ams_grad,
                        enable_factorization=enable_factorization,
                        enable_momentum=enable_momentum,
                        relative_step_size=relative_step_size)

        super(AdaFactor, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaFactor, self).__setstate__(state)

    def _experimental_reshape(self, shape):
        # Collapse a >2D tensor into a 2D matrix so that it can be factorized.
        temp_shape = shape[2:]
        if len(temp_shape) == 1:
            new_shape = (shape[0], shape[1] * shape[2])
        else:
            tmp_div = len(temp_shape) // 2 + len(temp_shape) % 2
            new_shape = (shape[0] * functools.reduce(operator.mul, temp_shape[tmp_div:], 1),
                         shape[1] * functools.reduce(operator.mul, temp_shape[:tmp_div], 1))
        return new_shape, copy(shape)

    def _check_shape(self, shape):
        '''
        output1 - True if the parameter is treated as a matrix (factorized), False if as a vector;
        output2 - True if the parameter needs to be reshaped to 2D
        '''
        if len(shape) > 2:
            return True, True
        elif len(shape) == 2 and (shape[0] == 1 or shape[1] == 1):
            return False, False
        elif len(shape) == 2:
            return True, False
        else:
            return False, False

    def _rms(self, x):
        return sqrt(torch.mean(x.pow(2)).item())

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                data_backup = p.data.clone().detach()

                if grad.is_sparse:
                    raise RuntimeError('AdaFactor does not support sparse gradients')

                is_matrix, is_need_reshape = self._check_shape(grad.size())
                new_shape = p.data.size()
                if is_need_reshape and group['enable_factorization']:
                    new_shape, old_shape = \
                        self._experimental_reshape(p.data.size())
                    grad = grad.view(new_shape)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    if group['enable_momentum']:
                        state['exp_avg'] = torch.zeros(new_shape, dtype=torch.float32, device=p.grad.device)

                    if is_matrix and group['enable_factorization']:
                        # Factored second moment: keep only row and column statistics.
                        state['exp_avg_sq_R'] = torch.zeros((1, new_shape[1]), dtype=torch.float32, device=p.grad.device)
                        state['exp_avg_sq_C'] = torch.zeros((new_shape[0], 1), dtype=torch.float32, device=p.grad.device)
                    else:
                        state['exp_avg_sq'] = torch.zeros(new_shape, dtype=torch.float32, device=p.grad.device)
                    if group['ams_grad']:
                        state['exp_avg_sq_hat'] = torch.zeros(new_shape, dtype=torch.float32, device=p.grad.device)

                if group['enable_momentum']:
                    exp_avg = state['exp_avg']

                if is_matrix and group['enable_factorization']:
                    exp_avg_sq_R = state['exp_avg_sq_R']
                    exp_avg_sq_C = state['exp_avg_sq_C']
                else:
                    exp_avg_sq = state['exp_avg_sq']

                if group['ams_grad']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']

                state['step'] += 1
                lr_t = group['lr'](state['step'])
                if group['relative_step_size']:
                    lr_t *= max(group['eps2'], self._rms(p.data))

                if group['enable_momentum']:
                    beta1_t = group['beta1'](state['step'])
                    exp_avg.mul_(beta1_t).add_(grad, alpha=1 - beta1_t)

                beta2_t = group['beta2'](state['step'])

                if is_matrix and group['enable_factorization']:
                    # Update the factored second moment and reconstruct it as an
                    # outer product of the column and row statistics.
                    grad_sq = torch.mul(grad, grad).add_(group['eps1'])
                    exp_avg_sq_R.mul_(beta2_t).add_(torch.sum(grad_sq, dim=0, keepdim=True), alpha=1 - beta2_t)
                    exp_avg_sq_C.mul_(beta2_t).add_(torch.sum(grad_sq, dim=1, keepdim=True), alpha=1 - beta2_t)
                    v = torch.mul(exp_avg_sq_C, exp_avg_sq_R).div_(torch.sum(exp_avg_sq_R))
                else:
                    exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=1 - beta2_t).add_((1 - beta2_t) * group['eps1'])
                    v = exp_avg_sq

                g = grad
                if group['enable_momentum']:
                    g = torch.div(exp_avg, 1 - beta1_t ** state['step'])

                if group['ams_grad']:
                    torch.max(exp_avg_sq_hat, v, out=exp_avg_sq_hat)
                    v = exp_avg_sq_hat
                    u = torch.div(g, (torch.div(v, 1 - beta2_t ** state['step'])).sqrt().add_(group['eps1']))
                else:
                    u = torch.div(g, v.sqrt())

                # Update clipping (section 6 of the paper).
                u.div_(max(1, self._rms(u) / group['clipping_threshold']))
                p.data.add_(-lr_t * (u.view(old_shape) if is_need_reshape and group['enable_factorization'] else u))

                # Decoupled weight decay, as in AdamW.
                if group['weight_decay'] != 0:
                    p.data.add_(data_backup, alpha=-group['weight_decay'] * lr_t)

        return loss
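

if __name__ == "__main__":
    # Illustrative sketch, not part of the original optimizer: the factored
    # branch of step() stores only the column statistics exp_avg_sq_R (1 x n)
    # and row statistics exp_avg_sq_C (m x 1) of the squared gradients, and
    # reconstructs the full second moment as outer(C, R) / sum(R). For a
    # rank-1 non-negative matrix the reconstruction is exact, which is the
    # approximation the factorization relies on.
    torch.manual_seed(0)
    g_sq = torch.rand(4, 1) * torch.rand(1, 5)   # rank-1 "squared gradient"
    R = g_sq.sum(dim=0, keepdim=True)            # column sums, shape (1, 5)
    C = g_sq.sum(dim=1, keepdim=True)            # row sums, shape (4, 1)
    v = C * R / R.sum()                          # factored reconstruction
    print('max reconstruction error:', (v - g_sq).abs().max().item())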

--------------------------------------------------------------------------------
/mnist_test.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "157it [00:01, 98.68it/s]"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "from adafactor import AdaFactor\n",
    "import math\n",
    "from tqdm import tqdm as tqdmn\n",
    "# Training settings\n",
    "\n",
    "class ProtoArgs(object):\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 1000\n",
    "        self.epochs = 10\n",
    "        self.lr = 0.01\n",
    "        self.momentum = 0.5\n",
    "        self.seed = 1\n",
    "        self.log_interval = 10000000\n",
    "        self.cuda = 3\n",
    "\n",
    "\n",
    "\n",
    "args = ProtoArgs()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "main_device = torch.device(\"cpu\")\n",
    "\n",
    "kwargs = {}\n",
    "if args.cuda > -1:\n",
    "    torch.cuda.manual_seed(args.seed)\n",
    "    main_device = torch.device(\"cuda:{}\".format(args.cuda))\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
    "\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True, **kwargs)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "model = Net()\n",
    "model.to(main_device)\n",
    "\n",
    "optimizer = AdaFactor(model.parameters(), non_constant_decay=True, enable_factorization=True)\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in tqdmn(enumerate(train_loader)):\n",
    "        data, target = data.to(main_device), target.to(main_device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if (batch_idx + 1) % args.log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "def test():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    for data, target in test_loader:\n",
    "        data, target = data.to(main_device), target.to(main_device)\n",
    "        with torch.no_grad():\n",
    "            output = model(data)\n",
    "        test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability\n",
    "        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))\n",
    "\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(epoch)\n",
    "    test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------