├── l2ws ├── __init__.py ├── utils │ ├── __init__.py │ ├── data_utils.py │ ├── nn_utils.py │ ├── generic_utils.py │ └── portfolio_utils.py ├── examples │ ├── __init__.py │ ├── lasso.py │ ├── jamming.py │ ├── unconstrained_qp.py │ ├── sparse_pca.py │ ├── phase_retrieval.py │ ├── markowitz.py │ ├── robust_ls.py │ └── solve_script.py ├── gd_model.py ├── ista_model.py ├── eg_model.py ├── l2ws_helper_fns.py ├── scs_problem.py ├── osqp_model.py └── scs_model.py ├── benchmarks ├── configs │ ├── robust_ls │ │ ├── robust_ls_agg.yaml │ │ ├── robust_ls_setup.yaml │ │ ├── robust_ls_plot.yaml │ │ └── robust_ls_run.yaml │ ├── robust_pca │ │ ├── robust_pca_agg.yaml │ │ ├── robust_pca_setup.yaml │ │ ├── robust_pca_plot.yaml │ │ ├── robust_pca_gif.yaml │ │ └── robust_pca_run.yaml │ ├── vehicle │ │ ├── vehicle_agg.yaml │ │ ├── vehicle_plot.yaml │ │ ├── vehicle_setup.yaml │ │ └── vehicle_run.yaml │ ├── robust_kalman │ │ ├── robust_kalman_agg.yaml │ │ ├── robust_kalman_setup.yaml │ │ ├── robust_kalman_plot.yaml │ │ └── robust_kalman_run.yaml │ ├── osc_mass │ │ ├── osc_mass_plot.yaml │ │ ├── osc_mass_setup.yaml │ │ ├── osc_mass_agg.yaml │ │ └── osc_mass_run.yaml │ ├── sparse_pca │ │ ├── sparse_pca_setup.yaml │ │ ├── sparse_pca_plot.yaml │ │ └── sparse_pca_run.yaml │ ├── lasso │ │ ├── lasso_setup.yaml │ │ ├── lasso_plot.yaml │ │ └── lasso_run.yaml │ ├── mpc │ │ ├── mpc_setup.yaml │ │ ├── mpc_plot.yaml │ │ └── mpc_run.yaml │ ├── unconstrained_qp │ │ ├── unconstrained_qp_setup.yaml │ │ ├── unconstrained_qp_plot.yaml │ │ └── unconstrained_qp_run.yaml │ ├── jamming │ │ ├── jamming_setup.yaml │ │ ├── jamming_plot.yaml │ │ └── jamming_run.yaml │ ├── mnist │ │ ├── mnist_setup.yaml │ │ ├── mnist_plot.yaml │ │ └── mnist_run.yaml │ ├── markowitz │ │ ├── markowitz_setup.yaml │ │ ├── markowitz_agg.yaml │ │ ├── markowitz_plot.yaml │ │ └── markowitz_run.yaml │ ├── phase_retrieval │ │ ├── phase_retrieval_setup.yaml │ │ ├── phase_retrieval_run.yaml │ │ └── phase_retrieval_plot.yaml │ ├── quadcopter │ │ ├── quadcopter_setup.yaml │ │ ├── quadcopter_plot.yaml │ │ └── quadcopter_run.yaml │ └── all │ │ └── plot.yaml ├── slurm_script_gpu.sh ├── slurm_script_cpu.sh ├── l2ws_setup.py └── l2ws_train.py ├── .github └── workflows │ ├── build.yml │ └── deploy.yml ├── pyproject.toml ├── tests ├── compare_scs.py ├── test_canonicalizations.py ├── test_l2ws_model.py └── test_algo_steps.py ├── .gitignore └── README.md /l2ws/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /l2ws/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /l2ws/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/configs/robust_ls/robust_ls_agg.yaml: -------------------------------------------------------------------------------- 1 | datetimes: [] -------------------------------------------------------------------------------- /benchmarks/configs/robust_pca/robust_pca_agg.yaml: -------------------------------------------------------------------------------- 1 | datetimes: [] -------------------------------------------------------------------------------- /benchmarks/configs/vehicle/vehicle_agg.yaml: 
-------------------------------------------------------------------------------- 1 | datetimes: [] 2 | -------------------------------------------------------------------------------- /benchmarks/configs/robust_kalman/robust_kalman_agg.yaml: -------------------------------------------------------------------------------- 1 | datetimes: [] -------------------------------------------------------------------------------- /benchmarks/configs/osc_mass/osc_mass_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: [] 2 | naive_ws_datetime: '' 3 | pretrain_datetime: '' 4 | no_learning_datetime: '' 5 | eval_iters: 1500 6 | accuracies: [1e-2, 1e-3, 1e-4] -------------------------------------------------------------------------------- /benchmarks/configs/sparse_pca/sparse_pca_setup.yaml: -------------------------------------------------------------------------------- 1 | n_orig: 10 #40 2 | k: 5 #10 3 | r: 5 #10 4 | solve_acc_abs: 1e-4 5 | solve_acc_rel: 0 6 | N_train: 100 7 | N_test: 10 8 | seed: 42 9 | solve: True 10 | -------------------------------------------------------------------------------- /benchmarks/configs/robust_pca/robust_pca_setup.yaml: -------------------------------------------------------------------------------- 1 | p: 30 2 | q: 4 3 | sparse_frac: 0.1 4 | low_rank: 2 5 | N_train: 1000 6 | N_test: 200 7 | A_star_seed: 0 8 | B_star_seed: None 9 | solve_acc_abs: 1e-3 10 | solve_acc_rel: 0 -------------------------------------------------------------------------------- /benchmarks/configs/osc_mass/osc_mass_setup.yaml: -------------------------------------------------------------------------------- 1 | T: 50 2 | nx: 36 # fixed 3 | nu: 9 # fixed 4 | state_box: 4 5 | control_box: .5 6 | x_init_box: 2 7 | N_train: 100 8 | N_test: 20 9 | Q_val: 1 10 | QT_val: 1 11 | R_val: 1 12 | solve_acc: 1e-4 -------------------------------------------------------------------------------- /benchmarks/configs/robust_ls/robust_ls_setup.yaml: -------------------------------------------------------------------------------- 1 | m_orig: 500 2 | n_orig: 800 3 | rho: 10 4 | solve_acc_abs: 1e-4 5 | solve_acc_rel: 0 6 | N_train: 100 7 | N_test: 20 8 | seed: 42 9 | b_nominal: 1.5 10 | b_range: .5 11 | solve: True 12 | -------------------------------------------------------------------------------- /benchmarks/configs/lasso/lasso_setup.yaml: -------------------------------------------------------------------------------- 1 | m_orig: 500 2 | n_orig: 500 3 | lambd: 10 4 | solve_acc_abs: 1e-10 5 | solve_acc_rel: 0 6 | N_train: 10000 7 | N_test: 1000 8 | seed: 42 9 | solve: True 10 | A_scale: 1 11 | b_min: 0 12 | b_max: 30 13 | -------------------------------------------------------------------------------- /benchmarks/configs/mpc/mpc_setup.yaml: -------------------------------------------------------------------------------- 1 | nx: 10 2 | nu: 5 3 | T: 10 4 | traj_length: 10 5 | x_init_factor: .5 6 | solve_acc_abs: 1e-4 7 | solve_acc_rel: 0 8 | N_train: 1000 9 | N_test: 100 10 | seed: 42 11 | solve: True 12 | noise_std_dev: 0.1 13 | -------------------------------------------------------------------------------- /benchmarks/configs/unconstrained_qp/unconstrained_qp_setup.yaml: -------------------------------------------------------------------------------- 1 | n_orig: 20 2 | # solve_acc_abs: 1e-10 3 | # solve_acc_rel: 0 4 | split_factor: 100 5 | N_train: 100 6 | N_test: 10 7 | seed: 42 8 | solve: True 9 | A_scale: 1 10 | c_min: -10 11 | c_max: 10 12 | 
P_split: 10 -------------------------------------------------------------------------------- /benchmarks/configs/jamming/jamming_setup.yaml: -------------------------------------------------------------------------------- 1 | n: 100 2 | beta_min: 1 3 | beta_max: 2 4 | sigma_min: 1 5 | sigma_max: 1.0001 6 | step_size: .05 7 | solve_acc_abs: 1e-10 8 | solve_acc_rel: 0 9 | N_train: 10 10 | N_test: 10 11 | seed: 42 12 | solve: True 13 | solve_iters: 100 14 | delta_frac: 1 15 | -------------------------------------------------------------------------------- /benchmarks/configs/jamming/jamming_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: 2 | - 2023-07-14/12-09-40 3 | loss_overlay_titles: [] 4 | nearest_neighbor_datetime: 2023-07-14/12-09-40 5 | pretrain_datetime: '' 6 | cold_start_datetime: 2023-07-14/12-09-40 7 | eval_iters: 500 8 | accuracies: 9 | - 0.1 10 | - 0.01 11 | - 0.001 12 | - 0.0001 -------------------------------------------------------------------------------- /benchmarks/configs/robust_kalman/robust_kalman_setup.yaml: -------------------------------------------------------------------------------- 1 | rollout_length: 100 2 | num_rollouts: 2 3 | T: 50 4 | mu: 2 5 | rho: 2 6 | gamma: .05 7 | solve_acc_abs: 1e-4 8 | solve_acc_rel: 0 9 | dt: .5 #.5 10 | sigma: 20 11 | p: 1e-8 12 | # N_train: 100 13 | # N_test: 20 14 | w_noise_var: .1 15 | y_noise_var: .1 16 | B_const: 1 17 | -------------------------------------------------------------------------------- /benchmarks/configs/mnist/mnist_setup.yaml: -------------------------------------------------------------------------------- 1 | solve_acc_abs: 1e-3 2 | solve_acc_rel: 0 3 | lambd: .001 4 | lambd2: 0.4 5 | N_train: 10000 6 | N_test: 1000 7 | seed: 42 8 | solve: True 9 | noise_std_dev: .001 10 | blur_size: 8 11 | dataset: emnist 12 | obj_const: 10 13 | A_scale: 1 14 | deblur_or_denoise: 'deblur' 15 | mri_size: 70 16 | # salt_and_pepper_noise_prob: 0.2 17 | -------------------------------------------------------------------------------- /benchmarks/configs/robust_pca/robust_pca_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: ['2023-01-17/09-09-54', '2023-01-17/09-42-53', '2023-01-17/10-22-45'] #['2023-01-16/23-31-58', '2023-01-17/08-51-56'] 2 | naive_ws_datetime: '' 3 | pretrain_datetime: '' #'2023-01-17/08-51-56' 4 | no_learning_datetime: '2023-01-17/10-22-45' #'2023-01-17/08-51-56' 5 | eval_iters: 500 6 | accuracies: [1e-2, 1e-3, 1e-4] -------------------------------------------------------------------------------- /benchmarks/configs/markowitz/markowitz_setup.yaml: -------------------------------------------------------------------------------- 1 | a: 3000 # 3000 for nasdaq, 2342 for yahoo 2 | N_train: 1000 3 | N_test: 200 4 | # mult_factor: 1000 5 | # std_mult: 1e-3 #.3 6 | pen_rets_min: -1.5 #-1.5 7 | pen_rets_max: -1.5 #-1.5 8 | # max_clip: 3e-3 #1 9 | # min_clip: -3e-3 #-1 10 | solve_acc: 1e-3 11 | data: 'eod' #'yahoo' #'nasdaq' 12 | alpha: .024 13 | scale_factor: 1e2 14 | idio_risk: 1e-3 -------------------------------------------------------------------------------- /benchmarks/configs/phase_retrieval/phase_retrieval_setup.yaml: -------------------------------------------------------------------------------- 1 | n_orig: 40 2 | d_mul: 3 # number of constraints will equal d_mul * n_orig 3 | x_var: 1 # variance in the Gaussians for X 4 | x_mean: 5 #5 5 | ###### change the above for phase_retrieval 6 
| 7 | solve_acc_abs: 1e-3 8 | solve_acc_rel: 0 9 | solve_max_iters: 1000 10 | N_train: 10000 11 | N_test: 1000 12 | seed: 42 13 | solve: True 14 | -------------------------------------------------------------------------------- /benchmarks/configs/vehicle/vehicle_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: [] 2 | pretrain_datetime: '' 3 | naive_ws_datetime: '' 4 | no_learning_datetime: '' 5 | eval_iters: 2500 6 | accuracies: [1e-2, 1e-3, 1e-4] 7 | # output_datetimes: ['2022-11-28/21-34-54', '2022-11-28/21-33-36', '2022-11-28/19-58-28'] 8 | # naive_ws_datetime: '2022-11-30/14-56-17' 9 | # pretrain_datetime: '' 10 | # no_learning_datetime: '2022-11-28/21-33-36' -------------------------------------------------------------------------------- /benchmarks/configs/markowitz/markowitz_agg.yaml: -------------------------------------------------------------------------------- 1 | datetimes: [] 2 | # ['2022-12-07/00-51-37', '2022-12-07/00-52-08', '2022-12-07/00-53-09', '2022-12-07/00-53-11', '2022-12-07/00-53-42', '2022-12-07/00-55-38', '2022-12-07/00-56-09', '2022-12-07/00-56-10'] 3 | # ['2022-12-06/22-29-16', '2022-12-06/22-29-24', '2022-12-06/22-29-27', '2022-12-06/22-30-57'] 4 | # ['2022-12-06/20-41-30', '2022-12-06/20-41-47', '2022-12-06/20-45-14', '2022-12-06/20-45-52', '2022-12-06/20-46-57', '2022-12-06/20-47-13'] -------------------------------------------------------------------------------- /benchmarks/configs/markowitz/markowitz_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: ['2022-12-07/09-02-57', '2022-12-07/01-13-54'] 2 | #['2022-12-07/09-03-08', '2022-12-07/09-02-57', '2022-12-07/01-13-54'] 3 | naive_ws_datetime: '2022-12-07/09-03-08' 4 | pretrain_datetime: '' 5 | no_learning_datetime: '2022-12-07/09-03-08' 6 | eval_iters: 500 7 | accuracies: [1e-2, 1e-3, 1e-4] 8 | # output_datetimes: ['2022-11-28/21-34-54', '2022-11-28/21-33-36', '2022-11-28/19-58-28'] 9 | # naive_ws_datetime: '2022-11-30/14-56-17' 10 | # pretrain_datetime: '' 11 | # no_learning_datetime: '2022-11-28/21-33-36' -------------------------------------------------------------------------------- /benchmarks/configs/markowitz/markowitz_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-4 3 | method: adam 4 | intermediate_layer_sizes: [500, 500, 500] 5 | batch_size: 50 6 | epochs: 200 7 | decay_lr: .6 8 | min_lr: 1e-7 9 | decay_every: 5000 10 | 11 | pretrain: 12 | pretrain_method: adam 13 | pretrain_stepsize: .001 14 | pretrain_iters: 2 15 | 16 | data: 17 | datetime: '' #'2022-12-07/00-11-31' 18 | 19 | a: 3000 20 | train_unrolls: 50 21 | eval_unrolls: 500 22 | eval_every_x_epochs: 10 23 | N_train: 1000 24 | N_test: 200 25 | num_samples: 200 26 | prediction_variable: w 27 | -------------------------------------------------------------------------------- /benchmarks/configs/lasso/lasso_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: 2 | - 2023-08-27/11-52-07 3 | - 2023-08-27/11-55-34 4 | - 2023-08-27/16-11-05 5 | # - 2023-08-27/16-11-35 6 | - 2023-08-27/13-21-10 7 | - 2023-09-05/20-39-56 # 2023-08-27/22-04-36 8 | - 2023-08-27/15-27-17 9 | - 2023-08-27/15-27-56 10 | # - 2023-08-27/13-48-25 11 | - 2023-08-27/13-55-58 12 | 13 | loss_overlay_titles: [] 14 | nearest_neighbor_datetime: 2023-08-27/11-52-07 15 | pretrain_datetime: '' 16 | cold_start_datetime: 
2023-08-27/11-52-07 17 | eval_iters: 500 18 | accuracies: 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 -------------------------------------------------------------------------------- /benchmarks/configs/unconstrained_qp/unconstrained_qp_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: 2 | - 2023-08-28/23-46-00 3 | - 2023-08-28/23-46-31 4 | - 2023-08-28/23-53-29 5 | # - 2023-08-28/23-54-38 6 | - 2023-08-29/09-47-00 7 | - 2023-08-27/22-43-26 8 | - 2023-08-27/22-54-44 9 | - 2023-08-27/22-47-43 10 | # - 2023-08-27/23-03-36 11 | - 2023-08-27/23-04-48 12 | 13 | 14 | loss_overlay_titles: [] 15 | nearest_neighbor_datetime: 2023-08-29/09-47-00 16 | pretrain_datetime: '' 17 | cold_start_datetime: 2023-08-29/09-47-00 18 | eval_iters: 500 19 | accuracies: 20 | - 0.1 21 | - 0.01 22 | - 0.001 23 | - 0.0001 -------------------------------------------------------------------------------- /benchmarks/configs/mpc/mpc_plot.yaml: -------------------------------------------------------------------------------- 1 | # output_datetimes: ['2023-05-18/20-08-46', '2023-05-18/20-08-46', '2023-05-18/10-34-45', '2023-05-20/15-54-06', '2023-05-20/15-47-51'] 2 | output_datetimes: ['2023-06-01/17-58-38', '2023-05-18/20-08-46', '2023-05-18/10-34-45', '2023-05-18/20-10-15', '2023-05-20/15-54-06', '2023-05-20/15-47-51'] 3 | loss_overlay_titles: [] 4 | nearest_neighbor_datetime: '2023-05-18/20-08-46' 5 | prev_sol_datetime: '2023-05-18/20-08-46' 6 | pretrain_datetime: '' 7 | cold_start_datetime: '2023-05-18/20-08-46' 8 | eval_iters: 500 9 | accuracies: [1e-1, 1e-2, 1e-3, 1e-4] 10 | rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 11 | abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 12 | -------------------------------------------------------------------------------- /benchmarks/configs/osc_mass/osc_mass_agg.yaml: -------------------------------------------------------------------------------- 1 | datetimes: [] 2 | #['2022-11-30/16-06-14'] 3 | # ['2022-11-30/13-43-05'] 4 | #['2022-11-24/09-48-56', '2022-11-24/09-49-26', '2022-11-24/09-49-28'] 5 | #['2022-11-06/00-30-31', '2022-11-06/00-30-50', '2022-11-06/00-30-55', '2022-11-06/00-31-22', '2022-11-06/00-31-53', '2022-11-06/00-32-08'] 6 | #['2022-11-06/01-02-30'] 7 | #['2022-11-06/00-30-31', '2022-11-06/00-30-50', '2022-11-06/00-30-55', 2022-11-06/00-31-22', 2022-11-06/00-31-53', '2022-11-06/00-32-08'] 8 | #['2022-11-05/16-21-56'] 9 | #['2022-11-05/16-25-56'] 10 | #['2022-11-04/09-20-28'] 11 | #['2022-11-04/09-19-36'] 12 | #['2022-11-04/08-56-25'] local 13 | #['2022-11-03/16-32-20'] synthetic 14 | -------------------------------------------------------------------------------- /benchmarks/configs/vehicle/vehicle_setup.yaml: -------------------------------------------------------------------------------- 1 | T: 30 2 | dt: .01 3 | slip_box_deg: 25 4 | yaw_box_deg: 40 5 | roll_box_deg: 15 6 | roll_rate_box_deg: 30 7 | delta0_box_deg: 45 8 | delta_rate_box_deg: 30 9 | Fy_box: 1000 10 | Mx_box: 20000 11 | Mz_box: 30000 12 | v_min: 2 13 | v_max: 35 14 | slip_penalty: 30 #100 15 | yaw_penalty: 200 #20000 16 | roll_penalty: 0 17 | roll_rate_penalty: 50 #500 18 | Fy_penalty: 1e-5 #1e-6 #1 19 | Mx_penalty: 1e-6 #1e-6 20 | Mz_penalty: 6e-6 #6e-6 21 | control_rate_lim: 35000 22 | # states: 'with' #'with' or 'without' 23 | # steer_features: 'low' #'low' or 'high' 24 | # matrices_fixed: False 25 | linear_system_solve: 'inverse' 26 | Fy_factor: 1e-4 27 | Mx_factor: 1e-4 28 | Mz_factor: 1e-4 29 | N_train: 100 30 | N_test: 20 31 | 
solve_acc: 1e-6 -------------------------------------------------------------------------------- /benchmarks/configs/sparse_pca/sparse_pca_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: 2 | - 2023-07-25/23-28-42 3 | - 2023-07-25/22-13-13 4 | - 2023-07-26/22-08-40 5 | # - 2023-07-26/22-38-49 6 | # - 2023-07-26/22-47-30 7 | - 2023-07-26/21-55-00 8 | - 2023-07-26/08-49-43 9 | - 2023-07-26/21-24-25 10 | # - 2023-07-26/20-47-36 11 | # - 2023-07-26/21-18-04 12 | # - 2023-08-05/11-47-49 -- trained over 40 hours 13 | # - 2023-08-05/11-48-20 -- trained over 40 hours 14 | loss_overlay_titles: [] 15 | nearest_neighbor_datetime: 2023-07-25/23-28-42 16 | pretrain_datetime: '' 17 | cold_start_datetime: 2023-07-25/23-28-42 18 | eval_iters: 500 19 | accuracies: 20 | - 0.1 21 | - 0.01 22 | - 0.001 23 | - 0.0001 24 | rel_tols: 25 | - 0.1 26 | - 0.01 27 | - 0.001 28 | - 0.0001 29 | - 1.0e-05 30 | abs_tols: 31 | - 0.1 32 | - 0.01 33 | - 0.001 34 | - 0.0001 35 | - 1.0e-05 36 | -------------------------------------------------------------------------------- /benchmarks/configs/quadcopter/quadcopter_setup.yaml: -------------------------------------------------------------------------------- 1 | QT_factor: 1 2 | Q_diag: 3 | - 1 4 | - 1 5 | - 1 6 | - 0 7 | - 0 8 | - 0 9 | - 0 10 | - 0 11 | - 0 12 | - 0 13 | R_diag: 14 | - 0.001 15 | - 0.001 16 | - 0.001 17 | - 0.001 18 | T: 10 19 | delta_u: 20 | - 20 21 | - 6 22 | - 6 23 | - 6 24 | dt: 0.05 25 | goal_bound: 0.5 26 | noise_std_dev: 0.0 27 | num_goals: 5 28 | num_rollouts: 110 29 | obstacle_tol: 0.1 30 | rollout_length: 100 31 | rollout_osqp_iters: 500 32 | save_gif: true 33 | seed: 42 34 | solve: true 35 | solve_acc_abs: 0.0001 36 | solve_acc_rel: 0 37 | u_max: 38 | - 20 39 | - 6 40 | - 6 41 | - 6 42 | u_min: 43 | - 2 44 | - -6 45 | - -6 46 | - -6 47 | waypoints: 10 48 | x_max: 49 | - 1 50 | - 1 51 | - 1 52 | - 50 53 | - 50 54 | - 50 55 | - 1 56 | - 1 57 | - 1 58 | - 1 59 | x_min: 60 | - -1 61 | - -1 62 | - -1 63 | - -50 64 | - -50 65 | - -50 66 | - 0 67 | - -1 68 | - -1 69 | - -1 70 | -------------------------------------------------------------------------------- /benchmarks/configs/all/plot.yaml: -------------------------------------------------------------------------------- 1 | markowitz: 2 | output_datetimes: ['2022-11-28/21-34-54', '2022-11-28/21-33-36', '2022-11-28/19-58-28'] 3 | naive_ws_datetime: '2022-11-30/14-56-17' 4 | pretrain_datetime: '' 5 | no_learning_datetime: '2022-11-28/21-33-36' 6 | eval_iters: 500 7 | accuracies: [1e-2, 1e-3, 1e-4] 8 | vehicle: 9 | output_datetimes: ['2022-11-27/10-06-50', '2022-11-27/12-09-16', '2022-11-26/18-21-01'] 10 | pretrain_datetime: '' 11 | naive_ws_datetime: '2022-11-30/15-06-40' 12 | no_learning_datetime: '2022-11-26/20-27-13' 13 | eval_iters: 500 14 | accuracies: [1e-2, 1e-3, 1e-4] 15 | mpc: 16 | output_datetimes: ['2022-11-30/19-39-39', '2022-11-30/19-40-09', '2022-11-30/16-03-42'] 17 | naive_ws_datetime: '2022-11-30/19-40-09' 18 | pretrain_datetime: '' 19 | no_learning_datetime: '2022-11-30/19-40-09' 20 | eval_iters: 500 21 | accuracies: [1e-2, 1e-3, 1e-4] 22 | -------------------------------------------------------------------------------- /benchmarks/configs/vehicle/vehicle_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-3 3 | method: adam 4 | intermediate_layer_sizes: [1000, 1000] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 
| pretrain:
12 |   pretrain_method: adam
13 |   pretrain_stepsize: 1e-3
14 |   pretrain_iters: 0
15 |   pretrain_batches: 10
16 | 
17 | data:
18 |   datetime: ''
19 | 
20 | 
21 | 
22 | train_unrolls: 20
23 | eval_unrolls: 500
24 | eval_every_x_epochs: 20
25 | save_every_x_epochs: 1
26 | test_every_x_epochs: 1
27 | write_csv_every_x_batches: 1
28 | N_train: 100
29 | N_test: 20
30 | num_samples: 20
31 | prediction_variable: w
32 | angle_anchors: [0]
33 | supervised: False
34 | plot_iterates: [0, 10, 20]
35 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum'
36 | share_all: False
37 | num_clusters: 1000
38 | pretrain_alpha: False
39 | normalize_inputs: False
40 | normalize_alpha: 'other'
--------------------------------------------------------------------------------
/benchmarks/configs/osc_mass/osc_mass_run.yaml:
--------------------------------------------------------------------------------
1 | nn_cfg:
2 |   lr: 1e-3
3 |   method: adam
4 |   intermediate_layer_sizes: [1000, 1000]
5 |   batch_size: 100
6 |   epochs: 1e6
7 |   decay_lr: .1
8 |   min_lr: 1e-7
9 |   decay_every: 1e7
10 | 
11 | pretrain:
12 |   pretrain_method: adam
13 |   pretrain_stepsize: 1e-3
14 |   pretrain_iters: 0
15 |   pretrain_batches: 10
16 | 
17 | data:
18 |   datetime: ''
19 | 
20 | 
21 | 
22 | train_unrolls: 20
23 | eval_unrolls: 500
24 | eval_every_x_epochs: 20
25 | save_every_x_epochs: 1
26 | test_every_x_epochs: 1
27 | write_csv_every_x_batches: 1
28 | N_train: 100
29 | N_test: 20
30 | num_samples: 20
31 | prediction_variable: w
32 | angle_anchors: [0]
33 | supervised: False
34 | plot_iterates: [0, 10, 20]
35 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum'
36 | share_all: False
37 | num_clusters: 1000
38 | pretrain_alpha: False
39 | normalize_inputs: False
40 | normalize_alpha: 'other'
--------------------------------------------------------------------------------
/l2ws/gd_model.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | 
3 | from l2ws.algo_steps import k_steps_eval_gd, k_steps_train_gd
4 | from l2ws.l2ws_model import L2WSmodel
5 | 
6 | 
7 | class GDmodel(L2WSmodel):
8 |     def __init__(self, **kwargs):
9 |         super(GDmodel, self).__init__(**kwargs)
10 | 
11 |     def initialize_algo(self, input_dict):
12 |         self.factor_static = None
13 |         self.algo = 'gd'
14 |         self.factors_required = False
15 |         self.q_mat_train, self.q_mat_test = input_dict['c_mat_train'], input_dict['c_mat_test']
16 |         P = input_dict['P']
17 |         gd_step = input_dict['gd_step']
18 |         n = P.shape[0]
19 |         self.output_size = n
20 | 
21 |         self.k_steps_train_fn = partial(k_steps_train_gd, P=P, gd_step=gd_step, jit=self.jit)
22 |         self.k_steps_eval_fn = partial(k_steps_eval_gd, P=P, gd_step=gd_step, jit=self.jit)
23 |         self.out_axes_length = 5
24 | 
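GDmodel above binds the fixed problem data (P, gd_step) into its step functions with functools.partial, so that only the iterate and the per-instance data vary at call time. A minimal self-contained sketch of that pattern follows; the body of k_steps_train_gd here is an illustrative stand-in (the real implementation lives in l2ws/algo_steps.py, which is not part of this listing), and its argument names are assumptions.

    from functools import partial

    import jax.numpy as jnp
    from jax import lax


    def k_steps_train_gd(k, z0, q, P, gd_step, jit):
        # illustrative stand-in: k gradient steps on f(z) = 0.5 z^T P z + q^T z,
        # i.e. z <- z - gd_step * (P @ z + q); the jit flag is unused in this toy
        def body(_, z):
            return z - gd_step * (P @ z + q)
        return lax.fori_loop(0, k, body, z0)


    n = 5
    P = jnp.eye(n)  # toy quadratic objective
    step_fn = partial(k_steps_train_gd, P=P, gd_step=0.1, jit=True)

    q = jnp.ones(n)  # one problem instance
    z_final = step_fn(k=50, z0=jnp.zeros(n), q=q)  # approaches -P^{-1} q = -q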
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 | 
3 | on: [push]
4 | 
5 | jobs:
6 | 
7 |   linter:
8 |     runs-on: ubuntu-latest
9 |     steps:
10 |       - uses: actions/checkout@v3
11 |       - uses: actions/setup-python@v4
12 |         with:
13 |           python-version: "3.10"
14 |       - name: ruff
15 |         uses: chartboost/ruff-action@v1
16 | 
17 |   test:
18 |     runs-on: ubuntu-latest
19 |     strategy:
20 |       matrix:
21 |         python-version: ["3.10"]
22 | 
23 |     steps:
24 |       - uses: actions/checkout@v3
25 |       - name: Set up Python ${{ matrix.python-version }}
26 |         uses: actions/setup-python@v4
27 |         with:
28 |           python-version: ${{ matrix.python-version }}
29 |       - name: Install dependencies
30 |         run: |
31 |           python -m pip install --upgrade pip
32 |           pip install ".[dev]"
33 |       - name: Test with pytest
34 |         run: |
35 |           pytest tests
36 | 
--------------------------------------------------------------------------------
/benchmarks/configs/robust_pca/robust_pca_gif.yaml:
--------------------------------------------------------------------------------
1 | eval_iters: 500
2 | # datetimes: ''
3 | # datetimes: ['2023-01-19/23-07-45', '2023-01-19/21-26-18']
4 | 
5 | # below is the best run for meeting jan 25
6 | # datetimes: ['2023-01-20/23-08-01', '2023-01-22/22-42-59', '2023-01-20/19-13-45']
7 | # labels: ['incr. loss sum to k', 'const. loss sum to k', 'loss only at k']
8 | 
9 | # below includes angles
10 | # datetimes: ['2023-01-24/20-21-46', '2023-01-23/12-16-06', '2023-01-23/12-11-45']
11 | # labels: ['t=10 fixed k', 't=20 fixed k', 't=20 sum k']
12 | 
13 | # change d
14 | datetimes: ['2023-01-24/23-51-49', '2023-01-25/07-58-25', '2023-01-25/07-58-41', '2023-01-24/23-45-44']
15 | labels: ['d=3', 'd=10', 'd=20', 'd=34 (full rank)']
16 | 
17 | # datetimes: '' #['2023-01-19/22-45-22', '2023-01-20/09-14-51']
18 | 
19 | add_k: True
20 | y_low: 3e-5
21 | y_high: 5e1
22 | gif_length: 30
23 | gradient: False
24 | # y_grad_low: 8e-4
25 | # y_grad_high: 3e1
26 | angle_prob_nums: [0, 1, 2]
--------------------------------------------------------------------------------
/l2ws/ista_model.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | 
3 | from l2ws.algo_steps import (
4 |     k_steps_eval_ista,
5 |     k_steps_train_ista,
6 | )
7 | from l2ws.l2ws_model import L2WSmodel
8 | 
9 | 
10 | class ISTAmodel(L2WSmodel):
11 |     def __init__(self, **kwargs):
12 |         super(ISTAmodel, self).__init__(**kwargs)
13 | 
14 |     def initialize_algo(self, input_dict):
15 |         self.factor_static = None
16 |         self.algo = 'ista'
17 |         self.factors_required = False
18 |         self.q_mat_train, self.q_mat_test = input_dict['b_mat_train'], input_dict['b_mat_test']
19 |         A = input_dict['A']
20 |         lambd = input_dict['lambd']
21 |         ista_step = input_dict['ista_step']
22 |         m, n = A.shape
23 |         self.output_size = n
24 | 
25 |         self.k_steps_train_fn = partial(k_steps_train_ista, A=A, lambd=lambd,
26 |                                         ista_step=ista_step, jit=self.jit)
27 |         self.k_steps_eval_fn = partial(k_steps_eval_ista, A=A, lambd=lambd,
28 |                                        ista_step=ista_step, jit=self.jit)
29 |         self.out_axes_length = 5
30 | 
--------------------------------------------------------------------------------
/benchmarks/configs/robust_pca/robust_pca_run.yaml:
--------------------------------------------------------------------------------
1 | nn_cfg:
2 |   lr: 1e-3
3 |   method: adam
4 |   intermediate_layer_sizes: [500, 500]
5 |   batch_size: 100
6 |   epochs: 10000
7 |   decay_lr: .6
8 |   min_lr: 1e-7
9 |   decay_every: 1000
10 | 
11 | plateau_decay:
12 |   min_lr: 1e-7
13 |   decay_factor: 5
14 |   avg_window_size: 10 # in epochs
15 |   tolerance: 1e-4
16 |   patience: 2
17 | 
18 | pretrain:
19 |   pretrain_method: adam
20 |   pretrain_stepsize: 1e-4
21 |   pretrain_iters: 0
22 |   pretrain_batches: 10
23 | 
24 | data:
25 |   datetime: ''
26 | 
27 | train_unrolls: 20
28 | eval_unrolls: 500
29 | eval_every_x_epochs: 10
30 | save_every_x_epochs: 1
31 | write_csv_every_x_batches: 1
32 | N_train: 1000
33 | N_test: 20
34 | num_samples: 100
35 | prediction_variable: w
36 | supervised: False
37 | angle_anchors: [0]
38 | dx: 0
39 | dy: 0
40 | tx: 50
41 | ty: 50
42 | learn_XY: False
43 | loss_method: constant_sum
44 | #increasing_sum or constant_sum or fixed_k
45 | 
plot_iterates: [0, 10, 20] 46 | share_all: False 47 | num_clusters: 100 48 | pretrain_alpha: False 49 | normalize_inputs: True 50 | normalize_alpha: 'other' 51 | epochs_jit: 10 52 | -------------------------------------------------------------------------------- /benchmarks/configs/jamming/jamming_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-4 3 | method: adam 4 | intermediate_layer_sizes: [500, 500] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-4 #1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | train_unrolls: 1 29 | eval_unrolls: 100 30 | eval_every_x_epochs: 20 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | epochs_jit: 2 35 | N_train: 10 36 | N_test: 10 37 | num_samples: 10 38 | prediction_variable: w 39 | angle_anchors: [0] 40 | supervised: True 41 | plot_iterates: [5] 42 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 43 | share_all: False 44 | num_clusters: 10 45 | pretrain_alpha: False 46 | normalize_inputs: False 47 | normalize_alpha: 'other' 48 | 49 | accuracies: [.1, .01, .001, .0001] 50 | rho_x: 1 51 | scale: 1 52 | alpha_relax: 1 53 | skip_startup: False -------------------------------------------------------------------------------- /benchmarks/configs/lasso/lasso_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-4 3 | method: adam 4 | intermediate_layer_sizes: [] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5000 # in epochs 15 | tolerance: -100 #1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | train_unrolls: 1 29 | eval_unrolls: 5000 30 | eval_every_x_epochs: 20 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | epochs_jit: 2 35 | N_train: 10000 36 | N_test: 1000 37 | num_samples: 1000 38 | prediction_variable: w 39 | angle_anchors: [0] 40 | supervised: True 41 | plot_iterates: [0, 10, 20] 42 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 43 | share_all: False 44 | num_clusters: 10 45 | pretrain_alpha: False 46 | normalize_inputs: True 47 | normalize_alpha: 'other' 48 | 49 | accuracies: [.1, .01, .001, .0001] 50 | rho_x: 1 51 | scale: 1 52 | alpha_relax: 1 53 | skip_startup: False -------------------------------------------------------------------------------- /benchmarks/configs/unconstrained_qp/unconstrained_qp_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-1 3 | method: adam 4 | intermediate_layer_sizes: [10] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 
28 | train_unrolls: 60 29 | eval_unrolls: 1000 30 | eval_every_x_epochs: 100 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | epochs_jit: 2 35 | N_train: 100 36 | N_test: 10 37 | num_samples: 10 38 | prediction_variable: w 39 | angle_anchors: [0] 40 | supervised: False 41 | plot_iterates: [0, 10, 20] 42 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 43 | share_all: True 44 | num_clusters: 10 45 | pretrain_alpha: False 46 | normalize_inputs: True 47 | normalize_alpha: 'other' 48 | 49 | accuracies: [.1, .01, .001, .0001] 50 | rho_x: 1 51 | scale: 1 52 | alpha_relax: 1 53 | skip_startup: False -------------------------------------------------------------------------------- /benchmarks/slurm_script_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=array-job # create a short name for your job 3 | #SBATCH --output=slurm-%A.%a.out # STDOUT file 4 | #SBATCH --error=slurm-%A.%a.err # STDERR file 5 | #SBATCH --nodes=1 # node count 6 | #SBATCH --ntasks=1 # total number of tasks across all nodes 7 | #SBATCH --cpus-per-task=1 # cpu-cores per task (>1 if multi-threaded tasks) 8 | #SBATCH --mem-per-cpu=100G # memory per cpu-core (4G is default) 9 | #SBATCH --array=0 # job array with index values 0, 1, 2, 3, 4 10 | #SBATCH --time=00:59:00 # total run time limit (HH:MM:SS) 11 | #SBATCH --mail-type=all # send email on job start, end and fault 12 | #SBATCH --mail-user=rajivs@princeton.edu # 13 | #SBATCH --gres=gpu:1 14 | 15 | echo "My SLURM_ARRAY_JOB_ID is $SLURM_ARRAY_JOB_ID." 16 | echo "My SLURM_ARRAY_TASK_ID is $SLURM_ARRAY_TASK_ID" 17 | echo "Executing on the machine:" $(hostname) 18 | 19 | # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='true' 20 | # export XLA_PYTHON_CLIENT_MEM_FRACTION='0.30' 21 | # XLA_PYTHON_CLIENT_ALLOCATOR=platform 22 | # export xla_force_host_platform_device_count=1 23 | 24 | 25 | python l2ws_train_script.py quadcopter cluster 26 | # python aggregate_slurm_runs_script.py robust_ls cluster 27 | 28 | 29 | # gpu command: #SBATCH --gres=gpu:1 -------------------------------------------------------------------------------- /benchmarks/configs/quadcopter/quadcopter_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: ['2023-07-31/11-31-17', 2 | '2023-08-22/16-48-35', 3 | '2023-08-22/18-09-18', 4 | # '2023-07-31/14-17-10', 5 | '2023-08-22/14-26-39', # fp 6 | '2023-07-31/11-19-01', 7 | '2023-07-30/23-40-17', 8 | '2023-07-31/10-04-44', 9 | # '2023-07-31/10-05-50', 10 | '2023-07-31/10-50-55'] # reg 11 | 12 | # FOR rel gap output_datetimes: ['2023-07-09/14-14-53', '2023-07-09/14-27-54', '2023-07-09/14-32-59', '2023-07-09/14-33-59', '2023-08-22/03-07-27', 13 | # '2023-07-09/15-06-38', '2023-07-09/15-23-01', '2023-07-09/15-35-15', '2023-08-22/10-25-00', '2023-07-09/15-40-50'] 14 | 15 | # output_datetimes: ['2023-07-09/14-14-53', '2023-07-09/14-27-54', '2023-07-09/14-32-59', '2023-07-09/14-33-59', '2023-07-09/14-35-29', 16 | # '2023-07-09/15-06-38', '2023-07-09/15-23-01', '2023-07-09/15-35-15', '2023-07-09/15-38-54', '2023-07-09/15-40-50'] 17 | # above: reg first, then fp 18 | # output_datetimes: ['2023-07-30/23-40-17'] 19 | loss_overlay_titles: [] 20 | # nearest_neighbor_datetime: '2023-07-30/23-21-16' 21 | # prev_sol_datetime: '2023-07-30/23-21-16' 22 | nearest_neighbor_datetime: '2023-08-22/15-28-27' 23 | prev_sol_datetime: '2023-08-22/15-28-27' 24 | pretrain_datetime: '' 25 | cold_start_datetime: 
'2023-08-22/15-28-27'
26 | eval_iters: 500
27 | accuracies: [1e-1, 1e-2, 1e-3, 1e-4]
28 | rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
29 | abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
30 | 
--------------------------------------------------------------------------------
/l2ws/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import hydra
4 | import yaml
5 | 
6 | 
7 | def recover_last_datetime(orig_cwd, example, stage):
8 |     '''
9 |     stage should be either
10 |     1. data_setup
11 |     2. aggregate
12 |     3. train
13 |     '''
14 |     folder = f"{orig_cwd}/outputs/{example}/{stage}_outputs/"
15 | 
16 |     date_entries = os.listdir(folder)
17 |     date_entries.sort()
18 |     last_date = date_entries[len(date_entries)-1]
19 |     date_folder = f"{folder}{last_date}"
20 | 
21 |     datetime_entries = os.listdir(date_folder)
22 |     datetime_entries.sort()
23 |     last_time = datetime_entries[len(datetime_entries)-1]
24 | 
25 |     last_datetime = f"{last_date}/{last_time}"
26 |     return last_datetime
27 | 
28 | def copy_data_file(example, datetime):
29 |     orig_cwd = hydra.utils.get_original_cwd()
30 |     # data_yaml_filename = f"{orig_cwd}/outputs/{example}/aggregate_outputs/{datetime}/data_setup_copied.yaml" # noqa
31 |     data_yaml_filename = f"{orig_cwd}/outputs/{example}/data_setup_outputs/{datetime}/.hydra/config.yaml" # noqa
32 | 
33 |     # read the yaml file
34 |     with open(data_yaml_filename, "r") as stream:
35 |         try:
36 |             setup_cfg = yaml.safe_load(stream)
37 |         except yaml.YAMLError as exc:
38 |             print(exc)
39 | 
40 |     # write the yaml file to the train_outputs folder
41 |     with open('data_setup_copied.yaml', 'w') as file:
42 |         yaml.dump(setup_cfg, file)
43 | 
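recover_last_datetime above relies on Hydra's outputs/<example>/<stage>_outputs/<date>/<time> directory layout: because the 'YYYY-MM-DD' and 'HH-MM-SS' folder names sort lexicographically in chronological order, sorting the entries and taking the last one recovers the most recent run. A small usage sketch, with hypothetical directory names:

    import os

    from l2ws.utils.data_utils import recover_last_datetime

    # suppose earlier runs created (hypothetical example):
    #   outputs/mpc/data_setup_outputs/2023-05-18/10-34-45/
    #   outputs/mpc/data_setup_outputs/2023-05-18/20-08-46/
    orig_cwd = os.getcwd()  # inside a Hydra job, use hydra.utils.get_original_cwd()
    last = recover_last_datetime(orig_cwd, 'mpc', 'data_setup')
    print(last)  # '2023-05-18/20-08-46', the most recent data-setup run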
--------------------------------------------------------------------------------
/benchmarks/configs/mpc/mpc_run.yaml:
--------------------------------------------------------------------------------
1 | nn_cfg:
2 |   lr: 1e-4
3 |   method: adam
4 |   intermediate_layer_sizes: [500]
5 |   batch_size: 100
6 |   epochs: 1e6
7 |   decay_lr: .1
8 |   min_lr: 1e-7
9 |   decay_every: 1e7
10 | 
11 | plateau_decay:
12 |   min_lr: 1e-7
13 |   decay_factor: 5
14 |   avg_window_size: 5 # in epochs
15 |   tolerance: 1e-8 #1e-3
16 |   patience: 1
17 | 
18 | 
19 | pretrain:
20 |   pretrain_method: adam
21 |   pretrain_stepsize: 1e-3
22 |   pretrain_iters: 0
23 |   pretrain_batches: 10
24 | 
25 | data:
26 |   datetime: ''
27 | 
28 | train_unrolls: 15
29 | eval_unrolls: 1000
30 | eval_every_x_epochs: 50
31 | save_every_x_epochs: 1
32 | test_every_x_epochs: 1
33 | write_csv_every_x_batches: 1
34 | epochs_jit: 10
35 | N_train: 1000
36 | N_test: 100
37 | num_samples: 100
38 | prediction_variable: w
39 | angle_anchors: [0]
40 | supervised: True
41 | plot_iterates: [0, 10, 20]
42 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum'
43 | share_all: False
44 | num_clusters: 10
45 | pretrain_alpha: False
46 | normalize_inputs: True
47 | normalize_alpha: 'other'
48 | 
49 | accuracies: [.1, .01, .001, .0001]
50 | rho_x: 1
51 | scale: 1
52 | alpha_relax: 1
53 | skip_startup: False
54 | save_weights_flag: True
55 | # load_weights_datetime: #'2023-05-12/15-22-37'
56 | 
57 | # solving in C
58 | # rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
59 | # abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
60 | # solve_c_num: 100
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | 
6 | [project]
7 | name = "l2ws"
8 | description="A package for warm-starting optimization algorithms, using data."
9 | readme = "README.md"
10 | license = {text = "Apache 2.0"}
11 | dynamic = ["version"]
12 | authors = [
13 |     {name = "Rajiv Sambharya", email="rajivs@princeton.edu"},
14 |     {name = "Brandon Amos", email="bda@meta.com"},
15 |     {name = "Georgina Hall", email="georgina.hall@insead.edu"},
16 |     {name = "Bartolomeo Stellato", email="bstellato@princeton.edu"}
17 | ]
18 | dependencies = [
19 |     "numpy",
20 |     "scipy",
21 |     "cvxpy>=1.3.0",
22 |     "matplotlib",
23 |     "jax",
24 |     "jaxopt",
25 |     "optax==0.1.5",
26 |     "hydra-core",
27 |     "trajax @ git+https://github.com/google/trajax",
28 |     "emnist",
29 |     "imageio"
30 | ]
31 | 
32 | [tool.setuptools.packages.find]
33 | include = ["l2ws*"]
34 | exclude = ["tutorials*", "benchmarks*", "tests*"]
35 | 
36 | [tool.setuptools_scm]
37 | # To infer version automatically from git
38 | write_to = "l2ws/_version.py"
39 | 
40 | [project.optional-dependencies]
41 | dev = ["pytest", "ruff", "ruff-lsp", "black", "pandas", "jupyterlab"]
42 | 
43 | 
44 | [tool.black]
45 | line-length = 100
46 | target-version = ['py310']
47 | 
48 | [tool.ruff]
49 | select = ["E", "F", "I"]
50 | ignore = ["E722"]
51 | line-length = 100
52 | exclude = ["build", "examples", "instances", "docs", "*__init__.py"]
53 | target-version = "py310"
54 | 
--------------------------------------------------------------------------------
/benchmarks/configs/robust_ls/robust_ls_plot.yaml:
--------------------------------------------------------------------------------
1 | output_datetimes: ['2023-07-20/15-20-15', '2023-07-20/15-22-02', '2023-07-20/15-23-38', '2023-07-20/15-25-31', '2023-07-20/15-26-28', '2023-07-20/15-27-15', '2023-07-20/15-28-16', '2023-08-02/10-32-25']
2 | # output_datetimes: ['2023-07-20/15-20-15', '2023-07-20/15-22-02', '2023-07-20/15-23-38', '2023-07-20/15-24-26', '2023-07-20/15-25-31', '2023-07-20/15-26-28', '2023-07-20/15-27-15', '2023-07-20/15-28-16', '2023-08-02/12-24-30', '2023-08-02/10-32-25']
3 | # output_datetimes: ['2023-07-20/15-20-15', '2023-07-20/15-22-02', '2023-07-20/15-23-38', '2023-07-20/15-24-26', '2023-07-20/15-25-31', '2023-07-20/15-26-28', '2023-07-20/15-27-15', '2023-07-20/15-28-16', '2023-07-20/15-29-20', '2023-07-20/15-29-59']
4 | # very old ['2023-05-24/16-12-32', '2023-05-24/16-13-32', '2023-05-24/16-16-56', '2023-05-24/16-17-35', '2023-05-24/16-18-05', '2023-06-15/20-58-39', '2023-06-15/20-56-15', '2023-06-15/21-24-04', '2023-07-11/21-07-41', '2023-07-11/21-08-43']
5 | # ['2023-05-24/16-12-32', '2023-05-24/16-13-32', '2023-05-24/16-16-56', '2023-05-24/16-17-35', '2023-05-24/16-18-05', '2023-05-24/16-19-20', '2023-06-15/20-58-39', '2023-06-15/20-56-15', '2023-06-15/21-24-04', '2023-07-11/21-07-41', '2023-07-11/21-08-43']
6 | loss_overlay_titles: []
7 | nearest_neighbor_datetime: '2023-07-20/15-33-13'
8 | # prev_sol_datetime: '2023-05-18/20-08-46'
9 | pretrain_datetime: ''
10 | cold_start_datetime: '2023-07-20/15-33-13'
11 | eval_iters: 299
12 | accuracies: [1e-1, 1e-2, 1e-3, 1e-4]
13 | rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
14 | abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
--------------------------------------------------------------------------------
/benchmarks/configs/robust_ls/robust_ls_run.yaml:
--------------------------------------------------------------------------------
1 | nn_cfg:
2 |   lr: 1e-3
3 |   method: adam
4 | 
intermediate_layer_sizes: [500, 500] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | 29 | eval_unrolls: 300 30 | eval_every_x_epochs: 20 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | N_train: 10000 35 | N_test: 1000 36 | num_samples: 100 37 | prediction_variable: w 38 | angle_anchors: [0] 39 | plot_iterates: [0, 10, 20] 40 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 41 | share_all: False 42 | num_clusters: 1000 43 | pretrain_alpha: False 44 | normalize_inputs: True 45 | normalize_alpha: 'other' 46 | epochs_jit: 2 47 | accuracies: [.1, .01, .001, .0001] 48 | 49 | rho_x: 1 50 | scale: 1 51 | alpha_relax: 1 52 | skip_startup: False 53 | # solve_c_num: 1000 54 | save_weights_flag: True 55 | # load_weights_datetime: '2023-08-02/10-19-58' 56 | train_unrolls: 30 57 | supervised: True 58 | lightweight: True 59 | 60 | # obj: output_datetimes: ['2023-05-24/16-12-32', '2023-05-24/16-13-32', '2023-05-24/16-16-56', '2023-05-24/16-17-35', '2023-05-24/16-18-05', 61 | # reg: '2023-06-15/20-58-39', '2023-06-15/20-56-15', '2023-06-15/21-24-04', '2023-07-11/21-07-41', '2023-07-11/21-08-43'] 62 | 63 | -------------------------------------------------------------------------------- /benchmarks/slurm_script_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=array-job # create a short name for your job 3 | #SBATCH --output=slurm-%A.%a.out # STDOUT file 4 | #SBATCH --error=slurm-%A.%a.err # STDERR file 5 | #SBATCH --nodes=1 # node count 6 | #SBATCH --ntasks=1 # total number of tasks across all nodes 7 | #SBATCH --cpus-per-task=1 # cpu-cores per task (>1 if multi-threaded tasks) 8 | #SBATCH --mem-per-cpu=10G # memory per cpu-core (4G is default) 9 | #SBATCH --array=0 # job array with index values 0, 1, 2, 3, 4 10 | #SBATCH --time=00:15:00 # total run time limit (HH:MM:SS) 11 | #SBATCH --mail-type=all # send email on job start, end and fault 12 | #SBATCH --mail-user=rajivs@princeton.edu # 13 | 14 | 15 | echo "My SLURM_ARRAY_JOB_ID is $SLURM_ARRAY_JOB_ID." 
16 | echo "My SLURM_ARRAY_TASK_ID is $SLURM_ARRAY_TASK_ID" 17 | echo "Executing on the machine:" $(hostname) 18 | 19 | # python l2ws_train_script.py sparse_pca cluster 20 | # python gif_script.py robust_pca cluster 21 | # python utils/portfolio_utils.py 22 | # python plot_script.py unconstrained_qp cluster 23 | # python plot_script.py lasso cluster 24 | # python plot_script.py mnist cluster 25 | # python plot_script.py quadcopter cluster 26 | # python plot_script.py robust_kalman cluster 27 | # python plot_script.py robust_ls cluster 28 | python plot_script.py sparse_pca cluster 29 | # python plot_script.py phase_retrieval cluster 30 | # python l2ws_train_script.py quadcopter cluster 31 | # python l2ws_setup_script.py unconstrained_qp cluster 32 | #python scs_c_speed.py markowitz 33 | # python aggregate_slurm_runs_script.py robust_pca cluster 34 | 35 | # gpu command: #SBATCH --gres=gpu:1 -------------------------------------------------------------------------------- /benchmarks/configs/quadcopter/quadcopter_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 0.0001 3 | method: adam 4 | intermediate_layer_sizes: 5 | - 100 6 | - 500 7 | batch_size: 100 8 | epochs: 1000000.0 9 | decay_lr: 0.1 10 | min_lr: 1.0e-07 11 | decay_every: 10000000.0 12 | plateau_decay: 13 | min_lr: 1.0e-07 14 | decay_factor: 5 15 | avg_window_size: 5 16 | tolerance: 1.0e-08 17 | patience: 1 18 | pretrain: 19 | pretrain_method: adam 20 | pretrain_stepsize: 0.001 21 | pretrain_iters: 0 22 | pretrain_batches: 10 23 | data: 24 | datetime: '' 25 | eval_every_x_epochs: 200 26 | save_every_x_epochs: 1 27 | test_every_x_epochs: 1 28 | write_csv_every_x_batches: 1 29 | epochs_jit: 10 30 | prediction_variable: w 31 | angle_anchors: 32 | - 0 33 | 34 | plot_iterates: 35 | - 0 36 | - 10 37 | - 20 38 | loss_method: fixed_k 39 | num_clusters: 10 40 | pretrain_alpha: false 41 | normalize_inputs: true 42 | normalize_alpha: other 43 | accuracies: 44 | - 0.1 45 | - 0.01 46 | - 0.001 47 | - 0.0001 48 | rho_x: 1 49 | scale: 1 50 | alpha_relax: 1 51 | skip_startup: false 52 | save_weights_flag: true 53 | 54 | supervised: true 55 | # eval_batch_size_test: 100 56 | solve_c_num: 0 57 | N_train: 10000 58 | N_test: 1000 59 | # num_samples: 1000 60 | num_samples_train: 100 61 | num_samples_test: 100 62 | train_unrolls: 5 63 | eval_unrolls: 500 64 | # load_weights_datetime: '2023-07-30/23-40-17' #'2023-07-09/14-14-53' #'2023-07-09/15-38-54' #'2023-07-09/14-35-29' #2023-07-09/14-14-53 65 | num_rollouts: 5 66 | closed_loop_budget: 15 67 | 68 | # reg: ['2023-07-09/14-14-53', '2023-07-09/14-27-54', '2023-07-09/14-32-59', '2023-07-09/14-33-59', '2023-08-22/03-07-27'] 69 | # fp: ['2023-07-09/15-06-38', '2023-07-09/15-23-01', '2023-07-09/15-35-15', '2023-08-22/10-25-00', '2023-07-09/15-40-50'] -------------------------------------------------------------------------------- /benchmarks/configs/robust_kalman/robust_kalman_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-08-02/22-45-44', '2023-08-03/09-03-44', 2 | '2023-07-27/22-14-14', '2023-08-02/22-57-13', '2023-07-27/22-24-52', '2023-08-02/22-50-48'] 3 | # output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-08-02/22-45-44', '2023-08-02/22-46-53', '2023-08-03/09-03-44', 4 | # '2023-07-27/22-14-14', '2023-08-02/22-57-13', '2023-07-27/22-24-52', '2023-08-02/22-48-46', '2023-08-02/22-50-48'] 5 | # for MOPTA: 
output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-08-02/22-45-44', '2023-08-02/22-46-53', '2023-08-03/09-03-44', 6 | # '2023-07-27/22-14-14', '2023-08-02/22-57-13', '2023-07-27/22-24-52'] 7 | #, '2023-08-02/22-48-46', '2023-08-02/22-50-48'] 8 | # output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-08-02/20-46-22', '2023-08-02/20-43-12', '2023-08-02/20-44-12', 9 | # '2023-07-27/22-14-14', '2023-07-27/22-23-07', '2023-07-27/22-24-52', '2023-08-02/20-47-33', '2023-08-02/20-49-06'] 10 | # output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-07-27/21-58-52', '2023-07-27/22-10-49', '2023-07-27/22-11-54', 11 | # '2023-07-27/22-14-14', '2023-07-27/22-23-07', '2023-07-27/22-24-52', '2023-07-27/22-29-31', '2023-07-27/22-42-31'] 12 | # very old output_datetimes: ['2023-07-27/14-40-49', '2023-07-23/16-05-40', '2023-07-23/14-44-55', '2023-07-23/16-18-11', '2023-07-23/16-30-17', 13 | # '2023-07-23/16-33-31', '2023-07-23/14-44-10', '2023-07-23/13-48-13', '2023-07-27/15-01-47', '2023-07-27/15-02-50'] 14 | loss_overlay_titles: [] 15 | nearest_neighbor_datetime: '2023-07-27/21-38-45' 16 | prev_sol_datetime: '2023-07-27/21-38-45' #'2023-05-18/20-08-46' 17 | pretrain_datetime: '' 18 | cold_start_datetime: '2023-07-27/21-38-45' 19 | eval_iters: 300 20 | accuracies: [1e-1, 1e-2, 1e-3, 1e-4] 21 | rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 22 | abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Build and upload to PyPI 2 | 3 | # Build on every branch push, tag push, and pull request change: 4 | on: [push, pull_request] 5 | # Alternatively, to publish when a (published) GitHub Release is created, use the following: 6 | # on: 7 | # push: 8 | # pull_request: 9 | # release: 10 | # types: 11 | # - published 12 | 13 | jobs: 14 | build_wheels: 15 | name: Build wheels 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | python -m pip install ".[dev]" 24 | python -m pip install build 25 | 26 | - name: Build wheels and source distribution 27 | run: | 28 | python -m build -s -w -o ./dist/ . 
29 | 30 | # Disable upload artifacts to avoid taking too much memory on github 31 | # actions 32 | # - uses: actions/upload-artifact@v3 33 | # with: 34 | # path: ./dist/* 35 | 36 | upload_pypi: 37 | name: Upload to PyPI 38 | needs: [build_wheels] 39 | runs-on: ubuntu-latest 40 | # upload to PyPI on every tag starting with 'v' 41 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 42 | # alternatively, to publish when a GitHub Release is created, use the following rule: 43 | # if: github.event_name == 'release' && github.event.action == 'published' 44 | steps: 45 | - uses: actions/download-artifact@v3 46 | with: 47 | # unpacks default artifact into dist/ 48 | # if `name: artifact` is omitted, the action will create extra parent dir 49 | name: artifact 50 | path: dist 51 | 52 | - uses: pypa/gh-action-pypi-publish@release/v1 53 | with: 54 | user: __token__ 55 | password: ${{ secrets.PYPI_API_TOKEN }} 56 | # To test: repository_url: https://test.pypi.org/legacy/ 57 | -------------------------------------------------------------------------------- /l2ws/eg_model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from l2ws.algo_steps import k_steps_eval_extragrad, k_steps_train_extragrad 4 | from l2ws.l2ws_model import L2WSmodel 5 | 6 | 7 | class EGmodel(L2WSmodel): 8 | def __init__(self, input_dict): 9 | super(EGmodel, self).__init__(input_dict) 10 | 11 | def initialize_algo(self, input_dict): 12 | self.m, self.n = input_dict['m'], input_dict['n'] 13 | self.q_mat_train, self.q_mat_test = input_dict['q_mat_train'], input_dict['q_mat_test'] 14 | self.algo = 'extragradient' 15 | self.factors_required = False 16 | self.factor_static = None 17 | 18 | eg_step = input_dict['eg_step'] 19 | m, n = self.m, self.n 20 | 21 | # function 22 | proj_X, proj_Y = input_dict['proj_X'], input_dict['proj_Y'] 23 | f = input_dict['f'] 24 | 25 | self.output_size = m + n 26 | self.out_axes_length = 5 27 | # self.k_steps_train_fn = partial(k_steps_train_extragrad, Q=Q, R=R, 28 | # eg_step=eg_step, jit=self.jit) 29 | self.k_steps_train_fn = partial( 30 | k_steps_train_extragrad, f=f, proj_X=proj_X, proj_Y=proj_Y, n=n, 31 | eg_step=eg_step, jit=self.jit) 32 | self.k_steps_eval_fn = partial(k_steps_eval_extragrad, 33 | f=f, proj_X=proj_X, proj_Y=proj_Y, n=n, 34 | eg_step=eg_step, jit=self.jit) 35 | 36 | # old 37 | # self.q_mat_train, self.q_mat_test = input_dict['q_mat_train'], input_dict['q_mat_test'] 38 | # m = R.shape[0] 39 | # n = Q.shape[0] 40 | # Q, R = input_dict['Q'], input_dict['R'] 41 | # self.k_steps_train_fn = partial(k_steps_train_extragrad, Q=Q, R=R, 42 | # eg_step=eg_step, jit=self.jit) 43 | # self.k_steps_eval_fn = partial(k_steps_eval_extragrad, Q=Q, R=R, 44 | # eg_step=eg_step, jit=self.jit) 45 | -------------------------------------------------------------------------------- /benchmarks/configs/mnist/mnist_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: ['2023-09-05/11-42-40', 2 | '2023-09-05/12-02-15', 3 | '2023-09-05/12-10-43', 4 | # '2023-09-05/12-30-10', 5 | '2023-09-05/12-31-19', 6 | '2023-08-28/16-00-25', 7 | '2023-08-28/16-14-46', 8 | '2023-08-28/16-15-38', 9 | # '2023-09-05/11-37-07', 10 | '2023-09-05/11-14-26'] 11 | 12 | # below is the failure 13 | # output_datetimes: ['2023-08-28/15-59-21', 14 | # '2023-08-28/15-08-17', 15 | # '2023-08-28/16-46-53', 16 | # '2023-08-28/16-47-34', 17 | # '2023-08-28/17-39-31', 18 | # 
'2023-08-28/16-00-25', 19 | # '2023-08-28/16-14-46', 20 | # '2023-08-28/16-15-38', 21 | # '2023-08-28/16-29-14', 22 | # '2023-08-28/16-30-25'] 23 | 24 | # belos is the way to get the plot 25 | # output_datetimes: ['2023-08-22/22-52-54', 26 | # '2023-08-22/22-58-52', 27 | # '2023-08-22/23-25-27', 28 | # # '2023-08-23/00-13-30', 29 | # '2023-08-23/00-14-41', 30 | # '2023-08-23/14-31-51', 31 | # '2023-08-20/19-34-54', 32 | # '2023-08-20/19-38-37', 33 | # # '2023-08-22/22-32-12', 34 | # '2023-08-22/22-38-07'] 35 | 36 | # output_datetimes: ['2023-08-19/21-55-53', '2023-08-19/21-56-31', '2023-08-19/21-57-04', '2023-08-19/22-27-02', '2023-08-19/22-29-37'] 37 | # output_datetimes: ['2023-07-19/23-07-27', '2023-07-21/16-29-12', '2023-07-21/16-37-23', '2023-07-21/16-43-27', '2023-07-21/16-47-44', '2023-07-21/17-15-45', '2023-07-21/17-16-47', '2023-07-21/17-19-27', '2023-07-21/17-21-21', '2023-07-21/17-22-37'] 38 | # ['2023-07-08/21-13-25', '2023-07-08/21-14-58', '2023-07-08/21-16-42', '2023-07-08/21-18-39', '2023-07-08/21-26-31', '2023-07-08/21-27-13', '2023-07-08/21-27-43', '2023-07-08/21-30-09', '2023-07-08/22-11-27', '2023-07-08/21-41-38', '2023-07-08/21-42-33'] 39 | loss_overlay_titles: [] 40 | nearest_neighbor_datetime: '2023-08-28/14-56-38' #'2023-08-20/19-34-54' # '2023-08-19/21-55-53' 41 | pretrain_datetime: '' 42 | cold_start_datetime: '2023-08-28/14-56-38' #'2023-08-20/19-34-54' # '2023-08-19/21-55-53' 43 | eval_iters: 499 44 | accuracies: [1e-1, 1e-2, 1e-3, 1e-4] 45 | rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 46 | abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 47 | -------------------------------------------------------------------------------- /benchmarks/configs/robust_kalman/robust_kalman_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-3 3 | method: adam 4 | intermediate_layer_sizes: [500, 500] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 10 # in epochs 15 | tolerance: 1e-4 16 | patience: 2 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | 29 | eval_unrolls: 500 30 | eval_every_x_epochs: 10 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 10 33 | write_csv_every_x_batches: 1 34 | N_train: 10000 35 | N_test: 1000 36 | num_samples_test: 100 37 | num_samples_train: 100 38 | angle_anchors: [0] 39 | 40 | plot_iterates: [0, 10, 20] 41 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 42 | share_all: False 43 | num_clusters: 2000 44 | pretrain_alpha: False 45 | normalize_inputs: True 46 | normalize_alpha: 'other' 47 | epochs_jit: 10 48 | accuracies: [1, .1, .01, .001, .0001] 49 | iterates_visualize: [5] 50 | 51 | rho_x: 1 52 | scale: 1 53 | alpha_relax: 1 54 | 55 | 56 | # solve_c_num: 1000 57 | save_weights_flag: True 58 | load_weights_datetime: '2023-07-27/21-59-54' 59 | vis_num: 10 60 | supervised: False 61 | train_unrolls: 5 62 | skip_startup: False 63 | 64 | # output_datetimes: 65 | # 8/2 66 | # output_datetimes: ['2023-07-27/21-38-45', '2023-07-27/21-59-54', '2023-08-02/14-41-25', '2023-08-02/14-52-11', '2023-08-02/15-10-41', 67 | # '2023-07-27/22-14-14', '2023-07-27/22-23-07', '2023-07-27/22-24-52', '2023-08-02/15-11-56', '2023-08-02/15-13-09'] 68 | 69 | # obj: ['2023-07-27/14-40-49', '2023-07-23/16-05-40', '2023-07-23/14-44-55', '2023-07-23/16-18-11', 
'2023-07-23/16-30-17', 70 | # reg: '2023-07-23/16-33-31', '2023-07-23/14-44-10', '2023-07-23/13-48-13', '2023-07-27/15-01-47', '2023-07-27/15-02-50'] 71 | 72 | #5: '2023-05-22/14-22-20' 73 | #15: '2023-05-22/14-23-18' 74 | #30: '2023-05-22/15-25-55' 75 | #60: '2023-05-22/15-29-14' 76 | #120: '2023-05-22/16-37-55' 77 | #1: '2023-05-22/16-43-06' -------------------------------------------------------------------------------- /l2ws/l2ws_helper_fns.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import lax 3 | 4 | 5 | def train_jitted_epochs(model, permutation, num_batches, epochs_jit, epoch=0): 6 | """ 7 | train self.epochs_jit at a time 8 | special case: the first time we call train_batch (i.e. epoch = 0) 9 | """ 10 | def train_over_epochs_body_simple_fn(batch, val): 11 | """ 12 | to be used as the body_fn in lax.fori_loop 13 | need to call partial for the specific permutation 14 | """ 15 | train_losses, params, state, permutation = val 16 | start_index = batch * model.batch_size 17 | batch_indices = lax.dynamic_slice( 18 | permutation, (start_index,), (model.batch_size,)) 19 | train_loss, params, state = model.train_batch( 20 | batch_indices, params, state) 21 | train_losses = train_losses.at[batch].set(train_loss) 22 | val = train_losses, params, state, permutation 23 | return val 24 | 25 | # epoch_batch_start_time = time.time() 26 | loop_size = int(num_batches * epochs_jit) 27 | epoch_train_losses = jnp.zeros(loop_size) 28 | if epoch == 0: 29 | # unroll the first iterate so that `init_val` and `body_fun` 30 | # below have the same output type, which is a requirement of 31 | # lax.while_loop and lax.scan. 32 | batch_indices = lax.dynamic_slice( 33 | permutation, (0,), (model.batch_size,)) 34 | 35 | train_loss_first, params, state = model.train_batch( 36 | batch_indices, model.params, model.state) 37 | 38 | epoch_train_losses = epoch_train_losses.at[0].set(train_loss_first) 39 | start_index = 1 40 | # (the body function is traced and compiled inside lax.fori_loop below) 41 | else: 42 | start_index = 0 43 | params, state = model.params, model.state 44 | 45 | init_val = epoch_train_losses, params, state, permutation 46 | 47 | val = lax.fori_loop(start_index, loop_size, train_over_epochs_body_simple_fn, init_val) 48 | epoch_train_losses, params, state, permutation = val 49 | return params, state, epoch_train_losses -------------------------------------------------------------------------------- /l2ws/utils/nn_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from jax import jit, random, vmap 6 | from scipy.spatial import distance_matrix 7 | 8 | 9 | def get_nearest_neighbors(train_inputs, test_inputs, z_stars_train): 10 | distances = distance_matrix(np.array(test_inputs), np.array(train_inputs)) 11 | indices = np.argmin(distances, axis=1) 12 | best_val = np.min(distances, axis=1) 13 | 14 | # print('distances', distances) 15 | # print('indices', indices) 16 | # print('best val', best_val) 17 | 18 | return z_stars_train[indices, :] 19 | 20 | 21 | def random_layer_params(m, n, key, scale=1e-2): 22 | # def random_layer_params(m, n, key, scale=1e-2): 23 | # def random_layer_params(m, n, key, scale=1e-1): 24 | w_key, b_key = random.split(key) 25 | # fan_in, fan_out = shape[0], shape[1] 26 | # scale = jnp.sqrt(2.0 / (m + n)) 27 | return scale * random.normal(w_key, (n, m)), scale * 
random.normal(b_key, (n,)) 28 | 29 | 30 | # Initialize all layers for a fully-connected neural network with sizes "sizes" 31 | def init_network_params(sizes, key): 32 | keys = random.split(key, len(sizes)) 33 | return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)] 34 | 35 | 36 | def init_matrix_params(t, n, key): 37 | X_list = [] 38 | for i in range(t): 39 | U = random.normal(key + i, (n, n)) 40 | X = U @ U.T 41 | norm_X = X / X.max() 42 | X_list.append(norm_X) 43 | return X_list 44 | 45 | 46 | def relu(x): 47 | return jnp.maximum(0, x) 48 | 49 | 50 | @jit 51 | def predict_y(params, inputs): 52 | for W, b in params[:-1]: 53 | outputs = jnp.dot(W, inputs) + b 54 | inputs = relu(outputs) 55 | final_w, final_b = params[-1] 56 | outputs = jnp.dot(final_w, inputs) + final_b 57 | return outputs 58 | 59 | 60 | batched_predict_y = vmap(predict_y, in_axes=(None, 0)) 61 | 62 | 63 | @functools.partial(jit, static_argnums=(1,)) 64 | def full_vec_2_components(input, T): 65 | L = input[0] 66 | L_vec = input[1:T] 67 | x = input[T:3*T] 68 | delta = input[3*T:3*T+2*(T-1)] 69 | s = input[3*T+2*(T-1):] 70 | return L, L_vec, x, delta, s 71 | -------------------------------------------------------------------------------- /tests/compare_scs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scs 3 | from scipy.sparse import csc_matrix 4 | 5 | 6 | def main(): 7 | # setup some problem 8 | P, A, b, c, cones_dict = robust_ls_setup(20, 20) 9 | 10 | scs_data = dict(P=P, A=A, b=b, c=c) 11 | 12 | # solve with scs 13 | solver = scs.SCS(scs_data, 14 | cones_dict, 15 | normalize=False, 16 | scale=1, 17 | adaptive_scale=False, 18 | rho_x=1, 19 | alpha=1, 20 | acceleration_lookback=0, 21 | eps_abs=.001, 22 | eps_rel=0) 23 | sol = solver.solve() 24 | x, y, s = sol['x'], sol['y'], sol['s'] 25 | print('x', x) 26 | print('y', y) 27 | print('s', s) 28 | 29 | # solve with our method 30 | 31 | 32 | def robust_ls_setup(m_orig, n_orig): 33 | rho = 1 34 | 35 | # random A matrix 36 | A = (np.random.rand(m_orig, n_orig) * 2) - 1 37 | 38 | m_orig, n_orig = A.shape 39 | m, n = 2 * m_orig + n_orig + 2, n_orig + 2 40 | A_dense = np.zeros((m, n)) 41 | b = np.zeros(m) 42 | 43 | # constraint 1 44 | A_dense[:n_orig, :n_orig] = -np.eye(n_orig) 45 | 46 | # constraint 2 47 | A_dense[n_orig, n_orig] = -1 48 | A_dense[n_orig + 1:n_orig + m_orig + 1, :n_orig] = -A 49 | 50 | b[n_orig:m_orig + n_orig] = np.random.normal(m_orig) # fill in for b when theta enters -- 51 | # here we can put anything since it will change 52 | 53 | # constraint 3 54 | A_dense[n_orig + m_orig + 1, n_orig + 1] = -1 55 | A_dense[n_orig + m_orig + 2:, :n_orig] = -np.eye(n_orig) 56 | 57 | # create sparse matrix 58 | A_sparse = csc_matrix(A_dense) 59 | 60 | # Quadratic objective 61 | P = np.zeros((n, n)) 62 | P_sparse = csc_matrix(P) 63 | 64 | # Linear objective 65 | c = np.zeros(n) 66 | c[n_orig], c[n_orig + 1] = 1, rho 67 | 68 | # cones 69 | q_array = [m_orig + 1, n_orig + 1] 70 | cones_dict = dict(z=0, l=n_orig, q=q_array) 71 | # cones_array = jnp.array([cones["z"], cones["l"]]) 72 | # cones_array = jnp.concatenate([cones_array, jnp.array(cones["q"])]) 73 | return P_sparse, A_sparse, b, c, cones_dict 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /benchmarks/configs/sparse_pca/sparse_pca_run.yaml: -------------------------------------------------------------------------------- 1 | 
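Returning to tests/compare_scs.py above: main() stops at the comment '# solve with our method' without the comparison it names. For reference, the robust least-squares program that robust_ls_setup canonicalizes by hand can be solved directly with CVXPY; this mirrors the formulation used in tests/test_canonicalizations.py later in this listing and is a cross-check sketch, not part of the test suite:

import cvxpy as cp

def robust_ls_cvxpy(A, b, rho=1):
    # reference solve of the SOCP:
    #   minimize    u + rho * v
    #   subject to  x >= 0, ||A x - b|| <= u, ||x|| <= v
    m, n = A.shape
    x, u, v = cp.Variable(n), cp.Variable(), cp.Variable()
    constraints = [x >= 0, cp.norm(A @ x - b) <= u, cp.norm(x) <= v]
    cp.Problem(cp.Minimize(u + rho * v), constraints).solve()
    return x.value, u.value, v.value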
nn_cfg: 2 | lr: 1e-3 3 | method: adam 4 | intermediate_layer_sizes: [500, 500] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-3 16 | patience: 50 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' #'2023-04-10/11-17-04' 27 | 28 | 29 | eval_unrolls: 300 30 | eval_every_x_epochs: 10 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | epochs_jit: 2 35 | N_train: 10000 36 | N_test: 1000 37 | # num_samples: 1000 38 | prediction_variable: w 39 | angle_anchors: [0] 40 | plot_iterates: [0, 10, 20] 41 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 42 | share_all: False 43 | num_clusters: 100 44 | pretrain_alpha: False 45 | normalize_inputs: True 46 | normalize_alpha: 'other' 47 | 48 | accuracies: [1, .1, .01, .001, .0001] 49 | rho_x: 1 50 | scale: 1 51 | alpha_relax: 1 52 | skip_startup: True 53 | # solve_c_num: 1000 54 | save_weights_flag: True 55 | 56 | load_weights_datetime: '2023-08-04/21-58-01' 57 | train_unrolls: 60 58 | supervised: True 59 | num_samples: 100 60 | # num_samples_test: 1000 61 | # num_samples_train: 100 62 | # eval_batch_size_test: 200 63 | # eval_batch_size_train: 200 64 | # num_samples: 10 65 | # lightweight: True 66 | 67 | # output_datetimes: ['2023-07-25/23-28-42', '2023-07-25/22-13-13', '2023-07-26/22-08-40', '2023-07-26/22-38-49', '2023-07-26/22-47-30', 68 | # '2023-07-26/21-55-00', '2023-07-26/08-49-43', '2023-07-26/21-24-25', '2023-08-04/21-54-55', '2023-08-04/21-58-01'] 69 | 70 | # output_datetimes: ['2023-07-25/23-28-42', '2023-07-25/22-13-13', '2023-07-26/22-08-40', '2023-07-26/22-38-49', '2023-07-26/22-47-30', 71 | # '2023-07-26/21-55-00', '2023-07-26/08-49-43', '2023-07-26/21-24-25', '2023-07-26/20-47-36', '2023-07-26/21-18-04'] 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | # below is old 93 | # obj: ['2023-06-05/10-59-27', '2023-06-04/22-30-33', '2023-06-04/21-53-52', '2023-06-04/21-50-00', '2023-06-04/22-23-46', 94 | # reg: '2023-06-14/23-19-14', '2023-06-14/23-50-40', '2023-06-14/23-25-13', '2023-07-11/16-25-13', '2023-07-11/18-59-59'] -------------------------------------------------------------------------------- /benchmarks/configs/mnist/mnist_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-3 3 | method: adam 4 | intermediate_layer_sizes: [] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-3 #1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | eval_every_x_epochs: 100 29 | save_every_x_epochs: 1 30 | test_every_x_epochs: 1 31 | write_csv_every_x_batches: 1 32 | epochs_jit: 2 33 | 34 | N_train: 10000 35 | N_test: 1000 36 | 37 | 38 | prediction_variable: w 39 | angle_anchors: [0] 40 | plot_iterates: [0, 10, 20] 41 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 42 | # share_all: False 43 | num_clusters: 10 44 | pretrain_alpha: False 45 | normalize_inputs: False 46 | normalize_alpha: 'other' 47 | 48 | accuracies: [.1, 
.01, .001, .0001] 49 | rho_x: 1 50 | scale: 1 51 | alpha_relax: 1 52 | 53 | 54 | 55 | save_weights_flag: True 56 | 57 | 58 | # eval_batch_size: 100 59 | 60 | 61 | 62 | 63 | # solving in C 64 | # rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 65 | # abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 66 | 67 | 68 | # rollouts 69 | # num_rollouts: 0 70 | # closed_loop_budget: 10 71 | 72 | # visualize 73 | iterates_visualize: [10, 20, 50, 90] #[10, 20, 50, 100, 200, 500] 74 | vis_num: 0 75 | skip_startup: True 76 | 77 | eval_unrolls: 100 78 | solve_c_num: 100 79 | num_samples: 100 80 | supervised: False 81 | train_unrolls: 60 82 | # load_weights_datetime: '2023-09-01/14-16-46' 83 | 84 | # output_datetimes: 85 | # fp: ['2023-09-01/14-20-37', 86 | # '2023-09-01/14-07-02', 87 | # '2023-09-01/14-12-38', 88 | # '2023-09-01/14-15-13', 89 | # '2023-08-23/00-14-41', 90 | # reg: '2023-08-23/14-31-51', 91 | # '2023-08-20/19-34-54', 92 | # '2023-08-20/19-38-37', 93 | # '2023-09-01/14-20-37', wrong 94 | # '2023-09-01/14-29-05'] 95 | 96 | # load_weights_datetime: '2023-08-22/22-58-52' 97 | # output_datetimes: 98 | # fp: ['2023-08-22/22-52-54', 99 | # '2023-08-22/22-58-52', 100 | # '2023-08-22/23-25-27', 101 | # '2023-08-23/00-13-30', 102 | # '2023-08-23/00-14-41', 103 | # reg: '2023-08-23/14-31-51', 104 | # '2023-08-20/19-34-54', 105 | # '2023-08-20/19-38-37', 106 | # '2023-08-22/22-32-12', 107 | # '2023-08-22/22-38-07'] -------------------------------------------------------------------------------- /benchmarks/configs/phase_retrieval/phase_retrieval_run.yaml: -------------------------------------------------------------------------------- 1 | nn_cfg: 2 | lr: 1e-3 #1e-2 3 | method: adam 4 | intermediate_layer_sizes: [500, 500] 5 | batch_size: 100 6 | epochs: 1e6 7 | decay_lr: .1 8 | min_lr: 1e-7 9 | decay_every: 1e7 10 | 11 | plateau_decay: 12 | min_lr: 1e-7 13 | decay_factor: 5 14 | avg_window_size: 5 # in epochs 15 | tolerance: 1e-3 16 | patience: 1 17 | 18 | 19 | pretrain: 20 | pretrain_method: adam 21 | pretrain_stepsize: 1e-3 22 | pretrain_iters: 0 23 | pretrain_batches: 10 24 | 25 | data: 26 | datetime: '' 27 | 28 | 29 | 30 | eval_every_x_epochs: 10 31 | save_every_x_epochs: 1 32 | test_every_x_epochs: 1 33 | write_csv_every_x_batches: 1 34 | N_train: 10000 # adjust these for larger setups 35 | N_test: 1000 36 | 37 | prediction_variable: w 38 | angle_anchors: [0] 39 | 40 | plot_iterates: [0, 10, 100] 41 | loss_method: 'fixed_k' #'fixed_k' #'constant_sum' 42 | share_all: False # this is the shared solution approach 43 | num_clusters: 1000 44 | pretrain_alpha: False 45 | normalize_inputs: True 46 | normalize_alpha: 'other' 47 | epochs_jit: 2 48 | accuracies: [1, .1, .01, .001, .0001] 49 | rho_x: 1 50 | scale: 1 51 | alpha_relax: 1 52 | 53 | num_samples_train: 2 54 | num_samples_test: 10 55 | skip_startup: False # can be used to skip the eval for no learn and NN 56 | solve_c_num: 100 57 | save_weights_flag: True 58 | eval_unrolls: 13000 59 | train_unrolls: 60 # this is the k that tells you how many FOM iterations 60 | supervised: True 61 | load_weights_datetime: '2023-08-05/11-55-58' 62 | # FOR gain plots output_datetimes: 63 | # reg: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-08-04/22-18-37', '2023-08-05/11-55-58'] 64 | # fp: ['2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-08-04/22-13-05', '2023-08-05/11-56-29'] 65 | 66 | # BELOW IS OLD (written Aug 22) 67 | # output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', 
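# A note on train_unrolls and loss_method in phase_retrieval_run.yaml above:
# per the inline comment, train_unrolls is the number k of first-order-method
# iterations unrolled during training. Assuming the usual learning-to-warm-start
# convention (the training code is not shown in this listing), 'fixed_k'
# evaluates the loss only at the k-th unrolled iterate, e.g. its distance to
# the precomputed solution z_star in the supervised case, while
# 'constant_sum' spreads the loss over the intermediate iterates.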
'2023-08-04/22-18-37', '2023-08-04/22-19-38', 68 | # '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-08-04/22-13-05', '2023-08-04/22-14-10'] 69 | # obj: '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-07-10/23-29-20', '2023-07-10/23-33-13' 70 | # reg: '2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-07-10/22-18-28', '2023-07-10/22-26-36' 71 | # output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-07-10/22-18-28', '2023-07-10/22-26-36', '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-07-10/23-29-20', '2023-07-10/23-33-13'] -------------------------------------------------------------------------------- /benchmarks/configs/phase_retrieval/phase_retrieval_plot.yaml: -------------------------------------------------------------------------------- 1 | output_datetimes: 2 | - 2023-08-23/00-09-06 3 | - 2023-08-22/21-21-41 4 | - 2023-08-22/23-05-00 5 | # - 2023-08-22/22-34-04 6 | - 2023-08-22/23-01-44 7 | - 2023-08-23/00-10-10 8 | - 2023-08-23/00-36-09 9 | - 2023-08-23/00-40-50 10 | # - 2023-08-23/01-04-25 11 | - 2023-08-23/01-05-35 12 | loss_overlay_titles: [] 13 | nearest_neighbor_datetime: 2023-08-23/02-09-18 14 | pretrain_datetime: '' 15 | cold_start_datetime: 2023-08-23/02-09-18 16 | eval_iters: 500 17 | accuracies: 18 | - 1 19 | - 0.1 20 | - 0.01 21 | - 0.001 22 | - 0.0001 23 | rel_tols: 24 | - 0.1 25 | - 0.01 26 | - 0.001 27 | - 0.0001 28 | - 1.0e-05 29 | abs_tols: 30 | - 0.1 31 | - 0.01 32 | - 0.001 33 | - 0.0001 34 | - 1.0e-05 35 | 36 | 37 | # FOR plots output_datetimes: 38 | # - 2023-07-10/21-01-41 39 | # - 2023-07-10/21-58-27 40 | # - 2023-07-10/21-56-15 41 | # - 2023-08-04/22-18-37 42 | # - 2023-08-04/22-19-38 43 | # - 2023-07-10/22-31-47 44 | # - 2023-07-10/23-13-58 45 | # - 2023-07-10/23-16-59 46 | # - 2023-08-04/22-13-05 47 | # - 2023-08-04/22-14-10 48 | # loss_overlay_titles: [] 49 | # nearest_neighbor_datetime: 2023-07-10/23-33-13 50 | # pretrain_datetime: '' 51 | # cold_start_datetime: 2023-07-10/23-33-13 52 | # eval_iters: 500 53 | # accuracies: 54 | # - 1 55 | # - 0.1 56 | # - 0.01 57 | # - 0.001 58 | # - 0.0001 59 | # rel_tols: 60 | # - 0.1 61 | # - 0.01 62 | # - 0.001 63 | # - 0.0001 64 | # - 1.0e-05 65 | # abs_tols: 66 | # - 0.1 67 | # - 0.01 68 | # - 0.001 69 | # - 0.0001 70 | # - 1.0e-05 71 | 72 | 73 | 74 | # FOR gain plots output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-08-04/22-18-37', '2023-08-05/11-55-58', 75 | # '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-08-04/22-13-05', '2023-08-05/11-56-29'] 76 | # loss_overlay_titles: [] 77 | # nearest_neighbor_datetime: '2023-07-10/23-33-13' 78 | # pretrain_datetime: '' 79 | # cold_start_datetime: '2023-07-10/23-33-13' 80 | # eval_iters: 500 81 | # accuracies: [1, 1e-1, 1e-2, 1e-3, 1e-4] 82 | # rel_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 83 | # abs_tols: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 84 | 85 | 86 | # reg60: '2023-08-05/11-55-58' 87 | # fp60: '2023-08-05/11-56-29' 88 | # output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-08-04/22-18-37', '2023-08-04/22-19-38', 89 | # '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-08-04/22-13-05', '2023-08-04/22-14-10'] 90 | # output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-08-04/22-17-03', '2023-08-04/22-18-37', '2023-08-04/22-19-38', 91 | # '2023-07-10/22-31-47', '2023-07-10/23-13-58', 
'2023-08-04/22-11-09', '2023-08-04/22-13-05', '2023-08-04/22-14-10'] 92 | # output_datetimes: ['2023-07-10/21-01-41', '2023-07-10/21-58-27', '2023-07-10/21-56-15', '2023-07-10/22-18-28', '2023-07-10/22-26-36', '2023-07-10/22-31-47', '2023-07-10/23-13-58', '2023-07-10/23-16-59', '2023-07-10/23-29-20', '2023-07-10/23-33-13'] 93 | 94 | -------------------------------------------------------------------------------- /l2ws/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from jax import random 7 | 8 | 9 | def count_files_in_directory(directory): 10 | file_count = 0 11 | for _, _, files in os.walk(directory): 12 | file_count += len(files) 13 | return file_count 14 | 15 | 16 | def setup_permutation(key_count, N_train, epochs_jit): 17 | permutations = [] 18 | for i in range(epochs_jit): 19 | key = random.PRNGKey(key_count) 20 | key_count += 1 21 | epoch_permutation = random.permutation(key, N_train) 22 | permutations.append(epoch_permutation) 23 | stacked_permutation = jnp.stack(permutations) 24 | permutation = jnp.ravel(stacked_permutation) 25 | return permutation 26 | 27 | 28 | def sample_plot(input, title, num_plot): 29 | num_plot = np.min([num_plot, 4]) 30 | for i in range(num_plot): 31 | plt.plot(input[i, :]) 32 | plt.ylabel(f"{title} values") 33 | plt.xlabel(f"{title} indices") 34 | plt.savefig(f"sample_{title}.pdf") 35 | plt.clf() 36 | 37 | 38 | def vec_symm(X, triu_indices=None, factor=jnp.sqrt(2)): 39 | """Returns a vectorized representation of a symmetric matrix `X`. 40 | Vectorization (including scaling) as per SCS. 41 | vec(X) = (X11, sqrt(2)*X21, ..., sqrt(2)*Xk1, X22, sqrt(2)*X32, ..., Xkk) 42 | """ 43 | 44 | # X = X.copy() 45 | X *= factor 46 | X = X.at[jnp.diag_indices(X.shape[0])].set(jnp.diagonal(X) / factor) 47 | if triu_indices is None: 48 | col_idx, row_idx = jnp.triu_indices(X.shape[0]) 49 | else: 50 | col_idx, row_idx = triu_indices 51 | return X[(row_idx, col_idx)] 52 | 53 | 54 | def unvec_symm(x, dim, triu_indices=None): 55 | """Returns a dim-by-dim symmetric matrix corresponding to `x`. 56 | `x` is a vector of length dim*(dim + 1)/2, corresponding to a symmetric 57 | matrix; the correspondence is as in SCS. 58 | X = [ X11 X12 ... X1k 59 | X21 X22 ... X2k 60 | ... 61 | Xk1 Xk2 ... 
Xkk ], 62 | where 63 | vec(X) = (X11, sqrt(2)*X21, ..., sqrt(2)*Xk1, X22, sqrt(2)*X32, ..., Xkk) 64 | """ 65 | 66 | X = jnp.zeros((dim, dim)) 67 | 68 | # triu_indices gets indices of upper triangular matrix in row-major order 69 | if triu_indices is None: 70 | col_idx, row_idx = jnp.triu_indices(dim) 71 | else: 72 | col_idx, row_idx = triu_indices 73 | z = jnp.zeros(x.size) 74 | 75 | if x.ndim > 1: 76 | for i in range(x.size): 77 | z = z.at[i].set(x[i][0, 0]) 78 | else: 79 | z = x 80 | 81 | X = X.at[(row_idx, col_idx)].set(z) 82 | 83 | X = X + X.T 84 | X /= jnp.sqrt(2) 85 | X = X.at[jnp.diag_indices(dim)].set(jnp.diagonal(X) * jnp.sqrt(2) / 2) 86 | return X 87 | 88 | 89 | # non jit loop 90 | def python_fori_loop(lower, upper, body_fun, init_val): 91 | val = init_val 92 | for i in range(lower, upper): 93 | val = body_fun(i, val) 94 | return val 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
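Stepping back to l2ws/utils/generic_utils.py above: vec_symm and unvec_symm implement the SCS convention for vectorizing symmetric matrices (off-diagonal entries scaled by sqrt(2)), and they are inverses of each other. A quick round-trip check:

import jax.numpy as jnp
from jax import random

from l2ws.utils.generic_utils import unvec_symm, vec_symm

key = random.PRNGKey(0)
B = random.normal(key, (4, 4))
X = (B + B.T) / 2                # symmetric test matrix
x = vec_symm(X)                  # vector of length 4 * (4 + 1) / 2 = 10
X_rec = unvec_symm(x, 4)
assert jnp.allclose(X, X_rec, atol=1e-6)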
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # added 132 | .DS_Store 133 | .vscode/ 134 | l2ws_env 135 | outputs/ 136 | data/ 137 | scs-python/ 138 | outputs 139 | notebooks/mnist_data/ 140 | slurm-* 141 | # examples/mnist.py 142 | examples/mnist_data/ 143 | mnist_data/ 144 | cartpole* 145 | high_mpc/ 146 | examples/cifar-10-batches-py/ 147 | examples/mri_data/ 148 | motivating_example/ 149 | examples/mri_data2/ 150 | rollout_2_flight_learned_old.gif 151 | rollout_2_flight_learned.gif 152 | rollout_2_flight_nn_old.gif 153 | rollout_2_flight_nn.gif 154 | rollout_2_flight_ps_old.gif 155 | rollout_2_flight_ps.gif 156 | combined_gif.gif 157 | perturb_plots/* 158 | 159 | # Direnv 160 | .envrc 161 | .direnv/ 162 | 163 | # SCM version 164 | l2ws/_version.py 165 | -------------------------------------------------------------------------------- /tests/test_canonicalizations.py: -------------------------------------------------------------------------------- 1 | import cvxpy as cp 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | from l2ws.examples.robust_ls import multiple_random_robust_ls 6 | from l2ws.examples.sparse_pca import multiple_random_sparse_pca 7 | from l2ws.scs_problem import scs_jax 8 | 9 | 10 | def test_phase_retrieval(): 11 | pass 12 | 13 | 14 | def test_sparse_pca(): 15 | n_orig, k, r, N = 60, 5, 10, 5 16 | 17 | # create n parametric problems 18 | P, A, cones, q_mat, theta_mat_jax, A_tensor = multiple_random_sparse_pca( 19 | n_orig, k, r, N, factor=False) 20 | 21 | # scs hyperparams 22 | rho_x = 1 23 | scale = 10 24 | alpha = 1 25 | 26 | # solve with our DR splitting 27 | m, n = A.shape 28 | x_ws = jnp.ones(n) 29 | y_ws = jnp.ones(m) 30 | s_ws = jnp.zeros(m) 31 | max_iters = 800 32 | c, b = q_mat[0, :n], q_mat[0, n:] 33 | data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws) 34 | sol_hsde = scs_jax(data, hsde=True, rho_x=rho_x, scale=scale, alpha=alpha, 35 | iters=max_iters, plot=False) 36 | x_jax = sol_hsde['x'] 37 | fp_res_hsde = sol_hsde['fixed_point_residuals'] 38 | 39 | # form matrix from vector solution 40 | jax_obj = c @ x_jax 41 | 42 | # solve with cvxpy 43 | X = cp.Variable((n_orig, n_orig), symmetric=True) 44 | constraints = [X >> 0, cp.sum(cp.abs(X)) <= k, cp.trace(X) == 1] 45 | prob = cp.Problem(cp.Minimize(-cp.trace(A_tensor[0, :, :] @ X)), constraints) 46 | # prob.solve(solver=cp.SCS, verbose=True, rho_x=1, normalize=False, adaptive_scale=False) 47 | # prob.solve(solver=cp.SCS, verbose=True, rho_x=1, normalize=False) 48 | prob.solve(solver=cp.SCS, verbose=True, rho_x=rho_x, scale=scale, 49 | alpha=alpha, eps_abs=1e-4, eps_rel=0, adaptive_scale=False) 50 | cvxpy_obj = prob.value 51 | 52 | assert jnp.abs((jax_obj - cvxpy_obj) / cvxpy_obj) <= 1e-2 53 | 54 | assert fp_res_hsde[0] > 10 55 | assert fp_res_hsde[-1] < 5e-2 and fp_res_hsde[-1] > 1e-16 56 | 57 | 58 | def test_robust_ls(): 59 | m_orig, n_orig = 10, 20 60 | N = 1 61 | rho, b_center, b_range = 2, 1, 1 62 | 63 | P, A, cones, q_mat, theta_mat, 
A_orig, b_orig_mat = multiple_random_robust_ls( 64 | m_orig, n_orig, rho, b_center, b_range, N) 65 | m, n = A.shape 66 | 67 | # solve with our DR splitting 68 | x_ws = np.ones(n) 69 | y_ws = np.ones(m) 70 | s_ws = np.zeros(m) 71 | max_iters = 500 72 | c, b = q_mat[0, :n], q_mat[0, n:] 73 | data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws) 74 | sol_hsde = scs_jax(data, hsde=True, iters=max_iters) 75 | x_jax = sol_hsde['x'] 76 | fp_res_hsde = sol_hsde['fixed_point_residuals'] 77 | x_jax_final = x_jax[:n_orig] 78 | 79 | # solve with cvxpy 80 | x, u, v = cp.Variable(n_orig), cp.Variable(), cp.Variable() 81 | constraints = [x >= 0, cp.norm(A_orig @ x - b_orig_mat[0, :]) <= u, cp.norm(x) <= v] 82 | prob = cp.Problem(cp.Minimize(u + rho * v), constraints) 83 | prob.solve() 84 | x_cvxpy = x.value 85 | 86 | assert jnp.linalg.norm(x_cvxpy - x_jax_final) <= 5e-2 87 | # assert jnp.all(jnp.diff(fp_res_hsde[1:]) < 1e-10) 88 | 89 | assert fp_res_hsde[0] > 10 90 | assert fp_res_hsde[-1] < 1e-3 and fp_res_hsde[-1] > 1e-16 91 | -------------------------------------------------------------------------------- /l2ws/utils/portfolio_utils.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import yfinance as yf 6 | 7 | # quandl.ApiConfig.api_key = os.environ['QUANDL_API_KEY'] 8 | 9 | 10 | def main(): 11 | nasdaq() 12 | # yahoo() 13 | 14 | def nasdaq(): 15 | # stacked_filename = 'data/portfolio_data/WIKI_prices_all.csv' 16 | stacked_filename = 'data/portfolio_data/EOD.csv' 17 | stacked_df = pd.read_csv(stacked_filename) 18 | 19 | # create a new dataframe with 'date' column as index 20 | new = stacked_df.set_index('date') 21 | 22 | # use pandas pivot function to sort adj_close by tickers 23 | clean_data = new.pivot(columns='ticker') 24 | 25 | # check the head of the output 26 | clean_data.head() 27 | 28 | num_nan = clean_data.isna().sum(axis=0) 29 | # close_data = clean_data.dropna(axis=1, how='any') 30 | # pdb.set_trace() 31 | small_nan_data = num_nan < 10 32 | indices = small_nan_data.to_numpy() 33 | close_data = clean_data.iloc[:, indices] 34 | 35 | # now get largest 3000 assets 36 | sums = close_data.min(axis=0).to_numpy() 37 | sums_argsorted = np.argsort(sums) 38 | highest_indices = sums_argsorted[-3000:] 39 | 40 | top3000 = close_data.iloc[:, highest_indices] 41 | 42 | # close_data_shorter = close_data.iloc[:3000, :] 43 | close_data_shorter = top3000 44 | 45 | # clean_data_shorter = clean_data.iloc[2300:, :] 46 | 47 | # fill in missing entries 48 | close_data_shorter = close_data_shorter.fillna(close_data_shorter.mean()) 49 | diff = -close_data_shorter.diff() 50 | diff2 = diff.iloc[1:,:] 51 | ret = diff2 / close_data_shorter.iloc[:-1,:] 52 | ret = ret.fillna(ret.mean()) 53 | 54 | 55 | short_ret = ret.iloc[:, :3000] 56 | # short_ret = short_ret.clip(lower=-.5, upper=.5) 57 | # pdb.set_trace() 58 | short_ret.to_csv('data/portfolio_data/returns.csv') 59 | 60 | covariance = short_ret.cov() 61 | covariance.to_csv('data/portfolio_data/covariance.csv') 62 | 63 | short_ret_np = short_ret.to_numpy() 64 | cov_np = covariance.to_numpy() 65 | 66 | # get factor model 67 | U, S, VT = np.linalg.svd(cov_np) 68 | pdb.set_trace() 69 | factor = 15 70 | S_factor = np.diag(S[:factor]) 71 | factor_cov = U[:, :factor] @ S_factor @ VT[:factor, :] 72 | 73 | # cov_np[cov_np > .001] = .001 74 | # cov_np[cov_np < -.001] = -.001 75 | filename = 'data/portfolio_data/eod_ret_cov.npz' 76 | pdb.set_trace() 77 | 
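    # note on the factor model above: keeping only the top factor = 15
    # singular triplets gives the best rank-15 approximation of cov_np in
    # Frobenius norm (Eckart-Young), which is what makes factor_cov a
    # low-rank 'factor model' of the covariance; also note that the
    # pdb.set_trace() calls in this function are interactive debugging
    # stops and will halt the script when it is run non-interactively.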
np.savez(filename, ret=short_ret_np, cov=factor_cov) 78 | 79 | def yahoo(): 80 | ''' 81 | get all tickers 82 | ''' 83 | tickers_df = pd.read_csv('data/portfolio_data/yahoo_tickers.csv') 84 | tickers_list_ = tickers_df.values.tolist() 85 | tickers_list = [tt[0] for tt in tickers_list_] 86 | data1 = yf.download(tickers_list[:2000], start="2020-01-01", end="2021-01-01") 87 | data2 = yf.download(tickers_list[2000:], start="2020-01-01", end="2021-01-01") 88 | data = data1.append(data2, ignore_index=True) 89 | close_data_all = data['Adj Close'] 90 | close_data = close_data_all.dropna(axis=1, how='all') 91 | close_data = close_data.fillna(close_data.mean()) 92 | 93 | ret = close_data.diff() / close_data 94 | covariance = ret.cov() 95 | 96 | ret_np = ret.to_numpy() 97 | cov_np = covariance.to_numpy() 98 | filename = 'data/portfolio_data/yahoo_ret_cov.npz' 99 | 100 | np.savez(filename, ret=ret_np, cov=cov_np) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /l2ws/examples/lasso.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import cvxpy as cp 4 | import yaml 5 | from l2ws.launcher import Workspace 6 | from l2ws.examples.solve_script import ista_setup_script 7 | import os 8 | from scipy.sparse import random 9 | 10 | def run(run_cfg): 11 | example = "lasso" 12 | data_yaml_filename = 'data_setup_copied.yaml' 13 | 14 | # read the yaml file 15 | with open(data_yaml_filename, "r") as stream: 16 | try: 17 | setup_cfg = yaml.safe_load(stream) 18 | except yaml.YAMLError as exc: 19 | print(exc) 20 | setup_cfg = {} 21 | 22 | # set the seed 23 | np.random.seed(setup_cfg['seed']) 24 | m_orig, n_orig = setup_cfg['m_orig'], setup_cfg['n_orig'] 25 | A_scale = setup_cfg['A_scale'] 26 | A = A_scale * jnp.array(np.random.normal(size=(m_orig, n_orig))) 27 | # n2 = int(n_orig / 2) 28 | # A = A.at[:, :n2].set(A[:, :n2] / 10) 29 | # A = A.at[:, :(int(n_orig / 2))] * 100 30 | # split = int(n_orig / 2) 31 | # A_vec = jnp.concatenate([100 * jnp.ones(split), 1 * jnp.ones(split)]) 32 | # A = jnp.diag(A_vec) 33 | # density = 0.1 34 | # A = A_scale * jnp.array(random(m_orig, n_orig, density=density, format='csr').todense()) 35 | evals, evecs = jnp.linalg.eigh(A.T @ A) 36 | ista_step = 1 / evals.max() 37 | lambd = setup_cfg['lambd'] 38 | 39 | static_dict = dict(A=A, lambd=lambd, ista_step=ista_step) 40 | 41 | # we directly save q now 42 | static_flag = True 43 | algo = 'ista' 44 | workspace = Workspace(algo, run_cfg, static_flag, static_dict, example) 45 | 46 | # run the workspace 47 | workspace.run() 48 | 49 | 50 | def setup_probs(setup_cfg): 51 | cfg = setup_cfg 52 | N_train, N_test = cfg.N_train, cfg.N_test 53 | N = N_train + N_test 54 | np.random.seed(setup_cfg['seed']) 55 | m_orig, n_orig = setup_cfg['m_orig'], setup_cfg['n_orig'] 56 | n2 = int(n_orig / 2) 57 | A_scale = setup_cfg['A_scale'] 58 | # b_scale = setup_cfg['b_scale'] 59 | A = A_scale * jnp.array(np.random.normal(size=(m_orig, n_orig))) 60 | # A = A.at[:, :n2].set(A[:, :n2] / 10) 61 | # split = int(n_orig / 2) 62 | # A_vec = jnp.concatenate([100 * jnp.ones(split), 1 * jnp.ones(split)]) 63 | # A = jnp.diag(A_vec) 64 | # density = 0.1 65 | # A = A_scale * jnp.array(random(m_orig, n_orig, density=density, format='csr').todense()) 66 | evals, evecs = jnp.linalg.eigh(A.T @ A) 67 | ista_step = 1 / evals.max() 68 | lambd = setup_cfg['lambd'] 69 | 70 | np.random.seed(cfg.seed) 71 | 72 
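The ISTA iterations themselves live in l2ws/algo_steps.py and are not shown in this listing. As a minimal sketch consistent with the step size ista_step = 1 / lambda_max(A^T A) computed above, one update for min_z .5 * ||A z - b||^2 + lambd * ||z||_1 is (the function names here are illustrative):

import jax.numpy as jnp

def soft_threshold(z, tau):
    # proximal operator of tau * ||.||_1
    return jnp.sign(z) * jnp.maximum(jnp.abs(z) - tau, 0.0)

def ista_update(z, A, b, lambd, ista_step):
    # gradient step on the smooth part, then the l1 prox
    grad = A.T @ (A @ z - b)
    return soft_threshold(z - ista_step * grad, ista_step * lambd)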
| # save output to output_filename 73 | output_filename = f"{os.getcwd()}/data_setup" 74 | 75 | b_min, b_max = setup_cfg['b_min'], setup_cfg['b_max'] 76 | # b_mat = b_scale * generate_b_mat(A, N, b_min, b_max) 77 | m, n = A.shape 78 | b_mat = (b_max - b_min) * np.random.rand(N, m) + b_min 79 | 80 | # 81 | # b_mat[:, :n2] = b_mat[:, :n2] / 10 82 | 83 | ista_setup_script(b_mat, A, lambd, output_filename) 84 | 85 | 86 | def generate_b_mat(A, N, p=.1): 87 | m, n = A.shape 88 | b0 = jnp.array(np.random.normal(size=(m))) 89 | b_mat = 0 * b0 + 1 * jnp.array(np.random.normal(size=(N, m))) + 1 90 | # b_mat = jnp.zeros((N, m)) 91 | # x_star_mask = np.random.binomial(1, p, size=(N, n)) 92 | # x_stars_dense = jnp.array(np.random.normal(size=(N, n))) 93 | # x_stars = jnp.multiply(x_star_mask, x_stars_dense) 94 | # for i in range(N): 95 | # b = A @ x_stars[i, :] 96 | # b_mat = b_mat.at[i, :].set(b) 97 | return b_mat 98 | 99 | def eval_ista_obj(z, A, b, lambd): 100 | return .5 * jnp.linalg.norm(A @ z - b) ** 2 + lambd * jnp.linalg.norm(z, ord=1) 101 | 102 | 103 | def obj_diff(obj, true_obj): 104 | return (obj - true_obj) 105 | 106 | 107 | def sol_2_obj_diff(z, b, true_obj, A, lambd): 108 | obj = eval_ista_obj(z, A, b, lambd) 109 | return obj_diff(obj, true_obj) 110 | 111 | def solve_many_probs_cvxpy(A, b_mat, lambd): 112 | """ 113 | solves many lasso problems where each problem has a different b vector 114 | """ 115 | m, n = A.shape 116 | N = b_mat.shape[0] 117 | z, b_param = cp.Variable(n), cp.Parameter(m) 118 | prob = cp.Problem(cp.Minimize(.5 * cp.sum_squares(np.array(A) @ z - b_param) + lambd * cp.norm(z, p=1))) 119 | # prob = cp.Problem(cp.Minimize(.5 * cp.sum_squares(np.array(A) @ z - b_param) + lambd * cp.tv(z))) 120 | z_stars = jnp.zeros((N, n)) 121 | objvals = jnp.zeros((N)) 122 | for i in range(N): 123 | b_param.value = np.array(b_mat[i, :]) 124 | prob.solve(verbose=False) 125 | objvals = objvals.at[i].set(prob.value) 126 | z_stars = z_stars.at[i, :].set(jnp.array(z.value)) 127 | print('finished solving cvxpy problems') 128 | return z_stars, objvals 129 | -------------------------------------------------------------------------------- /l2ws/examples/jamming.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import yaml 4 | from jax import vmap 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import time 8 | import jax.numpy as jnp 9 | import os 10 | import scs 11 | import cvxpy as cp 12 | import jax.scipy as jsp 13 | import jax.random as jra 14 | from l2ws.algo_steps import create_M 15 | from scipy.sparse import csc_matrix 16 | from l2ws.examples.solve_script import setup_script 17 | from l2ws.launcher import Workspace 18 | from l2ws.algo_steps import get_scaled_vec_and_factor 19 | from jaxopt.projection import projection_simplex 20 | from l2ws.algo_steps import k_steps_train_extragrad, k_steps_eval_extragrad 21 | 22 | 23 | plt.rcParams.update( 24 | { 25 | "text.usetex": True, 26 | "font.family": "serif", 27 | "font.size": 16, 28 | } 29 | ) 30 | log = logging.getLogger(__name__) 31 | 32 | 33 | def run(run_cfg): 34 | example = "jamming" 35 | data_yaml_filename = 'data_setup_copied.yaml' 36 | 37 | # read the yaml file 38 | with open(data_yaml_filename, "r") as stream: 39 | try: 40 | setup_cfg = yaml.safe_load(stream) 41 | except yaml.YAMLError as exc: 42 | print(exc) 43 | setup_cfg = {} 44 | 45 | # TODO 46 | # set the seed 47 | np.random.seed(setup_cfg['seed']) 48 | # n_orig = setup_cfg['n_orig'] 49 | # 
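    # projection_simplex (imported from jaxopt.projection above) is the
    # Euclidean projection onto the probability simplex
    # {x : x >= 0, sum(x) = 1}; for example,
    #     projection_simplex(jnp.array([0.5, 2.0, -1.0]))
    # returns exactly [0., 1., 0.]: the simplex threshold works out to 1.0
    # and only the largest entry survives it.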
d_mul = setup_cfg['d_mul'] 50 | # k = setup_cfg['k'] 51 | # static_dict = static_canon(n_orig, d_mul, rho_x=rho_x, scale=scale) 52 | 53 | # we directly save q now 54 | get_q = None 55 | static_flag = True 56 | algo = 'extragradient' 57 | m, n = setup_cfg['n'], setup_cfg['n'] 58 | eg_step = setup_cfg['step_size'] 59 | static_dict = dict(f=jamming_obj, 60 | proj_X=projection_simplex, 61 | proj_Y=projection_simplex, m=m, n=n, eg_step=eg_step) 62 | workspace = Workspace(algo, run_cfg, static_flag, static_dict, example) 63 | 64 | # run the workspace 65 | workspace.run() 66 | 67 | 68 | def setup_probs(setup_cfg): 69 | cfg = setup_cfg 70 | N_train, N_test = cfg.N_train, cfg.N_test 71 | N = N_train + N_test 72 | beta_min, beta_max = cfg.beta_min, cfg.beta_max 73 | sigma_min, sigma_max = cfg.sigma_min, cfg.sigma_max 74 | n = cfg.n 75 | k = cfg.solve_iters 76 | eg_step = cfg.step_size 77 | delta_frac = cfg.delta_frac 78 | 79 | np.random.seed(cfg.seed) 80 | key = jra.PRNGKey(cfg.seed) 81 | 82 | beta0 = beta_min + np.random.rand(n) * (beta_max - beta_min) 83 | 84 | # sample uniformly to get beta, sigma 85 | beta_delta = beta_min + np.random.rand(N, n) * (beta_max - beta_min) 86 | beta = beta0 + delta_frac * beta_delta 87 | 88 | # sample uniformly to get beta, sigma 89 | # beta = beta_min + np.random.rand(N, n) * (beta_max - beta_min) 90 | sigma = sigma_min + np.random.rand(N, n) * (sigma_max - sigma_min) 91 | theta_mat = jnp.hstack([beta, sigma]) 92 | 93 | # solve each problem using the extragradient method 94 | proj_X = projection_simplex 95 | proj_Y = projection_simplex 96 | 97 | z0 = -jnp.ones(2 * n) 98 | f = jamming_obj 99 | # z_final, iter_losses = k_steps_train_extragrad(k, z0, theta_mat[0, :], f, proj_X, proj_Y, n, eg_step, 100 | # supervised=False, z_star=None, jit=True) 101 | 102 | # save output to output_filename 103 | output_filename = f"{os.getcwd()}/data_setup" 104 | 105 | z_stars = jnp.zeros((N, 2 * n)) 106 | for i in range(N): 107 | if i % 100 == 0: 108 | print('solving problem', i) 109 | z_final, iter_losses, z_all, obj_diffs = k_steps_eval_extragrad(k, z0, 110 | theta_mat[i, :], 111 | f, proj_X, proj_Y, n, eg_step, 112 | supervised=False, z_star=None, jit=True) 113 | z_stars = z_stars.at[i, :].set(z_final) 114 | print('fixed point residual', iter_losses[-1]) 115 | if i % 1000 == 0: 116 | log.info("saving intermediate data at problem %d ...", i) 117 | t0 = time.time() 118 | jnp.savez( 119 | output_filename + str(i), 120 | thetas=theta_mat, 121 | z_stars=z_stars, 122 | ) 123 | 124 | # plt.plot(iter_losses) 125 | # plt.yscale('log') 126 | 127 | # import pdb 128 | # pdb.set_trace() 129 | # aa = z_all - z_final 130 | # diffs = jnp.linalg.norm(aa, axis=1) 131 | # plt.plot(diffs) 132 | plt.plot(iter_losses) 133 | plt.yscale('log') 134 | plt.savefig("fp_resids.pdf") 135 | plt.clf() 136 | 137 | 138 | 139 | # setup_script(q_mat, theta_mat_jax, solver, data, cones, output_filename, solve=cfg.solve) 140 | # save the data 141 | log.info("saving final data...") 142 | t0 = time.time() 143 | jnp.savez( 144 | output_filename, 145 | q_mat=theta_mat, 146 | thetas=theta_mat[:, :n], 147 | z_stars=z_stars, 148 | ) 149 | 150 | save_time = time.time() 151 | log.info(f"finished saving final data... 
took {save_time-t0} seconds") 152 | 153 | # save plot of first 5 solutions 154 | for i in range(5): 155 | plt.plot(z_stars[i, :]) 156 | plt.savefig("z_stars.pdf") 157 | plt.clf() 158 | 159 | # save plot of first 5 parameters 160 | for i in range(5): 161 | plt.plot(theta_mat[i, :]) 162 | plt.savefig("thetas.pdf") 163 | plt.clf() 164 | 165 | # import pdb 166 | # pdb.set_trace() 167 | 168 | 169 | def jamming_obj(x, y, theta): 170 | """ 171 | creates the objective in the saddle problem 172 | """ 173 | n = x.size 174 | beta, sigma = theta[:n], theta[n:] 175 | 176 | # cost_i = log(1 + beta_i * y_i / (sigma_i + x_i)), summed over i 177 | objs = batch_jamming_costs(beta, sigma, x, y) 178 | return jnp.sum(objs) 179 | 180 | 181 | def single_jamming_cost(beta, sigma, x, y): 182 | return jnp.log(1 + beta * y / (sigma + x)) 183 | 184 | 185 | batch_jamming_costs = vmap(single_jamming_cost, in_axes=(0, 0, 0, 0), out_axes=(0)) -------------------------------------------------------------------------------- /l2ws/scs_problem.py: -------------------------------------------------------------------------------- 1 | import cvxpy as cp 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import random 5 | from matplotlib import pyplot as plt 6 | 7 | from l2ws.algo_steps import ( 8 | create_M, 9 | create_projection_fn, 10 | extract_sol, 11 | get_scaled_vec_and_factor, 12 | k_steps_eval_scs, 13 | lin_sys_solve, 14 | ) 15 | 16 | 17 | class SCSinstance(object): 18 | def __init__(self, prob, solver, manual_canon=False): 19 | self.manual_canon = manual_canon 20 | if manual_canon: 21 | # manual canonicalization 22 | data = prob 23 | self.P = data['P'] 24 | self.A = data['A'] 25 | self.b = data['b'] 26 | self.c = data['c'] 27 | self.scs_data = dict(P=self.P, A=self.A, b=self.b, c=self.c) 28 | # self.cones = data['cones'] 29 | solver.update(b=self.b) 30 | solver.update(c=self.c) 31 | self.solver = solver 32 | 33 | else: 34 | # automatic canonicalization 35 | data = prob.get_problem_data(cp.SCS)[0] 36 | self.P = data['P'] 37 | self.A = data['A'] 38 | self.b = data['b'] 39 | self.c = data['c'] 40 | self.cones = dict(z=data['dims'].zero, l=data['dims'].nonneg) 41 | # self.cones = dict(data['dims'].zero, data['dims'].nonneg) 42 | self.prob = prob 43 | 44 | # we will use self.q for our DR-splitting 45 | 46 | self.q = jnp.concatenate([self.c, self.b]) 47 | self.solve() 48 | 49 | def solve(self): 50 | if self.manual_canon: 51 | # solver = scs.SCS(self.scs_data, self.cones, 52 | # eps_abs=1e-4, eps_rel=1e-4) 53 | # solver = scs.SCS(self.scs_data, self.cones, 54 | # eps_abs=1e-5, eps_rel=1e-5) 55 | # Solve! 
56 | sol = self.solver.solve() 57 | self.x_star = jnp.array(sol['x']) 58 | self.y_star = jnp.array(sol['y']) 59 | self.s_star = jnp.array(sol['s']) 60 | self.solve_time = sol['info']['solve_time'] / 1000 61 | else: 62 | self.prob.solve(solver=cp.SCS, verbose=True) 63 | self.x_star = jnp.array( 64 | self.prob.solution.attr['solver_specific_stats']['x']) 65 | self.y_star = jnp.array( 66 | self.prob.solution.attr['solver_specific_stats']['y']) 67 | self.s_star = jnp.array( 68 | self.prob.solution.attr['solver_specific_stats']['s']) 69 | 70 | 71 | def scs_jax(data, hsde=True, rho_x=1e-6, scale=.1, alpha=1.5, iters=5000, jit=True, plot=False): 72 | P, A = data['P'], data['A'] 73 | c, b = data['c'], data['b'] 74 | cones = data['cones'] 75 | zero_cone_size = cones['z'] 76 | 77 | m, n = A.shape 78 | 79 | M = create_M(P, A) 80 | 81 | algo_factor, scale_vec = get_scaled_vec_and_factor(M, rho_x, scale, m, n, zero_cone_size, 82 | hsde=hsde) 83 | q = jnp.concatenate([c, b]) 84 | 85 | proj = create_projection_fn(cones, n) 86 | 87 | key = random.PRNGKey(0) 88 | if 'x' in data.keys() and 'y' in data.keys() and 's' in data.keys(): 89 | # warm start with z = (x, y + s) or 90 | # z = (x, y + s, 1) with the hsde 91 | z = jnp.concatenate([data['x'], data['y'] + data['s'] / scale_vec[n:]]) 92 | if hsde: 93 | # we pick eta = 1 for feasibility of warm-start 94 | z = jnp.concatenate([z, jnp.array([1])]) 95 | else: 96 | if hsde: 97 | mu = 1 * random.normal(key, (m + n,)) 98 | z = jnp.concatenate([mu, jnp.array([1])]) 99 | else: 100 | z = 1 * random.normal(key, (m + n,)) 101 | 102 | if hsde: 103 | q_r = lin_sys_solve(algo_factor, q) 104 | else: 105 | q_r = q 106 | 107 | # eval_out = k_steps_eval_scs(iters, z, q_r, algo_factor, proj, P, A, 108 | # c, b, jit, hsde, zero_cone_size, rho_x=rho_x, 109 | # scale=scale, alpha=alpha) 110 | supervised, z_star = False, None 111 | eval_out = k_steps_eval_scs(iters, z, q_r, algo_factor, proj, P, A, supervised, z_star, 112 | jit, hsde, zero_cone_size, rho_x=rho_x, scale=scale, alpha=alpha) 113 | # z_final, iter_losses, primal_residuals, dual_residuals, z_all_plus_1, u_all, v_all = eval_out 114 | z_final, iter_losses, z_all_plus_1, primal_residuals, dual_residuals, u_all, v_all = eval_out 115 | 116 | u_final, v_final = u_all[-1, :], v_all[-1, :] 117 | 118 | # extract the primal and dual variables 119 | x, y, s = extract_sol(u_final, v_final, n, hsde) 120 | 121 | if plot: 122 | plt.plot(iter_losses, label='fixed point residuals') 123 | plt.yscale('log') 124 | plt.legend() 125 | plt.show() 126 | 127 | # populate the sol dictionary 128 | sol = {} 129 | sol['fixed_point_residuals'] = iter_losses 130 | sol['primal_residuals'] = primal_residuals 131 | sol['dual_residuals'] = dual_residuals 132 | sol['x'], sol['y'], sol['s'] = x, y, s 133 | return sol 134 | 135 | 136 | def ruiz_equilibrate(M, num_passes=20): 137 | """ 138 | NOT USED ANYWHERE -- ONLY BRIEFLY TESTED 139 | """ 140 | p, p_ = M.shape 141 | D, E = jnp.eye(p), jnp.eye(p) 142 | val = M, E, D 143 | 144 | def body(i, val): 145 | M, E, D = val 146 | drinv = 1 / jnp.sqrt(jnp.linalg.norm(M, jnp.inf, axis=1)) 147 | dcinv = 1 / jnp.sqrt(jnp.linalg.norm(M, jnp.inf, axis=0)) 148 | D = jnp.multiply(D, drinv) 149 | E = jnp.multiply(E, dcinv) 150 | M = jnp.multiply(M, dcinv) 151 | M = jnp.multiply(drinv[:, None], M) 152 | val = M, E, D 153 | return val 154 | val = jax.lax.fori_loop(0, num_passes, body, val) 155 | M, E, D = val 156 | 157 | # for i in range(num_passes): 158 | # drinv = 1 / jnp.sqrt(jnp.linalg.norm(M, jnp.inf, axis=1)) 159 | # 
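    # For reference, scs_jax above is exercised in
    # tests/test_canonicalizations.py essentially as
    #     data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws)
    #     sol = scs_jax(data, hsde=True, rho_x=1, scale=10, alpha=1, iters=800)
    #     x, fp_res = sol['x'], sol['fixed_point_residuals']
    # where passing the optional x/y/s entries triggers the warm start
    # z = (x, y + s / scale_vec[n:], 1) handled at the top of scs_jax.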
dcinv = 1 / jnp.sqrt(jnp.linalg.norm(M, jnp.inf, axis=0)) 160 | # D = jnp.multiply(D, drinv) 161 | # E = jnp.multiply(E, dcinv) 162 | # M = jnp.multiply(M, dcinv) 163 | # M = jnp.multiply(drinv[:, None], M) 164 | return M, E, D 165 | -------------------------------------------------------------------------------- /benchmarks/l2ws_setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import hydra 4 | 5 | import l2ws.examples.jamming as jamming 6 | import l2ws.examples.lasso as lasso 7 | import l2ws.examples.markowitz as markowitz 8 | import l2ws.examples.mnist as mnist 9 | import l2ws.examples.mpc as mpc 10 | import l2ws.examples.osc_mass as osc_mass 11 | import l2ws.examples.phase_retrieval as phase_retrieval 12 | import l2ws.examples.quadcopter as quadcopter 13 | import l2ws.examples.robust_kalman as robust_kalman 14 | import l2ws.examples.robust_ls as robust_ls 15 | import l2ws.examples.robust_pca as robust_pca 16 | import l2ws.examples.sparse_pca as sparse_pca 17 | import l2ws.examples.unconstrained_qp as unconstrained_qp 18 | import l2ws.examples.vehicle as vehicle 19 | 20 | 21 | @hydra.main(config_path='configs/markowitz', config_name='markowitz_setup.yaml') 22 | def main_setup_markowitz(cfg): 23 | markowitz.setup_probs(cfg) 24 | 25 | 26 | @hydra.main(config_path='configs/osc_mass', config_name='osc_mass_setup.yaml') 27 | def main_setup_osc_mass(cfg): 28 | osc_mass.setup_probs(cfg) 29 | 30 | 31 | @hydra.main(config_path='configs/vehicle', config_name='vehicle_setup.yaml') 32 | def main_setup_vehicle(cfg): 33 | vehicle.setup_probs(cfg) 34 | 35 | 36 | @hydra.main(config_path='configs/quadcopter', config_name='quadcopter_setup.yaml') 37 | def main_setup_quadcopter(cfg): 38 | quadcopter.setup_probs(cfg) 39 | 40 | 41 | @hydra.main(config_path='configs/mnist', config_name='mnist_setup.yaml') 42 | def main_setup_mnist(cfg): 43 | mnist.setup_probs(cfg) 44 | 45 | 46 | @hydra.main(config_path='configs/jamming', config_name='jamming_setup.yaml') 47 | def main_setup_jamming(cfg): 48 | jamming.setup_probs(cfg) 49 | 50 | 51 | @hydra.main(config_path='configs/robust_kalman', config_name='robust_kalman_setup.yaml') 52 | def main_setup_robust_kalman(cfg): 53 | robust_kalman.setup_probs(cfg) 54 | 55 | 56 | @hydra.main(config_path='configs/robust_pca', config_name='robust_pca_setup.yaml') 57 | def main_setup_robust_pca(cfg): 58 | robust_pca.setup_probs(cfg) 59 | 60 | 61 | @hydra.main(config_path='configs/robust_ls', config_name='robust_ls_setup.yaml') 62 | def main_setup_robust_ls(cfg): 63 | robust_ls.setup_probs(cfg) 64 | 65 | 66 | @hydra.main(config_path='configs/sparse_pca', config_name='sparse_pca_setup.yaml') 67 | def main_setup_sparse_pca(cfg): 68 | sparse_pca.setup_probs(cfg) 69 | 70 | 71 | @hydra.main(config_path='configs/phase_retrieval', config_name='phase_retrieval_setup.yaml') 72 | def main_setup_phase_retrieval(cfg): 73 | phase_retrieval.setup_probs(cfg) 74 | 75 | 76 | @hydra.main(config_path='configs/lasso', config_name='lasso_setup.yaml') 77 | def main_setup_lasso(cfg): 78 | lasso.setup_probs(cfg) 79 | 80 | 81 | @hydra.main(config_path='configs/mpc', config_name='mpc_setup.yaml') 82 | def main_setup_mpc(cfg): 83 | mpc.setup_probs(cfg) 84 | 85 | 86 | @hydra.main(config_path='configs/unconstrained_qp', config_name='unconstrained_qp_setup.yaml') 87 | def main_setup_unconstrained_qp(cfg): 88 | unconstrained_qp.setup_probs(cfg) 89 | 90 | 91 | if __name__ == '__main__': 92 | if sys.argv[2] == 'cluster': 93 | base = 
'hydra.run.dir=/scratch/gpfs/rajivs/learn2warmstart/outputs/' 94 | elif sys.argv[2] == 'local': 95 | base = 'hydra.run.dir=outputs/' 96 | if sys.argv[1] == 'markowitz': 97 | # step 1. remove the markowitz argument -- otherwise hydra uses it as an override 98 | # step 2. add the train_outputs/... argument for data_setup_outputs not outputs 99 | # sys.argv[1] = 'hydra.run.dir=outputs/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 100 | sys.argv[1] = base + 'markowitz/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 101 | sys.argv = [sys.argv[0], sys.argv[1]] 102 | main_setup_markowitz() 103 | elif sys.argv[1] == 'osc_mass': 104 | sys.argv[1] = base + 'osc_mass/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 105 | sys.argv = [sys.argv[0], sys.argv[1]] 106 | main_setup_osc_mass() 107 | elif sys.argv[1] == 'vehicle': 108 | sys.argv[1] = base + 'vehicle/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 109 | sys.argv = [sys.argv[0], sys.argv[1]] 110 | main_setup_vehicle() 111 | elif sys.argv[1] == 'robust_kalman': 112 | sys.argv[1] = base + 'robust_kalman/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 113 | sys.argv = [sys.argv[0], sys.argv[1]] 114 | main_setup_robust_kalman() 115 | elif sys.argv[1] == 'robust_pca': 116 | sys.argv[1] = base + 'robust_pca/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 117 | sys.argv = [sys.argv[0], sys.argv[1]] 118 | main_setup_robust_pca() 119 | elif sys.argv[1] == 'robust_ls': 120 | sys.argv[1] = base + 'robust_ls/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 121 | sys.argv = [sys.argv[0], sys.argv[1]] 122 | main_setup_robust_ls() 123 | elif sys.argv[1] == 'sparse_pca': 124 | sys.argv[1] = base + 'sparse_pca/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 125 | sys.argv = [sys.argv[0], sys.argv[1]] 126 | main_setup_sparse_pca() 127 | elif sys.argv[1] == 'phase_retrieval': 128 | sys.argv[1] = base + 'phase_retrieval/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 129 | sys.argv = [sys.argv[0], sys.argv[1]] 130 | main_setup_phase_retrieval() 131 | elif sys.argv[1] == 'lasso': 132 | sys.argv[1] = base + 'lasso/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 133 | sys.argv = [sys.argv[0], sys.argv[1]] 134 | main_setup_lasso() 135 | elif sys.argv[1] == 'mpc': 136 | sys.argv[1] = base + 'mpc/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 137 | sys.argv = [sys.argv[0], sys.argv[1]] 138 | main_setup_mpc() 139 | elif sys.argv[1] == 'unconstrained_qp': 140 | sys.argv[1] = base + 'unconstrained_qp/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 141 | sys.argv = [sys.argv[0], sys.argv[1]] 142 | main_setup_unconstrained_qp() 143 | elif sys.argv[1] == 'quadcopter': 144 | sys.argv[1] = base + 'quadcopter/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 145 | sys.argv = [sys.argv[0], sys.argv[1]] 146 | main_setup_quadcopter() 147 | elif sys.argv[1] == 'mnist': 148 | sys.argv[1] = base + 'mnist/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 149 | sys.argv = [sys.argv[0], sys.argv[1]] 150 | main_setup_mnist() 151 | elif sys.argv[1] == 'jamming': 152 | sys.argv[1] = base + 'jamming/data_setup_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 153 | sys.argv = [sys.argv[0], sys.argv[1]] 154 | main_setup_jamming() 155 | -------------------------------------------------------------------------------- /l2ws/examples/unconstrained_qp.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import cvxpy as cp 4 | import yaml 5 | from l2ws.launcher import 
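# Usage of benchmarks/l2ws_setup.py above, as implied by its argv handling:
#     python benchmarks/l2ws_setup.py <example> <location>
# for instance
#     python benchmarks/l2ws_setup.py lasso local
# where <example> selects one of the registered hydra entry points and
# <location> is 'cluster' or 'local', which picks the base hydra.run.dir
# that the data-setup outputs are written under.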
Workspace 6 | from l2ws.examples.solve_script import gd_setup_script 7 | import os 8 | 9 | 10 | def run(run_cfg): 11 | example = "unconstrained_qp" 12 | data_yaml_filename = 'data_setup_copied.yaml' 13 | 14 | # read the yaml file 15 | with open(data_yaml_filename, "r") as stream: 16 | try: 17 | setup_cfg = yaml.safe_load(stream) 18 | except yaml.YAMLError as exc: 19 | print(exc) 20 | setup_cfg = {} 21 | 22 | # set the seed 23 | np.random.seed(setup_cfg['seed']) 24 | n_orig = setup_cfg['n_orig'] 25 | 26 | split = setup_cfg['P_split'] 27 | split_factor = setup_cfg['split_factor'] 28 | P_vec = jnp.concatenate([split_factor * jnp.ones(split), 1 * jnp.ones(n_orig - split)]) 29 | P = jnp.diag(P_vec) 30 | 31 | # A_scale = setup_cfg['A_scale'] 32 | # A = A_scale * jnp.array(np.random.normal(size=(m_orig, n_orig))) 33 | # split = int(n_orig / 2) 34 | # A_vec = jnp.concatenate([100 * jnp.ones(split), 1 * jnp.ones(split)]) 35 | # A = jnp.diag(A_vec) 36 | # density = 0.1 37 | # A = A_scale * jnp.array(random(m_orig, n_orig, density=density, format='csr').todense()) 38 | # evals, evecs = jnp.linalg.eigh(A.T @ A) 39 | # gd_step = 1 / evals.max() 40 | gd_step = 1 / P.max() # 2 / (P.max() + 1) #1 / P.max() 41 | 42 | static_dict = dict(P=P, gd_step=gd_step) 43 | 44 | # we directly save q now 45 | static_flag = True 46 | algo = 'gd' 47 | workspace = Workspace(algo, run_cfg, static_flag, static_dict, example) 48 | 49 | # run the workspace 50 | workspace.run() 51 | 52 | 53 | def setup_probs(setup_cfg): 54 | cfg = setup_cfg 55 | N_train, N_test = cfg.N_train, cfg.N_test 56 | N = N_train + N_test 57 | np.random.seed(setup_cfg['seed']) 58 | n_orig = setup_cfg['n_orig'] 59 | 60 | split = setup_cfg['P_split'] 61 | split_factor = setup_cfg['split_factor'] 62 | P_vec = jnp.concatenate([split_factor * jnp.ones(split), 1 * jnp.ones(n_orig - split)]) 63 | P = jnp.diag(P_vec) 64 | c_min = setup_cfg['c_min'] 65 | c_max = setup_cfg['c_max'] 66 | # range_factor = setup_cfg['range_factor'] 67 | 68 | 69 | """ 70 | P is a diagonal matrix diag(p_1, ..., p_n) 71 | p_1, ..., p_split = split_factor 72 | p_{split+1}, ..., n = 1 73 | sample c in the following way 74 | c_i = p_i b_i cos(theta pi i / n) 75 | 76 | b is range_factor 77 | """ 78 | 79 | # generate theta 80 | 81 | # generate c_mat 82 | 83 | c1_mat = split_factor ** 2 * (c_min + (c_max - c_min) * jnp.array(np.random.rand(N, split))) 84 | c2_mat = c_min + (c_max - c_min) * jnp.array(np.random.rand(N, n_orig - split)) 85 | c_mat = jnp.hstack([c1_mat, c2_mat]) 86 | 87 | np.random.seed(cfg.seed) 88 | 89 | # save output to output_filename 90 | output_filename = f"{os.getcwd()}/data_setup" 91 | 92 | # b_min, b_max = setup_cfg['b_min'], setup_cfg['b_max'] 93 | # b_mat = b_scale * generate_b_mat(A, N, b_min, b_max) 94 | # m, n = A.shape 95 | # b_mat = (b_max - b_min) * np.random.rand(N, m) + b_min 96 | 97 | gd_setup_script(c_mat, P, output_filename) 98 | 99 | 100 | # def generate_b_mat(A, N, p=.1): 101 | # m, n = A.shape 102 | # b0 = jnp.array(np.random.normal(size=(m))) 103 | # b_mat = 0 * b0 + 1 * jnp.array(np.random.normal(size=(N, m))) + 1 104 | # # b_mat = jnp.zeros((N, m)) 105 | # # x_star_mask = np.random.binomial(1, p, size=(N, n)) 106 | # # x_stars_dense = jnp.array(np.random.normal(size=(N, n))) 107 | # # x_stars = jnp.multiply(x_star_mask, x_stars_dense) 108 | # # for i in range(N): 109 | # # b = A @ x_stars[i, :] 110 | # # b_mat = b_mat.at[i, :].set(b) 111 | # return b_mat 112 | 113 | # def eval_gd_obj(z, A, b, lambd): 114 | # return .5 * jnp.linalg.norm(A @ z - b) ** 
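# For the unconstrained QP above, min_z .5 * z^T P z + c^T z with P diagonal
# and positive definite, the minimizer is available in closed form: setting
# the gradient P z + c to zero gives z_star = -P^{-1} c, i.e. componentwise
# z_star_i = -c_i / p_i. solve_many_probs_cvxpy further down implements this
# via Q_inv (note that -Q_inv @ c_mat only lines up if the c vectors are
# stacked as columns of c_mat), and it is the fixed point that gradient
# descent with step size gd_step = 1 / P.max() converges to.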
100 | # def generate_b_mat(A, N, p=.1):
101 | #     m, n = A.shape
102 | #     b0 = jnp.array(np.random.normal(size=(m)))
103 | #     b_mat = 0 * b0 + 1 * jnp.array(np.random.normal(size=(N, m))) + 1
104 | #     # b_mat = jnp.zeros((N, m))
105 | #     # x_star_mask = np.random.binomial(1, p, size=(N, n))
106 | #     # x_stars_dense = jnp.array(np.random.normal(size=(N, n)))
107 | #     # x_stars = jnp.multiply(x_star_mask, x_stars_dense)
108 | #     # for i in range(N):
109 | #     #     b = A @ x_stars[i, :]
110 | #     #     b_mat = b_mat.at[i, :].set(b)
111 | #     return b_mat
112 |
113 | # def eval_gd_obj(z, A, b, lambd):
114 | #     return .5 * jnp.linalg.norm(A @ z - b) ** 2 + lambd * jnp.linalg.norm(z, ord=1)
115 |
116 |
117 | # def obj_diff(obj, true_obj):
118 | #     return (obj - true_obj)
119 |
120 |
121 | # def sol_2_obj_diff(z, b, true_obj, A, lambd):
122 | #     obj = eval_gd_obj(z, A, b, lambd)
123 | #     return obj_diff(obj, true_obj)
124 |
125 | # def solve_many_probs_cvxpy(A, b_mat, lambd):
126 | #     """
127 | #     solves many gd problems where each problem has a different b vector
128 | #     """
129 | #     m, n = A.shape
130 | #     N = b_mat.shape[0]
131 | #     z, b_param = cp.Variable(n), cp.Parameter(m)
132 | #     prob = cp.Problem(cp.Minimize(.5 * cp.sum_squares(np.array(A) @ z - b_param) + lambd * cp.norm(z, p=1)))
133 | #     # prob = cp.Problem(cp.Minimize(.5 * cp.sum_squares(np.array(A) @ z - b_param) + lambd * cp.tv(z)))
134 | #     z_stars = jnp.zeros((N, n))
135 | #     objvals = jnp.zeros((N))
136 | #     for i in range(N):
137 | #         b_param.value = np.array(b_mat[i, :])
138 | #         prob.solve(verbose=False)
139 | #         objvals = objvals.at[i].set(prob.value)
140 | #         z_stars = z_stars.at[i, :].set(jnp.array(z.value))
141 | #     print('finished solving cvxpy problems')
142 | #     return z_stars, objvals
143 |
144 |
145 | # def run(run_cfg):
146 | #     example = "unconstrained_qp"
147 | #     data_yaml_filename = 'data_setup_copied.yaml'
148 |
149 | #     # read the yaml file
150 | #     with open(data_yaml_filename, "r") as stream:
151 | #         try:
152 | #             setup_cfg = yaml.safe_load(stream)
153 | #         except yaml.YAMLError as exc:
154 | #             print(exc)
155 | #             setup_cfg = {}
156 |
157 | #     # set the seed
158 | #     np.random.seed(setup_cfg['seed'])
159 | #     n_orig = setup_cfg['n_orig']
160 |
161 | #     Q = jnp.array(np.random.normal(size=(n_orig, n_orig)))
162 | #     P = Q @ Q.T
163 | #     evals, evecs = jnp.linalg.eigh(P)
164 | #     gd_step = 1 / evals.max()
165 |
166 | #     # static_dict = static_canon(n_orig, k, rho_x=rho_x, scale=scale)
167 | #     static_dict = dict(Q=Q, gd_step=gd_step)
168 |
169 | #     # we directly save q now
170 | #     get_q = None
171 | #     static_flag = True
172 | #     algo = 'gd'
173 | #     workspace = Workspace(run_cfg, algo, static_flag, static_dict, example, get_q)
174 |
175 | #     # run the workspace
176 | #     workspace.run()
177 |
178 |
179 | def obj(Q, c, z):
180 |     return .5 * z.T @ Q @ z + c @ z
181 |
182 | def obj_diff(obj, true_obj):
183 |     return (obj - true_obj)
184 |
185 |
186 | def solve_many_probs_cvxpy(Q, c_mat):
187 |     """
188 |     solves many unconstrained QP problems in closed form (despite the name,
189 |     no cvxpy call is needed); each problem has a different c vector
190 |     """
191 |     Q_inv = jnp.linalg.inv(Q)
192 |     # Q is symmetric, so z_star_i = -Q^{-1} c_i for each row c_i of c_mat
193 |     z_stars = -c_mat @ Q_inv
194 |     objvals = .5 * jnp.sum((z_stars @ Q) * z_stars, axis=1) + jnp.sum(c_mat * z_stars, axis=1)
195 |     return z_stars, objvals
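# --- minimal sanity check (added illustration; not part of the original module) ---
# at a minimizer of .5 * z.T @ Q @ z + c @ z, the gradient Q @ z_star + c
# vanishes, so the closed-form batch solution above should satisfy
# z_stars @ Q + c_mat ~ 0.
if __name__ == "__main__":
    _rng = np.random.RandomState(0)
    _Q_half = jnp.array(_rng.normal(size=(5, 5)))
    _Q = _Q_half @ _Q_half.T + jnp.eye(5)  # symmetric positive definite
    _c_mat = jnp.array(_rng.normal(size=(3, 5)))
    _z_stars, _objvals = solve_many_probs_cvxpy(_Q, _c_mat)
    assert jnp.linalg.norm(_z_stars @ _Q + _c_mat) < 1e-4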
-------------------------------------------------------------------------------- /tests/test_l2ws_model.py: --------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | import scs
4 | from scipy.sparse import csc_matrix
5 |
6 | from l2ws.algo_steps import create_M, create_projection_fn, get_scaled_vec_and_factor
7 | from l2ws.examples.robust_ls import multiple_random_robust_ls
8 | from l2ws.scs_model import SCSmodel
9 |
10 |
11 | def multiple_random_robust_ls_setup(m_orig, n_orig, rho, b_center, b_range, N_train, N_test, rho_x,
12 |                                     scale):
13 |     N = N_train + N_test
14 |     P, A, cones, q_mat, theta_mat, A_orig, b_orig_mat = multiple_random_robust_ls(
15 |         m_orig, n_orig, rho, b_center, b_range, N)
16 |     m, n = A.shape
17 |
18 |     proj = create_projection_fn(cones, n)
19 |     q_mat_train, q_mat_test = q_mat[:N_train, :], q_mat[N_train:N, :]
20 |     train_inputs, test_inputs = theta_mat[:N_train, :], theta_mat[N_train:N, :]
21 |     static_M = create_M(P, A)
22 |     # static_algo_factor = jsp.linalg.lu_factor(static_M + jnp.eye(n + m))
23 |     zero_cone_size = cones['z']
24 |     static_algo_factor, scale_vec = get_scaled_vec_and_factor(
25 |         static_M, rho_x, scale, m, n, zero_cone_size)
26 |
27 |     static_prob_data = dict(P=P, A=A, cones=cones, proj=proj,
28 |                             static_M=static_M, static_algo_factor=static_algo_factor,
29 |                             m=m, n=n)
30 |     varying_prob_data = dict(q_mat_train=q_mat_train, q_mat_test=q_mat_test,
31 |                              train_inputs=train_inputs, test_inputs=test_inputs)
32 |     return static_prob_data, varying_prob_data
33 |
34 |
35 | def test_minimal_l2ws_model():
36 |     """
37 |     tests that we can initialize an L2WSmodel with the minimal amount of information needed
38 |
39 |     we test for
40 |     - no errors in creation of L2WSmodel without specifying most of the hyperparameters
41 |         such as any part of the neural network
42 |     - no errors during training
43 |     - a decrease in training loss of over 40% after training for 10 epochs
44 |     - test loss does not increase by more than 5%
45 |     - warm-starting with our architecture matches the SCS C implementation to machine precision
46 |     """
47 |     # get the problem
48 |     N_train, N_test = 10, 5
49 |     m_orig, n_orig = 30, 40
50 |     rho, b_center, b_range = 1, 1, 1
51 |
52 |     # scs hyperparams
53 |     alpha_relax = 1.1
54 |     rho_x = 1.1
55 |     scale = .2
56 |
57 |     static_prob_data, varying_prob_data = multiple_random_robust_ls_setup(
58 |         m_orig, n_orig, rho, b_center, b_range, N_train, N_test, rho_x, scale)
59 |     q_mat_train, q_mat_test = varying_prob_data['q_mat_train'], varying_prob_data['q_mat_test']
60 |     train_inputs, test_inputs = varying_prob_data['train_inputs'], varying_prob_data['test_inputs']
61 |
62 |     m, n, cones = static_prob_data['m'], static_prob_data['n'], static_prob_data['cones']
63 |     P, A = static_prob_data['P'], static_prob_data['A']
64 |
65 |     # enter into the L2WSmodel
66 |     algo_dict = dict(algorithm='scs',
67 |                      m=m,
68 |                      n=n,
69 |                      proj=static_prob_data['proj'],
70 |                      cones=cones,
71 |                      q_mat_train=q_mat_train,
72 |                      q_mat_test=q_mat_test,
73 |                      static_M=static_prob_data['static_M'],
74 |                      static_algo_factor=static_prob_data['static_algo_factor'],
75 |                      rho_x=rho_x,
76 |                      scale=scale,
77 |                      alpha_relax=alpha_relax)
78 |
79 |     l2ws_model = SCSmodel(train_unrolls=20,
80 |                           train_inputs=train_inputs,
81 |                           test_inputs=test_inputs,
82 |                           algo_dict=algo_dict)
83 |
84 |     # evaluate test before training
85 |     init_test_loss, init_time_per_iter = l2ws_model.short_test_eval()
86 |
87 |     # call train_batch without jitting
88 |     params, state = l2ws_model.params, l2ws_model.state
89 |     num_epochs = 10
90 |     losses = jnp.zeros(num_epochs)
91 |     for i in range(num_epochs):
92 |         train_result = l2ws_model.train_full_batch(params, state)
93 |         loss, params, state = train_result
94 |         losses = losses.at[i].set(loss)
95 |
96 |     # the loss should decrease from the first to the last epoch
97 |     assert losses[0] - losses[-1] > 0
98 |
99 |     # the final loss should be less than 60% of the first loss (i.e., at least a 40% reduction)
100 |     assert losses[-1] / losses[0] < 0.6
101 |
102 |     l2ws_model.params, l2ws_model.state = params, state
103 |
104 |     # evaluate test after training
105 |     final_test_loss, final_time_per_iter = l2ws_model.short_test_eval()
106 |
107 |     # test loss does not get 5% worse
108 |     assert final_test_loss < init_test_loss * 1.05
109 |
110 |     # after jitting, evaluating the test set should be much faster
111 |     assert final_time_per_iter < .1 * init_time_per_iter
112 |
113 |     # evaluate the training set for a
different number of iterations
118 |     loss, eval_out, time_per_prob = l2ws_model.evaluate(300, train_inputs, q_mat_train,
119 |                                                          z_stars=None, fixed_ws=False, tag='train')
120 |
122 |     losses, iter_losses, z_all_plus_1, angles, primal_residuals, dual_residuals, u_all, v_all = eval_out  # noqa
124 |
125 |     # warm-start SCS with z0 from z_all_plus_1
126 |     # SCS setup
127 |     max_iters = 6
128 |     P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
129 |     scs_data = dict(P=P_sparse, A=A_sparse, b=np.zeros(m), c=np.zeros(n))
130 |
131 |     solver = scs.SCS(scs_data,
132 |                      cones,
133 |                      normalize=False,
134 |                      scale=scale,
135 |                      adaptive_scale=False,
136 |                      rho_x=rho_x,
137 |                      alpha=alpha_relax,
138 |                      acceleration_lookback=0,
139 |                      max_iters=max_iters,
140 |                      eps_abs=1e-12,
141 |                      eps_rel=0)
142 |
143 |     x_ws = z_all_plus_1[0, 0, :n]
144 |     y_ws = z_all_plus_1[0, 0, n:n + m]
145 |     s_ws = np.zeros(m)
146 |
147 |     c_np = np.array(q_mat_train[0, :n])
148 |     b_np = np.array(q_mat_train[0, n:])
149 |     solver.update(b=b_np, c=c_np)
150 |     sol = solver.solve(warm_start=True, x=np.array(x_ws), y=np.array(y_ws), s=np.array(s_ws))
151 |     x_c = jnp.array(sol['x'])
152 |     y_c = jnp.array(sol['y'])
153 |     s_c = jnp.array(sol['s'])
154 |     u_final = u_all[0, max_iters - 1, :]
155 |     x_jax = u_final[:n] / u_all[0, max_iters - 1, -1]  # divide by tau = u[-1] to undo the homogeneous embedding
156 |     y_jax = u_final[n:n + m] / u_all[0, max_iters - 1, -1]
157 |     s_jax = v_all[0, max_iters - 1, n:n+m] / u_all[0, max_iters - 1, -1]
158 |
159 |
160 |     assert jnp.linalg.norm(x_jax - x_c) < 1e-10
161 |     assert jnp.linalg.norm(y_jax - y_c) < 1e-10
162 |     assert jnp.linalg.norm(s_jax - s_c) < 1e-10
-------------------------------------------------------------------------------- /l2ws/examples/sparse_pca.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import yaml
4 | from jax import vmap
5 | import pandas as pd
6 | import matplotlib.pyplot as plt
7 | import time
8 | import jax.numpy as jnp
9 | import os
10 | import scs
11 | import cvxpy as cp
12 | import jax.scipy as jsp
13 | from l2ws.algo_steps import create_M
14 | from scipy.sparse import csc_matrix
15 | from l2ws.examples.solve_script import setup_script
16 | from l2ws.launcher import Workspace
17 | from l2ws.algo_steps import get_scaled_vec_and_factor
18 |
19 |
20 | plt.rcParams.update(
21 |     {
22 |         "text.usetex": True,
23 |         "font.family": "serif",
24 |         "font.size": 16,
25 |     }
26 | )
27 | log = logging.getLogger(__name__)
28 |
29 |
30 | def run(run_cfg):
31 |     example = "sparse_pca"
32 |     data_yaml_filename = 'data_setup_copied.yaml'
33 |
34 |     # read the yaml file
35 |     with open(data_yaml_filename, "r") as stream:
36 |         try:
37 |             setup_cfg = yaml.safe_load(stream)
38 |         except yaml.YAMLError as exc:
39 |             print(exc)
40 |             setup_cfg = {}
41 |
42 |     # set the seed
43 |     np.random.seed(setup_cfg['seed'])
44 |     n_orig = setup_cfg['n_orig']
45 |     k = setup_cfg['k']
46 |
47 |     # non-identity DR scaling
48 |     rho_x = run_cfg.get('rho_x', 1)
49 |     scale = run_cfg.get('scale', 1)
50 |
51 |     static_dict = static_canon(n_orig, k, rho_x=rho_x, scale=scale)
52 |
53 |     # we directly save q now
54 |     # get_q = None
55 |
static_flag = True 56 | # workspace = Workspace(run_cfg, static_flag, static_dict, example, get_q) 57 | 58 | algo = 'scs' 59 | workspace = Workspace(algo, run_cfg, static_flag, static_dict, example) 60 | 61 | # run the workspace 62 | workspace.run() 63 | 64 | 65 | def multiple_random_sparse_pca(n_orig, k, r, N, factor=True, seed=42): 66 | out_dict = static_canon(n_orig, k, factor=factor) 67 | # c, b = out_dict['c'], out_dict['b'] 68 | P_sparse, A_sparse = out_dict['P_sparse'], out_dict['A_sparse'] 69 | cones = out_dict['cones_dict'] 70 | prob, A_param = out_dict['prob'], out_dict['A_param'] 71 | P, A = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense()) 72 | 73 | # get theta_mat 74 | A_tensor, theta_mat = generate_A_tensor(N, n_orig, r) 75 | theta_mat_jax = jnp.array(theta_mat) 76 | 77 | # get theta_mat 78 | m, n = A.shape 79 | q_mat = get_q_mat(A_tensor, prob, A_param, m, n) 80 | # import pdb 81 | # pdb.set_trace() 82 | 83 | return P, A, cones, q_mat, theta_mat_jax, A_tensor 84 | 85 | 86 | def generate_A_tensor(N, n_orig, r): 87 | """ 88 | generates covariance matrices A_1, ..., A_N 89 | where each A_i has shape (n_orig, n_orig) 90 | A_i = F Sigma_i F^T 91 | where F has shape (n_orig, r) 92 | i.e. each Sigma_i is psd (Sigma_i = B_i B_i^T) and is different 93 | B_i has shape (r, r) 94 | F stays the same for each problem 95 | We let theta = upper_tri(Sigma_i) 96 | """ 97 | # first generate a random A matrix 98 | # A0 = np.random.rand(n_orig, n_orig) 99 | A0 = np.random.normal(size=(n_orig, n_orig)) 100 | 101 | # take the SVD 102 | U, S, VT = np.linalg.svd(A0) 103 | 104 | # take F to be the first r columns of U 105 | F = U[:, :r] 106 | A_tensor = np.zeros((N, n_orig, n_orig)) 107 | r_choose_2 = int(r * (r + 1) / 2) 108 | theta_mat = np.zeros((N, r_choose_2)) 109 | # theta_mat = np.zeros((N, n_orig * r)) 110 | B0 = np.diag(np.sqrt(S[:r])) 111 | 112 | for i in range(N): 113 | # B = .1*np.random.rand(r, r) #np.diag(np.random.rand(r)) #2 * np.random.rand(r, r) - 1 114 | # B = np.random.normal(size=(r, r)) 115 | delta = 2 * np.random.rand(r, r) - 1 116 | B = .1 * delta + B0 117 | # B = 2 * np.random.rand(r, r) 118 | # B = np.random.normal(size=(r, r)) 119 | Sigma = .1 * B @ B.T 120 | col_idx, row_idx = np.triu_indices(r) 121 | theta_mat[i, :] = Sigma[(row_idx, col_idx)] 122 | A_tensor[i, :, :] = F @ Sigma @ F.T 123 | 124 | # curr_perturb = np.random.normal(size=(n_orig, r)) 125 | # C = F + .1 * curr_perturb 126 | # A_tensor[i, :, :] = C @ C.T 127 | # theta_mat[i, :] = np.ravel(curr_perturb) 128 | 129 | return A_tensor, theta_mat 130 | 131 | 132 | def cvxpy_prob(n_orig, k): 133 | A_param = cp.Parameter((n_orig, n_orig), symmetric=True) 134 | X = cp.Variable((n_orig, n_orig), symmetric=True) 135 | constraints = [X >> 0, cp.sum(cp.abs(X)) <= k, cp.trace(X) == 1] 136 | prob = cp.Problem(cp.Minimize(-cp.trace(A_param @ X)), constraints) 137 | return prob, A_param 138 | 139 | 140 | def get_q_mat(A_tensor, prob, A_param, m, n): 141 | N, n_orig, _ = A_tensor.shape 142 | q_mat = jnp.zeros((N, m + n)) 143 | for i in range(N): 144 | # set the parameter 145 | A_param.value = A_tensor[i, :, :] 146 | 147 | # get the problem data 148 | data, _, __ = prob.get_problem_data(cp.SCS) 149 | 150 | c, b = data['c'], data['b'] 151 | n = c.size 152 | q_mat = q_mat.at[i, :n].set(c) 153 | q_mat = q_mat.at[i, n:].set(b) 154 | return q_mat 155 | 156 | 157 | def static_canon(n_orig, k, rho_x=1, scale=1, factor=True): 158 | # create the cvxpy problem 159 | prob, A_param = cvxpy_prob(n_orig, k) 160 | 161 | # get the 
problem data
162 |     data, _, __ = prob.get_problem_data(cp.SCS)
163 |
164 |     A_sparse, c, b = data['A'], data['c'], data['b']
165 |     m, n = A_sparse.shape
166 |     P_sparse = csc_matrix(np.zeros((n, n)))
167 |     cones_cp = data['dims']
168 |
169 |     # factor for DR splitting
170 |     m, n = A_sparse.shape
171 |     P_jax, A_jax = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense())
172 |     M = create_M(P_jax, A_jax)
173 |     zero_cone_size = cones_cp.zero
174 |
175 |     if factor:
176 |         algo_factor, scale_vec = get_scaled_vec_and_factor(M, rho_x, scale, m, n,
177 |                                                            zero_cone_size)
178 |         # algo_factor = jsp.linalg.lu_factor(M + jnp.eye(n + m))
179 |     else:
180 |         algo_factor = None
181 |
182 |     # set the dict
183 |     cones = {'z': cones_cp.zero, 'l': cones_cp.nonneg, 'q': cones_cp.soc, 's': cones_cp.psd}
184 |     out_dict = dict(
185 |         M=M,
186 |         algo_factor=algo_factor,
187 |         cones_dict=cones,
188 |         A_sparse=A_sparse,
189 |         P_sparse=P_sparse,
190 |         b=b,
191 |         c=c,
192 |         prob=prob,
193 |         A_param=A_param
194 |     )
195 |     return out_dict
196 |
197 |
198 | def setup_probs(setup_cfg):
199 |     cfg = setup_cfg
200 |     N_train, N_test = cfg.N_train, cfg.N_test
201 |     N = N_train + N_test
202 |     n_orig = cfg.n_orig
203 |
204 |     np.random.seed(cfg.seed)
205 |
206 |     # save output to output_filename
207 |     output_filename = f"{os.getcwd()}/data_setup"
208 |
209 |     P, A, cones, q_mat, theta_mat_jax, A_tensor = multiple_random_sparse_pca(
210 |         n_orig, cfg.k, cfg.r, N, factor=False)
211 |     P_sparse, A_sparse = csc_matrix(P), csc_matrix(A)
212 |     m, n = A.shape
213 |
214 |     # create scs solver object
215 |     # we can cache the factorization if we do it like this
216 |     b_np, c_np = np.array(q_mat[0, n:]), np.array(q_mat[0, :n])
217 |     data = dict(P=P_sparse, A=A_sparse, b=b_np, c=c_np)
218 |     tol_abs = cfg.solve_acc_abs
219 |     tol_rel = cfg.solve_acc_rel
220 |     solver = scs.SCS(data, cones, normalize=False, alpha=1, scale=1,
221 |                      rho_x=1, adaptive_scale=False, eps_abs=tol_abs, eps_rel=tol_rel)
222 |
223 |     setup_script(q_mat, theta_mat_jax, solver, data, cones, output_filename, solve=cfg.solve)
224 |
225 |     # import pdb  # debugging breakpoint disabled so the setup script runs to completion
226 |     # pdb.set_trace()
-------------------------------------------------------------------------------- /l2ws/osqp_model.py: --------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import osqp
6 | from scipy.sparse import csc_matrix
7 |
8 | from l2ws.algo_steps import k_steps_eval_osqp, k_steps_train_osqp, unvec_symm
9 | from l2ws.l2ws_model import L2WSmodel
10 |
11 |
12 | class OSQPmodel(L2WSmodel):
13 |     def __init__(self, **kwargs):
14 |         super(OSQPmodel, self).__init__(**kwargs)
15 |
16 |     def initialize_algo(self, input_dict):
17 |         # self.m, self.n = self.A.shape
18 |         self.algo = 'osqp'
19 |         self.m, self.n = input_dict['m'], input_dict['n']
20 |         self.q_mat_train, self.q_mat_test = input_dict['q_mat_train'], input_dict['q_mat_test']
21 |
22 |         self.rho = input_dict['rho']
23 |         self.sigma = input_dict.get('sigma', 1)
24 |         self.alpha = input_dict.get('alpha', 1)
25 |         self.output_size = self.n + self.m
26 |
27 |         """
28 |         break into the 2 cases
29 |         1. factors are the same for each problem (i.e. matrices A and P don't change)
30 |         2. 
factors change for each problem
31 |         """
32 |         self.factors_required = True
33 |         self.factor_static_bool = input_dict.get('factor_static_bool', True)
34 |         if self.factor_static_bool:
35 |             self.A = input_dict['A']
36 |             # self.P = input_dict.get('P', None)
37 |             self.P = input_dict['P']
38 |             self.factor_static = input_dict['factor']
39 |             self.k_steps_train_fn = partial(
40 |                 k_steps_train_osqp, A=self.A, rho=self.rho, sigma=self.sigma, jit=self.jit)
41 |             self.k_steps_eval_fn = partial(k_steps_eval_osqp, P=self.P,
42 |                                            A=self.A, rho=self.rho, sigma=self.sigma, jit=self.jit)
43 |         else:
44 |             self.k_steps_train_fn = self.create_k_steps_train_fn_dynamic()
45 |             self.k_steps_eval_fn = self.create_k_steps_eval_fn_dynamic()
46 |             # self.k_steps_eval_fn = partial(k_steps_eval_osqp, rho=rho, sigma=sigma, jit=self.jit)
47 |
48 |             self.factors_train = input_dict['factors_train']
49 |             self.factors_test = input_dict['factors_test']
50 |
51 |         # self.k_steps_train_fn = partial(k_steps_train_osqp, factor=factor, A=self.A, rho=rho,
52 |         #                                 sigma=sigma, jit=self.jit)
53 |         # self.k_steps_eval_fn = partial(k_steps_eval_osqp, factor=factor, P=self.P, A=self.A,
54 |         #                                rho=rho, sigma=sigma, jit=self.jit)
55 |         self.out_axes_length = 6
56 |
57 |     def create_k_steps_train_fn_dynamic(self):
58 |         """
59 |         creates the self.k_steps_train_fn function for the dynamic case
60 |         acts as a wrapper around the k_steps_train_osqp function from algo_steps.py
61 |
62 |         we want to maintain the argument inputs as (k, z0, q_bar, factor, supervised, z_star)
63 |         """
64 |         m, n = self.m, self.n
65 |
66 |         def k_steps_train_osqp_dynamic(k, z0, q, factor, supervised, z_star):
67 |             nc2 = int(n * (n + 1) / 2)
68 |             q_bar = q[:2 * m + n]
69 |             # P is not needed for the train steps, so we do not unpack it here
70 |             A = jnp.reshape(q[2 * m + n + nc2:], (m, n))
71 |             return k_steps_train_osqp(k=k, z0=z0, q=q_bar,
72 |                                       factor=factor, A=A, rho=self.rho, sigma=self.sigma,
73 |                                       supervised=supervised, z_star=z_star, jit=self.jit)
74 |         return k_steps_train_osqp_dynamic
75 |
76 |     def create_k_steps_eval_fn_dynamic(self):
77 |         """
78 |         creates the self.k_steps_eval_fn function for the dynamic case
79 |         acts as a wrapper around the k_steps_eval_osqp function from algo_steps.py
80 |
81 |         we want to maintain the argument inputs as (k, z0, q_bar, factor, supervised, z_star)
82 |         """
83 |         m, n = self.m, self.n
84 |
85 |         def k_steps_eval_osqp_dynamic(k, z0, q, factor, supervised, z_star):
86 |             nc2 = int(n * (n + 1) / 2)
87 |             q_bar = q[:2 * m + n]
88 |             P = unvec_symm(q[2 * m + n: 2 * m + n + nc2], n)
89 |             A = jnp.reshape(q[2 * m + n + nc2:], (m, n))
90 |             return k_steps_eval_osqp(k=k, z0=z0, q=q_bar,
91 |                                      factor=factor, P=P, A=A, rho=self.rho, sigma=self.sigma,
92 |                                      supervised=supervised, z_star=z_star, jit=self.jit)
93 |         return k_steps_eval_osqp_dynamic
94 |
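    # --- note added for clarity ---
    # in the dynamic case, each problem's parameter vector q stacks the data as
    #   q = (c, l, u, vec_symm(P), vec(A))
    # with c in R^n, l and u in R^m, the n(n+1)/2 upper-triangular entries of P
    # (recovered with unvec_symm), and the m*n entries of A (reshaped to (m, n)),
    # which is exactly how the wrappers above and solve_c below slice q.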
95 |     def solve_c(self, z0_mat, q_mat, rel_tol, abs_tol, max_iter=40000):
96 |         # assume M doesn't change across problems
97 |         # static problem data
98 |         m, n = self.m, self.n
99 |         nc2 = int(n * (n + 1) / 2)
100 |
101 |         if self.factor_static_bool:
102 |             P, A = self.P, self.A
103 |         else:
104 |             P, A = np.ones((n, n)), np.zeros((m, n))  # placeholder shapes; the true P and A are read from q_mat in the loop below
105 |         P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
106 |
107 |
108 |         osqp_solver = osqp.OSQP()
109 |
110 |
111 |         # q = q_mat[0, :]
112 |         c, l, u = np.zeros(n), np.zeros(m), np.zeros(m)  # noqa
113 |
114 |         rho = 1  # note: rho is hardcoded to 1 here rather than taken from self.rho
115 |         osqp_solver.setup(P=P_sparse, q=c, A=A_sparse, l=l, u=u, alpha=self.alpha, rho=rho,
116 |                           sigma=self.sigma, polish=False,
117 |                           adaptive_rho=False, scaling=0, max_iter=max_iter, verbose=True,
118 |                           eps_abs=abs_tol, eps_rel=rel_tol)
119 |
120 |         num = z0_mat.shape[0]
121 |         solve_times = np.zeros(num)
122 |         solve_iters = np.zeros(num)
123 |         x_sols = jnp.zeros((num, n))
124 |         y_sols = jnp.zeros((num, m))
125 |         for i in range(num):
126 |             if not self.factor_static_bool:
127 |                 P = unvec_symm(q_mat[i, 2 * m + n: 2 * m + n + nc2], n)
128 |                 A = jnp.reshape(q_mat[i, 2 * m + n + nc2:], (m, n))
129 |                 c, l, u = np.array(q_mat[i, :n]), np.array(q_mat[i, n:n + m]), np.array(q_mat[i, n + m:n + 2 * m])  # noqa
130 |
131 |                 P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
135 |                 osqp_solver = osqp.OSQP()
136 |                 osqp_solver.setup(P=P_sparse, q=c, A=A_sparse, l=l, u=u, alpha=self.alpha, rho=rho,
137 |                                   sigma=self.sigma, polish=False,
138 |                                   adaptive_rho=False, scaling=0, max_iter=max_iter, verbose=True,
139 |                                   eps_abs=abs_tol, eps_rel=rel_tol)
140 |                 # osqp_solver.update(Px=P_sparse, Ax=csc_matrix(np.array(A)))
141 |             else:
142 |                 # set c, l, u
143 |                 c, l, u = q_mat[i, :n], q_mat[i, n:n + m], q_mat[i, n + m:n + 2 * m]  # noqa
144 |                 osqp_solver.update(q=np.array(c))
145 |                 osqp_solver.update(l=np.array(l), u=np.array(u))
146 |
147 |
148 |
149 |             # set the warm start
151 |             x_ws, y_ws = np.array(z0_mat[i, :n]), np.array(z0_mat[i, n:n + m])
152 |
153 |             # pass the warm start to OSQP
154 |             osqp_solver.warm_start(x=x_ws, y=y_ws)
155 |
156 |             # solve
157 |             results = osqp_solver.solve()
159 |
160 |             # set the solve time in milliseconds (OSQP reports seconds)
161 |             solve_times[i] = results.info.solve_time * 1000
162 |             solve_iters[i] = results.info.iter
163 |
164 |             # set the results
165 |             x_sols = x_sols.at[i, :].set(results.x)
166 |             y_sols = y_sols.at[i, :].set(results.y)
167 |
168 |         return solve_times, solve_iters, x_sols, y_sols
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # L2WS
2 | This repository is by
3 | [Rajiv Sambharya](https://rajivsambharya.github.io/),
4 | [Georgina Hall](https://sites.google.com/view/georgina-hall),
5 | [Brandon Amos](http://bamos.github.io/),
6 | and [Bartolomeo Stellato](https://stellato.io/),
7 | and contains the Python source code to
8 | reproduce the experiments in our paper
9 | "[Learning to Warm-Start Fixed-Point Optimization Algorithms](https://arxiv.org/pdf/2309.07835.pdf)."
10 | For an earlier conference version targeting QPs only, check out [this repo](https://github.com/stellatogrp/l2ws_qp).
11 |
12 | If you find this repository helpful in your publications, please consider citing our papers.
13 |
14 | # Abstract
15 | We introduce a machine-learning framework to warm-start fixed-point optimization algorithms. Our architecture consists of a neural network mapping problem parameters to warm starts, followed by a predefined number of fixed-point iterations. We propose two loss functions designed to either minimize the fixed-point residual or the distance to a ground truth solution. In this way, the neural network predicts warm starts with the end-to-end goal of minimizing the downstream loss. An important feature of our architecture is its flexibility, in that it can predict a warm start for fixed-point algorithms run for any number of steps, without being limited to the number of steps it has been trained on. 
We provide PAC-Bayes generalization bounds on unseen data for common classes of fixed-point operators: contractive, linearly convergent, and averaged. Applying this framework to well-known applications in control, statistics, and signal processing, we observe a significant reduction in the number of iterations and solution time required to solve these problems, through learned warm starts.
16 |
17 | ## Installation
18 | To install the package, run
19 | ```
20 | $ pip install git+https://github.com/stellatogrp/l2ws
21 | ```
22 |
23 | ## Getting started
24 |
25 | ### Intro tutorials
26 | You can find introductory tutorials on how to use `l2ws` in the folder `tutorials/`.
27 |
28 |
29 | ### Running experiments
30 | To download the experiments, you should clone this repository with
31 | ```
32 | git clone https://github.com/stellatogrp/l2ws_fixed_point.git
33 | ```
34 | Experiments can be run from the `benchmarks/` folder using the commands below:
35 | ```
36 | python l2ws_setup.py local
37 | python l2ws_train.py local
38 | python plot_script.py local
39 | ```
40 |
41 | Replace the ``` ``` with one of the following to run an experiment.
42 | ```
43 | unconstrained_qp
44 | lasso
45 | quadcopter
46 | mnist
47 | robust_kalman
48 | robust_ls
49 | phase_retrieval
50 | sparse_pca
51 | ```
52 |
53 | ***
54 | #### ```l2ws_setup.py```
55 |
56 | The first script ```l2ws_setup.py``` creates all of the problem instances and solves them.
57 | The number of problems to be solved is set in the setup config file.
58 | That config file also includes other parameters that define the problem instances.
59 | This only needs to be run once for each example.
60 | Depending on the example, this can take some time because 10000 problems are being solved.
61 | After running this script, the results are saved to a folder like
62 | ```
63 | outputs/quadcopter/data_setup_outputs/2022-06-03/14-54-32/
64 | ```
65 |
66 | ***
67 | #### ```l2ws_train.py```
68 |
69 | The second script ```l2ws_train.py``` does the actual training using the output from the previous setup command.
70 | In particular, in the config file, it takes a datetime that points to the setup output.
71 | By default, it takes the most recent setup if this pointer is empty.
72 | The train config file holds information about the actual training process.
73 | Run this file for each $k$ value to train for that number of fixed-point steps.
74 | To replicate our results in the paper, the only inputs that need to be changed are the ones that determine the number of training steps and which of the two loss functions you are using:
75 | - ```train_unrolls``` (an integer that is the value $k$)
76 | - ```supervised``` (either True or False)
77 |
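For concreteness, these two entries might look like the following sketch in a train config (illustrative values only; the remaining fields are left at their defaults):
```yaml
train_unrolls: 20   # k: the number of fixed-point steps unrolled in the architecture
supervised: False   # False: fixed-point residual loss; True: regression loss to a ground-truth solution
```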
78 | Each run for a given $k$ and loss function creates an output folder like
79 | ```
80 | outputs/quadcopter/train_outputs/2022-06-04/15-14-05/
81 | ```
82 | In this folder there are many metrics that are stored.
83 | We highlight the main ones here (both the raw data in csv files and the corresponding plots in pdf files).
84 |
85 |
86 | - Fixed-point residuals over the test problems
87 |
88 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/plots/iters_compared_test.csv```
89 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/plots/eval_iters_test.pdf```
90 |
91 | - Fixed-point residuals over the training problems
92 |
93 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/plots/iters_compared_train.csv```
94 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/eval_iters_train.pdf```
95 |
96 | - Losses over epochs: for training this holds the average loss (for either loss function); for testing we plot the fixed-point residual at $k$ steps
97 |
98 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/train_test_results.csv```
99 | ```outputs/quadcopter/train_outputs/2022-06-04/15-14-05/losses_over_training.pdf```
100 |
101 | - The ```accuracies``` folder holds the results that are used for the tables. First, it holds the average number of iterations to reach the desired accuracies ($0.1$, $0.01$, $0.001$, and $0.0001$ by default).
102 | Second, it holds the reduction in iterations in comparison to the cold start.
103 |
104 | ```outputs/quadcopter/2022-12-03/14-54-32/plots/accuracies```
105 |
106 | - The ```solve_c``` folder holds the results that we show for the timings (for OSQP and SCS) in the paper.
107 | In the train config file, we can set the accuracies that we run OSQP and SCS to (both the relative and absolute accuracies are set to the same value).
108 |
109 | ```outputs/quadcopter/2022-12-03/14-54-32/plots/solve_c```
110 |
111 |
112 |
113 | ***
114 | #### ```plot_script.py```
115 |
116 | The third script ```plot_script.py``` plots the results across many different training runs.
117 | Each run of this script creates a new folder
118 | ```
119 | outputs/quadcopter/plots/2022-06-04/15-14-05/
120 | ```
121 |
122 |
123 |
124 | For the image deblurring task, we use the EMNIST dataset found at https://www.nist.gov/itl/products-and-services/emnist-dataset and use pip to install emnist (https://pypi.org/project/emnist/).
125 |
126 |
127 | Adjust the config files to try different settings; for example, the number of train/test problems, the number of evaluation iterations, and the number of training steps.
128 | Additionally, the neural network and problem setup configurations can be updated.
129 | We automatically use the most recent output after each stage, but a specific datetime can be inputted. Additionally, the final evaluation plot can take in multiple training datetimes in a list. See the commented-out lines in the config files.
130 |
131 | ***
132 |
133 |
134 | # Important files in the backend
135 | To reproduce our results, this part is not needed.
136 |
137 | - The ```l2ws/examples``` folder holds the code for each of the numerical experiments we run. Its main purpose is to be used in conjunction with ```l2ws_setup.py```.
138 |
139 | - An important note is that the code is set to periodically evaluate the train and test sets; this is set with the ```eval_every_x_epochs``` entry in the run config file.
140 | When we evaluate, the fixed-point curves are updated (see the above files for the run config).
141 |
142 | We can also set the number of problems we solve with C code (for OSQP and SCS) with ```solve_c_num```. This will create the results that are used for our timing tables.
143 | ***
144 |
145 | The ```l2ws``` folder holds the code that implements our architecture and allows for the training. In particular,
146 |
147 | - ```l2ws/launcher.py``` is the workspace which holds the L2WSmodel below.
148 | All of the evaluation and training is run through this workspace.
149 |
150 | - ```l2ws/algo_steps.py``` holds all of the code that runs the algorithms
151 |
152 |   - the fixed-point algorithms follow the same form in case you want to try your own algorithm (see the sketch below)
153 |
154 | - ```l2ws/l2ws_model.py``` holds the L2WSmodel object, i.e., the architecture. This code allows us to
155 |   - evaluate the problems (both test and train) for any initialization technique
156 |   - train the neural network weights with the given parameters: the number of fixed-point steps in the architecture $k$ (```train_unrolls```) and the training loss $`\ell^{\rm fp}_{\theta}`$ (```supervised=False```) or $`\ell^{\rm reg}_{\theta}`$ (```supervised=True```)
157 |
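To make this concrete, here is a minimal sketch (ours, not the repository's exact signatures) of a fixed-point algorithm written in the common form above, together with the two training losses that ```supervised=False``` and ```supervised=True``` select:
```python
import jax.numpy as jnp
from jax import lax

def fixed_point_step(z, q, P, gd_step):
    # one step of the operator T: here, gradient descent on .5 z^T P z + q^T z
    return z - gd_step * (P @ z + q)

def run_k_steps(z0, q, P, gd_step, k):
    # unroll k fixed-point steps z_{i+1} = T(z_i)
    return lax.fori_loop(0, k, lambda i, z: fixed_point_step(z, q, P, gd_step), z0)

def fixed_point_residual_loss(z0, q, P, gd_step, k):
    # supervised=False: penalize the fixed-point residual ||T(z_k) - z_k||
    zk = run_k_steps(z0, q, P, gd_step, k)
    return jnp.linalg.norm(fixed_point_step(zk, q, P, gd_step) - zk)

def regression_loss(z0, q, P, gd_step, k, z_star):
    # supervised=True: penalize the distance to a ground-truth solution
    zk = run_k_steps(z0, q, P, gd_step, k)
    return jnp.linalg.norm(zk - z_star)
```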
-------------------------------------------------------------------------------- /l2ws/examples/phase_retrieval.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import yaml
4 | from jax import vmap
5 | import pandas as pd
6 | import matplotlib.pyplot as plt
7 | import time
8 | import jax.numpy as jnp
9 | import os
10 | import scs
11 | import cvxpy as cp
12 | import jax.scipy as jsp
13 | import jax.random as jra
14 | from l2ws.algo_steps import create_M
15 | from scipy.sparse import csc_matrix
16 | from l2ws.examples.solve_script import setup_script
17 | from l2ws.launcher import Workspace
18 | from l2ws.algo_steps import get_scaled_vec_and_factor
19 |
20 |
21 | plt.rcParams.update(
22 |     {
23 |         "text.usetex": True,
24 |         "font.family": "serif",
25 |         "font.size": 16,
26 |     }
27 | )
28 | log = logging.getLogger(__name__)
29 |
30 |
31 | def run(run_cfg):
32 |     example = "phase_retrieval"
33 |     data_yaml_filename = 'data_setup_copied.yaml'
34 |
35 |     # read the yaml file
36 |     with open(data_yaml_filename, "r") as stream:
37 |         try:
38 |             setup_cfg = yaml.safe_load(stream)
39 |         except yaml.YAMLError as exc:
40 |             print(exc)
41 |             setup_cfg = {}
42 |
43 |     ######################### TODO
44 |     # set the seed
45 |     np.random.seed(setup_cfg['seed'])
46 |     n_orig = setup_cfg['n_orig']
47 |     d_mul = setup_cfg['d_mul']
48 |     # k = setup_cfg['k']
49 |
50 |
51 |     # non-identity DR scaling
52 |     rho_x = run_cfg.get('rho_x', 1)
53 |     scale = run_cfg.get('scale', 1)
54 |
55 |     static_dict = static_canon(n_orig, d_mul, rho_x=rho_x, scale=scale)
56 |
57 |     # we directly save q now
58 |     # get_q = None (no longer used)
59 |     static_flag = True
60 |     algo = 'scs'
61 |     workspace = Workspace(algo, run_cfg, static_flag, static_dict, example)
62 |
63 |     # run the workspace
64 |     workspace.run()
65 |
66 |
67 | def multiple_random_phase_retrieval(n_orig, d_mul, x_mean, x_var, N, seed=42):
68 |     ######################### TODO
69 |     out_dict = static_canon(n_orig, d_mul)
70 |     # # c, b = out_dict['c'], out_dict['b']
71 |     P_sparse, A_sparse = out_dict['P_sparse'], out_dict['A_sparse']
72 |     cones = out_dict['cones_dict']
73 |     prob, b_param = out_dict['prob'], out_dict['b_param']
74 |     P, A = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense())
75 |     A_tensor = out_dict['A_tensor']
76 |
77 |
78 |     # get theta_mat and b_vals together
79 |     b_matrix = generate_theta_mat_b_vals(N, A_tensor, x_mean, x_var, n_orig, d_mul)
80 |     theta_mat_jax = jnp.array(b_matrix)
81 |
82 |     # convert to q_mat
83 |     m, n = A.shape
84 |     q_mat = get_q_mat(b_matrix, prob, b_param, m, n)
85 |
86 |     return P, A, cones, q_mat, theta_mat_jax  # possibly return more
87 |
88 |
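# --- note added for clarity (our reading of the sampling below) ---
# generate_psi_vecs draws each mask entry as the product of two independent
# random variables: a uniform draw from {1, 1j, -1, -1j} and a magnitude that
# is sqrt(2)/2 with probability 0.2 and sqrt(3) with probability 0.8. This
# resembles the octanary coded-diffraction-pattern masks from the phase
# retrieval literature (up to the choice of the magnitude probabilities).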
89 | def generate_psi_vecs(n, d):
90 |     '''
91 |     Generate vector of psi's, TODO: type up details
92 |     '''
93 |     # first random var
94 |     A_vals = np.array([1, 1j, -1, -1j])
95 |     psi_A = np.random.choice(A_vals, size=(n, d))
96 |     # print('psi_A', psi_A)
97 |
98 |     # second random var
99 |     B_vals = np.array([np.sqrt(2)/2, np.sqrt(3)])
100 |     B_probs = np.array([0.2, 0.8])
101 |     psi_B = np.random.choice(B_vals, size=(n, d), p=B_probs)
102 |     # print('psi_B', psi_B)
103 |
104 |     # element-wise multiply
105 |     out = np.multiply(psi_A, psi_B, dtype='complex_')
106 |     return out
107 |
108 |
109 | def generate_A_tensor(n_orig, d_mul):
110 |     dftmtx = jnp.fft.fft(jnp.eye(n_orig))
111 |     d = n_orig * d_mul
112 |     A_out = np.zeros((d, n_orig, n_orig), dtype='complex_')
113 |     psi = generate_psi_vecs(n_orig, d_mul)
114 |     for l in range(n_orig):
115 |         Wl = dftmtx[l, :].reshape(1, -1)
116 |         for j in range(d_mul):
117 |             curr_psi = psi[:, j]
118 |             ai = np.multiply(Wl, curr_psi.conjugate(), dtype='complex_')
119 |             Ai = ai.T @ ai.conjugate()
120 |             # index the measurements contiguously by mask j; the original
121 |             # (j-1) * n_orig + l relied on negative-index wraparound for j = 0
122 |             A_out[j * n_orig + l, :, :] = Ai
123 |     return A_out / 100
124 |
125 |
126 | def cvxpy_prob(n_orig, d_mul, seed=42):
127 |     """
128 |     builds the phase retrieval SDP relaxation:
129 |     minimize tr(X) subject to tr(A_i X) = b_i for all i and X >> 0,
130 |     with the right-hand side b as the cvxpy parameter
131 |     """
132 |     d = n_orig * d_mul
133 |     A_tensor = generate_A_tensor(n_orig, d_mul)
134 |
139 |     b_param = cp.Parameter(d)
140 |     X = cp.Variable((n_orig, n_orig), hermitian=True)
141 |     constraints = [X >> 0]
142 |     for i in range(d):
143 |         Ai = A_tensor[i]
144 |         constraints += [cp.trace(Ai @ X) == b_param[i]]
145 |     prob = cp.Problem(cp.Minimize(cp.trace(X)), constraints)
146 |     return prob, A_tensor, b_param
147 |
148 |
149 | def generate_theta_mat_b_vals(N, A_tensor, x_mean, x_var, n_orig, d_mul):
150 |     d = n_orig * d_mul
151 |     b_matrix = np.zeros((N, d))
154 |     negate1 = np.random.binomial(n=1, p=0.5, size=(n_orig))
155 |     negate2 = np.random.binomial(n=1, p=0.5, size=(n_orig))
156 |     negate1[negate1 == 0] = -1
157 |     negate2[negate2 == 0] = -1
158 |     for i in range(N):
159 |         # this is where the parameterization comes in;
160 |         # one could modify where the xi comes from
168 |         xi = np.multiply(np.random.normal(size=(n_orig), loc=x_mean, scale=np.sqrt(x_var)), negate1) \
169 |             + 1j * np.multiply(np.random.normal(size=(n_orig), loc=x_mean, scale=np.sqrt(x_var)), negate2)
170 |         Xi = np.outer(xi, xi.conjugate()) / 10
175 |         for j in range(d):
176 |             # the trace will be real for hermitian matrices, but we use np.real to remove small complex floats
177 |             b_matrix[i, j] = np.real(np.trace(A_tensor[j] @ Xi))
178 |     return b_matrix
179 |
180 |
181 | def get_q_mat(b_matrix, prob, b_param, m, n):
182 |     """
183 |     builds q = (c, b) for each problem: each row of b_matrix is written into
184 |     the cvxpy parameter b_param and the canonicalized data is read back
187 |     """
188 |     N = b_matrix.shape[0]
189 |     q_mat = jnp.zeros((N, m + n))
190 |     for i in range(N):
191 |         # set the parameter
192 |         b_param.value = b_matrix[i, :]
193 |
194 |         # get the problem data
195 |         data, _, __ = prob.get_problem_data(cp.SCS)
196 |
197 |         c, b = data['c'], data['b']
198 |         n = c.size
199 |         q_mat = q_mat.at[i, :n].set(c)
200 |         q_mat = q_mat.at[i, n:].set(b)
201 |     return q_mat
202 |
203 |
204 | def static_canon(n_orig, d_mul, rho_x=1, scale=1, factor=True, seed=42):
205 |     # create the cvxpy problem
206 |     prob, A_tensor, b_param = cvxpy_prob(n_orig, d_mul, seed=42)
207 |
208 |     # get the problem data
209 |     data, _, __ = prob.get_problem_data(cp.SCS)
210 |
211 |     A_sparse, c, b = data['A'], data['c'], data['b']
212 |     m, n = A_sparse.shape
213 |     P_sparse = csc_matrix(np.zeros((n, n)))
214 |     cones_cp = data['dims']
215 |
216 |     # factor for DR splitting
217 |     m, n = A_sparse.shape
218 |     P_jax, A_jax = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense())
219 |     M = create_M(P_jax, A_jax)
220 |     zero_cone_size = cones_cp.zero
221 |
222 |     if factor:
223 |         algo_factor, scale_vec = get_scaled_vec_and_factor(M, rho_x, scale, m, n,
224 |                                                            zero_cone_size)
225 |         # algo_factor = jsp.linalg.lu_factor(M + jnp.eye(n + m))
226 |     else:
227 |         algo_factor = None
228 |
232 |     # set the dict
233 |     cones = {'z': cones_cp.zero, 'l': cones_cp.nonneg, 'q': cones_cp.soc, 's': cones_cp.psd}
234 |     out_dict = dict(
235 |         M=M,
236 |         algo_factor=algo_factor,
237 |         cones_dict=cones,
238 |         A_sparse=A_sparse,
239 |         P_sparse=P_sparse,
240 |         b=b,
241 |         c=c,
242 |         prob=prob,
244 |         A_tensor=A_tensor,
245 |         b_param=b_param,
246 |     )
247 |     return out_dict
248 |
249 |
250 | def setup_probs(setup_cfg):
251 |     cfg = setup_cfg
252 |     N_train, N_test = cfg.N_train, cfg.N_test
253 |     N = N_train + N_test
254 |     n_orig = cfg.n_orig
255 |     d_mul = cfg.d_mul
256 |     x_var = cfg.x_var
257 |     x_mean = cfg.x_mean
258 |
260 |     np.random.seed(cfg.seed)
261 |     key = jra.PRNGKey(cfg.seed)
262 |
263 |     # save output to output_filename
264 |     output_filename = f"{os.getcwd()}/data_setup"
265 |
266 |     ################## TODO add extra params to generation
267 |     P, A, cones, q_mat, theta_mat_jax = multiple_random_phase_retrieval(
268 |         n_orig, d_mul, x_mean, x_var, N)
269 |
270 |     P_sparse, A_sparse = csc_matrix(P), csc_matrix(A)
271 |     m, n = A.shape
272 |
273 |     # create scs solver object
274 |     # we can cache the factorization if we do it like this
275 |     b_np, c_np = np.array(q_mat[0, n:]), np.array(q_mat[0, :n])
276 |     data = dict(P=P_sparse, A=A_sparse, b=b_np, c=c_np)
277 |     tol_abs = cfg.solve_acc_abs
278 |     tol_rel = cfg.solve_acc_rel
279 |     max_iters = cfg.get('solve_max_iters', 10000)
280 |     solver = scs.SCS(data, cones, eps_abs=tol_abs, eps_rel=tol_rel, max_iters=max_iters)
281 |
282 |     setup_script(q_mat, theta_mat_jax, solver, data, cones, output_filename, solve=cfg.solve)
-------------------------------------------------------------------------------- /l2ws/scs_model.py: --------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import cvxpy as cp
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import scs
7 | from scipy.sparse import csc_matrix
8 |
9 | from l2ws.algo_steps import (
10 |     create_M,
11 |     get_scaled_vec_and_factor,
12 |     k_steps_eval_scs,
13 |     k_steps_train_scs,
14 | )
15 | from l2ws.l2ws_model import
L2WSmodel
16 |
17 |
18 | class SCSmodel(L2WSmodel):
19 |     def __init__(self, **kwargs):
20 |         super(SCSmodel, self).__init__(**kwargs)
21 |
22 |     def initialize_algo(self, input_dict):
23 |         """
24 |         the input_dict is required to contain these keys;
25 |         otherwise there is an error
26 |         """
27 |         self.factors_required = True
28 |         self.factor_static_bool = input_dict.get('factor_static_bool', True)
29 |         self.algo = 'scs'
31 |         self.hsde = input_dict.get('hsde', True)
32 |         self.m, self.n = input_dict['m'], input_dict['n']
33 |         self.cones = input_dict['cones']
34 |         self.proj, self.static_flag = input_dict['proj'], input_dict.get('static_flag', True)
35 |         self.q_mat_train, self.q_mat_test = input_dict['q_mat_train'], input_dict['q_mat_test']
36 |
37 |         M = input_dict['static_M']
38 |         self.P = M[:self.n, :self.n]
39 |         self.A = -M[self.n:, :self.n]
40 |
41 |         factor = input_dict['static_algo_factor']
42 |         self.factor = factor
43 |         self.factor_static = factor
44 |
45 |         # hyperparameters of scs
46 |         self.rho_x = input_dict.get('rho_x', 1)
47 |         self.scale = input_dict.get('scale', 1)
48 |         self.alpha_relax = input_dict.get('alpha_relax', 1)
49 |
50 |         # not a hyperparameter, but used for the scale knob
51 |         self.zero_cone_size = self.cones['z']
52 |         lightweight = input_dict.get('lightweight', False)
53 |
54 |         self.output_size = self.n + self.m
55 |         self.out_axes_length = 8
56 |
57 |         self.k_steps_train_fn = partial(k_steps_train_scs, factor=factor, proj=self.proj,
58 |                                         rho_x=self.rho_x, scale=self.scale,
59 |                                         alpha=self.alpha_relax, jit=self.jit,
60 |                                         m=self.m,
61 |                                         n=self.n,
62 |                                         zero_cone_size=self.zero_cone_size,
63 |                                         hsde=True)
64 |         self.k_steps_eval_fn = partial(k_steps_eval_scs, factor=factor, proj=self.proj,
65 |                                        P=self.P, A=self.A,
66 |                                        zero_cone_size=self.zero_cone_size,
67 |                                        rho_x=self.rho_x, scale=self.scale,
68 |                                        alpha=self.alpha_relax,
69 |                                        jit=self.jit,
70 |                                        hsde=True,
71 |                                        lightweight=lightweight)
72 |
74 |     def setup_optimal_solutions(self,
75 |                                 z_stars_train,
76 |                                 z_stars_test,
77 |                                 x_stars_train=None,
78 |                                 x_stars_test=None,
79 |                                 y_stars_train=None,
80 |                                 y_stars_test=None):
81 |         if x_stars_train is not None:
87 |             self.z_stars_train = jnp.array(z_stars_train)
88 |             self.z_stars_test = jnp.array(z_stars_test)
89 |             self.x_stars_train = jnp.array(x_stars_train)
90 |             self.x_stars_test = jnp.array(x_stars_test)
91 |             self.y_stars_train = jnp.array(y_stars_train)
92 |             self.y_stars_test = jnp.array(y_stars_test)
93 |             self.u_stars_train = jnp.hstack([self.x_stars_train, self.y_stars_train])
94 |             self.u_stars_test = jnp.hstack([self.x_stars_test, self.y_stars_test])
95 |         elif z_stars_train is not None:
96 |             self.z_stars_train = z_stars_train
97 |             self.z_stars_test = z_stars_test
98 |         else:
99 |             self.z_stars_train, self.z_stars_test = None, None
100 |
101 |
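    # --- note added for clarity ---
    # solve_c below re-solves each problem with the compiled SCS (C) solver,
    # warm-started from the learned iterate z0. Under the homogeneous self-dual
    # embedding used here, z is interpreted as (x, y + s, 1) (see get_xys_from_z
    # below), so the warm start passed to SCS is x = z[:n], y = z[n:n + m], s = 0.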
102 |     def solve_c(self, z0_mat, q_mat, rel_tol, abs_tol, max_iter=10000):
103 |         # assume M doesn't change across problems
104 |         # static problem data
105 |         m, n = self.m, self.n
106 |         P, A = self.P, self.A
107 |
108 |         # set up the solver with zero vectors for b and c (updated per problem below)
109 |         b_zeros, c_zeros = np.zeros(m), np.zeros(n)
112 |         P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
113 |         c_data = dict(P=P_sparse, A=A_sparse, c=c_zeros, b=b_zeros)
114 |
115 |         solver = scs.SCS(c_data,
116 |                          self.cones,
117 |                          normalize=False,
118 |                          scale=self.scale,
119 |                          adaptive_scale=False,
120 |                          rho_x=self.rho_x,
121 |                          alpha=self.alpha_relax,
122 |                          acceleration_lookback=0,
123 |                          max_iters=max_iter,
124 |                          eps_abs=abs_tol,
125 |                          eps_rel=rel_tol,
126 |                          verbose=False)
127 |
137 |         num = z0_mat.shape[0]
138 |         solve_times = np.zeros(num)
139 |         solve_iters = np.zeros(num)
140 |         x_sols = jnp.zeros((num, n))
141 |         y_sols = jnp.zeros((num, m))
142 |         for i in range(num):
143 |             # set b and c for problem i
147 |             b, c = q_mat[i, n:], q_mat[i, :n]
148 |             solver.update(b=np.array(b))
149 |             solver.update(c=np.array(c))
150 |
151 |             # set the warm start
152 |             x, y, s = self.get_xys_from_z(z0_mat[i, :], m, n)
153 |             x_ws, y_ws = np.array(x), np.array(y)
154 |             s_ws = np.array(s)
155 |
156 |             # solve with the warm start
158 |             sol = solver.solve(warm_start=True, x=x_ws, y=y_ws, s=s_ws)
159 |
164 |             # set the solve time in milliseconds (SCS reports milliseconds)
167 |             solve_times[i] = sol['info']['solve_time']
168 |             solve_iters[i] = sol['info']['iter']
169 |
170 |             # set the results
171 |             x_sols = x_sols.at[i, :].set(sol['x'])
172 |             y_sols = y_sols.at[i, :].set(sol['y'])
173 |
174 |         return solve_times, solve_iters, x_sols, y_sols
175 |
176 |     def get_xys_from_z(self, z_init, m, n):
177 |         """
178 |         z = (x, y + s, 1)
179 |         we always set the last entry of z to be 1
180 |         we allow s to be zero (we just need z[n:m + n] = y + s)
181 |         """
183 |         x = z_init[:n]
184 |         y = z_init[n:n + m]
185 |         s = jnp.zeros(m)
186 |         return x, y, s
187 |
188 |
189 | def get_scs_factor(P, A, cones, rho_x=1, scale=1):
190 |     m, n = A.shape
191 |     zero_cone_size = cones['z']
192 |     P_jax = jnp.array(P.todense())
193 |     A_jax = jnp.array(A.todense())
194 |     M_jax = create_M(P_jax, A_jax)
195 |     algo_factor, scale_vec = get_scaled_vec_and_factor(M_jax, rho_x, scale, m, n,
196 |                                                        zero_cone_size)
197 |     return M_jax, algo_factor, scale_vec
198 |
199 |
200 | def solve_cvxpy_get_params(prob, cp_param, theta_values):
201 |     cp_param.value = theta_values[0]
202 |     prob.solve()
203 |     data, _, __ = prob.get_problem_data(cp.SCS)
204 |     c, b = data['c'], data['b']
205 |     A = data['A']
206 |
207 |     m = b.size
208 |     n = c.size
209 |     P = csc_matrix(np.zeros((n, n)))
210 |
211 |     cones_cp = data['dims']
212 |     cones = {'z': cones_cp.zero, 'l': cones_cp.nonneg, 'q': cones_cp.soc, 's': cones_cp.psd}
213 |
214 |     N = len(theta_values)
215 |
216 |     q_mat = np.zeros((N, m + n))
217 |     z_stars = np.zeros((N, m + n))
218 |     x_stars = np.zeros((N, n))
219 |     y_stars = np.zeros((N, m))
220 |     for i in range(N):
221 |
cp_param.value = theta_values[i] 222 | prob.solve() 223 | data, _, __ = prob.get_problem_data(cp.SCS) 224 | c, b = data['c'], data['b'] 225 | 226 | # q = (c, b) 227 | q_mat[i, :n] = c 228 | q_mat[i, n:] = b 229 | 230 | # get the optimal solution 231 | x_star = prob.solution.attr['solver_specific_stats']['x'] 232 | y_star = prob.solution.attr['solver_specific_stats']['y'] 233 | s_star = prob.solution.attr['solver_specific_stats']['s'] 234 | 235 | # transform the solution to the z variable 236 | z_star = np.concatenate([x_star, y_star + s_star]) 237 | z_stars[i, :] = z_star 238 | x_stars[i, :] = x_star 239 | y_stars[i, :] = y_star 240 | return z_stars, q_mat, cones, P, A 241 | -------------------------------------------------------------------------------- /l2ws/examples/markowitz.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from jax import vmap 3 | from l2ws.scs_problem import SCSinstance, scs_jax 4 | import numpy as np 5 | import pdb 6 | from l2ws.launcher import Workspace 7 | import jax.numpy as jnp 8 | from scipy.sparse import csc_matrix 9 | import jax.scipy as jsp 10 | import time 11 | import matplotlib.pyplot as plt 12 | import os 13 | import scs 14 | import logging 15 | import yaml 16 | log = logging.getLogger(__name__) 17 | 18 | 19 | def run(run_cfg): 20 | ''' 21 | retrieve data for this config 22 | theta is all of the following 23 | theta = (ret, pen_risk, pen_hold, pen_trade, w0) 24 | 25 | Sigma is constant 26 | 27 | just need (theta, factor, u_star), Pi 28 | ''' 29 | # todo: retrieve data and put into a nice form - OR - just save to nice form 30 | 31 | ''' 32 | create workspace 33 | needs to know the following somehow -- from the run_cfg 34 | 1. nn cfg 35 | 2. (theta, factor, u_star)_i=1^N 36 | 3. Pi 37 | 38 | 2. and 3. 
are stored in data files and the run_cfg holds the location 39 | 40 | it will create the l2a_model 41 | ''' 42 | 43 | datetime = run_cfg.data.datetime 44 | orig_cwd = hydra.utils.get_original_cwd() 45 | example = 'markowitz' 46 | folder = f"{orig_cwd}/outputs/{example}/aggregate_outputs/{datetime}" 47 | data_yaml_filename = f"{folder}/data_setup_copied.yaml" 48 | 49 | # read the yaml file 50 | with open(data_yaml_filename, "r") as stream: 51 | try: 52 | setup_cfg = yaml.safe_load(stream) 53 | except yaml.YAMLError as exc: 54 | print(exc) 55 | setup_cfg = {} 56 | 57 | pen_ret = 10**setup_cfg['pen_rets_min'] 58 | a = setup_cfg['a'] 59 | static_dict = static_canon( 60 | setup_cfg['data'], a, setup_cfg['idio_risk'], setup_cfg['scale_factor']) 61 | 62 | def get_q(theta): 63 | q = jnp.zeros(2*a + 1) 64 | q = q.at[:a].set(-theta * pen_ret) 65 | q = q.at[a].set(1) 66 | return q 67 | get_q_batch = vmap(get_q, in_axes=(0), out_axes=(0)) 68 | static_flag = True 69 | workspace = Workspace(run_cfg, static_flag, static_dict, 'markowitz', get_q_batch) 70 | 71 | ''' 72 | run the workspace 73 | ''' 74 | workspace.run() 75 | 76 | 77 | def setup_probs(setup_cfg): 78 | print('entered convex markowitz', flush=True) 79 | cfg = setup_cfg 80 | 81 | a = cfg.a 82 | N_train, N_test = cfg.N_train, cfg.N_test 83 | N = N_train + N_test 84 | # std_mult = cfg.std_mult 85 | pen_rets_min = cfg.pen_rets_min 86 | pen_rets_max = cfg.pen_rets_max 87 | alpha = cfg.alpha 88 | # max_clip, min_clip = cfg.max_clip, cfg.min_clip 89 | 90 | # p is the size of each feature vector (mu and pen_rets_factor) 91 | if pen_rets_max > pen_rets_min: 92 | p = a + 1 93 | else: 94 | p = a 95 | thetas = jnp.zeros((N, p)) 96 | 97 | # read in the returns dataframe 98 | orig_cwd = hydra.utils.get_original_cwd() 99 | if cfg.data == 'yahoo': 100 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/yahoo_ret_cov.npz" 101 | elif cfg.data == 'nasdaq': 102 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/ret_cov.npz" 103 | elif cfg.data == 'eod': 104 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/eod_ret_cov_factor.npz" 105 | 106 | ret_cov_loaded = np.load(ret_cov_np) 107 | # Sigma = ret_cov_loaded['cov'] + np.eye(a) * cfg.idio_risk 108 | ret = ret_cov_loaded['ret'][1:, :a] 109 | 110 | # ret_mean = ret.mean(axis=0) 111 | # clipped_ret_mean_orig = np.clip(ret_mean, a_min=min_clip, a_max=max_clip) 112 | # clipped_ret_mean = clipped_ret_mean_orig * SCALE_FACTOR 113 | 114 | log.info('creating static canonicalization...') 115 | t0 = time.time() 116 | out_dict = static_canon(cfg.data, a, cfg.idio_risk, cfg.scale_factor) 117 | t1 = time.time() 118 | log.info(f"finished static canonicalization - took {t1-t0} seconds") 119 | 120 | A_sparse = out_dict['A_sparse'] 121 | A_sparse, P_sparse = out_dict['A_sparse'], out_dict['P_sparse'] 122 | b = out_dict['b'] 123 | n = a 124 | m = a + 1 125 | cones_dict = dict(z=1, l=a) 126 | 127 | ''' 128 | save output to output_filename 129 | ''' 130 | # save to outputs/mm-dd-ss/... 
file
131 |     if "SLURM_ARRAY_TASK_ID" in os.environ.keys():
132 |         slurm_idx = os.environ["SLURM_ARRAY_TASK_ID"]
133 |         output_filename = f"{os.getcwd()}/data_setup_slurm_{slurm_idx}"
134 |     else:
135 |         output_filename = f"{os.getcwd()}/data_setup_slurm"
136 |     '''
137 |     create scs solver object
138 |     we can cache the factorization if we do it like this
139 |     '''
140 |     blank_b = np.zeros(m)
141 |     blank_c = np.zeros(n)
142 |     data = dict(P=P_sparse, A=A_sparse, b=blank_b, c=blank_c)
143 |     tol = cfg.solve_acc
144 |     solver = scs.SCS(data, cones_dict, eps_abs=tol, eps_rel=tol)
145 |     solve_times = np.zeros(N)
146 |     x_stars = jnp.zeros((N, n))
147 |     y_stars = jnp.zeros((N, m))
148 |     s_stars = jnp.zeros((N, m))
149 |     q_mat = jnp.zeros((N, n + m))
150 |     mu_mat = np.zeros((N, a))
151 |     pen_rets = np.zeros(N)
152 |     scs_instances = []
153 |
154 |     # mu_mat = SCALE_FACTOR * np.random.multivariate_normal(clipped_ret_mean_orig, Sigma, size=(N))
157 |     T = ret.shape[0]
158 |     noise = np.sqrt(.02) * np.random.normal(size=(N, a))
159 |
160 |     for i in range(N):
161 |         log.info(f"solving problem number {i}")
162 |         time_index = i % T
163 |         mu_mat[i, :] = cfg.scale_factor * alpha * (ret[time_index, :] + noise[i, :])
164 |         # mu_mat[i, :] = clipped_ret_mean * \
165 |         #     (1 + std_mult*np.random.normal(size=(a))
166 |         #      ) + std_mult *np.random.normal(size=(a))
167 |
168 |         sample = np.random.rand(1) * (pen_rets_max -
169 |                                       pen_rets_min) + pen_rets_min
170 |         pen_rets[i] = 10 ** sample
171 |         thetas = thetas.at[i, :a].set(mu_mat[i, :])
172 |         if pen_rets_max > pen_rets_min:
173 |             thetas = thetas.at[i, a].set(pen_rets[i])
174 |
175 |         # manual canon
176 |         c = -mu_mat[i, :] * pen_rets[i]
177 |         manual_canon_dict = {'P': P_sparse, 'A': A_sparse,
178 |                              'b': b, 'c': c,
179 |                              'cones': cones_dict}
180 |         scs_instance = SCSinstance(
181 |             manual_canon_dict, solver, manual_canon=True)
182 |
183 |         scs_instances.append(scs_instance)
184 |         x_stars = x_stars.at[i, :].set(scs_instance.x_star)
185 |         y_stars = y_stars.at[i, :].set(scs_instance.y_star)
186 |         s_stars = s_stars.at[i, :].set(scs_instance.s_star)
187 |         q_mat = q_mat.at[i, :].set(scs_instance.q)
188 |         solve_times[i] = scs_instance.solve_time
189 |
190 |         # input into our_scs
191 |         P_jax, A_jax = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense())
192 |         b_jax, c_jax = jnp.array(b), jnp.array(c)
193 |         data = dict(P=P_jax, A=A_jax, b=b_jax, c=c_jax, cones=cones_dict)
194 |         scs_jax(data, iters=1000)
196 |
197 |     # M, E, D = ruiz_equilibrate(M)
198 |     # pdb.set_trace()  # debugging breakpoint disabled so the setup script runs to completion
199 |
200 |     # save the final data
202 |     log.info('saving final data...')
203 |     t0 = time.time()
204 |     jnp.savez(output_filename,
205 |               thetas=thetas,
206 |               x_stars=x_stars,
207 |               y_stars=y_stars,
208 |               )
209 |     save_time = time.time()
210 |     log.info(f"finished saving final data... 
took {save_time-t0}'") 211 | 212 | # save plot of first 5 solutions 213 | for i in range(20): 214 | plt.plot(x_stars[i, :], label=i) 215 | if T < N: 216 | plt.plot(x_stars[T, :], label=T) 217 | plt.legend() 218 | plt.savefig('opt_solutions.pdf') 219 | plt.clf() 220 | 221 | # save plot of first 5 parameters 222 | for i in range(5): 223 | plt.plot(thetas[i, :], label=i) 224 | if T < N: 225 | plt.plot(thetas[T, :], label=T) 226 | plt.legend() 227 | plt.savefig('thetas.pdf') 228 | 229 | 230 | def static_canon(data, a, idio_risk, scale_factor): 231 | ''' 232 | This method produces the parts of each problem that does not change 233 | i.e. P, A, b, cones 234 | 235 | It creates the matrix 236 | M = [P A.T 237 | -A 0] 238 | 239 | It also returns the necessary factorizations 240 | 1. factor(I + M) 241 | 2. factor(A.T A) 242 | ''' 243 | orig_cwd = hydra.utils.get_original_cwd() 244 | if data == 'yahoo': 245 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/yahoo_ret_cov.npz" 246 | elif data == 'nasdaq': 247 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/ret_cov.npz" 248 | elif data == 'eod': 249 | ret_cov_np = f"{orig_cwd}/data/portfolio_data/eod_ret_cov_factor.npz" 250 | 251 | ret_cov_loaded = np.load(ret_cov_np) 252 | Sigma = ret_cov_loaded['cov'][:a, :a] + idio_risk * np.eye(a) 253 | n = a 254 | m = a + 1 255 | 256 | # scale Sigma 257 | Sigma = scale_factor * Sigma 258 | 259 | # do the manual canonicalization 260 | b = np.zeros(a + 1) 261 | b[0] = 1 262 | A = np.zeros((a + 1, a)) 263 | A[0, :] = 1 264 | A[1:, :] = -np.eye(a) 265 | A_sparse = csc_matrix(A) 266 | P_sparse = csc_matrix(Sigma) 267 | 268 | # cones 269 | cones_dict = dict(z=1, l=a) 270 | cones_array = jnp.array([cones_dict['z'], cones_dict['l']]) 271 | 272 | # factor for dual prediction from primal 273 | ATA_factor = jsp.linalg.cho_factor(A.T @ A) 274 | 275 | # create the matrix M 276 | M = jnp.zeros((n + m, n + m)) 277 | P_jax = jnp.array(Sigma) 278 | A_jax = jnp.array(A) 279 | M = M.at[:n, :n].set(P_jax) 280 | M = M.at[:n, n:].set(A_jax.T) 281 | M = M.at[n:, :n].set(-A_jax) 282 | 283 | # factor for DR splitting 284 | algo_factor = jsp.linalg.lu_factor(M + jnp.eye(n+m)) 285 | 286 | out_dict = dict(Sigma=Sigma, M=M, 287 | ATA_factor=ATA_factor, 288 | algo_factor=algo_factor, 289 | cones_array=cones_array, 290 | cones_dict=cones_dict, 291 | A_sparse=A_sparse, 292 | P_sparse=P_sparse, 293 | b=b) 294 | return out_dict 295 | -------------------------------------------------------------------------------- /benchmarks/l2ws_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import hydra 4 | 5 | import l2ws.examples.jamming as jamming 6 | import l2ws.examples.lasso as lasso 7 | import l2ws.examples.markowitz as markowitz 8 | import l2ws.examples.mnist as mnist 9 | import l2ws.examples.mpc as mpc 10 | import l2ws.examples.osc_mass as osc_mass 11 | import l2ws.examples.phase_retrieval as phase_retrieval 12 | import l2ws.examples.quadcopter as quadcopter 13 | import l2ws.examples.robust_kalman as robust_kalman 14 | import l2ws.examples.robust_ls as robust_ls 15 | import l2ws.examples.robust_pca as robust_pca 16 | import l2ws.examples.sparse_pca as sparse_pca 17 | import l2ws.examples.unconstrained_qp as unconstrained_qp 18 | import l2ws.examples.vehicle as vehicle 19 | from l2ws.utils.data_utils import copy_data_file, recover_last_datetime 20 | 21 | 22 | @hydra.main(config_path='configs/markowitz', config_name='markowitz_run.yaml') 23 | def main_run_markowitz(cfg): 24 | orig_cwd = 
hydra.utils.get_original_cwd() 25 | example = 'markowitz' 26 | agg_datetime = cfg.data.datetime 27 | if agg_datetime == '': 28 | # get the most recent datetime and update datetimes 29 | agg_datetime = recover_last_datetime(orig_cwd, example, 'aggregate') 30 | cfg.data.datetime = agg_datetime 31 | copy_data_file(example, agg_datetime) 32 | markowitz.run(cfg) 33 | 34 | 35 | @hydra.main(config_path='configs/osc_mass', config_name='osc_mass_run.yaml') 36 | def main_run_osc_mass(cfg): 37 | orig_cwd = hydra.utils.get_original_cwd() 38 | example = 'osc_mass' 39 | agg_datetime = cfg.data.datetime 40 | if agg_datetime == '': 41 | # get the most recent datetime and update datetimes 42 | agg_datetime = recover_last_datetime(orig_cwd, example, 'aggregate') 43 | cfg.data.datetime = agg_datetime 44 | copy_data_file(example, agg_datetime) 45 | osc_mass.run(cfg) 46 | 47 | 48 | @hydra.main(config_path='configs/lasso', config_name='lasso_run.yaml') 49 | def main_run_lasso(cfg): 50 | orig_cwd = hydra.utils.get_original_cwd() 51 | example = 'lasso' 52 | setup_datetime = cfg.data.datetime 53 | if setup_datetime == '': 54 | # get the most recent datetime and update datetimes 55 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 56 | cfg.data.datetime = setup_datetime 57 | copy_data_file(example, setup_datetime) 58 | lasso.run(cfg) 59 | 60 | 61 | @hydra.main(config_path='configs/quadcopter', config_name='quadcopter_run.yaml') 62 | def main_run_quadcopter(cfg): 63 | orig_cwd = hydra.utils.get_original_cwd() 64 | example = 'quadcopter' 65 | setup_datetime = cfg.data.datetime 66 | if setup_datetime == '': 67 | # get the most recent datetime and update datetimes 68 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 69 | cfg.data.datetime = setup_datetime 70 | copy_data_file(example, setup_datetime) 71 | quadcopter.run(cfg) 72 | 73 | 74 | @hydra.main(config_path='configs/jamming', config_name='jamming_run.yaml') 75 | def main_run_jamming(cfg): 76 | orig_cwd = hydra.utils.get_original_cwd() 77 | example = 'jamming' 78 | setup_datetime = cfg.data.datetime 79 | if setup_datetime == '': 80 | # get the most recent datetime and update datetimes 81 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 82 | cfg.data.datetime = setup_datetime 83 | copy_data_file(example, setup_datetime) 84 | jamming.run(cfg) 85 | 86 | 87 | @hydra.main(config_path='configs/mnist', config_name='mnist_run.yaml') 88 | def main_run_mnist(cfg): 89 | orig_cwd = hydra.utils.get_original_cwd() 90 | example = 'mnist' 91 | setup_datetime = cfg.data.datetime 92 | if setup_datetime == '': 93 | # get the most recent datetime and update datetimes 94 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 95 | cfg.data.datetime = setup_datetime 96 | copy_data_file(example, setup_datetime) 97 | mnist.run(cfg) 98 | 99 | 100 | @hydra.main(config_path='configs/unconstrained_qp', config_name='unconstrained_qp_run.yaml') 101 | def main_run_unconstrained_qp(cfg): 102 | orig_cwd = hydra.utils.get_original_cwd() 103 | example = 'unconstrained_qp' 104 | setup_datetime = cfg.data.datetime 105 | if setup_datetime == '': 106 | # get the most recent datetime and update datetimes 107 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 108 | cfg.data.datetime = setup_datetime 109 | copy_data_file(example, setup_datetime) 110 | unconstrained_qp.run(cfg) 111 | 112 | 113 | @hydra.main(config_path='configs/mpc', config_name='mpc_run.yaml') 114 | def main_run_mpc(cfg): 115 | 
orig_cwd = hydra.utils.get_original_cwd() 116 | example = 'mpc' 117 | setup_datetime = cfg.data.datetime 118 | if setup_datetime == '': 119 | # get the most recent datetime and update datetimes 120 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 121 | cfg.data.datetime = setup_datetime 122 | copy_data_file(example, setup_datetime) 123 | mpc.run(cfg) 124 | 125 | 126 | @hydra.main(config_path='configs/robust_kalman', config_name='robust_kalman_run.yaml') 127 | def main_run_robust_kalman(cfg): 128 | orig_cwd = hydra.utils.get_original_cwd() 129 | example = 'robust_kalman' 130 | setup_datetime = cfg.data.datetime 131 | if setup_datetime == '': 132 | # get the most recent datetime and update datetimes 133 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 134 | cfg.data.datetime = setup_datetime 135 | copy_data_file(example, setup_datetime) 136 | robust_kalman.run(cfg) 137 | 138 | 139 | @hydra.main(config_path='configs/robust_pca', config_name='robust_pca_run.yaml') 140 | def main_run_robust_pca(cfg): 141 | orig_cwd = hydra.utils.get_original_cwd() 142 | example = 'robust_pca' 143 | agg_datetime = cfg.data.datetime 144 | if agg_datetime == '': 145 | # get the most recent datetime and update datetimes 146 | agg_datetime = recover_last_datetime(orig_cwd, example, 'aggregate') 147 | cfg.data.datetime = agg_datetime 148 | copy_data_file(example, agg_datetime) 149 | robust_pca.run(cfg) 150 | 151 | 152 | @hydra.main(config_path='configs/robust_ls', config_name='robust_ls_run.yaml') 153 | def main_run_robust_ls(cfg): 154 | orig_cwd = hydra.utils.get_original_cwd() 155 | example = 'robust_ls' 156 | setup_datetime = cfg.data.datetime 157 | if setup_datetime == '': 158 | # get the most recent datetime and update datetimes 159 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 160 | cfg.data.datetime = setup_datetime 161 | copy_data_file(example, setup_datetime) 162 | robust_ls.run(cfg) 163 | 164 | 165 | @hydra.main(config_path='configs/sparse_pca', config_name='sparse_pca_run.yaml') 166 | def main_run_sparse_pca(cfg): 167 | orig_cwd = hydra.utils.get_original_cwd() 168 | example = 'sparse_pca' 169 | setup_datetime = cfg.data.datetime 170 | if setup_datetime == '': 171 | # get the most recent datetime and update datetimes 172 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 173 | cfg.data.datetime = setup_datetime 174 | copy_data_file(example, setup_datetime) 175 | sparse_pca.run(cfg) 176 | 177 | 178 | @hydra.main(config_path='configs/phase_retrieval', config_name='phase_retrieval_run.yaml') 179 | def main_run_phase_retrieval(cfg): 180 | orig_cwd = hydra.utils.get_original_cwd() 181 | example = 'phase_retrieval' 182 | setup_datetime = cfg.data.datetime 183 | if setup_datetime == '': 184 | # get the most recent datetime and update datetimes 185 | setup_datetime = recover_last_datetime(orig_cwd, example, 'data_setup') 186 | cfg.data.datetime = setup_datetime 187 | copy_data_file(example, setup_datetime) 188 | phase_retrieval.run(cfg) 189 | 190 | 191 | @hydra.main(config_path='configs/vehicle', config_name='vehicle_run.yaml') 192 | def main_run_vehicle(cfg): 193 | orig_cwd = hydra.utils.get_original_cwd() 194 | example = 'vehicle' 195 | agg_datetime = cfg.data.datetime 196 | if agg_datetime == '': 197 | # get the most recent datetime and update datetimes 198 | agg_datetime = recover_last_datetime(orig_cwd, example, 'aggregate') 199 | cfg.data.datetime = agg_datetime 200 | copy_data_file(example, agg_datetime) 201 
| vehicle.run(cfg) 202 | 203 | 204 | if __name__ == '__main__': 205 | if sys.argv[2] == 'cluster': 206 | base = 'hydra.run.dir=/scratch/gpfs/rajivs/learn2warmstart/outputs/' 207 | elif sys.argv[2] == 'local': 208 | base = 'hydra.run.dir=outputs/' 209 | if sys.argv[1] == 'markowitz': 210 | # step 1. remove the markowitz argument -- otherwise hydra uses it as an override 211 | # step 2. add the train_outputs/... argument for train_outputs not outputs 212 | # sys.argv[1] = 'hydra.run.dir=outputs/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 213 | sys.argv[1] = base + 'markowitz/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 214 | sys.argv = [sys.argv[0], sys.argv[1]] 215 | main_run_markowitz() 216 | elif sys.argv[1] == 'osc_mass': 217 | sys.argv[1] = base + 'osc_mass/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 218 | sys.argv = [sys.argv[0], sys.argv[1]] 219 | main_run_osc_mass() 220 | elif sys.argv[1] == 'vehicle': 221 | sys.argv[1] = base + 'vehicle/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 222 | sys.argv = [sys.argv[0], sys.argv[1]] 223 | main_run_vehicle() 224 | elif sys.argv[1] == 'robust_kalman': 225 | sys.argv[1] = base + 'robust_kalman/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 226 | sys.argv = [sys.argv[0], sys.argv[1]] 227 | main_run_robust_kalman() 228 | elif sys.argv[1] == 'robust_pca': 229 | sys.argv[1] = base + 'robust_pca/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 230 | sys.argv = [sys.argv[0], sys.argv[1]] 231 | main_run_robust_pca() 232 | elif sys.argv[1] == 'robust_ls': 233 | sys.argv[1] = base + 'robust_ls/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 234 | sys.argv = [sys.argv[0], sys.argv[1]] 235 | main_run_robust_ls() 236 | elif sys.argv[1] == 'sparse_pca': 237 | sys.argv[1] = base + 'sparse_pca/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 238 | sys.argv = [sys.argv[0], sys.argv[1]] 239 | main_run_sparse_pca() 240 | elif sys.argv[1] == 'phase_retrieval': 241 | sys.argv[1] = base + 'phase_retrieval/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 242 | sys.argv = [sys.argv[0], sys.argv[1]] 243 | main_run_phase_retrieval() 244 | elif sys.argv[1] == 'lasso': 245 | sys.argv[1] = base + 'lasso/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 246 | sys.argv = [sys.argv[0], sys.argv[1]] 247 | main_run_lasso() 248 | elif sys.argv[1] == 'mpc': 249 | sys.argv[1] = base + 'mpc/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 250 | sys.argv = [sys.argv[0], sys.argv[1]] 251 | main_run_mpc() 252 | elif sys.argv[1] == 'unconstrained_qp': 253 | sys.argv[1] = base + 'unconstrained_qp/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 254 | sys.argv = [sys.argv[0], sys.argv[1]] 255 | main_run_unconstrained_qp() 256 | elif sys.argv[1] == 'quadcopter': 257 | sys.argv[1] = base + 'quadcopter/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 258 | sys.argv = [sys.argv[0], sys.argv[1]] 259 | main_run_quadcopter() 260 | elif sys.argv[1] == 'mnist': 261 | sys.argv[1] = base + 'mnist/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 262 | sys.argv = [sys.argv[0], sys.argv[1]] 263 | main_run_mnist() 264 | elif sys.argv[1] == 'jamming': 265 | sys.argv[1] = base + 'jamming/train_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 266 | sys.argv = [sys.argv[0], sys.argv[1]] 267 | main_run_jamming() 268 | -------------------------------------------------------------------------------- /l2ws/examples/robust_ls.py: -------------------------------------------------------------------------------- 1 | from l2ws.scs_problem import SCSinstance 2 | import numpy as np 3 | from l2ws.launcher import Workspace 4 | import 
jax.numpy as jnp 5 | from scipy.sparse import csc_matrix 6 | import time 7 | import matplotlib.pyplot as plt 8 | import os 9 | import scs 10 | import logging 11 | import yaml 12 | from jax import vmap 13 | import pandas as pd 14 | from l2ws.algo_steps import get_scaled_vec_and_factor 15 | 16 | 17 | plt.rcParams.update( 18 | { 19 | "text.usetex": True, 20 | "font.family": "serif", 21 | "font.size": 16, 22 | } 23 | ) 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def single_q(theta, rho, m_orig, n_orig): 28 | # note: m, n are the sizes of the constraint matrix in the SOCP 29 | # theta is the vector b 30 | m = 2 * n_orig + m_orig + 2 31 | n = n_orig + 2 32 | 33 | # c 34 | c = jnp.zeros(n) 35 | c = c.at[-1].set(rho) 36 | c = c.at[-2].set(1) 37 | 38 | # b 39 | b = jnp.zeros(m) 40 | b = b.at[n_orig + 1: n_orig + m_orig + 1].set(-theta) 41 | 42 | # q 43 | m = b.size 44 | q = jnp.zeros(m + n) 45 | q = q.at[:n].set(c) 46 | q = q.at[n:].set(b) 47 | 48 | return q 49 | 50 | 51 | def run(run_cfg): 52 | example = "robust_ls" 53 | data_yaml_filename = 'data_setup_copied.yaml' 54 | 55 | # read the yaml file 56 | with open(data_yaml_filename, "r") as stream: 57 | try: 58 | setup_cfg = yaml.safe_load(stream) 59 | except yaml.YAMLError as exc: 60 | print(exc) 61 | setup_cfg = {} 62 | 63 | # set the seed 64 | np.random.seed(setup_cfg['seed']) 65 | m_orig, n_orig = setup_cfg['m_orig'], setup_cfg['n_orig'] 66 | rho = setup_cfg['rho'] 67 | 68 | # create the nominal matrix A 69 | A = (np.random.rand(m_orig, n_orig) * 2) - 1 70 | 71 | # non-identity DR scaling 72 | rho_x = run_cfg.get('rho_x', 1) 73 | scale = run_cfg.get('scale', 1) 74 | 75 | static_dict = static_canon(A, rho, rho_x=rho_x, scale=scale) 76 | 77 | get_q = None 78 | 79 | # static_flag = True 80 | # means that the matrices don't change across problems 81 | 82 | static_flag = True 83 | algo = 'scs' 84 | workspace = Workspace(algo, run_cfg, static_flag, static_dict, example) 85 | 86 | # run the workspace 87 | workspace.run() 88 | 89 | 90 | def setup_probs(setup_cfg): 91 | print("entered robust least squares setup", flush=True) 92 | cfg = setup_cfg 93 | N_train, N_test = cfg.N_train, cfg.N_test 94 | N = N_train + N_test 95 | m_orig, n_orig = cfg.m_orig, cfg.n_orig 96 | 97 | np.random.seed(cfg.seed) 98 | A = (np.random.rand(m_orig, n_orig) * 2) - 1 99 | 100 | log.info("creating static canonicalization...") 101 | t0 = time.time() 102 | out_dict = static_canon(A, cfg.rho) 103 | 104 | t1 = time.time() 105 | log.info(f"finished static canonicalization - took {t1-t0} seconds") 106 | 107 | cones_dict = out_dict["cones_dict"] 108 | A_sparse, P_sparse = out_dict["A_sparse"], out_dict["P_sparse"] 109 | b, c = out_dict["b"], out_dict["c"] 110 | m, n = A_sparse.shape 111 | 112 | # save output to output_filename 113 | output_filename = f"{os.getcwd()}/data_setup" 114 | 115 | # create scs solver object 116 | # we can cache the factorization if we do it like this 117 | data = dict(P=P_sparse, A=A_sparse, b=b, c=c) 118 | tol_abs = cfg.solve_acc_abs 119 | tol_rel = cfg.solve_acc_rel 120 | solver = scs.SCS(data, cones_dict, eps_abs=tol_abs, eps_rel=tol_rel) 121 | solve_times = np.zeros(N) 122 | x_stars = jnp.zeros((N, n)) 123 | y_stars = jnp.zeros((N, m)) 124 | s_stars = jnp.zeros((N, m)) 125 | q_mat = jnp.zeros((N, m + n)) 126 | 127 | # sample theta for each problem 128 | thetas_np = (2 * np.random.rand(N, m_orig) - 1) * cfg.b_range + cfg.b_nominal 129 | thetas = jnp.array(thetas_np) 130 | 131 | batch_q = vmap(single_q, in_axes=(0, None, None, None), out_axes=(0)) 
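    # note: vmap lifts single_q over the batch dimension of thetas:
    # in_axes=(0, None, None, None) maps over the rows of thetas while
    # broadcasting rho, m_orig, n_orig, and out_axes=0 stacks the per-problem
    # q = (c, b) vectors into a matrix of shape (N, m + n). A quick
    # (hypothetical) shape check:
    #     q_small = batch_q(thetas[:3], cfg.rho, cfg.m_orig, cfg.n_orig)
    #     assert q_small.shape == (3, m + n)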
132 | 133 | q_mat = batch_q(thetas, cfg.rho, cfg.m_orig, cfg.n_orig) 134 | 135 | scs_instances = [] 136 | 137 | if setup_cfg['solve']: 138 | for i in range(N): 139 | log.info(f"solving problem number {i}") 140 | 141 | # update 142 | b = np.array(q_mat[i, n:]) 143 | c = np.array(q_mat[i, :n]) 144 | 145 | # manual canon 146 | manual_canon_dict = { 147 | "P": P_sparse, 148 | "A": A_sparse, 149 | "b": b, 150 | "c": c, 151 | "cones": cones_dict, 152 | } 153 | scs_instance = SCSinstance(manual_canon_dict, solver, manual_canon=True) 154 | 155 | scs_instances.append(scs_instance) 156 | x_stars = x_stars.at[i, :].set(scs_instance.x_star) 157 | y_stars = y_stars.at[i, :].set(scs_instance.y_star) 158 | s_stars = s_stars.at[i, :].set(scs_instance.s_star) 159 | q_mat = q_mat.at[i, :].set(scs_instance.q) 160 | solve_times[i] = scs_instance.solve_time 161 | 162 | if i % 1000 == 0: 163 | log.info(f"saving final data... after solving problem number {i}") 164 | jnp.savez( 165 | output_filename, 166 | thetas=thetas, 167 | x_stars=x_stars, 168 | y_stars=y_stars, 169 | s_stars=s_stars, 170 | q_mat=q_mat 171 | ) 172 | # save the data 173 | log.info("final saving final data...") 174 | t0 = time.time() 175 | jnp.savez( 176 | output_filename, 177 | thetas=thetas, 178 | x_stars=x_stars, 179 | y_stars=y_stars, 180 | s_stars=s_stars, 181 | q_mat=q_mat 182 | ) 183 | else: 184 | log.info("final saving final data...") 185 | t0 = time.time() 186 | jnp.savez( 187 | output_filename, 188 | thetas=thetas, 189 | q_mat=q_mat, 190 | m=m, 191 | n=n 192 | ) 193 | 194 | # save solve times 195 | df_solve_times = pd.DataFrame(solve_times, columns=['solve_times']) 196 | df_solve_times.to_csv('solve_times.csv') 197 | 198 | save_time = time.time() 199 | log.info(f"finished saving final data... took {save_time-t0}'") 200 | 201 | # save plot of first 5 solutions 202 | for i in range(5): 203 | plt.plot(x_stars[i, :]) 204 | plt.savefig("x_stars.pdf") 205 | plt.clf() 206 | 207 | # save plot of first 5 solutions - just x 208 | for i in range(5): 209 | plt.plot(x_stars[i, :-2]) 210 | plt.savefig("x_stars_just_x.pdf") 211 | plt.clf() 212 | 213 | # save plot of first 5 solutions - non-zeros 214 | for i in range(5): 215 | plt.plot(x_stars[i, :-2] >= 0.0001) 216 | plt.savefig("x_stars_zero_one.pdf") 217 | plt.clf() 218 | 219 | # correlation matrix 220 | corrcoef = np.corrcoef(x_stars[:, :-2] >= 0.0001) 221 | print('corrcoef', corrcoef) 222 | plt.imshow(corrcoef) 223 | plt.savefig("corrcoef_zero_one.pdf") 224 | plt.clf() 225 | 226 | for i in range(5): 227 | plt.plot(y_stars[i, :]) 228 | plt.savefig("y_stars.pdf") 229 | plt.clf() 230 | 231 | # save plot of first 5 parameters 232 | for i in range(5): 233 | plt.plot(thetas[i, :]) 234 | plt.savefig("thetas.pdf") 235 | plt.clf() 236 | 237 | 238 | def static_canon(A, rho, rho_x=1, scale=1, factor=True): 239 | """ 240 | Let A have shape (m_orig, n_orig) 241 | min_{x,u,v} u + rho v 242 | s.t. x >= 0 243 | ||Ax-b||_2 <= u 244 | ||x||_2 <= v 245 | 246 | min_{x,u,v} u + rho v 247 | s.t. 
-x + s_1 == 0 (n) 248 | -(u, Ax) + s_2 == (-b, 0) (m, 1) 249 | -(v, x) + s_3 == 0 (n, 1) 250 | s_1 in R^n+ 251 | s_2 in SOC(m, 1) 252 | s_3 in SOC(n, 1) 253 | 254 | in total: 255 | m = 2 * n_orig + m_orig + 2 constraints 256 | n = n_orig + 2 vars 257 | 258 | Assume that A is fixed from problem to problem 259 | 260 | vars = (x, u, v) 261 | c = (0, 1, rho) 262 | """ 263 | m_orig, n_orig = A.shape 264 | m, n = 2 * n_orig + m_orig + 2, n_orig + 2 265 | A_dense = np.zeros((m, n)) 266 | b = np.zeros(m) 267 | 268 | # constraint 1 269 | A_dense[:n_orig, :n_orig] = -np.eye(n_orig) 270 | 271 | # constraint 2 272 | A_dense[n_orig, n_orig] = -1 273 | A_dense[n_orig + 1:n_orig + m_orig + 1, :n_orig] = -A 274 | 275 | b[n_orig + 1:m_orig + n_orig + 1] = 0 # fill in for b when theta enters -- 276 | # here we can put anything since it will change 277 | 278 | # constraint 3 279 | A_dense[n_orig + m_orig + 1, n_orig + 1] = -1 280 | A_dense[n_orig + m_orig + 2:, :n_orig] = -np.eye(n_orig) 281 | 282 | # create sparse matrix 283 | A_sparse = csc_matrix(A_dense) 284 | 285 | # cones 286 | q_array = [m_orig + 1, n_orig + 1] 287 | cones = dict(z=0, l=n_orig, q=q_array) 288 | cones_array = jnp.array([cones["z"], cones["l"]]) 289 | cones_array = jnp.concatenate([cones_array, jnp.array(cones["q"])]) 290 | 291 | # Quadratic objective 292 | P = np.zeros((n, n)) 293 | P_sparse = csc_matrix(P) 294 | 295 | # Linear objective 296 | c = np.zeros(n) 297 | c[n_orig], c[n_orig + 1] = 1, rho 298 | 299 | # create the matrix M 300 | M = jnp.zeros((n + m, n + m)) 301 | P = P_sparse.todense() 302 | A = A_sparse.todense() 303 | P_jax = jnp.array(P) 304 | A_jax = jnp.array(A) 305 | M = M.at[:n, :n].set(P_jax) 306 | M = M.at[:n, n:].set(A_jax.T) 307 | M = M.at[n:, :n].set(-A_jax) 308 | 309 | # factor for DR splitting 310 | # algo_factor = jsp.linalg.lu_factor(M + jnp.eye(n + m)) 311 | if factor: 312 | algo_factor, scale_vec = get_scaled_vec_and_factor(M, rho_x, scale, m, n, 313 | cones['z']) 314 | # algo_factor = jsp.linalg.lu_factor(M + jnp.eye(n + m)) 315 | else: 316 | algo_factor = None 317 | 318 | A_sparse = csc_matrix(A) 319 | P_sparse = csc_matrix(P) 320 | 321 | out_dict = dict( 322 | M=M, 323 | algo_factor=algo_factor, 324 | cones_dict=cones, 325 | cones_array=cones_array, 326 | A_sparse=A_sparse, 327 | P_sparse=P_sparse, 328 | b=b, 329 | c=c 330 | ) 331 | return out_dict 332 | 333 | 334 | def random_robust_ls(m_orig, n_orig, rho, b_center, b_range, seed=42): 335 | """ 336 | given dimensions, returns a random robust least squares problem 337 | """ 338 | A_orig = (np.random.rand(m_orig, n_orig) * 2) - 1 339 | out = static_canon(A_orig, rho) 340 | c, b = out['c'], out['b'] 341 | P_sparse, A_sparse = out['P_sparse'], out['A_sparse'] 342 | P, A = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense()) 343 | cones = out['cones_dict'] 344 | 345 | b_rand_np = (2 * np.random.rand(m_orig) - 1) * b_range + b_center 346 | b[n_orig + 1:m_orig + n_orig + 1] = -np.array(b_rand_np) 347 | return P, A, jnp.array(c), jnp.array(b), cones 348 | 349 | 350 | def multiple_random_robust_ls(m_orig, n_orig, rho, b_center, b_range, N, seed=42): 351 | A_orig = (np.random.rand(m_orig, n_orig) * 2) - 1 352 | out = static_canon(A_orig, rho) 353 | c, b = out['c'], out['b'] 354 | c_jax = jnp.array(c) 355 | P_sparse, A_sparse = out['P_sparse'], out['A_sparse'] 356 | P, A = jnp.array(P_sparse.todense()), jnp.array(A_sparse.todense()) 357 | m, n = A.shape 358 | cones = out['cones_dict'] 359 | 360 | q_mat = jnp.zeros((N, m + n)) 361 | theta_mat = jnp.zeros((N, 
m_orig)) 362 | b_orig = jnp.zeros((N, m_orig)) 363 | for i in range(N): 364 | b_rand_np = (2 * np.random.rand(m_orig) - 1) * b_range + b_center 365 | b_orig = b_orig.at[i, :].set(b_rand_np) 366 | 367 | b[n_orig + 1:m_orig + n_orig + 1] = -np.array(b_rand_np) 368 | b_jax = jnp.array(b) 369 | 370 | theta_mat = theta_mat.at[i, :].set(b_rand_np) 371 | q_mat = q_mat.at[i, :n].set(c_jax) 372 | q_mat = q_mat.at[i, n:].set(b_jax) 373 | 374 | return P, A, cones, q_mat, theta_mat, A_orig, b_orig 375 | -------------------------------------------------------------------------------- /tests/test_algo_steps.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax.numpy as jnp 4 | import jax.scipy as jsp 5 | import numpy as np 6 | import scs 7 | from scipy.sparse import csc_matrix 8 | 9 | from l2ws.algo_steps import ( 10 | create_M, 11 | create_projection_fn, 12 | get_scale_vec, 13 | k_steps_eval_scs, 14 | k_steps_train_scs, 15 | lin_sys_solve, 16 | ) 17 | from l2ws.examples.robust_ls import random_robust_ls 18 | from l2ws.examples.sparse_pca import multiple_random_sparse_pca 19 | from l2ws.scs_problem import scs_jax 20 | 21 | 22 | def test_train_vs_eval(): 23 | # get a random robust least squares problem 24 | m_orig, n_orig = 20, 25 25 | rho = 1 26 | b_center, b_range = 1, 1 27 | P, A, c, b, cones = random_robust_ls(m_orig, n_orig, rho, b_center, b_range) 28 | m, n = A.shape 29 | zero_cone_size = cones['z'] 30 | proj = create_projection_fn(cones, n) 31 | k = 20 32 | z0 = jnp.ones(m + n + 1) 33 | M = create_M(P, A) 34 | 35 | rho_x, scale = 1e-5, .1 36 | scale_vec = get_scale_vec(rho_x, scale, m, n, zero_cone_size) 37 | scale_vec_diag = jnp.diag(scale_vec) 38 | factor = jsp.linalg.lu_factor(M + scale_vec_diag) 39 | 40 | q = jnp.concatenate([c, b]) 41 | q_r = lin_sys_solve(factor, q) 42 | 43 | train_out = k_steps_train_scs(k, z0, q_r, factor, supervised=False, 44 | z_star=None, proj=proj, jit=False, hsde=True, 45 | m=m, n=n, zero_cone_size=zero_cone_size, rho_x=rho_x, scale=scale) 46 | z_final_train, iter_losses_train = train_out 47 | 48 | eval_out = k_steps_eval_scs(k, z0, q_r, factor, proj, P, A, c, b, jit=True, 49 | hsde=True, zero_cone_size=zero_cone_size, rho_x=rho_x, scale=scale) 50 | z_final_eval, iter_losses_eval = eval_out[:2] 51 | assert jnp.linalg.norm(iter_losses_train - iter_losses_eval) <= 1e-10 52 | assert jnp.linalg.norm(z_final_eval - z_final_train) <= 1e-10 53 | 54 | 55 | def test_jit_speed(): 56 | # problem setup 57 | m_orig, n_orig = 30, 40 58 | rho = 1 59 | b_center, b_range = 1, 1 60 | P, A, c, b, cones = random_robust_ls(m_orig, n_orig, rho, b_center, b_range) 61 | m, n = A.shape 62 | max_iters = 1000 63 | 64 | # fix warm start 65 | x_ws = np.ones(n) 66 | y_ws = np.ones(m) 67 | s_ws = np.zeros(m) 68 | 69 | # solve with jit 70 | t0_jit = time.time() 71 | data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws) 72 | sol_hsde = scs_jax(data, hsde=True, iters=max_iters) 73 | x_jit, y_jit, s_jit = sol_hsde['x'], sol_hsde['y'], sol_hsde['s'] 74 | fp_res_jit = sol_hsde['fixed_point_residuals'] 75 | t1_jit = time.time() 76 | jit_time = t1_jit - t0_jit 77 | 78 | # solve without jit 79 | t0_non_jit = time.time() 80 | data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws) 81 | sol_hsde = scs_jax(data, hsde=True, iters=max_iters) 82 | x_non_jit, y_non_jit, s_non_jit = sol_hsde['x'], sol_hsde['y'], sol_hsde['s'] 83 | fp_res_non_jit = sol_hsde['fixed_point_residuals'] 84 | t1_non_jit = time.time() 85 | 
non_jit_time = t1_non_jit - t0_non_jit
86 |
87 |     # the first call pays the jit compilation cost, so it should be slower
88 |     assert jit_time - non_jit_time > 0
90 |
91 |     # these should match to machine precision
92 |     assert jnp.linalg.norm(x_jit - x_non_jit) < 1e-10
93 |     assert jnp.linalg.norm(y_jit - y_non_jit) < 1e-10
94 |     assert jnp.linalg.norm(s_jit - s_non_jit) < 1e-10
95 |
96 |     # make sure the residuals start high and end very low
97 |     assert fp_res_jit[0] > .1 and fp_res_non_jit[0] > .1
98 |     assert fp_res_jit[-1] < 1e-7 and fp_res_jit[-1] > 1e-16
99 |     assert fp_res_non_jit[-1] < 1e-7 and fp_res_non_jit[-1] > 1e-16
100 |
101 |
102 | def test_hsde_socp_robust_ls():
103 |     """
104 |     tests to make sure hsde returns the same solution as the non-hsde;
105 |     also tests an socp with cones of different sizes (there are 2 SOCs)
106 |     """
107 |     # get a random robust least squares problem
108 |     m_orig, n_orig = 50, 55
109 |     rho = 1
110 |     b_center, b_range = 1, 1
111 |     P, A, c, b, cones = random_robust_ls(m_orig, n_orig, rho, b_center, b_range)
112 |
113 |     data = dict(P=P, A=A, c=c, b=b, cones=cones)
114 |     iters = 400
115 |
116 |     sol_std = scs_jax(data, hsde=False, iters=iters, rho_x=1, scale=1, alpha=1)
117 |     x_std, y_std, s_std = sol_std['x'], sol_std['y'], sol_std['s']
118 |     fp_res_std = sol_std['fixed_point_residuals']
119 |
120 |     sol_hsde = scs_jax(data, hsde=True, iters=iters)
121 |     x_hsde, y_hsde, s_hsde = sol_hsde['x'], sol_hsde['y'], sol_hsde['s']
122 |     fp_res_hsde = sol_hsde['fixed_point_residuals']
123 |
127 |     assert jnp.linalg.norm(x_hsde - x_std) < 1e-3
128 |     assert jnp.linalg.norm(y_hsde - y_std) < 1e-3
129 |     assert jnp.linalg.norm(s_hsde - s_std) < 1e-3
130 |
131 |     # make sure the residuals start high and end very low
132 |     assert fp_res_std[0] > .1 and fp_res_hsde[0] > .1
134 |     assert fp_res_hsde[-1] < 1e-4 and fp_res_hsde[-1] > 1e-16
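# note: hsde=True exercises the homogeneous self-dual embedding form of the
# splitting (the form the C implementation of SCS uses), while hsde=False runs
# the splitting directly on the non-embedded KKT system; both should reach the
# same (x, y, s), which is what the asserts above check to within 1e-3.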
135 |
136 |
209 | def test_c_vs_jax_sdp():
210 |     """
211 |     check that the iterate returned by the C solver matches the one returned
212 |     by jax when both are warm-started from the same point
213 |     """
214 |     # get a random sparse pca problem
215 |     P, A, cones, q_mat, theta_mat_jax, A_tensor = multiple_random_sparse_pca(
216 |         n_orig=30, k=10, r=10, N=5)
217 |     m, n = A.shape
218 |
219 |     max_iters = 10
220 |
221 |     # fix warm start
222 |     x_ws = np.ones(n)
223 |     y_ws = np.ones(m)
224 |     s_ws = np.zeros(m)
225 |
226 |     # solve in C
227 |     c, b = q_mat[0, :n], q_mat[0, n:]
228 |     P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
229 |     c_np, b_np = np.array(c), np.array(b)
230 |     c_data = dict(P=P_sparse, A=A_sparse, c=c_np, b=b_np)
231 |     solver = scs.SCS(c_data,
232 |                      cones,
233 |                      normalize=False,
234 |                      scale=.1,
235 |                      adaptive_scale=False,
236 |                      rho_x=.01,
237 |                      alpha=1.6,
238 |                      acceleration_lookback=0,
239 |                      max_iters=max_iters)
240 |
241 |     sol = solver.solve(warm_start=True, x=x_ws, y=y_ws, s=s_ws)
242 |     x_c = jnp.array(sol['x'])
243 |     y_c = jnp.array(sol['y'])
244 |     s_c = jnp.array(sol['s'])
245 |
246 |     # solve with our jax implementation
247 |     data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws)
248 |     sol_hsde = scs_jax(data, hsde=True, iters=max_iters, alpha=1.6, scale=.1, rho_x=.01)
249 |     x_jax, y_jax, s_jax = sol_hsde['x'], sol_hsde['y'], sol_hsde['s']
250 |
251 |     # these should match to machine precision
252 |     assert jnp.linalg.norm(x_jax - x_c) < 1e-6
253 |     assert jnp.linalg.norm(y_jax - y_c) < 1e-6
254 |     assert jnp.linalg.norm(s_jax - s_c) < 1e-6
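# note: for these iterate-for-iterate comparisons against the C solver to
# hold, the C side must be pinned to plain fixed-point iterations:
# normalize=False, adaptive_scale=False, acceleration_lookback=0, with
# scale, rho_x, and alpha matched to the corresponding scs_jax call.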
255 |
256 |
258 | def test_c_vs_jax_socp():
259 |     """
260 |     check that the iterate returned by the C solver matches the one returned
261 |     by jax when both are warm-started from the same point
262 |     """
263 |     # get a random robust least squares problem
264 |     m_orig, n_orig = 30, 40
265 |     rho = 1
266 |     b_center, b_range = 1, 1
267 |     P, A, c, b, cones = random_robust_ls(m_orig, n_orig, rho, b_center, b_range)
268 |     m, n = A.shape
269 |
270 |     max_iters = 10
271 |
272 |     # fix warm start
273 |     x_ws = np.ones(n)
274 |     y_ws = np.ones(m)
275 |     s_ws = np.zeros(m)
276 |
277 |     # select hyperparameters
278 |     scale = 10
279 |     rho_x = 1e-3
280 |     alpha = 1.8
281 |
282 |     # solve in C
283 |     P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
284 |     c_np, b_np = np.array(c), np.array(b)
285 |     c_data = dict(P=P_sparse, A=A_sparse, c=c_np, b=b_np)
286 |     solver = scs.SCS(c_data,
287 |                      cones,
288 |                      normalize=False,
289 |                      scale=scale,
290 |                      adaptive_scale=False,
291 |                      rho_x=rho_x,
292 |                      alpha=alpha,
293 |                      acceleration_lookback=0,
294 |                      max_iters=max_iters)
295 |
296 |     sol = solver.solve(warm_start=True, x=x_ws, y=y_ws, s=s_ws)
297 |     x_c = jnp.array(sol['x'])
298 |     y_c = jnp.array(sol['y'])
299 |     s_c = jnp.array(sol['s'])
300 |
301 |     # solve with our jax implementation
302 |     data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_ws, y=y_ws, s=s_ws)
303 |     sol_hsde = scs_jax(data, hsde=True, iters=max_iters, scale=scale, rho_x=rho_x, alpha=alpha)
304 |     x_jax, y_jax, s_jax = sol_hsde['x'], sol_hsde['y'], sol_hsde['s']
305 |
306 |     # these should match to machine precision
307 |     assert jnp.linalg.norm(x_jax - x_c) < 1e-10
308 |     assert jnp.linalg.norm(y_jax - y_c) < 1e-10
309 |     assert jnp.linalg.norm(s_jax - s_c) < 1e-10
310 |
311 |
312 |
313 | def test_warm_start_from_opt():
314 |     """
315 |     this is the only test that warm-starts s at a nonzero point;
316 |     that matters for the non-identity DR scaling
317 |     """
318 |     m_orig, n_orig = 30, 40
319 |     rho = 1
320 |     b_center, b_range = 1, 1
321 |     P, A, c, b, cones = random_robust_ls(m_orig, n_orig, rho, b_center, b_range)
322 |     m, n = A.shape
323 |
324 |     max_iters = 1
325 |
326 |     # fix warm start
327 |     x_ws = np.ones(n)
328 |     y_ws = np.ones(m)
329 |     s_ws = np.zeros(m)
330 |
331 |     # solve in C to get close to opt
332 |     P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
333 |     c_np, b_np = np.array(c), np.array(b)
334 |     c_data = dict(P=P_sparse, A=A_sparse, c=c_np, b=b_np)
335 |     solver_opt = scs.SCS(c_data,
336 |                          cones,
337 |                          normalize=False,
338 |                          scale=1,
339 |                          adaptive_scale=False,
340 |                          rho_x=1,
341 |                          acceleration_lookback=0,
342 |                          max_iters=1000)
343 |
344 |     sol = solver_opt.solve(warm_start=True, x=x_ws, y=y_ws, s=s_ws)
345 |     x_opt = jnp.array(sol['x'])
346 |     y_opt = jnp.array(sol['y'])
347 |     s_opt = jnp.array(sol['s'])
348 |
349 |     # set hyperparameters
350 |     rho_x = .1
351 |     alpha = 1.5
352 |     scale = 1.01
353 |
354 |     # warm start scs from opt
355 |     solver = scs.SCS(c_data,
356 |                      cones,
357 |                      normalize=False,
358 |                      scale=scale,
359 |                      adaptive_scale=False,
360 |                      rho_x=rho_x,
361 |                      alpha=alpha,
362 |                      acceleration_lookback=0,
363 |                      max_iters=max_iters,
364 |                      eps_abs=1e-12,
365 |                      eps_rel=1e-12)
366 |     sol = solver.solve(warm_start=True, x=np.array(x_opt), y=np.array(y_opt), s=np.array(s_opt))
367 |     x_c = jnp.array(sol['x'])
368 |     y_c = jnp.array(sol['y'])
369 |     s_c = jnp.array(sol['s'])
370 |
371 |     # warm start our implementation from opt
372 |     data = dict(P=P, A=A, c=c, b=b, cones=cones, x=x_opt, y=y_opt, s=s_opt)
373 |     sol_hsde = scs_jax(data, hsde=True, jit=False, iters=max_iters,
374 |                        rho_x=rho_x, scale=scale, alpha=alpha)
375 |     x_jax, y_jax, s_jax = sol_hsde['x'], sol_hsde['y'], sol_hsde['s']
376 |
378 |     # these should match to machine precision
379 |     assert jnp.linalg.norm(x_jax - x_c) < 1e-12
380 |     assert jnp.linalg.norm(y_jax - y_c) < 1e-12
381 |     assert jnp.linalg.norm(s_jax - s_c) < 1e-12
382 |
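# note: a minimal sketch of the warm-start interface the tests above rely on
# (sizes are illustrative; assumes the imports at the top of this file):
if __name__ == "__main__":
    P, A, c, b, cones = random_robust_ls(10, 15, 1, 1, 1)
    m, n = A.shape
    data = dict(P=P, A=A, c=c, b=b, cones=cones,
                x=np.ones(n), y=np.ones(m), s=np.zeros(m))  # (x, y, s) warm start
    sol = scs_jax(data, hsde=True, iters=50)
    print(sol['x'].shape, sol['y'].shape, sol['s'].shape)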
--------------------------------------------------------------------------------
/l2ws/examples/solve_script.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import time
6 | import jax.numpy as jnp
7 | from l2ws.scs_problem import SCSinstance
9 | import cvxpy as cp
10 | from scipy.sparse import csc_matrix, save_npz, load_npz
11 | import osqp
12 |
13 |
14 | plt.rcParams.update(
15 |     {
16 |         "text.usetex": True,
17 |         "font.family": "serif",
18 |         "font.size": 16,
19 |     }
20 | )
21 | log = logging.getLogger(__name__)
22 |
23 |
24 | def save_results_dynamic(output_filename, theta_mat, z_stars, q_mat, factors, ref_traj_tensor=None):
25 |     """
26 |     saves the results from the setup phase
27 |     saves q_mat in csc_matrix form to save space
28 |     everything else is saved as a npz file
29 |     also plots z_stars, q, thetas
30 |     """
31 |
32 |     # save theta_mat and z_stars (the factors are not saved for now --
33 |     # load_results_dynamic treats them as optional)
34 |     t0 = time.time()
35 |     if ref_traj_tensor is None:
36 |         jnp.savez(
37 |             output_filename,
38 |             thetas=jnp.array(theta_mat),
39 |             z_stars=z_stars
42 |         )
43 |     else:
44 |         # ref_traj_tensor has shape (num_rollouts, num_goals, goal_length)
45 |         jnp.savez(
46 |             output_filename,
47 |             thetas=jnp.array(theta_mat),
48 |             z_stars=z_stars,
49 |             ref_traj_tensor=ref_traj_tensor
53 |         )
55 |     t1 = time.time()
56 |     print('time to save non-sparse', t1 - t0)
57 |
58 |     # save the q_mat but as a sparse object
59 |     t2 = time.time()
60 |     q_mat_sparse = csc_matrix(q_mat)
61 |     save_npz(f"{output_filename}_q", q_mat_sparse)
62 |     t3 = time.time()
63 |     print('time to save sparse', t3 - t2)
64 |
65 |     # save plot of first 5 solutions
66 |     for i in range(5):
67 |         plt.plot(z_stars[i, :])
68 |     plt.savefig("z_stars.pdf")
69 |     plt.clf()
70 |
71 |     # save plot of first 5 q
72 |     for i in range(5):
73 |         plt.plot(q_mat[i, :])
74 |     plt.savefig("q.pdf")
75 |     plt.clf()
76 |
77 |     # save plot of first 5 parameters
78 |     for i in range(5):
79 |         plt.plot(theta_mat[i, :])
80 |     plt.savefig("thetas.pdf")
81 |     plt.clf()
82 |
83 |
84 | def load_results_dynamic(output_filename):
85 |     """
86 |     returns the saved results from the corresponding save_results_dynamic
87 |     function; the factors are returned only if they were saved (None otherwise)
88 |     """
89 |     q_mat_sparse = load_npz(f"{output_filename}_q")
90 |     loaded_obj = jnp.load(output_filename)
91 |     theta_mat, z_stars = loaded_obj['thetas'], loaded_obj['z_stars']
92 |     if 'factors0' in loaded_obj.files:
93 |         factors = (loaded_obj['factors0'], loaded_obj['factors1'])
94 |     else:
95 |         factors = None
96 |     return theta_mat, z_stars, q_mat_sparse, factors
97 |
98 |
99 | def direct_osqp_setup_script(theta_mat, q_mat, P, A, output_filename, z_stars=None):
100 |     """
101 |     solves many QPs in OSQP form, one per row of q_mat; each row stacks a
102 |     different (c, l, u) while P and A are fixed
103 |     """
104 |     m, n = A.shape
105 |     N = q_mat.shape[0]
106 |
107 |     # set up the osqp solver once for the fixed P, A
108 |     osqp_solver = osqp.OSQP()
109 |     P_sparse, A_sparse = csc_matrix(np.array(P)), csc_matrix(np.array(A))
110 |     c, l, u = np.zeros(n), np.zeros(m), np.zeros(m)
111 |     osqp_solver.setup(P=P_sparse, q=c, A=A_sparse, l=l, u=u,
112 |                       max_iter=2000, verbose=True, eps_abs=1e-5, eps_rel=1e-5)
113 |
114 |     solve_times = np.zeros(N)
115 |     if z_stars is None:
116 |         z_stars = jnp.zeros((N, n + 2 * m))
118 |     x_stars = []
119 |     y_stars = []
120 |     w_stars = []
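    # note: the pattern below is setup-once / update-many: osqp_solver.setup()
    # above factors the KKT system for the fixed (P, A) a single time, and each
    # iteration only calls osqp_solver.update(q=..., l=..., u=...), which
    # reuses the cached factorization, so per-problem solves stay cheap.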
121 |     for i in range(N):
122 |         log.info(f"solving problem number {i}")
123 |
124 |         # update the parametric part (c, l, u)
125 |         c, l, u = q_mat[i, :n], q_mat[i, n:n + m], q_mat[i, n + m:]
126 |         osqp_solver.update(q=np.array(c))
127 |         osqp_solver.update(l=np.array(l), u=np.array(u))
128 |
129 |         # solve with osqp
130 |         results = osqp_solver.solve()
131 |
132 |         # record the solve time in seconds
133 |         solve_times[i] = results.info.solve_time
134 |
136 |         # record the solution: primal x, dual y, and w = A x
140 |         x_stars.append(results.x)
141 |         y_stars.append(results.y)
142 |         w_stars.append(A @ results.x)
143 |
157 |         if i % 1000 == 0:
158 |             # periodically checkpoint the (partial) solutions found so far
159 |             log.info("saving data...")
161 |             jnp.savez(
162 |                 output_filename,
163 |                 thetas=jnp.array(theta_mat),
164 |                 z_stars=jnp.hstack([jnp.stack(x_stars),
165 |                                     jnp.stack(y_stars),
166 |                                     jnp.stack(w_stars)]),
167 |                 q_mat=q_mat
168 |             )
169 |     z_stars = jnp.hstack([jnp.stack(x_stars), jnp.stack(y_stars), jnp.stack(w_stars)])
170 |
171 |     # save the data
172 |     log.info("saving final data...")
173 |     t0 = time.time()
174 |     jnp.savez(
175 |         output_filename,
176 |         thetas=jnp.array(theta_mat),
177 |         z_stars=z_stars,
178 |         q_mat=q_mat
179 |     )
180 |
181 |     # save solve times
182 |     df_solve_times = pd.DataFrame(solve_times, columns=['solve_times'])
183 |     df_solve_times.to_csv('solve_times.csv')
184 |
185 |     save_time = time.time()
186 |     log.info(f"finished saving final data... took {save_time-t0} seconds")
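    # note: the iterate layout stacked above is z = (x, y, w) with w = A @ x,
    # so z_stars has width n + 2 * m; downstream code would presumably slice it
    # as x = z[:n], y = z[n:n + m], w = z[n + m:] (the slicing is an
    # assumption, inferred from the widths used here).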
187 |
188 |     # save plot of first 5 solutions
189 |     for i in range(5):
190 |         plt.plot(z_stars[i, :])
191 |     plt.savefig("z_stars.pdf")
192 |     plt.clf()
193 |
194 |     # save plot of first 5 q
195 |     for i in range(5):
196 |         plt.plot(q_mat[i, :])
197 |     plt.savefig("q.pdf")
198 |     plt.clf()
199 |
201 |     # save plot of first 5 parameters
202 |     for i in range(5):
203 |         plt.plot(theta_mat[i, :])
204 |     plt.savefig("thetas.pdf")
205 |     plt.clf()
206 |
207 |     return z_stars
208 |
209 |
210 | def osqp_setup_script(theta_mat, q_mat, P, A, output_filename, z_stars=None):
211 |     """
212 |     solves many QPs in OSQP form via cvxpy, one per row of q_mat; each row
213 |     stacks a different (c, l, u) while P and A are fixed
214 |     """
215 |     m, n = A.shape
216 |     N = q_mat.shape[0]
217 |
218 |     # setup cvxpy
219 |     x, w = cp.Variable(n), cp.Variable(m)
220 |     c_param, l_param, u_param = cp.Parameter(n), cp.Parameter(m), cp.Parameter(m)
221 |     constraints = [A @ x == w, l_param <= w, w <= u_param]
222 |     prob = cp.Problem(cp.Minimize(.5 * cp.quad_form(x, P) + c_param @ x), constraints)
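    # note: this cvxpy model encodes the standard OSQP-form QP
    #     min (1/2) x' P x + c' x   s.t.  l <= A x <= u
    # by introducing w = A x with box constraints l <= w <= u; the dual
    # variable of the constraint A x == w recovers y, so z = (x, y) has
    # width n + m.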
223 |
224 |     solve_times = np.zeros(N)
225 |     if z_stars is None:
226 |         z_stars = jnp.zeros((N, n + m))
227 |     objvals = jnp.zeros((N))
228 |     for i in range(N):
229 |         log.info(f"solving problem number {i}")
230 |
231 |         # solve with cvxpy
232 |         c_param.value = np.array(q_mat[i, :n])
233 |         l_param.value = np.array(q_mat[i, n:n + m])
234 |         u_param.value = np.array(q_mat[i, n + m:])
235 |         prob.solve(verbose=True, solver=cp.OSQP, eps_abs=1e-03, eps_rel=1e-03)
236 |         objvals = objvals.at[i].set(prob.value)
237 |
238 |         x_star = jnp.array(x.value)
239 |         y_star = jnp.array(constraints[0].dual_value)
240 |         z_star = jnp.concatenate([x_star, y_star])
241 |         z_stars = z_stars.at[i, :].set(z_star)
242 |         solve_times[i] = prob.solver_stats.solve_time
243 |
244 |     # save the data
245 |     log.info("saving final data...")
246 |     t0 = time.time()
247 |     jnp.savez(
248 |         output_filename,
249 |         thetas=jnp.array(theta_mat),
250 |         z_stars=z_stars,
251 |         q_mat=q_mat
252 |     )
253 |
254 |     # save solve times
255 |     df_solve_times = pd.DataFrame(solve_times, columns=['solve_times'])
256 |     df_solve_times.to_csv('solve_times.csv')
257 |
258 |     save_time = time.time()
259 |     log.info(f"finished saving final data... took {save_time-t0} seconds")
260 |
261 |     # save plot of first 5 solutions
262 |     for i in range(5):
263 |         plt.plot(z_stars[i, :])
264 |     plt.savefig("z_stars.pdf")
265 |     plt.clf()
266 |
267 |     # save plot of first 5 q
268 |     for i in range(5):
269 |         plt.plot(q_mat[i, :])
270 |     plt.savefig("q.pdf")
271 |     plt.clf()
272 |
274 |     # save plot of first 5 parameters
275 |     for i in range(5):
276 |         plt.plot(theta_mat[i, :])
277 |     plt.savefig("thetas.pdf")
278 |     plt.clf()
279 |
280 |
281 | def ista_setup_script(b_mat, A, lambd, output_filename):
283 |     """
284 |     solves many lasso problems where each problem has a different b vector
285 |     """
286 |     m, n = A.shape
287 |     N = b_mat.shape[0]
288 |     z, b_param = cp.Variable(n), cp.Parameter(m)
289 |     prob = cp.Problem(cp.Minimize(.5 * cp.sum_squares(np.array(A) @ z - b_param) + lambd * cp.norm(z, p=1)))
290 |     z_stars = jnp.zeros((N, n))
291 |     objvals = jnp.zeros((N))
292 |     solve_times = np.zeros(N)
293 |     for i in range(N):
294 |         log.info(f"solving problem number {i}")
295 |         b_param.value = np.array(b_mat[i, :])
296 |         prob.solve(verbose=True)
297 |         objvals = objvals.at[i].set(prob.value)
298 |         z_stars = z_stars.at[i, :].set(jnp.array(z.value))
299 |         solve_times[i] = prob.solver_stats.solve_time
300 |
301 |     # save the data
302 |     log.info("saving final data...")
303 |     t0 = time.time()
304 |     jnp.savez(
305 |         output_filename,
306 |         thetas=jnp.array(b_mat),
307 |         z_stars=z_stars,
308 |     )
309 |
310 |     # save solve times
311 |     df_solve_times = pd.DataFrame(solve_times, columns=['solve_times'])
312 |     df_solve_times.to_csv('solve_times.csv')
313 |
314 |     save_time = time.time()
315 |     log.info(f"finished saving final data... took {save_time-t0} seconds")
316 |
317 |     # save plot of first 5 solutions
318 |     for i in range(5):
319 |         plt.plot(z_stars[i, :])
320 |     plt.savefig("z_stars.pdf")
321 |     plt.clf()
322 |
324 |     # save plot of first 5 parameters
325 |     for i in range(5):
326 |         plt.plot(b_mat[i, :])
327 |     plt.savefig("thetas.pdf")
328 |     plt.clf()
329 |
330 |
331 | def gd_setup_script(c_mat, P, output_filename):
332 |     """
333 |     solves many unconstrained QPs where each problem has a different c vector;
334 |     the solutions are available in closed form, so no solver is called
335 |     """
349 |     P_inv = jnp.linalg.inv(P)
350 |     z_stars = (-P_inv @ c_mat.T).T
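    # note: no iterative solver is needed here: for the unconstrained QP
    #     min_z (1/2) z' P z + c' z,
    # the optimality condition P z + c = 0 gives z* = -P^{-1} c, which the
    # line above evaluates row-wise for every c in c_mat.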
351 |
352 |     # save the data
353 |     log.info("saving final data...")
354 |     t0 = time.time()
355 |     jnp.savez(
356 |         output_filename,
357 |         thetas=jnp.array(c_mat),
358 |         z_stars=z_stars,
359 |     )
360 |
365 |     save_time = time.time()
366 |     log.info(f"finished saving final data... took {save_time-t0} seconds")
367 |
368 |     # save plot of first 5 solutions
369 |     for i in range(5):
370 |         plt.plot(z_stars[i, :])
371 |     plt.savefig("z_stars.pdf")
372 |     plt.clf()
373 |
375 |     # save plot of first 5 parameters
376 |     for i in range(5):
377 |         plt.plot(c_mat[i, :])
378 |     plt.savefig("thetas.pdf")
379 |     plt.clf()
380 |
381 |
382 | def setup_script(q_mat, theta_mat, solver, data, cones_dict, output_filename, solve=True):
383 |     N = q_mat.shape[0]
384 |     m, n = data['A'].shape
385 |
386 |     solve_times = np.zeros(N)
387 |     x_stars = jnp.zeros((N, n))
388 |     y_stars = jnp.zeros((N, m))
389 |     s_stars = jnp.zeros((N, m))
392 |
393 |     P_sparse, A_sparse = data['P'], data['A']
394 |     if solve:
395 |         for i in range(N):
396 |             log.info(f"solving problem number {i}")
397 |
398 |             # update the parametric part (b, c)
399 |             b = np.array(q_mat[i, n:])
400 |             c = np.array(q_mat[i, :n])
401 |
402 |             # manual canon
403 |             manual_canon_dict = {
404 |                 "P": P_sparse,
405 |                 "A": A_sparse,
406 |                 "b": b,
407 |                 "c": c,
408 |                 "cones": cones_dict,
409 |             }
410 |
411 |             scs_instance = SCSinstance(manual_canon_dict, solver, manual_canon=True)
412 |
414 |             x_stars = x_stars.at[i, :].set(scs_instance.x_star)
415 |             y_stars = y_stars.at[i, :].set(scs_instance.y_star)
416 |             s_stars = s_stars.at[i, :].set(scs_instance.s_star)
417 |             q_mat = q_mat.at[i, :].set(scs_instance.q)
418 |             solve_times[i] = scs_instance.solve_time
419 |
420 |             if i % 1000 == 0:
421 |                 log.info(f"checkpointing data after solving problem number {i}")
422 |                 jnp.savez(
423 |                     output_filename,
424 |                     thetas=theta_mat,
425 |                     x_stars=x_stars,
426 |                     y_stars=y_stars,
427 |                     s_stars=s_stars,
428 |                     q_mat=q_mat
429 |                 )
430 |     # save the data
431 |     log.info("saving final data...")
432 |     t0 = time.time()
433 |     jnp.savez(
434 |         output_filename,
435 |         thetas=theta_mat,
436 |         x_stars=x_stars,
437 |         y_stars=y_stars,
438 |         s_stars=s_stars,
439 |         q_mat=q_mat
440 |     )
441 |
442 |     # save solve times
443 |     df_solve_times = pd.DataFrame(solve_times, columns=['solve_times'])
444 |     df_solve_times.to_csv('solve_times.csv')
445 |
446 |     save_time = time.time()
447 |     log.info(f"finished saving final data... took {save_time-t0} seconds")
448 |
449 |     # save plot of first 5 solutions
450 |     for i in range(5):
451 |         plt.plot(x_stars[i, :])
452 |     plt.savefig("x_stars.pdf")
453 |     plt.clf()
454 |
455 |     for i in range(5):
456 |         plt.plot(y_stars[i, :])
457 |     plt.savefig("y_stars.pdf")
458 |     plt.clf()
459 |
460 |     # save plot of first 5 parameters
461 |     for i in range(5):
462 |         plt.plot(theta_mat[i, :])
463 |     plt.savefig("thetas.pdf")
464 |     plt.clf()
465 |
466 |     return x_stars, y_stars, s_stars
467 |
--------------------------------------------------------------------------------