├── KRR alternate regularizer synthetic high noise.ipynb
├── KRR alternate regularizer synthetic low noise.ipynb
├── README.md
├── krr_alternate_regularizer.py
└── list_of_installed_packages.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Distributionally Robust Optimization and Generalization in Kernel Methods
This repository contains the supporting code for the paper:

[Matthew Staib, Stefanie Jegelka. Distributionally Robust Optimization and Generalization in Kernel Methods. In _Advances in Neural Information Processing Systems 32_, 2019.](https://arxiv.org/abs/1905.10943)

```
@inproceedings{staib2019distributionally,
  author    = {Staib, Matthew and Jegelka, Stefanie},
  title     = {Distributionally Robust Optimization and Generalization in Kernel Methods},
  booktitle = {Advances in Neural Information Processing Systems 32},
  year      = {2019}
}
```


## Dependencies
All code is Python 3, run inside a virtualenv; every package installed in that virtualenv is listed in `list_of_installed_packages.txt`.
(Most of these are probably unnecessary; the essential ones are scikit-learn, numpy, and scipy.)


## Getting started
Load up a virtualenv with the packages listed above, then run the two included notebooks.
--------------------------------------------------------------------------------
/krr_alternate_regularizer.py:
--------------------------------------------------------------------------------
import numpy as np

from scipy.optimize import minimize
from sklearn.gaussian_process.kernels import RBF


def fn_and_grad_square_reg(d, K):
    """Value and gradient, in d, of sqrt(trace((diag(d) @ K) ** 4)),
    where ** is the elementwise power (so only the diagonal of K enters).

    Auto-generated from matrixcalculus.org.
    """
    assert isinstance(K, np.ndarray)
    dim = K.shape
    assert len(dim) == 2
    K_rows = dim[0]
    K_cols = dim[1]
    assert isinstance(d, np.ndarray)
    dim = d.shape
    assert len(dim) == 1
    d_rows = dim[0]
    assert K_cols == d_rows == K_rows

    t_0 = (1 / 2)
    T_1 = (d[:, np.newaxis] * K)  # row scaling, i.e. diag(d) @ K
    functionValue = (np.trace((T_1 ** 4)) ** t_0)
    gradient = (((4 * (np.trace(((K * d[np.newaxis, :]) ** 4)) ** (t_0 - 1))) / 2)
                * np.diag(np.dot((np.eye(d_rows, K_cols) * (T_1 ** 3)), K)))

    return functionValue, gradient


def fn_and_grad_normal_reg(a, K):
    """Value and gradient of the standard KRR regularizer a^T K a."""
    reg = a.T.dot(K.dot(a))
    grad = 2 * K.dot(a)

    return reg, grad


def fit_krr_both_regularizers(K, y, lmb_new, lmb_old, a_init=None):
    """Fit KRR coefficients under a weighted combination of both regularizers:
    lmb_new weights the alternate regularizer, lmb_old the usual a^T K a term.
    """
    def fun(a):
        errs = K.dot(a) - y

        val_fit = np.sum(errs ** 2)
        grad_fit = 2 * K.dot(errs)

        val_total = val_fit
        grad_total = grad_fit

        if lmb_new > 0:
            if np.all(a == 0):
                # the gradient formula is undefined at a = 0 (it contains 0 ** (-1/2)),
                # but the regularizer value there is 0 with subgradient 0
                val_reg_new = 0
                grad_reg_new = np.zeros(grad_total.shape)
            else:
                # note: K ** 0.5 is the elementwise square root, not a matrix square root
                val_reg_new, grad_reg_new = fn_and_grad_square_reg(a, K ** 0.5)

            val_total += lmb_new * val_reg_new
            grad_total += lmb_new * grad_reg_new

        if lmb_old > 0:
            val_reg_old, grad_reg_old = fn_and_grad_normal_reg(a, K)

            val_total += lmb_old * val_reg_old
            grad_total += lmb_old * grad_reg_old

        return val_total, grad_total

    if a_init is None:
        a_init = np.zeros(len(y))
    res = minimize(fun,
                   a_init,
                   method='L-BFGS-B',
                   jac=True,
                   options={'gtol': 1e-8, 'ftol': 1e-10})

    val = res.fun
    a = res.x
    return a, val
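
# A minimal sanity check (not part of the original repository): compare the
# analytic gradient returned by fn_and_grad_square_reg against central finite
# differences on a random PSD matrix. The helper name, seed, and step size
# below are illustrative choices, not values from the paper.
def _check_square_reg_gradient(n=5, eps=1e-6, seed=0):
    rng = np.random.RandomState(seed)
    B = rng.randn(n, n)
    K = B.dot(B.T)  # random positive semidefinite matrix standing in for a kernel
    d = rng.randn(n)

    _, grad = fn_and_grad_square_reg(d, K)

    # central finite differences, one coordinate at a time
    fd = np.zeros(n)
    for i in range(n):
        e = np.zeros(n)
        e[i] = eps
        f_plus, _ = fn_and_grad_square_reg(d + e, K)
        f_minus, _ = fn_and_grad_square_reg(d - e, K)
        fd[i] = (f_plus - f_minus) / (2 * eps)

    # maximum absolute deviation; should be tiny if the gradient is correct
    return np.max(np.abs(grad - fd))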

def create_population_distribution_2(seed=1, n=10000):
    """Build a synthetic population: inputs x ~ N(0, 1) and noiseless labels
    from a ground-truth function in the RBF RKHS with two centers."""
    np.random.seed(seed)

    k = 2
    true_x = np.reshape(np.array([-1, 1]), (k, 1))
    true_a = np.reshape(np.array([-1, 1]), (1, k))

    m = RBF()

    def true_eval(x_new):
        return np.sum(m(x_new, true_x) * true_a, axis=1)

    x = np.random.randn(n, 1)
    x = np.array(sorted(x))  # sort the (n, 1) inputs for convenient plotting

    # labels are noiseless here; observation noise is added later, in
    # sample_from_population
    y = true_eval(x).flatten()

    return x, y, m, true_eval


def sample_from_population(x, y, n=20, stddev=1., seed=1):
    """Draw n points (with replacement) from the population and add Gaussian
    observation noise to the labels."""
    np.random.seed(seed)

    sample_inx = np.random.choice(range(len(x)), n)

    return x[sample_inx], y[sample_inx] + stddev * np.random.randn(n)


def create_pred(x, m):
    def pred(x_new, a_learned):
        return np.sum(m(x_new, x) * a_learned, axis=1)

    return pred


def create_fit_quality(pred, x, y):
    def fit_quality(a_learned):
        """Mean squared error of the learned function over the population."""
        y_fit = pred(x, a_learned)

        return np.mean((y_fit - y) ** 2)

    return fit_quality


def main(seed1, seed2, lmbs_old, lmbs_new, n_sample=20, stddev=1., verbose=False):
    x, y, m, true_eval = create_population_distribution_2(seed1)

    x_s, y_s = sample_from_population(x, y, n_sample, seed=seed2, stddev=stddev)
    K = m(x_s)
    pred = create_pred(x_s, m)

    fit_quality = create_fit_quality(pred, x, y)

    out_dicts = []

    # warm-start every regularized fit from the unregularized solution
    a_init, val = fit_krr_both_regularizers(K, y_s, 0, 0)

    for lmb_old in lmbs_old:
        for lmb_new in lmbs_new:
            a, val = fit_krr_both_regularizers(K, y_s, lmb_new, lmb_old, a_init)

            this_dict = {
                'lmb_old': lmb_old,
                'lmb_new': lmb_new,
                'l2': fit_quality(a)
            }

            out_dicts.append(this_dict)
            if verbose:
                print(this_dict)

    return {
        'lmbs_new': lmbs_new,
        'lmbs_old': lmbs_old,
        'out_dicts': out_dicts
    }
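
# Example driver (not in the original repository): sweep a small grid of
# regularization weights and print the population MSE for each setting.
# The seeds and lambda grids below are illustrative values, not the ones
# used for the figures in the paper.
if __name__ == '__main__':
    results = main(seed1=1,
                   seed2=1,
                   lmbs_old=[0.0, 0.1, 1.0],
                   lmbs_new=[0.0, 0.1, 1.0],
                   n_sample=20,
                   stddev=1.,
                   verbose=True)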
--------------------------------------------------------------------------------
/list_of_installed_packages.txt:
--------------------------------------------------------------------------------
Most of these are probably unnecessary except for scikit-learn and numpy/scipy.
The version of virtualenv used was 16.4.3

attrs==19.1.0
backcall==0.1.0
bleach==3.1.0
bottle==0.12.16
CProfileV==1.0.7
cvxpy==1.0.21
cycler==0.10.0
Cython==0.29.7
decorator==4.4.0
defusedxml==0.5.0
dill==0.2.9
ecos==2.0.7.post1
entrypoints==0.3
fastcache==1.0.2
future==0.17.1
ipykernel==5.1.0
ipython==7.4.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
jedi==0.13.3
Jinja2==2.10.1
jsonschema==3.0.1
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
kiwisolver==1.1.0
MarkupSafe==1.1.1
matplotlib==3.0.3
mistune==0.8.4
Mosek==9.0.88
multiprocess==0.70.7
nbconvert==5.4.1
nbformat==4.4.0
notebook==5.7.8
numpy==1.16.2
osqp==0.5.0
pandas==0.24.2
pandocfilters==1.4.2
parso==0.4.0
pathos==0.2.3
pexpect==4.6.0
pickleshare==0.7.5
Pillow==6.0.0
pox==0.2.5
ppft==1.6.4.9
prometheus-client==0.6.0
prompt-toolkit==2.0.9
ptyprocess==0.6.0
Pygments==2.3.1
pyparsing==2.4.0
pyrsistent==0.14.11
python-dateutil==2.8.0
pytz==2018.9
pyzmq==18.0.1
qtconsole==4.4.3
scikit-learn==0.20.3
scipy==1.2.1
scs==2.1.0
Send2Trash==1.5.0
six==1.12.0
sklearn==0.0
terminado==0.8.2
testpath==0.4.2
torch==1.1.0
torchvision==0.2.2.post3
tornado==6.0.2
traitlets==4.3.2
UNKNOWN==0.0.0
wcwidth==0.1.7
webencodings==0.5.1
widgetsnbextension==3.4.2
--------------------------------------------------------------------------------