├── .gitignore
├── README.md
├── dataset.py
├── model.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-grad-norm

PyTorch implementation of GradNorm. GradNorm addresses the problem of balancing multiple losses in multi-task learning by learning adjustable weight coefficients, one per task loss.
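
The toy regression experiment from Section 4 of the paper can be run with `python train.py`; the defaults (`--n-iter 25000 --mode grad_norm --alpha 0.12 --sigma 100.0`) train two regression tasks whose targets differ in scale by a factor of 100. The snippet below is a condensed sketch of a single GradNorm update built from the modules in this repository; it leaves out the CUDA handling, logging, and plotting that `train.py` adds, so treat it as an illustration of the balancing step rather than a replacement for the training loop.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

from dataset import RegressionDataset
from model import RegressionModel, RegressionTrain

sigmas = [1.0, 100.0]                   # output scales of the two toy tasks
epsilons = np.random.normal(scale=3.5, size=(len(sigmas), 100, 250)).astype(np.float32)
alpha = 0.12                            # GradNorm asymmetry hyperparameter

dataset = RegressionDataset(sigmas, epsilons)
loader = DataLoader(dataset, batch_size=200, shuffle=False)
model = RegressionTrain(RegressionModel(len(sigmas)))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x, targets = next(iter(loader))
task_loss = model(x, targets)           # one MSE loss L_i per task
initial_task_loss = task_loss.detach()  # L_i(0); train.py stores this once at t=0

# backward pass on the weighted sum to get gradients for the shared layers
loss = torch.sum(model.weights * task_loss)
optimizer.zero_grad()
loss.backward(retain_graph=True)

# gradient norm G_i of each weighted task loss w.r.t. the last shared layer
W = model.get_last_shared_layer()
norms = torch.stack([
    torch.norm(model.weights[i] *
               torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)[0])
    for i in range(len(task_loss))
])

# inverse training rates r_i and the constant GradNorm target mean(G) * r_i^alpha
loss_ratio = task_loss.detach() / initial_task_loss
inverse_train_rate = loss_ratio / loss_ratio.mean()
target = (norms.mean() * inverse_train_rate ** alpha).detach()

# GradNorm loss and the gradient it induces on the task weights w_i
grad_norm_loss = torch.sum(torch.abs(norms - target))
model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

optimizer.step()

# renormalize so the task weights sum to the number of tasks
model.weights.data *= len(sigmas) / model.weights.data.sum()
```

Passing `--mode equal_weight` leaves the task weights at their initial value of 1 and serves as the fixed-weight baseline.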

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch

import numpy as np

from torch.utils import data


class RegressionDataset(data.Dataset):
    '''
    Dataset for the toy regression experiment in Section 4 of the paper
    '''

    def __init__(self, sigmas, epsilons):
        '''
        Initialize the dataset
        Inputs:
            sigmas: ($\sigma_i$) fixed scalars that set the scales of the outputs of each function $f_i$
            epsilons: ($\epsilon_i$) task-specific matrices that differentiate the tasks
        '''

        # B is a constant matrix with its elements generated IID
        # from a normal distribution N(0,10)
        self.B = np.random.normal(scale=10, size=(100, 250)).astype(np.float32)

        # check that the given epsilons have the appropriate size
        assert epsilons.shape == (len(sigmas), 100, 250)

        # assign the epsilons and the sigmas
        self.sigmas = np.array(sigmas).astype(np.float32)
        self.epsilons = np.array(epsilons).astype(np.float32)


    def __len__(self):
        return 100


    def __getitem__(self, index):

        # draw a single input sample with d=250, normalized to unit length
        x = np.random.uniform(-1, 1, size=(250,)).astype(np.float32)
        x = x / np.linalg.norm(x)

        # compute one target value for each of the tasks
        ys = []
        for i in range(len(self.sigmas)):
            # eq (3) in the paper:
            # each target is $\sigma_i \tanh((B + \epsilon_i) \mathbf{x})$
            ys.append(
                self.sigmas[i] * np.tanh((self.B + self.epsilons[i]).dot(x))
            )
        ys = np.stack(ys)

        # convert everything to torch tensors
        x = torch.from_numpy(x).float()
        ys = torch.from_numpy(ys).float()

        return x, ys
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch

from torch.nn.modules.loss import MSELoss

import torch.nn.functional as F




class RegressionTrain(torch.nn.Module):
    '''
    Wraps the multi-task model together with the learnable task weights $w_i$ and the per-task MSE losses.
    '''

    def __init__(self, model):
        '''
        Store the model, create one learnable weight per task (initialized to 1) and set up the MSE loss.
        '''

        # initialize the module using super() constructor
        super(RegressionTrain, self).__init__()
        # assign the architecture
        self.model = model
        # assign the weights for each task
        self.weights = torch.nn.Parameter(torch.ones(model.n_tasks).float())
        # loss function
        self.mse_loss = MSELoss()


    def forward(self, x, ts):
        B, n_tasks = ts.shape[:2]
        ys = self.model(x)

        # check that the model produces one output per task
        assert(ys.size()[1] == n_tasks)
        task_loss = []
        for i in range(n_tasks):
            task_loss.append( self.mse_loss(ys[:,i,:], ts[:,i,:]) )
        task_loss = torch.stack(task_loss)

        return task_loss


    def get_last_shared_layer(self):
        return self.model.get_last_shared_layer()



class RegressionModel(torch.nn.Module):
    '''
    Four shared fully connected layers followed by one linear output head per task.
    '''

    def __init__(self, n_tasks):
        '''
        Constructor of the architecture.
        Input:
            n_tasks: number of tasks to solve ($T$ in the paper)
        '''

        # initialize the module using super() constructor
        super(RegressionModel, self).__init__()

        # number of tasks to solve
        self.n_tasks = n_tasks
        # fully connected layers
        self.l1 = torch.nn.Linear(250, 100)
        self.l2 = torch.nn.Linear(100, 100)
        self.l3 = torch.nn.Linear(100, 100)
        self.l4 = torch.nn.Linear(100, 100)
        # branches for each task
        for i in range(self.n_tasks):
            setattr(self, 'task_{}'.format(i), torch.nn.Linear(100, 100))


    def forward(self, x):
        # forward pass through the common fully connected layers
        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        h = F.relu(self.l3(h))
        h = F.relu(self.l4(h))

        # forward pass through each output layer
        outs = []
        for i in range(self.n_tasks):
            layer = getattr(self, 'task_{}'.format(i))
            outs.append(layer(h))

        return torch.stack(outs, dim=1)


    def get_last_shared_layer(self):
        return self.l4


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

import argparse
import torch

import numpy as np

from dataset import RegressionDataset
from model import RegressionModel, RegressionTrain

import matplotlib.pyplot as plt

from torch.utils import data

import torch.nn.functional as F


def train_toy_example(args):

    # set the random seeds for reproducibility
    np.random.seed(123)
    torch.cuda.manual_seed_all(123)
    torch.manual_seed(123)

    # define the sigmas, the number of tasks and the epsilons
    # for the toy example
    sigmas = [1.0, float(args.sigma)]
    print('Training toy example with sigmas={}'.format(sigmas))
    n_tasks = len(sigmas)
    epsilons = np.random.normal(scale=3.5, size=(n_tasks, 100, 250)).astype(np.float32)

    # initialize the data loader
    dataset = RegressionDataset(sigmas, epsilons)
    data_loader = data.DataLoader(dataset, batch_size=200, num_workers=4, shuffle=False)

    # initialize the model and use CUDA if available
    model = RegressionTrain(RegressionModel(n_tasks))
    if torch.cuda.is_available():
        model.cuda()

    # initialize the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    n_iterations = int(args.n_iter)
    weights = []
    task_losses = []
    loss_ratios = []
    grad_norm_losses = []

    # run n_iter iterations of training
    for t in range(n_iterations):

        # the data loader yields a single batch per iteration (batch_size > dataset size)
        for (it, batch) in enumerate(data_loader):
            # get the inputs X and the target values
            X = batch[0]
            ts = batch[1]
            if torch.cuda.is_available():
                X = X.cuda()
                ts = ts.cuda()

            # evaluate each task loss L_i(t)
            task_loss = model(X, ts)  # this does a forward pass through the model and evaluates the per-task losses
            # compute the weighted loss w_i(t) * L_i(t)
            weighted_task_loss = torch.mul(model.weights, task_loss)
            # store the initial losses L_i(0) at t=0
            if t == 0:
                # set L(0)
                if torch.cuda.is_available():
                    initial_task_loss = task_loss.data.cpu()
                else:
                    initial_task_loss = task_loss.data
                initial_task_loss = initial_task_loss.numpy()

            # get the total loss
            loss = torch.sum(weighted_task_loss)
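
            # What follows is the GradNorm balancing step:
            #   (1) backpropagate the weighted sum of the task losses to get gradients for the shared network,
            #   (2) zero the gradient of the task weights w_i, which are updated from the GradNorm loss instead,
            #   (3) in 'grad_norm' mode, compute the per-task gradient norms G_i at the last shared layer,
            #       the inverse training rates r_i and L_grad = sum_i |G_i - mean(G) * r_i^alpha|, and write
            #       d L_grad / d w_i into model.weights.grad,
            #   (4) take one optimizer step and renormalize the weights so that they sum to n_tasks
            #       (in 'equal_weight' mode the weights keep a zero gradient and stay at 1).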
            # clear the gradients
            optimizer.zero_grad()
            # do the backward pass to compute the gradients for the whole set of weights
            # this is equivalent to computing each \nabla_W L_i(t)
            loss.backward(retain_graph=True)

            # set the gradients of w_i(t) to zero because these gradients have to be updated using the GradNorm loss
            #print('Before turning to 0: {}'.format(model.weights.grad))
            model.weights.grad.data = model.weights.grad.data * 0.0
            #print('Turning to 0: {}'.format(model.weights.grad))

            # default GradNorm loss, so that the 'equal_weight' mode can still be logged below
            grad_norm_loss = torch.zeros(1)

            # switch for each weighting algorithm:
            # --> grad norm
            if args.mode == 'grad_norm':

                # get the layer of shared weights
                W = model.get_last_shared_layer()

                # get the gradient norms for each of the tasks
                # G^{(i)}_w(t)
                norms = []
                for i in range(len(task_loss)):
                    # get the gradient of this task loss with respect to the shared parameters
                    gygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)
                    # compute the norm
                    norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))
                norms = torch.stack(norms)
                #print('G_w(t): {}'.format(norms))


                # compute the inverse training rate r_i(t)
                # \tilde{L}_i(t) = L_i(t) / L_i(0)
                if torch.cuda.is_available():
                    loss_ratio = task_loss.data.cpu().numpy() / initial_task_loss
                else:
                    loss_ratio = task_loss.data.numpy() / initial_task_loss
                # r_i(t)
                inverse_train_rate = loss_ratio / np.mean(loss_ratio)
                #print('r_i(t): {}'.format(inverse_train_rate))


                # compute the mean norm \tilde{G}_w(t)
                if torch.cuda.is_available():
                    mean_norm = np.mean(norms.data.cpu().numpy())
                else:
                    mean_norm = np.mean(norms.data.numpy())
                #print('tilde G_w(t): {}'.format(mean_norm))


                # compute the GradNorm loss
                # this term is treated as a constant: no gradient flows through it
                constant_term = torch.tensor(mean_norm * (inverse_train_rate ** args.alpha), requires_grad=False)
                if torch.cuda.is_available():
                    constant_term = constant_term.cuda()
                #print('Constant term: {}'.format(constant_term))
                # this is the GradNorm loss itself
                grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
                #print('GradNorm loss {}'.format(grad_norm_loss))

                # compute the gradient for the weights
                model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

            # do a step with the optimizer
            optimizer.step()
            '''
            print('')
            wait = input("PRESS ENTER TO CONTINUE.")
            print('')
            '''

            # renormalize the weights so that they sum to the number of tasks
            normalize_coeff = n_tasks / torch.sum(model.weights.data, dim=0)
            model.weights.data = model.weights.data * normalize_coeff

            # record
            if torch.cuda.is_available():
                task_losses.append(task_loss.data.cpu().numpy())
                loss_ratios.append(np.sum(task_losses[-1] / task_losses[0]))
                weights.append(model.weights.data.cpu().numpy())
                grad_norm_losses.append(grad_norm_loss.data.cpu().numpy())
            else:
                task_losses.append(task_loss.data.numpy())
                loss_ratios.append(np.sum(task_losses[-1] / task_losses[0]))
                weights.append(model.weights.data.numpy())
                grad_norm_losses.append(grad_norm_loss.data.numpy())

            if t % 100 == 0:
                if torch.cuda.is_available():
                    print('{}/{}: loss_ratio={}, weights={}, task_loss={}, grad_norm_loss={}'.format(
                        t, args.n_iter, loss_ratios[-1], model.weights.data.cpu().numpy(), task_loss.data.cpu().numpy(), grad_norm_loss.data.cpu().numpy()))
                else:
                    print('{}/{}: loss_ratio={}, weights={}, task_loss={}, grad_norm_loss={}'.format(
                        t, args.n_iter, loss_ratios[-1], model.weights.data.numpy(), task_loss.data.numpy(), grad_norm_loss.data.numpy()))

    task_losses = np.array(task_losses)
    weights = np.array(weights)

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    fig = plt.figure()
    ax1 = fig.add_subplot(2, 3, 1)
    ax1.set_title(r'Loss (scale $\sigma_0=1.0$)')
    ax2 = fig.add_subplot(2, 3, 2)
    ax2.set_title(r'Loss (scale $\sigma_1={}$)'.format(sigmas[1]))
    ax3 = fig.add_subplot(2, 3, 3)
    ax3.set_title(r"$\sum_i L_i(t) / L_i(0)$")
    ax4 = fig.add_subplot(2, 3, 4)
    ax4.set_title(r'$L_{\mathrm{grad}}$')

    ax5 = fig.add_subplot(2, 3, 5)
    ax5.set_title(r'Change of weights $w_i$ over time')

    ax1.plot(task_losses[:, 0])
    ax2.plot(task_losses[:, 1])
    ax3.plot(loss_ratios)
    ax4.plot(grad_norm_losses)
    ax5.plot(weights[:, 0])
    ax5.plot(weights[:, 1])
    plt.show()



if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='GradNorm')
    parser.add_argument('--n-iter', '-it', type=int, default=25000)
    parser.add_argument('--mode', '-m', choices=('grad_norm', 'equal_weight'), default='grad_norm')
    parser.add_argument('--alpha', '-a', type=float, default=0.12)
    parser.add_argument('--sigma', '-s', type=float, default=100.0)
    args = parser.parse_args()

    train_toy_example(args)

--------------------------------------------------------------------------------
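
The same update can be dropped into other multi-task models. The sketch below assumes a hypothetical `MultiTaskNet` with a shared trunk and one head per task (these names are placeholders and are not part of this repository); the GradNorm steps themselves mirror `train.py`. The only model-specific decision is which layer plays the role of the last shared layer: here the trunk's final `Linear`, in `train.py` the fourth fully connected layer `model.l4`, with the gradient norm taken over that layer's weight matrix.

```python
import torch


class MultiTaskNet(torch.nn.Module):
    """Hypothetical example model: a shared trunk and one linear head per task."""

    def __init__(self, n_tasks=2, d_in=16, d_hidden=32):
        super().__init__()
        self.trunk = torch.nn.Sequential(
            torch.nn.Linear(d_in, d_hidden), torch.nn.ReLU(),
            torch.nn.Linear(d_hidden, d_hidden), torch.nn.ReLU(),
        )
        self.heads = torch.nn.ModuleList(
            [torch.nn.Linear(d_hidden, 1) for _ in range(n_tasks)]
        )
        # learnable task weights w_i, initialized to 1 as in model.py
        self.weights = torch.nn.Parameter(torch.ones(n_tasks))

    def forward(self, x):
        h = self.trunk(x)
        return torch.cat([head(h) for head in self.heads], dim=1)  # (batch, n_tasks)


def gradnorm_step(model, optimizer, task_loss, initial_task_loss, alpha=0.12):
    """One GradNorm update, mirroring the loop in train.py.

    task_loss: 1-D tensor with one loss per task, graph still attached.
    initial_task_loss: detached copy of the task losses at the first iteration.
    """
    # backward pass on the weighted sum, for the shared parameters
    loss = torch.sum(model.weights * task_loss)
    optimizer.zero_grad()
    loss.backward(retain_graph=True)

    # per-task gradient norms G_i at the last shared layer (here: last Linear of the trunk)
    last_shared = model.trunk[-2]
    norms = torch.stack([
        torch.norm(model.weights[i] *
                   torch.autograd.grad(task_loss[i], last_shared.parameters(),
                                       retain_graph=True)[0])
        for i in range(len(task_loss))
    ])

    # inverse training rates r_i and the constant GradNorm target mean(G) * r_i^alpha
    loss_ratio = task_loss.detach() / initial_task_loss
    inverse_train_rate = loss_ratio / loss_ratio.mean()
    target = (norms.mean() * inverse_train_rate ** alpha).detach()

    # replace the gradient of the task weights with the gradient of the GradNorm loss
    grad_norm_loss = torch.sum(torch.abs(norms - target))
    model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

    optimizer.step()

    # renormalize so the task weights keep summing to the number of tasks
    model.weights.data *= len(task_loss) / model.weights.data.sum()
    return grad_norm_loss.detach()


# usage sketch with random data
model = MultiTaskNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x, y = torch.randn(8, 16), torch.randn(8, 2)
preds = model(x)
task_loss = torch.stack([torch.nn.functional.mse_loss(preds[:, i], y[:, i]) for i in range(2)])
gradnorm_step(model, optimizer, task_loss, initial_task_loss=task_loss.detach())
```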