├── README.md ├── convex_nn.py ├── convexnn_pytorch_stepsize_fig.py └── plots ├── cifar_multiclass_stepsize_obj.png ├── cifar_multiclass_stepsize_testacc.png └── cifar_multiclass_stepsize_tracc.png /README.md: -------------------------------------------------------------------------------- 1 | # Convex optimization for two-layer ReLU neural networks 2 | 3 | In this repository, we provide two distinct implementations to optimize two-layer ReLU neural networks. Particularly, we utilize the exact convex formulations introduced in [1]. Then, we optimize these equivalent architectures both via the interior point solvers in CVXPY and optimizers in PyTorch. 4 | 5 | Run the following CVXPY based implementation to perform a binary classification task on a toy dataset: 6 | 7 | ```` 8 | python convex_nn.py 9 | ```` 10 | 11 | Run the following PyTorch implementation to perform a ten class classification task on CIFAR-10 (see the plots folder for the training results): 12 | 13 | ```` 14 | python convexnn_pytorch_stepsize_fig.py --GD 0 --CVX 0 --n_epochs 100 100 --solver_cvx sgd 15 | ```` 16 | 17 | [1] M. Pilanci and T. Ergen. Neural Networks are Convex Regularizers: Exact Polynomial-time Convex Optimization Formulations for Two-layer Networks. ICML 2020 (http://proceedings.mlr.press/v119/pilanci20a.html) 18 | -------------------------------------------------------------------------------- /convex_nn.py: -------------------------------------------------------------------------------- 1 | ## This is a basic CVXPY based implementation on a toy dataset for the paper 2 | ## "Neural Networks are Convex Regularizers: Exact Polynomial-time Convex Optimization Formulations for Two-layer Networks" 3 | import numpy as np 4 | import cvxpy as cp 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def relu(x): 9 | return np.maximum(0,x) 10 | def drelu(x): 11 | return x>=0 12 | n=10 13 | d=3 14 | X=np.random.randn(n,d-1) 15 | X=np.append(X,np.ones((n,1)),axis=1) 16 | 17 | y=((np.linalg.norm(X[:,0:d-1],axis=1)>1)-0.5)*2 18 | beta=1e-4 19 | 20 | 21 | dmat=np.empty((n,0)) 22 | 23 | ## Finite approximation of all possible sign patterns 24 | for i in range(int(1e2)): 25 | u=np.random.randn(d,1) 26 | dmat=np.append(dmat,drelu(np.dot(X,u)),axis=1) 27 | 28 | dmat=(np.unique(dmat,axis=1)) 29 | 30 | 31 | # Optimal CVX 32 | m1=dmat.shape[1] 33 | Uopt1=cp.Variable((d,m1)) 34 | Uopt2=cp.Variable((d,m1)) 35 | 36 | ## Below we use hinge loss as a performance metric for binary classification 37 | yopt1=cp.Parameter((n,1)) 38 | yopt2=cp.Parameter((n,1)) 39 | yopt1=cp.sum(cp.multiply(dmat,(X*Uopt1)),axis=1) 40 | yopt2=cp.sum(cp.multiply(dmat,(X*Uopt2)),axis=1) 41 | cost=cp.sum(cp.pos(1-cp.multiply(y,yopt1-yopt2)))/n+beta*(cp.mixed_norm(Uopt1.T,2,1)+cp.mixed_norm(Uopt2.T,2,1)) 42 | constraints=[] 43 | constraints+=[cp.multiply((2*dmat-np.ones((n,m1))),(X*Uopt1))>=0] 44 | constraints+=[cp.multiply((2*dmat-np.ones((n,m1))),(X*Uopt2))>=0] 45 | prob=cp.Problem(cp.Minimize(cost),constraints) 46 | prob.solve() 47 | cvx_opt=prob.value 48 | print("Convex program objective value (eq (8)): ",cvx_opt) 49 | -------------------------------------------------------------------------------- /convexnn_pytorch_stepsize_fig.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dill 3 | import pickle 4 | from datetime import datetime 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | matplotlib.rcParams['pdf.fonttype'] = 42 8 | matplotlib.rcParams['ps.fonttype'] = 42 9 | import time 10 | import scipy 11 | from scipy.sparse.linalg import LinearOperator 12 | import torch 13 | import sklearn.linear_model 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.autograd import Variable 16 | import torch.nn as nn 17 | import argparse 18 | import random 19 | 20 | 21 | 22 | 23 | def parse_args(): 24 | # Parse arguments 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--GD', nargs=1, type=int, required=True) 27 | parser.add_argument('--CVX', nargs=1, type=int, required=True) 28 | parser.add_argument('--n_epochs', nargs=2, type=int, required=True) 29 | parser.add_argument('--solver_cvx', type=str, nargs=1, default="adam") 30 | parser.add_argument('--seed', type=int, default=42) 31 | args = parser.parse_args() 32 | random.seed(a=args.seed) 33 | np.random.seed(seed=args.seed) 34 | torch.manual_seed(seed=args.seed) 35 | 36 | return args 37 | 38 | ARGS=parse_args() 39 | 40 | # In[2]: 41 | 42 | 43 | class FCNetwork(nn.Module): 44 | def __init__(self, H, num_classes=10, input_dim=3072): 45 | self.num_classes = num_classes 46 | super(FCNetwork, self).__init__() 47 | self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.ReLU()) 48 | self.layer2 = nn.Linear(H, num_classes, bias=False) 49 | 50 | def forward(self, x): 51 | x = x.reshape(x.size(0), -1) 52 | out = self.layer2(self.layer1(x)) 53 | return out 54 | 55 | # functions for generating sign patterns 56 | def check_if_already_exists(element_list, element): 57 | # check if element exists in element_list 58 | # where element is a numpy array 59 | for i in range(len(element_list)): 60 | if np.array_equal(element_list[i], element): 61 | return True 62 | return False 63 | 64 | class PrepareData(Dataset): 65 | def __init__(self, X, y): 66 | if not torch.is_tensor(X): 67 | self.X = torch.from_numpy(X) 68 | else: 69 | self.X = X 70 | 71 | if not torch.is_tensor(y): 72 | self.y = torch.from_numpy(y) 73 | else: 74 | self.y = y 75 | 76 | def __len__(self): 77 | return len(self.X) 78 | 79 | def __getitem__(self, idx): 80 | return self.X[idx], self.y[idx] 81 | 82 | 83 | class PrepareData3D(Dataset): 84 | def __init__(self, X, y, z): 85 | if not torch.is_tensor(X): 86 | self.X = torch.from_numpy(X) 87 | else: 88 | self.X = X 89 | 90 | if not torch.is_tensor(y): 91 | self.y = torch.from_numpy(y) 92 | else: 93 | self.y = y 94 | 95 | if not torch.is_tensor(z): 96 | self.z = torch.from_numpy(z) 97 | else: 98 | self.z = z 99 | 100 | 101 | def __len__(self): 102 | return len(self.X) 103 | 104 | def __getitem__(self, idx): 105 | return self.X[idx], self.y[idx], self.z[idx] 106 | 107 | def generate_conv_sign_patterns(A2, P, verbose=False): 108 | # generate convolutional sign patterns 109 | n, c, p1, p2 = A2.shape 110 | A = A2.reshape(n,int(c*p1*p2)) 111 | fsize=9*c 112 | d=c*p1*p2; 113 | fs=int(np.sqrt(9)) 114 | unique_sign_pattern_list = [] 115 | u_vector_list = [] 116 | 117 | for i in range(P): 118 | # obtain a sign pattern 119 | ind1=np.random.randint(0,p1-fs+1) 120 | ind2=np.random.randint(0,p2-fs+1) 121 | u1p= np.zeros((c,p1,p2)) 122 | u1p[:,ind1:ind1+fs,ind2:ind2+fs]=np.random.normal(0, 1, (fsize,1)).reshape(c,fs,fs) 123 | u1=u1p.reshape(d,1) 124 | sampled_sign_pattern = (np.matmul(A, u1) >= 0)[:,0] 125 | unique_sign_pattern_list.append(sampled_sign_pattern) 126 | u_vector_list.append(u1) 127 | 128 | if verbose: 129 | print("Number of unique sign patterns generated: " + str(len(unique_sign_pattern_list))) 130 | return len(unique_sign_pattern_list),unique_sign_pattern_list, u_vector_list 131 | 132 | 133 | def generate_sign_patterns(A, P, verbose=False): 134 | # generate sign patterns 135 | n, d = A.shape 136 | sign_pattern_list = [] # sign patterns 137 | u_vector_list = [] # random vectors used to generate the sign paterns 138 | umat = np.random.normal(0, 1, (d,P)) 139 | sampled_sign_pattern_mat = (np.matmul(A, umat) >= 0) 140 | for i in range(P): 141 | sampled_sign_pattern = sampled_sign_pattern_mat[:,i] 142 | sign_pattern_list.append(sampled_sign_pattern) 143 | u_vector_list.append(umat[:,i]) 144 | if verbose: 145 | print("Number of sign patterns generated: " + str(len(sign_pattern_list))) 146 | return len(sign_pattern_list),sign_pattern_list, u_vector_list 147 | 148 | def one_hot(labels, num_classes=10): 149 | y = torch.eye(num_classes) 150 | return y[labels.long()] 151 | 152 | 153 | 154 | #=====================================STANDARD NON-CONVEX NETWORK===================================== 155 | 156 | 157 | def loss_func_primal(yhat, y, model, beta): 158 | loss = 0.5 * torch.norm(yhat - y)**2 159 | 160 | ## l2 norm on first layer weights, l1 squared norm on second layer 161 | for layer, p in enumerate(model.parameters()): 162 | if layer == 0: 163 | loss += beta/2 * torch.norm(p)**2 164 | else: 165 | loss += beta/2 * sum([torch.norm(p[:, j], 1)**2 for j in range(p.shape[1])]) 166 | 167 | return loss 168 | 169 | def validation_primal(model, testloader, beta, device): 170 | test_loss = 0 171 | test_correct = 0 172 | 173 | for ix, (_x, _y) in enumerate(testloader): 174 | _x = Variable(_x).float().to(device) 175 | _y = Variable(_y).float().to(device) 176 | 177 | output = model.forward(_x) 178 | yhat = model(_x).float() 179 | 180 | loss = loss_func_primal(yhat, one_hot(_y).to(device), model, beta) 181 | 182 | test_loss += loss.item() 183 | test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum() 184 | 185 | return test_loss, test_correct 186 | 187 | # solves nonconvex problem 188 | def sgd_solver_pytorch_v2(ds, ds_test, num_epochs, num_neurons, beta, 189 | learning_rate, batch_size, solver_type, schedule, 190 | LBFGS_param, verbose=False, 191 | num_classes=10, D_in=3*1024, test_len=10000, 192 | train_len=50000, device='cuda'): 193 | 194 | device = torch.device(device) 195 | # D_in is input dimension, H is hidden dimension, D_out is output dimension. 196 | H, D_out = num_neurons, num_classes 197 | # create the model 198 | model = FCNetwork(H, D_out, D_in).to(device) 199 | 200 | if solver_type == "sgd": 201 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9) 202 | elif solver_type == "adam": 203 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)#, 204 | elif solver_type == "adagrad": 205 | optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)#, 206 | elif solver_type == "adadelta": 207 | optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)#, 208 | elif solver_type == "LBFGS": 209 | optimizer = torch.optim.LBFGS(model.parameters(), history_size=LBFGS_param[0], max_iter=LBFGS_param[1])#, 210 | 211 | # arrays for saving the loss and accuracy 212 | losses = np.zeros((int(num_epochs*np.ceil(train_len / batch_size)))) 213 | accs = np.zeros(losses.shape) 214 | losses_test = np.zeros((num_epochs+1)) 215 | accs_test = np.zeros((num_epochs+1)) 216 | times = np.zeros((losses.shape[0]+1)) 217 | times[0] = time.time() 218 | 219 | losses_test[0], accs_test[0] = validation_primal(model, ds_test, beta, device) # loss on the entire test set 220 | 221 | if schedule==1: 222 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 223 | verbose=verbose, 224 | factor=0.5, 225 | eps=1e-12) 226 | elif schedule==2: 227 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99) 228 | 229 | iter_no = 0 230 | for i in range(num_epochs): 231 | for ix, (_x, _y) in enumerate(ds): 232 | #=========make input differentiable======================= 233 | _x = Variable(_x).to(device) 234 | _y = Variable(_y).to(device) 235 | 236 | #========forward pass===================================== 237 | yhat = model(_x).float() 238 | 239 | loss = loss_func_primal(yhat, one_hot(_y).to(device), model, beta)/len(_y) 240 | correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y) 241 | 242 | 243 | optimizer.zero_grad() # zero the gradients on each pass before the update 244 | loss.backward() # backpropagate the loss through the model 245 | optimizer.step() # update the gradients w.r.t the loss 246 | 247 | losses[iter_no] = loss.item() # loss on the minibatch 248 | accs[iter_no] = correct 249 | 250 | iter_no += 1 251 | times[iter_no] = time.time() 252 | 253 | # get test loss and accuracy 254 | losses_test[i+1], accs_test[i+1] = validation_primal(model, ds_test, beta, device) # loss on the entire test set 255 | 256 | if i % 1 == 0: 257 | print("Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}".format(i, num_epochs, 258 | np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), 259 | np.round(losses_test[i+1], 3)/test_len, np.round(accs_test[i+1]/test_len, 3))) 260 | if schedule>0: 261 | scheduler.step(losses[iter_no-1]) 262 | 263 | return losses, accs, losses_test/test_len, accs_test/test_len, times, model 264 | 265 | 266 | #=====================================CONVEX NETWORK===================================== 267 | 268 | 269 | class custom_cvx_layer(torch.nn.Module): 270 | def __init__(self, d, num_neurons, num_classes=10): 271 | """ 272 | In the constructor we instantiate two nn.Linear modules and assign them as 273 | member variables. 274 | """ 275 | super(custom_cvx_layer, self).__init__() 276 | 277 | # P x d x C 278 | self.v = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True) 279 | self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True) 280 | 281 | def forward(self, x, sign_patterns): 282 | sign_patterns = sign_patterns.unsqueeze(2) 283 | x = x.view(x.shape[0], -1) # n x d 284 | 285 | Xv_w = torch.matmul(x, self.v - self.w) # P x N x C 286 | 287 | # for some reason, the permutation is necessary. not sure why 288 | DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) # N x P x C 289 | y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C 290 | 291 | return y_pred 292 | 293 | def get_nonconvex_cost(y, model, _x, beta, device): 294 | _x = _x.view(_x.shape[0], -1) 295 | Xv = torch.matmul(_x, model.v) 296 | Xw = torch.matmul(_x, model.w) 297 | Xv_relu = torch.max(Xv, torch.Tensor([0]).to(device)) 298 | Xw_relu = torch.max(Xw, torch.Tensor([0]).to(device)) 299 | 300 | prediction_w_relu = torch.sum(Xv_relu - Xw_relu, dim=0, keepdim=False) 301 | prediction_cost = 0.5 * torch.norm(prediction_w_relu - y)**2 302 | 303 | regularization_cost = beta * (torch.sum(torch.norm(model.v, dim=1)**2) + torch.sum(torch.norm(model.w, p=1, dim=1)**2)) 304 | 305 | return prediction_cost + regularization_cost 306 | def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device): 307 | _x = _x.view(_x.shape[0], -1) 308 | 309 | # term 1 310 | loss = 0.5 * torch.norm(yhat - y)**2 311 | # term 2 312 | loss = loss + beta * torch.sum(torch.norm(model.v, dim=1)) 313 | loss = loss + beta * torch.sum(torch.norm(model.w, dim=1)) 314 | 315 | # term 3 316 | sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1 317 | 318 | Xv = torch.matmul(_x, torch.sum(model.v, dim=2, keepdim=True)) # N x d times P x d x 1 -> P x N x 1 319 | DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1 320 | relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device)) 321 | loss = loss + rho * torch.sum(relu_term_v) 322 | 323 | Xw = torch.matmul(_x, torch.sum(model.w, dim=2, keepdim=True)) 324 | DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2)) 325 | relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device)) 326 | loss = loss + rho * torch.sum(relu_term_w) 327 | 328 | return loss 329 | 330 | def validation_cvxproblem(model, testloader, u_vectors, beta, rho, device): 331 | test_loss = 0 332 | test_correct = 0 333 | test_noncvx_cost = 0 334 | 335 | with torch.no_grad(): 336 | for ix, (_x, _y) in enumerate(testloader): 337 | _x = Variable(_x).to(device) 338 | _y = Variable(_y).to(device) 339 | _x = _x.view(_x.shape[0], -1) 340 | _z = (torch.matmul(_x, torch.from_numpy(u_vectors).float().to(device)) >= 0) 341 | 342 | output = model.forward(_x, _z) 343 | yhat = model(_x, _z).float() 344 | 345 | loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x, _z, beta, rho, device) 346 | 347 | test_loss += loss.item() 348 | test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum() 349 | 350 | test_noncvx_cost += get_nonconvex_cost(one_hot(_y).to(device), model, _x, beta, device) 351 | 352 | return test_loss, test_correct, test_noncvx_cost 353 | def sgd_solver_cvxproblem(ds, ds_test, num_epochs, num_neurons, beta, 354 | learning_rate, batch_size, rho, u_vectors, 355 | solver_type, LBFGS_param, verbose=False, 356 | n=60000, d=3072, num_classes=10, device='cpu'): 357 | device = torch.device(device) 358 | 359 | # create the model 360 | model = custom_cvx_layer(d, num_neurons, num_classes).to(device) 361 | 362 | if solver_type == "sgd": 363 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9) 364 | elif solver_type == "adam": 365 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)#, 366 | elif solver_type == "adagrad": 367 | optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)#, 368 | elif solver_type == "adadelta": 369 | optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)#, 370 | elif solver_type == "LBFGS": 371 | optimizer = torch.optim.LBFGS(model.parameters(), history_size=LBFGS_param[0], max_iter=LBFGS_param[1])#, 372 | 373 | # arrays for saving the loss and accuracy 374 | losses = np.zeros((int(num_epochs*np.ceil(n / batch_size)))) 375 | accs = np.zeros(losses.shape) 376 | noncvx_losses = np.zeros(losses.shape) 377 | 378 | losses_test = np.zeros((num_epochs+1)) 379 | accs_test = np.zeros((num_epochs+1)) 380 | noncvx_losses_test = np.zeros((num_epochs+1)) 381 | 382 | times = np.zeros((losses.shape[0]+1)) 383 | times[0] = time.time() 384 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 385 | verbose=verbose, 386 | factor=0.5, 387 | eps=1e-12) 388 | 389 | model.eval() 390 | losses_test[0], accs_test[0], noncvx_losses_test[0] = validation_cvxproblem(model, ds_test, u_vectors, beta, rho, device) # loss on the entire test set 391 | 392 | iter_no = 0 393 | print('starting training') 394 | for i in range(num_epochs): 395 | model.train() 396 | for ix, (_x, _y, _z) in enumerate(ds): 397 | #=========make input differentiable======================= 398 | _x = Variable(_x).to(device) 399 | _y = Variable(_y).to(device) 400 | _z = Variable(_z).to(device) 401 | 402 | #========forward pass===================================== 403 | yhat = model(_x, _z).float() 404 | 405 | loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x,_z, beta, rho, device)/len(_y) 406 | correct = torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()/len(_y) # accuracy 407 | #=======backward pass===================================== 408 | optimizer.zero_grad() # zero the gradients on each pass before the update 409 | loss.backward() # backpropagate the loss through the model 410 | optimizer.step() # update the gradients w.r.t the loss 411 | 412 | losses[iter_no] = loss.item() # loss on the minibatch 413 | accs[iter_no] = correct 414 | noncvx_losses[iter_no] = get_nonconvex_cost(one_hot(_y).to(device), model, _x, beta, device)/len(_y) 415 | 416 | iter_no += 1 417 | times[iter_no] = time.time() 418 | 419 | model.eval() 420 | # get test loss and accuracy 421 | losses_test[i+1], accs_test[i+1], noncvx_losses_test[i+1] = validation_cvxproblem(model, ds_test, u_vectors, beta, rho, device) # loss on the entire test set 422 | 423 | if i % 1 == 0: 424 | print("Epoch [{}/{}], TRAIN: noncvx/cvx loss: {}, {} acc: {}. TEST: noncvx/cvx loss: {}, {} acc: {}".format(i, num_epochs, 425 | np.round(noncvx_losses[iter_no-1], 3), np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), 426 | np.round(noncvx_losses_test[i+1], 3)/10000, np.round(losses_test[i+1], 3)/10000, np.round(accs_test[i+1]/10000, 3))) 427 | 428 | scheduler.step(losses[iter_no-1]) 429 | 430 | return noncvx_losses, accs, noncvx_losses_test/10000, accs_test/10000, times, losses, losses_test/10000 431 | 432 | 433 | 434 | 435 | 436 | # cifar-10 -- using the version downloaded from "http://www.cs.toronto.edu/~kriz/cifar.html" 437 | import os 438 | directory = os.path.dirname(os.path.realpath(__file__)) 439 | 440 | import torchvision.datasets as datasets 441 | import torchvision.transforms as transforms 442 | 443 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 444 | 445 | train_dataset = datasets.CIFAR10( 446 | directory, train=True, download=True, 447 | transform=transforms.Compose([ 448 | transforms.ToTensor(), 449 | normalize, 450 | ])) 451 | 452 | test_dataset = datasets.CIFAR10( 453 | directory, train=False, download=True, 454 | transform=transforms.Compose([ 455 | transforms.ToTensor(), 456 | normalize, 457 | ])) 458 | 459 | 460 | 461 | # data extraction 462 | print('Extracting the data') 463 | dummy_loader= torch.utils.data.DataLoader( 464 | train_dataset, batch_size=50000, shuffle=False, 465 | pin_memory=True, sampler=None) 466 | for A, y in dummy_loader: 467 | pass 468 | Apatch=A.detach().clone() 469 | 470 | A = A.view(A.shape[0], -1) 471 | n,d=A.size() 472 | 473 | 474 | 475 | # problem parameters 476 | P, verbose = 4096, True # SET verbose to True to see progress 477 | GD_only=ARGS.GD[0] 478 | CVX_only=ARGS.CVX[0] 479 | beta = 1e-3 # regularization parameter 480 | num_epochs1, batch_size = ARGS.n_epochs[0], 1000 # 481 | num_neurons = P # number of neurons is equal to number of hyperplane arrangements 482 | 483 | 484 | # create dataloaders 485 | train_loader = torch.utils.data.DataLoader( 486 | train_dataset, batch_size=batch_size, shuffle=True, 487 | pin_memory=True, sampler=None) 488 | 489 | test_loader = torch.utils.data.DataLoader( 490 | test_dataset, batch_size=1000, shuffle=False, 491 | pin_memory=True) 492 | 493 | 494 | 495 | 496 | 497 | # SGD solver for the nonconvex problem 498 | if CVX_only==0: 499 | 500 | solver_type = "sgd" # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS" 501 | schedule=0 # learning rate schedule (0: Nothing, 1: ReduceLROnPlateau, 2: ExponentialLR) 502 | LBFGS_param = [10, 4] # these parameters are for the LBFGS solver 503 | learning_rate = 1e-2 504 | 505 | ## SGD1 constant 506 | print('SGD1-training-mu={}'.format(learning_rate)) 507 | results_noncvx_sgd1 = sgd_solver_pytorch_v2(train_loader, test_loader, num_epochs1, num_neurons, beta, 508 | learning_rate, batch_size, solver_type, schedule, 509 | LBFGS_param, verbose=True, 510 | num_classes=10, D_in=d, train_len=n ) 511 | 512 | 513 | ## SGD2 constant 514 | learning_rate = 5e-3 515 | print('SGD2-training-mu={}'.format(learning_rate)) 516 | results_noncvx_sgd2 = sgd_solver_pytorch_v2(train_loader, test_loader, num_epochs1, num_neurons, beta, 517 | learning_rate, batch_size, solver_type, schedule, 518 | LBFGS_param, verbose=True, 519 | num_classes=10, D_in=d, train_len=n ) 520 | 521 | ## SGD3 constant 522 | learning_rate = 1e-3 523 | print('SGD3-training-mu={}'.format(learning_rate)) 524 | results_noncvx_sgd3 = sgd_solver_pytorch_v2(train_loader, test_loader, num_epochs1, num_neurons, beta, 525 | learning_rate, batch_size, solver_type, schedule, 526 | LBFGS_param, verbose=True, 527 | num_classes=10, D_in=d, train_len=n ) 528 | 529 | 530 | # Solver for the convex problem 531 | if GD_only ==0: 532 | 533 | rho = 1e-2 # coefficient to penalize the violated constraints 534 | solver_type = ARGS.solver_cvx[0] # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS" 535 | LBFGS_param = [10, 4] 536 | batch_size = 1000 537 | num_epochs2, batch_size = ARGS.n_epochs[1], 1000 538 | 539 | 540 | 541 | # Convex 542 | print('Generating sign patterns') 543 | num_neurons,sign_pattern_list, u_vector_list = generate_sign_patterns(A, P, verbose) 544 | sign_patterns = np.array([sign_pattern_list[i].int().data.numpy() for i in range(num_neurons)]) 545 | u_vectors = np.asarray(u_vector_list).reshape((num_neurons, A.shape[1])).T 546 | 547 | ds_train = PrepareData3D(X=A, y=y, z=sign_patterns.T) 548 | ds_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True) 549 | 550 | test_loader = torch.utils.data.DataLoader( 551 | test_dataset, batch_size=1000, shuffle=False, 552 | pin_memory=True) 553 | 554 | 555 | # Convex1 556 | learning_rate = 1e-6 # 1e-6 for sgd 557 | print('Convex Random1-mu={}'.format(learning_rate)) 558 | results_cvx1 = sgd_solver_cvxproblem(ds_train, test_loader, num_epochs2, num_neurons, beta, 559 | learning_rate, batch_size, rho, u_vectors, solver_type, LBFGS_param, verbose=True, 560 | n=n, device='cuda') 561 | 562 | # Convex2 563 | learning_rate = 5e-7 # 1e-6 for sgd 564 | print('Convex Random2-mu={}'.format(learning_rate)) 565 | results_cvx2 = sgd_solver_cvxproblem(ds_train, test_loader, num_epochs2, num_neurons, beta, 566 | learning_rate, batch_size, rho, u_vectors, solver_type, LBFGS_param, verbose=True, 567 | n=n, device='cuda') 568 | 569 | 570 | 571 | # Convex with convolutional patterns 572 | print('Generating conv sign patterns') 573 | num_neurons,sign_pattern_list, u_vector_list = generate_conv_sign_patterns(Apatch, P, verbose) 574 | sign_patterns = np.array([sign_pattern_list[i].int().data.numpy() for i in range(num_neurons)]) 575 | u_vectors = np.asarray(u_vector_list).reshape((num_neurons, A.shape[1])).T 576 | 577 | ds_train = PrepareData3D(X=A, y=y, z=sign_patterns.T) 578 | ds_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True) 579 | 580 | # Convex Conv1 581 | learning_rate = 1e-6 582 | print('Convex Conv1-mu={}'.format(learning_rate)) 583 | results_cvx_conv1 = sgd_solver_cvxproblem(ds_train, test_loader, num_epochs2, num_neurons, beta, 584 | learning_rate, batch_size, rho, u_vectors, solver_type, LBFGS_param, verbose=True, 585 | n=n, device='cuda') 586 | 587 | # Convex Conv2 588 | learning_rate = 5e-7 589 | print('Convex Conv2-mu={}'.format(learning_rate)) 590 | results_cvx_conv2 = sgd_solver_cvxproblem(ds_train, test_loader, num_epochs2, num_neurons, beta, 591 | learning_rate, batch_size, rho, u_vectors, solver_type, LBFGS_param, verbose=True, 592 | n=n, device='cuda') 593 | 594 | 595 | # plots and saves the results 596 | import pickle 597 | from datetime import datetime 598 | now = datetime.now() 599 | if GD_only==1 and CVX_only==0: 600 | 601 | 602 | 603 | results_noncvx_sgd1v2=results_noncvx_sgd1[:5] 604 | results_noncvx_sgd2v2=results_noncvx_sgd2[:5] 605 | results_noncvx_sgd3v2=results_noncvx_sgd3[:5] 606 | 607 | 608 | 609 | 610 | print('Saving the objects') 611 | torch.save([num_epochs1,results_noncvx_sgd1v2, results_noncvx_sgd2v2, results_noncvx_sgd3v2 612 | ],'results_fig_gdonly_stepsize_cifar10_'+now.strftime("%d-%m-%Y_%H-%M-%S")+'.pt') 613 | 614 | 615 | 616 | elif GD_only==0 and CVX_only==1: 617 | 618 | print('Saving the objects') 619 | torch.save([num_epochs2, results_cvx1,results_cvx2, 620 | results_cvx_conv1,results_cvx_conv2],'results_fig_cvxonly_stepsize_cifar10_'+now.strftime("%d-%m-%Y_%H-%M-%S")+'.pt') 621 | 622 | else: 623 | 624 | results_noncvx_sgd1v2=results_noncvx_sgd1[:5] 625 | results_noncvx_sgd2v2=results_noncvx_sgd2[:5] 626 | results_noncvx_sgd3v2=results_noncvx_sgd3[:5] 627 | print('Saving the objects') 628 | torch.save([num_epochs1,num_epochs2,results_noncvx_sgd1v2, results_noncvx_sgd2v2, results_noncvx_sgd3v2, results_cvx1,results_cvx2, 629 | results_cvx_conv1,results_cvx_conv2],'results_fig_all_stepsize_cifar10_'+now.strftime("%d-%m-%Y_%H-%M-%S")+'.pt') 630 | 631 | import matplotlib.pyplot as plt 632 | 633 | skip=1#int(num_epochs1/num_epochs2) 634 | mark_sgd=10 635 | mark_cvx=30 636 | 637 | marker_size_sgd=10 638 | marker_size_cvx=12 639 | 640 | 641 | plt.gcf().set_facecolor("white") 642 | #fig,ax = plt.subplots() 643 | 644 | # plot 645 | fsize=24 646 | fsize_legend=15 647 | 648 | plt.rcParams.update({'font.size': 24}) 649 | plt.xlabel('Time(s)',fontsize=fsize); plt.grid() 650 | 651 | plot_no = 1 # select --> 0: cost, 1: accuracy 652 | 653 | 654 | 655 | num_all_iters1 = results_noncvx_sgd1v2[4].shape[0] - 1 656 | num_all_iters2 = results_cvx1[4].shape[0] - 1 657 | 658 | iters_per_epoch1 = num_all_iters1 // num_epochs1 659 | iters_per_epoch2 = num_all_iters2 // num_epochs2 660 | 661 | epoch_times_noncvx1 = results_noncvx_sgd1v2[4][0:num_all_iters1+1:iters_per_epoch1]-results_noncvx_sgd1v2[4][0] 662 | epoch_times_noncvx2 = results_noncvx_sgd2v2[4][0:num_all_iters1+1:iters_per_epoch1]-results_noncvx_sgd2v2[4][0] 663 | epoch_times_noncvx3 = results_noncvx_sgd3v2[4][0:num_all_iters1+1:iters_per_epoch1]-results_noncvx_sgd3v2[4][0] 664 | 665 | 666 | epoch_times_cvx1 = results_cvx1[4][0:num_all_iters2+1:iters_per_epoch2]-results_cvx1[4][0] 667 | epoch_times_cvx2 = results_cvx2[4][0:num_all_iters2+1:iters_per_epoch2]-results_cvx2[4][0] 668 | 669 | epoch_times_cvx_conv1= results_cvx_conv1[4][0:num_all_iters2+1:iters_per_epoch2]-results_cvx_conv1[4][0] 670 | epoch_times_cvx_conv2= results_cvx_conv2[4][0:num_all_iters2+1:iters_per_epoch2]-results_cvx_conv2[4][0] 671 | 672 | 673 | plt.grid() 674 | 675 | # To plot results in the validation set 676 | plt.plot( epoch_times_noncvx1[::skip],results_noncvx_sgd1v2[plot_no+2][::skip],'--', color='darkred', markevery=mark_sgd,linewidth=3.0, markersize=marker_size_sgd,label="SGD-$\mu=1e-2$") 677 | plt.plot( epoch_times_noncvx2[::skip],results_noncvx_sgd2v2[plot_no+2][::skip],'--', color='red', markevery=mark_sgd,linewidth=3.0, markersize=marker_size_sgd,label="SGD-$\mu=5e-3$") 678 | plt.plot( epoch_times_noncvx3[::skip],results_noncvx_sgd3v2[plot_no+2][::skip],'--', color='lightcoral', markevery=mark_sgd,linewidth=3.0, markersize=marker_size_sgd,label="SGD-$\mu=1e-3$") 679 | 680 | 681 | plt.plot( epoch_times_cvx1,results_cvx1[plot_no+2], 'o--', color='g', markevery=mark_cvx,linewidth=3.0, markersize=marker_size_cvx,label="Convex-Random-$\mu=1e-6$") 682 | plt.plot( epoch_times_cvx2,results_cvx2[plot_no+2], 'o--', color='lime', markevery=mark_cvx,linewidth=3.0, markersize=marker_size_cvx,label="Convex-Random-$\mu=5e-7$") 683 | 684 | plt.plot( epoch_times_cvx_conv1,results_cvx_conv1[plot_no+2], 'o--', color='b', markevery=mark_cvx,linewidth=3.0, markersize=marker_size_cvx,label="Convex-Conv-$\mu=1e-6$") 685 | plt.plot( epoch_times_cvx_conv2,results_cvx_conv1[plot_no+2], 'o--', color='lightblue', markevery=mark_cvx,linewidth=3.0, markersize=marker_size_cvx,label="Convex-Conv-$\mu=5e-7$") 686 | 687 | 688 | plt.legend(prop={'size': fsize_legend}) 689 | plt.ylabel("Test Accuracy",fontsize=fsize) 690 | plt.ylim(0.3, 0.6) 691 | plt.xlim(0, 4500) 692 | 693 | plt.grid() 694 | plt.savefig('cifar_multiclass_stepsize_testacc.png', format='png', bbox_inches='tight') 695 | 696 | 697 | plt.figure() 698 | # To plot training acc 699 | 700 | plt.xlabel('Time(s)',fontsize=fsize) 701 | plt.grid() 702 | 703 | p11=results_noncvx_sgd1v2[1].reshape(-1,1) 704 | p12=results_noncvx_sgd2v2[1].reshape(-1,1) 705 | p13=results_noncvx_sgd3v2[1].reshape(-1,1) 706 | 707 | p21=results_cvx1[1].reshape(-1,1) 708 | p22=results_cvx2[1].reshape(-1,1) 709 | 710 | p31=results_cvx_conv1[1].reshape(-1,1) 711 | p32=results_cvx_conv2[1].reshape(-1,1) 712 | 713 | 714 | 715 | n=50000 716 | batch_size1=1000 717 | batch_size2=1000 718 | 719 | plt.plot(epoch_times_noncvx1[:-1][::skip],p11[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='darkred', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=1e-2$") 720 | plt.plot(epoch_times_noncvx2[:-1][::skip],p12[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='red', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=5e-2$") 721 | plt.plot(epoch_times_noncvx3[:-1][::skip],p13[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='lightcoral', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=1e-3$") 722 | 723 | plt.plot( epoch_times_cvx1[:-1],p21[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-',color='g', markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Random-$\mu=1e-6$") 724 | plt.plot( epoch_times_cvx2[:-1],p22[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-',color='lime', markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Random-$\mu=5e-7$") 725 | 726 | plt.plot( epoch_times_cvx_conv1[:-1],p31[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-', color='b',markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Conv-$\mu=1e-6$") 727 | plt.plot( epoch_times_cvx_conv2[:-1],p32[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-', color='lightblue',markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Conv-$\mu=5e-7$") 728 | 729 | 730 | plt.xlim(0, 4500) 731 | 732 | plt.ylabel("Training Accuracy",fontsize=fsize) 733 | plt.grid() 734 | matplotlib.pyplot.grid(True, which="both") 735 | plt.savefig('cifar_multiclass_stepsize_tracc.png', format='png', bbox_inches='tight') 736 | 737 | 738 | # To plot training loss 739 | 740 | plt.figure() 741 | 742 | plt.xlabel('Time(s)',fontsize=fsize) 743 | plt.grid() 744 | p11=results_noncvx_sgd1v2[0].reshape(-1,1) 745 | p12=results_noncvx_sgd2v2[0].reshape(-1,1) 746 | p13=results_noncvx_sgd3v2[0].reshape(-1,1) 747 | 748 | p21=results_cvx1[5].reshape(-1,1) 749 | p22=results_cvx2[5].reshape(-1,1) 750 | 751 | p31=results_cvx_conv1[5].reshape(-1,1) 752 | p32=results_cvx_conv2[5].reshape(-1,1) 753 | 754 | 755 | 756 | n=50000 757 | batch_size1=1000 758 | batch_size2=1000 759 | 760 | plt.semilogy(epoch_times_noncvx1[:-1][::skip],p11[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='darkred', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=1e-2$") 761 | plt.semilogy(epoch_times_noncvx2[:-1][::skip],p12[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='red', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=5e-2$") 762 | plt.semilogy(epoch_times_noncvx3[:-1][::skip],p13[np.arange(num_epochs1)*int(n/batch_size1)][::skip],'-',color='lightcoral', markevery=mark_sgd,linewidth=3, markersize=marker_size_sgd,label="SGD-$\mu=1e-3$") 763 | 764 | plt.semilogy( epoch_times_cvx1[:-1],p21[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-',color='g', markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Random-$\mu=1e-6$") 765 | plt.semilogy( epoch_times_cvx2[:-1],p22[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-',color='lime', markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Random-$\mu=5e-7$") 766 | 767 | plt.semilogy( epoch_times_cvx_conv1[:-1],p31[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-', color='b',markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Conv-$\mu=1e-6$") 768 | plt.semilogy( epoch_times_cvx_conv2[:-1],p32[np.arange(num_epochs2)*int(n/batch_size2)] ,'o-', color='lightblue',markevery=mark_cvx,linewidth=3, markersize=marker_size_cvx,label="Convex-Conv-$\mu=5e-7$") 769 | 770 | 771 | plt.xlim(0, 4500) 772 | 773 | plt.ylabel("Objective Value",fontsize=fsize) 774 | plt.grid() 775 | matplotlib.pyplot.grid(True, which="both") 776 | plt.savefig('cifar_multiclass_stepsize_obj.png', format='png', bbox_inches='tight') 777 | 778 | 779 | 780 | 781 | 782 | -------------------------------------------------------------------------------- /plots/cifar_multiclass_stepsize_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pilancilab/convex_nn/e401184311dafbfa5ef9196941d3ddf003823fa4/plots/cifar_multiclass_stepsize_obj.png -------------------------------------------------------------------------------- /plots/cifar_multiclass_stepsize_testacc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pilancilab/convex_nn/e401184311dafbfa5ef9196941d3ddf003823fa4/plots/cifar_multiclass_stepsize_testacc.png -------------------------------------------------------------------------------- /plots/cifar_multiclass_stepsize_tracc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pilancilab/convex_nn/e401184311dafbfa5ef9196941d3ddf003823fa4/plots/cifar_multiclass_stepsize_tracc.png --------------------------------------------------------------------------------