├── README.md
└── zdt
    ├── moosvgd.py
    ├── run_zdt_moosvgd.py
    └── zdt_functions.py

/README.md:
--------------------------------------------------------------------------------
# Multi-Objective Optimization with Sampling Methods

This repo contains the code for the NeurIPS 2021 spotlight paper ([paper link](https://proceedings.neurips.cc/paper/2021/hash/7bb16972da003e87724f048d76b7e0e1-Abstract.html)):
**Profiling Pareto Front With Multi-Objective Stein Variational Gradient Descent**
by *Xingchao Liu, Xin Tong, and Qiang Liu* from UT Austin and the National University of Singapore.

## Requirements
```PyTorch```, ```Numpy``` and ```pymoo==0.4.3.dev0```. All of them can be installed with ```pip```.

## Usage

### ZDT Problems
Run ```python run_zdt_moosvgd.py``` to reproduce the results of MOO-SVGD on the ZDT problems. Intermediate Pareto front plots are saved to the ```figs/``` directory every 1000 iterations.

## Citation
If you use our code, please consider citing us with:
```BibTex
@article{liu2021profiling,
  title={Profiling Pareto Front With Multi-Objective Stein Variational Gradient Descent},
  author={Liu, Xingchao and Tong, Xin and Liu, Qiang},
  journal={Advances in Neural Information Processing Systems},
  year={2021}
}
```
--------------------------------------------------------------------------------
/zdt/moosvgd.py:
--------------------------------------------------------------------------------
import math
import torch

def solve_min_norm_2_loss(grad_1, grad_2):
    # Closed-form solution of the two-objective min-norm (MGDA) problem:
    # per particle, find gamma minimizing ||gamma*grad_1 + (1-gamma)*grad_2||^2.
    v1v1 = torch.sum(grad_1*grad_1, dim=1)
    v2v2 = torch.sum(grad_2*grad_2, dim=1)
    v1v2 = torch.sum(grad_1*grad_2, dim=1)
    gamma = -1.0 * ( (v1v2 - v2v2) / (v1v1+v2v2 - 2*v1v2) )
    # Clip gamma when the unconstrained minimizer falls outside (0, 1).
    gamma[v1v2>=v1v1] = 0.999
    gamma[v1v2>=v2v2] = 0.001
    gamma = gamma.view(-1, 1)
    g_w = gamma.repeat(1, grad_1.shape[1])*grad_1 + (1.-gamma.repeat(1, grad_2.shape[1]))*grad_2

    return g_w

def median(tensor):
    """
    torch.median() acts differently from np.median(). We want to simulate the numpy implementation.
    """
    tensor = tensor.detach().flatten()
    tensor_max = tensor.max()[None]
    return (torch.cat((tensor, tensor_max)).median() + tensor.median()) / 2.

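# --- Added illustration (not part of the original code) -----------------------
# kernel_functional_rbf below builds an RBF kernel over the per-particle loss
# vectors, using the SVGD median heuristic for the bandwidth plus an extra
# problem-specific scaling (5e-6 for the ZDT problems). The helper here is a
# minimal, self-contained sketch of the plain median-heuristic kernel only; its
# name and the input shape are assumptions for illustration, not original code.
def _median_heuristic_rbf_sketch(losses):
    # losses: (n_particles, n_objectives) tensor of per-particle loss values
    d2 = torch.norm(losses[:, None] - losses, dim=2).pow(2)   # squared pairwise distances
    h = median(d2) / math.log(losses.shape[0])                # bandwidth from the median heuristic
    return torch.exp(-d2 / h)                                 # (n, n) RBF kernel matrix
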
def kernel_functional_rbf(losses):
    n = losses.shape[0]
    pairwise_distance = torch.norm(losses[:, None] - losses, dim=2).pow(2)
    h = median(pairwise_distance) / math.log(n)
    # No brackets around 5e-6*h on purpose: the exponent is (-pairwise_distance / 5e-6) * h.
    # The constant 5e-6 is used for zdt1, zdt2 and zdt3.
    kernel_matrix = torch.exp(-pairwise_distance / 5e-6*h)
    return kernel_matrix

def get_gradient(grad_1, grad_2, inputs, losses):
    n = inputs.size(0)
    #inputs = inputs.detach().requires_grad_(True)

    g_w = solve_min_norm_2_loss(grad_1, grad_2)
    ### g_w: (n_particles, x_dim)
    # See https://github.com/activatedgeek/svgd/issues/1#issuecomment-649235844 for why there is a factor -0.5
    kernel = kernel_functional_rbf(losses)
    kernel_grad = -0.5 * torch.autograd.grad(kernel.sum(), inputs, allow_unused=True)[0]

    gradient = (kernel.mm(g_w) - kernel_grad) / n

    return gradient

--------------------------------------------------------------------------------
/zdt/run_zdt_moosvgd.py:
--------------------------------------------------------------------------------
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.optim import Adam
from moosvgd import get_gradient
from pymoo.factory import get_problem
from pymoo.util.plotting import plot
from zdt_functions import *
import time
from pymoo.factory import get_performance_indicator

cur_problem = 'zdt3'
run_num = 0

if __name__ == '__main__':
    x = torch.rand((50, 30))
    x.requires_grad = True
    optimizer = Adam([x], lr=5e-4)

    ref_point = get_ref_point(cur_problem)
    hv = get_performance_indicator('hv', ref_point=ref_point)
    iters = 10000
    start_time = time.time()
    hv_results = []
    for i in range(iters):
        loss_1, loss_2 = loss_function(x, problem=cur_problem)
        pfront = torch.cat([loss_1.unsqueeze(1), loss_2.unsqueeze(1)], dim=1)
        pfront = pfront.detach().cpu().numpy()
        hvi = hv.calc(pfront)
        hv_results.append(hvi)

        if i%1000 == 0:
            # Plot the current particles against the analytic Pareto front.
            problem = get_problem(cur_problem)
            x_p = problem.pareto_front()[:, 0]
            y_p = problem.pareto_front()[:, 1]
            plt.scatter(x_p, y_p, c='r')

            plt.scatter(loss_1.detach().cpu().numpy(), loss_2.detach().cpu().numpy())
            plt.savefig('figs/%s_%d.png'%(cur_problem, i))
            plt.close()

        loss_1.sum().backward(retain_graph=True)
        grad_1 = x.grad.detach().clone()
        x.grad.zero_()

        loss_2.sum().backward(retain_graph=True)
        grad_2 = x.grad.detach().clone()
        x.grad.zero_()

        # Perform the gradient normalization trick
        grad_1 = torch.nn.functional.normalize(grad_1, dim=0)
        grad_2 = torch.nn.functional.normalize(grad_2, dim=0)

        optimizer.zero_grad()
        losses = torch.cat([loss_1.unsqueeze(1), loss_2.unsqueeze(1)], dim=1)
        x.grad = get_gradient(grad_1, grad_2, x, losses)
        optimizer.step()

        # Keep the particles inside the (0, 1) box of the ZDT problems.
        x.data = torch.clamp(x.data.clone(), min=1e-6, max=1.-1e-6)

        print(i, 'time:', time.time()-start_time, 'hv:', hvi, loss_1.sum().detach().cpu().numpy(), loss_2.sum().detach().cpu().numpy())
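# --- Added sketch (not part of the original script) ---------------------------
# hv_results records the hypervolume at every iteration but is never written out
# above. A minimal way to save and plot the curve after the loop could look like
# the following; the output file names are assumptions for illustration:
#
#     np.save('figs/%s_hv_run%d.npy' % (cur_problem, run_num), np.array(hv_results))
#     plt.plot(hv_results)
#     plt.xlabel('iteration')
#     plt.ylabel('hypervolume')
#     plt.savefig('figs/%s_hv_run%d.png' % (cur_problem, run_num))
#     plt.close()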
--------------------------------------------------------------------------------
/zdt/zdt_functions.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

def loss_function(x, problem='zdt1'):
    # Standard ZDT formulation: f1(x) = x_1 and f2(x) = g(x)*h(f1, g),
    # where g(x) = 1 + 9/(n-1) * sum_{i=2}^{n} x_i with n = 30 decision variables.
    ### x: (n_particles, 30)
    f = x[:, 0]
    g = x[:, 1:]

    if problem == 'zdt1':
        g = g.sum(dim=1, keepdim=False) * (9./29.) + 1.
        h = 1. - torch.sqrt(f/g)

    if problem == 'zdt2':
        g = g.sum(dim=1, keepdim=False) * (9./29.) + 1.
        h = 1. - (f/g)**2

    if problem == 'zdt3':
        g = g.sum(dim=1, keepdim=False) * (9./29.) + 1.
        h = 1. - torch.sqrt(f/g) - (f/g)*torch.sin(10.*np.pi*f)

    return f, g*h

def get_ref_point(problem='zdt1'):
    # Reference points used by the hypervolume indicator for each ZDT problem.
    if problem == 'zdt1':
        return np.array([0.99022638, 6.39358545])

    if problem == 'zdt2':
        return np.array([0.99022638, 7.71577261])

    if problem == 'zdt3':
        return np.array([0.99022638, 6.54635266])

--------------------------------------------------------------------------------