├── README.md
└── pcgrad-example.py

/README.md:
--------------------------------------------------------------------------------
# PCGrad-pytorch-example

A simple PyTorch example of PCGrad (gradient surgery for multi-task learning), based on the original [implementation](https://github.com/tianheyu927/PCGrad). PCGrad resolves conflicting gradients in multi-task training: each task's gradient is projected onto the normal plane of any other task gradient it conflicts with (negative inner product) before the per-task gradients are summed.
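A minimal usage sketch, mirroring the `__main__` block of `pcgrad-example.py` (the toy two-task data, iteration count, and SGD settings below are only illustrative):

```python
import torch
import torch.optim as opt

# Net and PCGrad_backward are defined in pcgrad-example.py
net = Net()
optimizer = opt.SGD(net.parameters(), lr=0.01)

for it in range(1000):
    X = torch.rand(20, 10)                          # toy inputs
    y = [torch.randint(0, 2, (20,)),                # one label tensor per task
         torch.randint(0, 2, (20,))]
    losses = PCGrad_backward(net, optimizer, X, y)  # writes projected grads into p.grad
    optimizer.step()
```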
--------------------------------------------------------------------------------
/pcgrad-example.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import random
from itertools import accumulate


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.freeze = nn.Linear(10, 10)  # frozen module
        self.base1 = nn.Linear(10, 10)
        self.task1 = nn.Linear(10, 2)
        self.task2 = nn.Linear(10, 2)

        for p in self.freeze.parameters():
            p.requires_grad = False

    def forward(self, x):
        x = F.relu(self.freeze(x), inplace=True)
        x = F.relu(self.base1(x), inplace=True)
        t1 = self.task1(x)
        t2 = self.task2(x)
        return t1, t2


def normal_backward(net, optimizer, X, y, loss_layer=nn.CrossEntropyLoss()):
    # baseline: sum the per-task losses and backprop once
    num_tasks = len(y)  # T
    optimizer.zero_grad()
    result = net(X)
    losses = []
    for i in range(num_tasks):
        losses.append(loss_layer(result[i], y[i]))

    tot_loss = sum(losses)
    tot_loss.backward()
    return losses


def PCGrad_backward(net, optimizer, X, y, loss_layer=nn.CrossEntropyLoss()):
    grads_task = []
    params = [p for group in optimizer.param_groups for p in group['params']]
    grad_shapes = [p.shape if p.requires_grad else None for p in params]
    grad_numel = [p.numel() if p.requires_grad else 0 for p in params]
    num_tasks = len(y)  # T
    losses = []
    optimizer.zero_grad()

    # calculate the full flattened gradient for each task
    for i in range(num_tasks):
        result = net(X)
        loss = loss_layer(result[i], y[i])
        losses.append(loss)
        loss.backward()

        grad = [p.grad.detach().clone().flatten()
                if (p.requires_grad and p.grad is not None) else None
                for p in params]

        # fill zero grad if grad is None but requires_grad is true
        grads_task.append(torch.cat([
            g if g is not None else torch.zeros(grad_numel[j], device=params[j].device)
            for j, g in enumerate(grad)]))
        optimizer.zero_grad()

    # shuffle the task order so the projections are applied in random order
    random.shuffle(grads_task)

    # gradient projection
    grads_task = torch.stack(grads_task, dim=0)  # (T, # of params)
    proj_grad = grads_task.clone()

    def _proj_grad(grad_task):
        for k in range(num_tasks):
            inner_product = torch.sum(grad_task * grads_task[k])
            proj_direction = inner_product / (torch.sum(
                grads_task[k] * grads_task[k]) + 1e-12)
            # project only when the gradients conflict (negative inner product)
            grad_task = grad_task - torch.min(
                proj_direction, torch.zeros_like(proj_direction)) * grads_task[k]
        return grad_task

    # project each task gradient, then sum them
    proj_grad = torch.sum(torch.stack(
        [_proj_grad(g) for g in proj_grad]), dim=0)  # (# of params, )

    indices = [0] + list(accumulate(grad_numel))
    assert len(params) == len(grad_shapes) == len(indices[:-1])
    for param, grad_shape, start_idx, end_idx in zip(
            params, grad_shapes, indices[:-1], indices[1:]):
        if grad_shape is not None:
            # write the projected gradient back; assign directly so this also
            # works when zero_grad() has set .grad to None
            param.grad = proj_grad[start_idx:end_idx].view(grad_shape).clone()

    return losses


if __name__ == '__main__':

    net = Net()
    net.train()
    optimizer = opt.SGD(net.parameters(), lr=0.01)
    num_task = 2
    num_iterations = 50000
    for it in range(num_iterations):
        # toy two-task data: the tasks use opposite label assignments
        X = torch.rand(20, 10) - \
            torch.cat([torch.zeros(10, 10), torch.ones(10, 10)])
        y = [torch.cat([torch.zeros(10,), torch.ones(10,)]).long(),
             torch.cat([torch.ones(10,), torch.zeros(10,)]).long()]
        losses = PCGrad_backward(net, optimizer, X, y)
        # losses = normal_backward(net, optimizer, X, y)
        optimizer.step()
        if it % 100 == 0:
            print("iter {} total loss: {}".format(
                it, sum([l.item() for l in losses])))
--------------------------------------------------------------------------------