├── README.md ├── aux_data └── rad_figure.png ├── diffusionsim_experiments ├── README_diffusionsim.md ├── diffusionsim_plotcurves.py └── diffusionsimsimple.py ├── nn_experiments ├── README_ff.md ├── README_rnn.md ├── cifar_launch.py ├── cifarffcommands.txt ├── data.py ├── layers.py ├── mnist_launch.py ├── mnistffcommands.txt ├── models.py ├── plot_cifar.py ├── plot_mnist.py ├── rnn_mnist_launch.py ├── rnn_models.py ├── rnn_plottingscripts.py ├── rnn_plottingscripts_appendix.py ├── rnn_train_and_eval.py ├── train_and_eval.py └── utils.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Randomized Automatic Differentiation 2 | 3 | 4 | 5 | Paper: https://arxiv.org/abs/2007.10412 6 | 7 | This repository contains code for all three types of experiments used in the paper: 8 | 1. Feedforward network experiments described in ``nn_experiments/README_ff.md`` 9 | 2. Recurrent network experiments described in ``nn_experiments/README_rnn.md`` 10 | 3. PDE experiments described in ``diffusionsim_experiments/README_diffusionsim.md`` 11 | 12 | For the neural network experiments, the layers using randomized autodiff are in ``nn_experiments/layers.py``. 13 | 14 | ## Citation 15 | To cite this work, please use 16 | ``` 17 | @misc{oktay2020randomized, 18 | title={Randomized Automatic Differentiation}, 19 | author={Deniz Oktay and Nick McGreivy and Joshua Aduol and Alex Beatson and Ryan P. Adams}, 20 | year={2020}, 21 | eprint={2007.10412}, 22 | archivePrefix={arXiv}, 23 | primaryClass={cs.LG} 24 | } 25 | ``` 26 | 27 | ## Authors: 28 | * [Deniz Oktay](http://www.cs.princeton.edu/~doktay/) 29 | * [Nick McGreivy](https://scholar.princeton.edu/nickmcgreivy/home) 30 | * [Joshua Aduol]() 31 | * [Alex Beatson](https://www.cs.princeton.edu/~abeatson/) 32 | * [Ryan P. 
Adams](https://www.cs.princeton.edu/~rpa/) 33 | -------------------------------------------------------------------------------- /aux_data/rad_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/RandomizedAutomaticDifferentiation/bff4483434d72612eda680c5a33312ab5f834f30/aux_data/rad_figure.png -------------------------------------------------------------------------------- /diffusionsim_experiments/README_diffusionsim.md: -------------------------------------------------------------------------------- 1 | # Running diffusion sim code 2 | 3 | ## Requirements 4 | 5 | See requirements.txt file for full pip requirements.\ 6 | Major requirements: jax, jaxlib 7 | 8 | ## Reproducing results in paper 9 | 10 | ``mkdir diff_data``\ 11 | ``cp plot_curves.py diff_data/``\ 12 | ``python diffusionsimsimple.py --keep_frac 1.0 --filename 1.0 --num_opt 800``\ 13 | ``python diffusionsimsimple.py --keep_frac 0.1 --filename 0.1 --num_opt 800``\ 14 | ``python diffusionsimsimple.py --keep_frac 0.01 --filename 0.01 --num_opt 800``\ 15 | ``python diffusionsimsimple.py --keep_frac 0.005 --filename 0.005 --num_opt 800``\ 16 | ``python diffusionsimsimple.py --keep_frac 0.002 --filename 0.002 --num_opt 800``\ 17 | ``python diffusionsimsimple.py --keep_frac 0.001 --filename 0.001 --num_opt 800``\ 18 | ``cd diff_data``\ 19 | ``python plot_curves.py`` 20 | -------------------------------------------------------------------------------- /diffusionsim_experiments/diffusionsim_plotcurves.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pickle 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from IPython.display import display, HTML 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import csv 10 | from numpy import genfromtxt 11 | 12 | params = { 13 | 'axes.labelsize': 12, 14 | 'font.size': 12, 15 | 'legend.fontsize': 12, 16 | 'xtick.labelsize': 12, 17 | 'ytick.labelsize': 12, 18 | 'text.usetex': True, 19 | 'figure.figsize': [6, 4], 20 | 'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}', 21 | } 22 | plt.rcParams.update(params) 23 | 24 | fig = plt.figure(figsize=(10,40)) 25 | plt.axes(frameon=0) # turn off frames 26 | plt.grid(axis='y', color='0.9', linestyle='-', linewidth=1) 27 | 28 | ax = plt.subplot(511) 29 | plt.title('Training Loss vs Iterations for Reaction-Diffusion Equation') 30 | ax.set_yscale('log') 31 | 32 | marker_size = 10 33 | 34 | prefixes = [("0.001"),("0.002"), ("0.005"), ("0.01"), ("0.1"), ("1.0")] 35 | 36 | 37 | for pref in prefixes: 38 | my_data = genfromtxt('{}_loss.csv'.format(pref), delimiter=',') 39 | ax.plot(my_data, label=pref, ms=marker_size) 40 | 41 | ax.legend(title='Memory fraction') 42 | fig.savefig('diffusion_sim_curves.pdf') 43 | -------------------------------------------------------------------------------- /diffusionsim_experiments/diffusionsimsimple.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | import numpy as npo 3 | from jax.ops import index_add, index, index_update 4 | import jax 5 | from jax import jit 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | from jax.experimental import optimizers 9 | from jax import value_and_grad 10 | import argparse 11 | import time 12 | import matplotlib.animation as animation 13 | from numpy import genfromtxt 14 | import csv 15 | from 
functools import partial 16 | import pdb 17 | import os 18 | 19 | PI = np.pi 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--keep_frac", type=float, default=0.5, help="fraction of phi we store" 24 | ) 25 | parser.add_argument( 26 | "--filename", 27 | type=str, 28 | default="baseline", 29 | help="prefix for the .mp4 files and .csv file.", 30 | ) 31 | parser.add_argument( 32 | "--num_opt", 33 | type=int, 34 | default=800, 35 | help="number of optimization iterations", 36 | ) 37 | 38 | assert os.path.isdir("diff_data") 39 | 40 | args = parser.parse_args() 41 | anim_init_filename = "diff_data/{}_initial.mp4".format(args.filename) 42 | anim_final_filename = "diff_data/{}_final.mp4".format(args.filename) 43 | csvfile = "diff_data/{}_loss.csv".format(args.filename) 44 | 45 | keep_frac = args.keep_frac 46 | 47 | assert keep_frac <= 1.0 48 | assert keep_frac > 0.0 49 | 50 | # Need 4 D delta_t / (delta_x)^2 < 1 for numerical stability 51 | 52 | # Here we set 4 D delta_t / (delta_x)^2 = 0.25 53 | 54 | # Simulate for 10 units of time 55 | T_f = 10 56 | # 2D box with size 1 x 1 57 | nx = 32 58 | dx = 1 / nx 59 | # Small timestep for numerical stability 60 | dt = 1.0 / 4096 61 | nt = int(T_f * 4096) 62 | # 1/4 cancels factor of 4 in mean-squared distance 63 | D = 0.25 64 | 65 | N_optim = args.num_opt 66 | 67 | rows, cols = nx + 1, nx + 1 68 | 69 | x = np.linspace(0, nx * dx, nx + 1) 70 | xg, yg = np.meshgrid(x, x) 71 | 72 | 73 | sin2pix = np.sin(2 * PI * xg) 74 | cos2pix = np.cos(2 * PI * xg) 75 | 76 | 77 | def get_source(source_params, t): 78 | return ( 79 | source_params[0] 80 | + source_params[1] * np.sin(PI * t) 81 | + source_params[2] * np.cos(PI * t) 82 | + source_params[3] * sin2pix * np.sin(PI * t) 83 | + source_params[4] * sin2pix * np.cos(PI * t) 84 | + source_params[5] * cos2pix * np.sin(PI * t) 85 | + source_params[6] * cos2pix * np.cos(PI * t) 86 | ) 87 | 88 | 89 | def get_ic(): 90 | return np.sin(PI * xg) * np.sin(PI * yg) 91 | 92 | 93 | ic = get_ic() 94 | 95 | 96 | def get_perturbation(): 97 | return 0.25 * np.sin(2 * PI * xg) * np.sin(PI * yg) 98 | 99 | 100 | perturbation = get_perturbation() 101 | 102 | 103 | def get_target(t): 104 | return ic + perturbation * np.sin(PI * t) 105 | 106 | 107 | def step_phi(phi, S): 108 | """ 109 | 110 | takes phi and returns phi at t+dt, using the update formula 111 | 112 | phi^{t+1}_{i,j} = phi^t_{i,j} + 113 | (D dt / dx^2) * (phi^t_{i+1,j} + phi^t_{i,j+1} - 4 phi^t_{i,j} + 114 | phi^t_{i-1,j} + phi^t_{i,j-1}) 115 | + dt * S^t_{ij} * phi^i_{ij} 116 | 117 | it all needs to be in the same index_add step otherwise we're not updating phi 118 | using the previous timestep 119 | 120 | """ 121 | phi = index_add( 122 | phi, 123 | index[1:-1, 1:-1], 124 | (D * dt / dx ** 2) 125 | * ( 126 | np.roll(phi, 1, axis=0) 127 | + np.roll(phi, 1, axis=1) 128 | + np.roll(phi, -1, axis=0) 129 | + np.roll(phi, -1, axis=1) 130 | - 4.0 * phi 131 | )[1:-1, 1:-1] 132 | + dt * S[1:-1, 1:-1] * phi[1:-1, 1:-1], 133 | ) 134 | return phi 135 | 136 | 137 | def loss_fn(phi, target): 138 | return np.mean((phi - target) ** 2) 139 | 140 | 141 | def sim_step(phi, source_params, t): 142 | source = get_source(source_params, t) 143 | target = get_target(t) 144 | phi = step_phi(phi, source) 145 | dl = loss_fn(phi, target) * dt 146 | return phi, dl 147 | 148 | 149 | def sim_step_target_sampled(phi, source_params, t, target): 150 | source = get_source(source_params, t) 151 | phi = step_phi(phi, source) 152 | dl = loss_fn(phi, target) * dt 153 | return phi, dl 154 | 
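# A minimal illustrative sanity check, assuming the default constants defined
# above (D = 0.25, dt = 1/4096, dx = 1/32): the explicit update in step_phi is
# stable when 4 * D * dt / dx**2 < 1, and with these defaults the ratio is
# 4 * 0.25 * (1/4096) * 32**2 = 0.25, matching the comment near the top of
# this file.
assert 4 * D * dt / dx ** 2 < 1.0, "explicit diffusion step would be unstable"
# For example, a single step with a zero source only changes the interior of
# the grid, since step_phi writes to the [1:-1, 1:-1] slice:
#   phi1 = step_phi(get_ic(), np.zeros((rows, cols)))
#   assert np.allclose(phi1[0, :], get_ic()[0, :])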
155 | 156 | @jax.custom_vjp 157 | def sim_step_subsampled(phi, source_params, t, key): 158 | phi, dl = sim_step(phi, source_params, t) 159 | return phi, dl 160 | 161 | 162 | def sim_step_sub_fwd(phi, source_params, t, key): 163 | n_ixs = rows * cols 164 | all_ixs = np.arange(n_ixs, dtype=np.int32) 165 | keep_ixs = jax.random.shuffle(key, all_ixs)[: int(n_ixs * keep_frac)] 166 | 167 | phi_sub = phi.ravel()[keep_ixs] 168 | 169 | primals = sim_step_subsampled(phi, source_params, t, key) 170 | 171 | return primals, (phi_sub, source_params, t, key) 172 | 173 | 174 | def sim_step_sub_rev(res, g): 175 | 176 | phi_sub, source_params, t, key = res 177 | n_ixs = rows * cols 178 | all_ixs = np.arange(n_ixs, dtype=np.int32) 179 | keep_ixs = jax.random.shuffle(key, all_ixs)[: int(n_ixs * keep_frac)] 180 | 181 | phi = np.zeros((rows, cols)).ravel() 182 | phi = jax.ops.index_add(phi, keep_ixs, phi_sub) 183 | phi = phi.reshape(rows, cols) 184 | new_target = np.zeros((rows, cols)).ravel() 185 | new_target = jax.ops.index_add( 186 | new_target, keep_ixs, get_target(t).ravel()[keep_ixs] 187 | ) 188 | new_target = new_target.reshape(rows, cols) 189 | 190 | def forward(phi, source_params): 191 | return sim_step_target_sampled(phi, source_params, t, new_target) 192 | 193 | primals, vjp = jax.vjp(forward, phi, source_params) 194 | grad_phi, grad_source = vjp(g) 195 | return (grad_phi, grad_source, np.zeros_like(t), np.zeros_like(key)) 196 | 197 | 198 | sim_step_subsampled.defvjp(sim_step_sub_fwd, sim_step_sub_rev) 199 | 200 | 201 | def simulate(source_params, ic, key=jax.random.PRNGKey(0)): 202 | scan_step = lambda phi, t_and_key: sim_step_subsampled( 203 | phi, source_params, t_and_key[0], t_and_key[1] 204 | ) 205 | scan_inputs = (np.linspace(0, T_f, nt + 1), jax.random.split(key, nt + 1)) 206 | 207 | phi, losses = jax.lax.scan(scan_step, ic, scan_inputs) 208 | return np.sum(losses) 209 | 210 | 211 | def print_loss(loss, step): 212 | print("step is : {} loss is: {}".format(step, loss)) 213 | 214 | 215 | def write_csv(losses, csvfile): 216 | with open(csvfile, "w") as f: 217 | writer = csv.writer(f, delimiter=",") 218 | writer.writerow(losses) 219 | 220 | 221 | def animate(source_params, ic, filename): 222 | scan_step = lambda phi, t: anim_step(phi, source_params, t) 223 | _, phis = jax.lax.scan(scan_step, ic, np.linspace(0, T_f, nt + 1)) 224 | fig = plt.figure() 225 | ax1 = fig.add_subplot(1, 2, 1, projection="3d") 226 | ax2 = fig.add_subplot(1, 2, 2, projection="3d") 227 | ax = ax1, ax2 228 | ax1.set_zlim(0, 1) 229 | ax2.set_zlim(0, 1) 230 | ani = animation.FuncAnimation( 231 | fig, 232 | lambda i: anim_callback(int(i), phis, ax), 233 | blit=False, 234 | frames=np.arange(0, 40000, 1000), 235 | ) 236 | ani.save(filename) 237 | 238 | 239 | def anim_callback(i, phis, ax): 240 | ax1, ax2 = ax 241 | t = i * dt 242 | target = get_target(t) 243 | ax1.clear() 244 | ax2.clear() 245 | ax1.set_zlim(0, 1) 246 | ax2.set_zlim(0, 1) 247 | ax1.set_title("Neutron Density") 248 | ax2.set_title("Target Neutron Density") 249 | surface = ax1.plot_surface(xg, yg, phis[i, :, :]) 250 | surface = ax2.plot_surface(xg, yg, target) 251 | return surface 252 | 253 | 254 | def anim_step(phi, source_params, t): 255 | source = get_source(source_params, t) 256 | phi = step_phi(phi, source) 257 | return phi, phi 258 | 259 | 260 | key = jax.random.PRNGKey(0) 261 | key, k1, k2 = jax.random.split(key, 3) 262 | source_params = jax.random.uniform(key, minval=-0.1, maxval=0.1, shape=(7,)) 263 | source_params = index_update( 264 | source_params, 
index[0], 4.932 265 | ) # 4.932 will give steady state 266 | ic = get_ic() 267 | 268 | objective_with_grad = jit( 269 | value_and_grad(lambda source_params: simulate(source_params, ic)) 270 | ) 271 | 272 | opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2) 273 | 274 | opt_state = opt_init(source_params) 275 | 276 | animate(source_params, ic, anim_init_filename) 277 | 278 | t1 = time.time() 279 | losses = npo.zeros(N_optim) 280 | for step in range(N_optim): 281 | l, g = objective_with_grad(get_params(opt_state)) 282 | opt_state = opt_update(step, g, opt_state) 283 | print_loss(l, step) 284 | losses[step] = l 285 | 286 | write_csv(losses, csvfile) 287 | 288 | t2 = time.time() 289 | print("Time to run: {}".format(t2 - t1)) 290 | 291 | animate(get_params(opt_state), ic, anim_final_filename) 292 | -------------------------------------------------------------------------------- /nn_experiments/README_ff.md: -------------------------------------------------------------------------------- 1 | # Running Feedforward Experiments 2 | 3 | ## Requirements 4 | 5 | See requirements.txt file for full pip requirements.\ 6 | Major requirements: torch, torchvision, tensorboard, matplotlib 7 | 8 | NOTE: Only verified compatibility with torch version 1.3.1 + torchvision version 0.4.2\ 9 | torch version 1.5.0 is confirmed NOT to work. 10 | 11 | ## Reproducing results from paper 12 | 13 | ``cifarffcommands.txt`` and ``mnistffcommands.txt`` contains the commands to run the experiments for the feedforward experiments in the main text. The ``lr`` and ``weight_decay`` hyper-parameters have been tuned separately for each experiment, as described in the main text. 14 | 15 | ## Plotting results from paper 16 | 17 | ``plot_mnist.py`` and ``plot_cifar.py`` create the plots that result from the experiments. 18 | 19 | All the subdirectories will be created in the current directory. These can be changed using the ``exp_root`` argument for ``{cifar, mnist}_launch.py`` and ``EXP_ROOT`` global parameter in ``plot_{cifar, mnist}.py``. 20 | 21 | ## Datasets 22 | 23 | Data will be downloaded in the ``data`` subdirectory in the current directory, unless the ``data_root`` argument is provided to ``{cifar, mnist}_launch.py``. 24 | -------------------------------------------------------------------------------- /nn_experiments/README_rnn.md: -------------------------------------------------------------------------------- 1 | # Running RNN code 2 | 3 | ## Requirements 4 | 5 | See requirements.txt file for full pip requirements.\ 6 | Major requirements: torch, torchvision, tensorboard, matplotlib 7 | 8 | NOTE: Only verified compatibility with torch version 1.3.1 + torchvision version 0.4.2\ 9 | torch version 1.5.0 is confirmed NOT to work. 10 | 11 | ## Reproducing results in paper 12 | 13 | To reproduce the RNN results in the paper, run the following commands for each type of experiment. Replace `{exp_dir}` 14 | with the path to the directory where results will be stored and ``{data_dir}`` with the path to which the MNIST data will 15 | be downloaded and saved or already exists. Replace ``{exp_name}`` with the name of the experiment and ``{seed}`` with the 16 | given random seeds below. 17 | 18 | ##### Baseline ##### 19 | 20 | Seeds are ``1, 2, 28222136`` 21 | 22 | ``python rnn_mnist_launch.py --exp_name=irnn_baseline --seed={seed} --batch_size=150 --dataset=mnist --lr=1e-4 --keep_frac=1.0 --hidden_size=100 --weight_decay=0. --clip=1. 
--save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400`` 23 | 24 | ##### Small Batch ##### 25 | 26 | Seeds are ``1, 2, 3`` 27 | 28 | ``python rnn_mnist_launch.py --exp_name=irnn_small_batch --seed={seed} --batch_size=21 --dataset=mnist --lr=1e-4 --keep_frac=1.0 --hidden_size=100 --weight_decay=0. --clip=1. --save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400`` 29 | 30 | ##### Sparse ##### 31 | 32 | Seeds are ``1, 2, 3`` 33 | 34 | ``python rnn_mnist_launch.py --exp_name=rand_irnn_sparse --seed={seed} --batch_size=150 --dataset=mnist --lr=1e-4 --keep_frac=0.1 --hidden_size=100 --weight_decay=0. --clip=1. --save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400 --sparse=True --full_random=False`` 35 | 36 | ##### Full Sparse ##### 37 | 38 | Seeds are ``1, 2, 3`` 39 | 40 | ``python rnn_mnist_launch.py --exp_name=rand_irnn_sparse_full --seed={seed} --batch_size=150 --dataset=mnist --lr=1e-4 --keep_frac=0.1 --hidden_size=100 --weight_decay=0. --clip=1. --save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400 --sparse=True --full_random=True`` 41 | 42 | ##### RP ##### 43 | 44 | Seeds are ``1, 2, 20268186`` 45 | 46 | ``python rnn_mnist_launch.py --exp_name=rand_irnn_rp --seed={seed} --batch_size=150 --dataset=mnist --lr=1e-4 --keep_frac=0.1 --hidden_size=100 --weight_decay=0. --clip=1. --save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400 --sparse=False --full_random=False`` 47 | 48 | ##### Full RP ##### 49 | 50 | Seeds are ``1, 2, 3`` 51 | 52 | ``python rnn_mnist_launch.py --exp_name=rand_irnn_rp_full --seed={seed} --batch_size=150 --dataset=mnist --lr=1e-4 --keep_frac=0.1 --hidden_size=100 --weight_decay=0. --clip=1. --save_inter=10 --data_root={data_dir} --exp_root={exp_dir} --augment=False --use_writer=False --with_replace=True --simple=True --max_iterations=200000 --simple_test_eval_frequency=400 --sparse=False --full_random=True`` 53 | 54 | 55 | ## Plotting results in paper 56 | 57 | ``rnn_plottingscripts.py`` creates the plot from the main text, and ``rnn_plottingscripts_appendix.py`` creates the plot for the appendix. The ``EXP_BASE_DIR`` variable in the plotting scripts should be changed to the chosen ``{exp_dir}`` above, and ``exp_folders`` should contain the chosen ``--exp_name``s. 
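
As a convenience, the per-seed commands above can also be driven from a short Python script. The sketch below only illustrates the pattern; the script name ``launch_all_rnn.py`` and the ``EXP_DIR``/``DATA_DIR`` paths are placeholders, while the flag values are copied from the commands in this README.

```python
# launch_all_rnn.py -- illustrative driver for the commands listed above.
import subprocess

EXP_DIR = "/path/to/experiments"   # placeholder for {exp_dir}
DATA_DIR = "/path/to/mnist"        # placeholder for {data_dir}

COMMON = [
    "--dataset=mnist", "--lr=1e-4", "--hidden_size=100", "--weight_decay=0.",
    "--clip=1.", "--save_inter=10", "--augment=False", "--use_writer=False",
    "--with_replace=True", "--simple=True", "--max_iterations=200000",
    "--simple_test_eval_frequency=400",
    "--data_root={}".format(DATA_DIR), "--exp_root={}".format(EXP_DIR),
]

# (exp_name, seeds, experiment-specific flags) for each configuration above.
CONFIGS = [
    ("irnn_baseline",         [1, 2, 28222136], ["--batch_size=150", "--keep_frac=1.0"]),
    ("irnn_small_batch",      [1, 2, 3],        ["--batch_size=21", "--keep_frac=1.0"]),
    ("rand_irnn_sparse",      [1, 2, 3],        ["--batch_size=150", "--keep_frac=0.1", "--sparse=True", "--full_random=False"]),
    ("rand_irnn_sparse_full", [1, 2, 3],        ["--batch_size=150", "--keep_frac=0.1", "--sparse=True", "--full_random=True"]),
    ("rand_irnn_rp",          [1, 2, 20268186], ["--batch_size=150", "--keep_frac=0.1", "--sparse=False", "--full_random=False"]),
    ("rand_irnn_rp_full",     [1, 2, 3],        ["--batch_size=150", "--keep_frac=0.1", "--sparse=False", "--full_random=True"]),
]

for exp_name, seeds, extra in CONFIGS:
    for seed in seeds:
        cmd = ["python", "rnn_mnist_launch.py",
               "--exp_name={}".format(exp_name),
               "--seed={}".format(seed)] + extra + COMMON
        subprocess.run(cmd, check=True)
```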
58 | -------------------------------------------------------------------------------- /nn_experiments/cifar_launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import random 5 | import signal 6 | import shutil 7 | import argparse 8 | import numpy as np 9 | import collections 10 | 11 | import data as rpdata 12 | import models as rpmodels 13 | import utils as rputils 14 | 15 | from train_and_eval import run_model 16 | 17 | import torch 18 | import torch.utils.tensorboard as tb 19 | 20 | 21 | ''' 22 | To override these parameters, specify in command line, such as: 23 | python mnist_launch.py --batch_size=100 24 | 25 | It is important to precede the argument with '--' and include an '=' sign. 26 | Do not use quotes around the value. 27 | ''' 28 | args_template = rputils.ParameterMap( 29 | # experiment name. prepended to pickle name 30 | exp_name='', 31 | 32 | # input batch size for training (default: 64) 33 | batch_size=128, 34 | 35 | # accepted dataset 36 | dataset='cifar10', 37 | 38 | # input batch size for testing (default: 1000) 39 | test_batch_size=1000, 40 | 41 | # number of epochs to train (default: 14) 42 | epochs=200, 43 | 44 | # learning rate (default: 1.0) 45 | lr=0.1, 46 | 47 | # Learning rate step gamma (default: 0.7) 48 | gamma=0.2, 49 | 50 | # disables CUDA training 51 | no_cuda=False, 52 | 53 | # random seed. do not set seed if 0. 54 | seed=0, 55 | 56 | # how many batches to wait before logging training status 57 | log_interval=10, 58 | 59 | # the fraction of activations to reduce to with RAD 60 | keep_frac=1.0, 61 | 62 | # hidden layer size for MNISTFCNet 63 | hidden_size=300, 64 | 65 | # number of epochs after which to drop the lr 66 | lr_drop_step=1, 67 | 68 | # l2 weight decay on the parameters 69 | weight_decay=5e-4, 70 | 71 | # For Saving the current Model 72 | save_model=True, 73 | 74 | # CIFAR lr schedule 75 | training_schedule='cifar', 76 | 77 | # Use adam optimizer in cifar 78 | cifar_adam=False, 79 | 80 | # If true, runs baseline without RP 81 | rp_layer='rpconv', 82 | 83 | # Number of epochs to test on the train set. 84 | train_test_interval=5, 85 | 86 | # Data Root 87 | data_root='./data', 88 | 89 | # Experiment Root 90 | exp_root='', 91 | 92 | # If true, randomly splits the training set into train/val (validation 5000). Ignores test set. 93 | validation=False, 94 | 95 | # Whether to sample with replacement or not while training. True gives real SGD. 96 | with_replace=False, 97 | 98 | # Whether to use augmentation in training dataset. 99 | augment=True, 100 | 101 | # Additive noise to add. 102 | rand_noise=0.0, 103 | 104 | # Wide-ResNet width multiplier. 105 | width_multiplier=1, 106 | 107 | # If experiment exists, overwrite 108 | override=False, 109 | 110 | # If true, generate random matrix independent across batch 111 | full_random=False, 112 | 113 | # If > 0, save this many intermediate checkpoints 114 | save_inter=0, 115 | 116 | # Whether to do simple iteration based training instead of epoch based. 117 | simple=False, 118 | 119 | # Following are only used when simple is True. 120 | max_iterations=-1, 121 | simple_log_frequency=-1, 122 | simple_test_eval_frequency=-1, 123 | simple_test_eval_per_train_test=-1, 124 | simple_scheduler_step_frequency=-1, 125 | simple_model_checkpoint_frequency=-1, 126 | 127 | # If true, samples training set with replacement. 128 | bootstrap_train=False, 129 | 130 | # If false, uses random projections. If true, uses sampling. 
131 | sparse=False, 132 | 133 | # If true, also uses RAD on ReLU layers. 134 | rand_relu=False, 135 | ) 136 | 137 | 138 | def main(additional_args): 139 | args = args_template.clone() 140 | rputils.override_arguments(args, additional_args) 141 | 142 | # If simple is set, default to these arguments. 143 | # Note that we override again at the end, so specified 144 | # arguments take precedence over defaults. 145 | if args.simple: 146 | args.max_iterations = 100000 147 | args.simple_log_frequency = 10 148 | args.simple_test_eval_frequency = 400 149 | args.simple_test_eval_per_train_test = 10 150 | args.simple_scheduler_step_frequency = 10000 151 | args.simple_model_checkpoint_frequency = 10000 152 | args.save_inter = 1 153 | 154 | args.batch_size = 150 155 | args.gamma = 0.6 156 | args.training_schedule = 'epoch_step' 157 | args.cifar_adam = True 158 | args.lr = 0.002 159 | args.with_replace = True 160 | args.augment = False 161 | args.validation = False 162 | args.lr_drop_step = 1 163 | rputils.override_arguments(args, additional_args) 164 | 165 | if not os.path.exists(args.exp_root): 166 | print('Creating experiment root directory {}'.format(args.exp_root)) 167 | os.mkdir(args.exp_root) 168 | if not args.exp_name: 169 | args.exp_name = 'exp{}'.format(random.randint(100000, 999999)) 170 | 171 | if args.seed == 0: 172 | args.seed = random.randint(10000000, 99999999) 173 | 174 | args.exp_dir = os.path.join(args.exp_root, args.exp_name) 175 | os.environ['LAST_EXPERIMENT_DIR'] = args.exp_dir 176 | if args.override and os.path.exists(args.exp_dir): 177 | print("Overriding existing directory.") 178 | shutil.rmtree(args.exp_dir) 179 | assert not os.path.exists(args.exp_dir) 180 | print("Creating experiment with name {} in {}".format(args.exp_name, args.exp_dir)) 181 | os.mkdir(args.exp_dir) 182 | with open(os.path.join(args.exp_dir, 'experiment_args.txt'), 'w') as f: 183 | f.write(str(args)) 184 | 185 | if args.save_inter > 0: 186 | args.inter_dir = os.path.join(args.exp_dir, 'intermediate_checkpoints') 187 | if not os.path.exists(args.inter_dir): 188 | print('Creating directory for intermediate checkpoints.') 189 | os.mkdir(args.inter_dir) 190 | 191 | args.pickle_dir = os.path.join(args.exp_dir, 'pickles') 192 | if not os.path.exists(args.pickle_dir): 193 | print('Creating pickle directory in experiment directory.') 194 | os.mkdir(args.pickle_dir) 195 | 196 | use_cuda = not args.no_cuda and torch.cuda.is_available() 197 | print('Seed is {}'.format(args.seed)) 198 | torch.manual_seed(args.seed) 199 | torch.backends.cudnn.deterministic = True 200 | torch.backends.cudnn.benchmark = False 201 | np.random.seed(args.seed) 202 | 203 | device = torch.device("cuda" if use_cuda else "cpu") 204 | 205 | rp_args = {} 206 | rp_args['rp_layer'] = args.rp_layer 207 | rp_args['keep_frac'] = args.keep_frac 208 | rp_args['rand_noise'] = args.rand_noise 209 | rp_args['width_multiplier'] = args.width_multiplier 210 | rp_args['full_random'] = args.full_random 211 | rp_args['sparse'] = args.sparse 212 | 213 | models = [ 214 | (rpmodels.CIFARConvNet(rp_args=rp_args, rand_relu=args.rand_relu), args.exp_name + "cifarconvnet8", args.exp_name + "cifarconvnet8"), 215 | ] 216 | 217 | # Check if correct dataset is used for each model. 
218 | for model, _, _ in models: 219 | if model.kCompatibleDataset != args.dataset: 220 | raise NotImplementedError( 221 | 'Unsupported dataset {} with model {}'.format(args.dataset, model.__class__.__name__) 222 | ) 223 | 224 | for model, pickle_string, model_string in models: 225 | run_model(model, args, device, None, pickle_string, model_string) 226 | 227 | 228 | if __name__ == '__main__': 229 | main(sys.argv[1:]) 230 | 231 | -------------------------------------------------------------------------------- /nn_experiments/cifarffcommands.txt: -------------------------------------------------------------------------------- 1 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0000-project --simple=True --validation=False --lr=0.000435 --weight_decay=5.485464e-05 --keep_frac=0.1 --bootstrap_train=True 2 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0001-project --simple=True --validation=False --lr=0.000435 --weight_decay=5.485464e-05 --keep_frac=0.1 --bootstrap_train=True 3 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0002-project --simple=True --validation=False --lr=0.000435 --weight_decay=5.485464e-05 --keep_frac=0.1 --bootstrap_train=True 4 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0003-project --simple=True --validation=False --lr=0.000435 --weight_decay=5.485464e-05 --keep_frac=0.1 --bootstrap_train=True 5 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0004-project --simple=True --validation=False --lr=0.000435 --weight_decay=5.485464e-05 --keep_frac=0.1 --bootstrap_train=True 6 | 7 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0005-smallbatch --simple=True --validation=False --lr=0.000430 --weight_decay=1.962288e-06 --batch_size=22 --bootstrap_train=True 8 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0006-smallbatch --simple=True --validation=False --lr=0.000430 --weight_decay=1.962288e-06 --batch_size=22 --bootstrap_train=True 9 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0007-smallbatch --simple=True --validation=False --lr=0.000430 --weight_decay=1.962288e-06 --batch_size=22 --bootstrap_train=True 10 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0008-smallbatch --simple=True --validation=False --lr=0.000430 --weight_decay=1.962288e-06 --batch_size=22 --bootstrap_train=True 11 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0009-smallbatch --simple=True --validation=False --lr=0.000430 --weight_decay=1.962288e-06 --batch_size=22 --bootstrap_train=True 12 | 13 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0010-baseline --simple=True --validation=False --lr=0.001240 --weight_decay=4.340514e-05 --bootstrap_train=True 14 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0011-baseline --simple=True --validation=False --lr=0.001240 --weight_decay=4.340514e-05 --bootstrap_train=True 15 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0012-baseline --simple=True --validation=False --lr=0.001240 --weight_decay=4.340514e-05 --bootstrap_train=True 16 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0013-baseline --simple=True --validation=False --lr=0.001240 --weight_decay=4.340514e-05 --bootstrap_train=True 17 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0014-baseline --simple=True --validation=False --lr=0.001240 --weight_decay=4.340514e-05 --bootstrap_train=True 18 | 19 | python cifar_launch.py --exp_root=cifarexperiments 
--exp_name=0015-samesample --simple=True --keep_frac=0.1 --sparse=True --validation=False --lr=0.001159 --weight_decay=9.990263e-04 --bootstrap_train=True 20 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0016-samesample --simple=True --keep_frac=0.1 --sparse=True --validation=False --lr=0.001159 --weight_decay=9.990263e-04 --bootstrap_train=True 21 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0017-samesample --simple=True --keep_frac=0.1 --sparse=True --validation=False --lr=0.001159 --weight_decay=9.990263e-04 --bootstrap_train=True 22 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0018-samesample --simple=True --keep_frac=0.1 --sparse=True --validation=False --lr=0.001159 --weight_decay=9.990263e-04 --bootstrap_train=True 23 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0019-samesample --simple=True --keep_frac=0.1 --sparse=True --validation=False --lr=0.001159 --weight_decay=9.990263e-04 --bootstrap_train=True 24 | 25 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0020-diffsample --simple=True --keep_frac=0.1 --sparse=True --lr=0.002139 --weight_decay=7.453841e-04 --full_random=True --bootstrap_train=True 26 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0021-diffsample --simple=True --keep_frac=0.1 --sparse=True --lr=0.002139 --weight_decay=7.453841e-04 --full_random=True --bootstrap_train=True 27 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0022-diffsample --simple=True --keep_frac=0.1 --sparse=True --lr=0.002139 --weight_decay=7.453841e-04 --full_random=True --bootstrap_train=True 28 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0023-diffsample --simple=True --keep_frac=0.1 --sparse=True --lr=0.002139 --weight_decay=7.453841e-04 --full_random=True --bootstrap_train=True 29 | python cifar_launch.py --exp_root=cifarexperiments --exp_name=0024-diffsample --simple=True --keep_frac=0.1 --sparse=True --lr=0.002139 --weight_decay=7.453841e-04 --full_random=True --bootstrap_train=True 30 | -------------------------------------------------------------------------------- /nn_experiments/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | 5 | class ApplyTransform(torch.utils.data.Dataset): 6 | """ 7 | Source: https://stackoverflow.com/a/56587747 -- Accessed: 04/05/2020 8 | 9 | Apply transformations to a Dataset 10 | 11 | Arguments: 12 | dataset (Dataset): A Dataset that returns (sample, target) 13 | transform (callable, optional): A function/transform to be applied on the sample 14 | target_transform (callable, optional): A function/transform to be applied on the target 15 | 16 | """ 17 | def __init__(self, dataset, transform=None, target_transform=None): 18 | self.dataset = dataset 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | # yes, you don't need these 2 lines below :( 22 | if transform is None and target_transform is None: 23 | print("Am I a joke to you? 
:)") 24 | 25 | def __getitem__(self, idx): 26 | sample, target = self.dataset[idx] 27 | if self.transform is not None: 28 | sample = self.transform(sample) 29 | if self.target_transform is not None: 30 | target = self.target_transform(target) 31 | return sample, target 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | 37 | def get_dataset(dataset, batch_size, test_batch_size=100, with_replace=False, 38 | num_workers=2, data_root='./data', validation=False, augment=True, bootstrap_loader=False, bootstrap_train=False, **kwargs): 39 | if dataset == 'mnist': 40 | train_transform = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.1307,), (0.3081,)) 43 | ]) 44 | 45 | test_transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.1307,), (0.3081,)) 48 | ]) 49 | 50 | dataset_cls = datasets.MNIST 51 | num_classes = 10 52 | input_size = (1, 28, 28) 53 | elif dataset == 'cifar10': 54 | train_transform = transforms.Compose([ 55 | transforms.RandomCrop(32, padding=4), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 59 | ]) 60 | 61 | test_transform = transforms.Compose([ 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 64 | ]) 65 | 66 | dataset_cls = datasets.CIFAR10 67 | num_classes = 10 68 | input_size = (3, 32, 32) 69 | elif dataset == 'cifar100': 70 | train_transform = transforms.Compose([ 71 | transforms.RandomCrop(32, padding=4), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 75 | ]) 76 | 77 | test_transform = transforms.Compose([ 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 80 | ]) 81 | 82 | dataset_cls = datasets.CIFAR100 83 | num_classes = 100 84 | input_size = (3, 32, 32) 85 | else: 86 | raise NotImplementedError('Unsupported Dataset {}'.format(dataset)) 87 | 88 | if validation: 89 | trainset_pre = dataset_cls(root=data_root, train=True, download=True) 90 | total_len = len(trainset_pre) 91 | val_set_size = 5000 92 | with torch.random.fork_rng(): 93 | torch.random.manual_seed(17) # So that validation dataset is deterministic 94 | trainset, valset = \ 95 | torch.utils.data.random_split(trainset_pre, [total_len - val_set_size, val_set_size]) 96 | if bootstrap_train: 97 | bootstrap_samples = torch.randint(low=0, high=len(trainset), size=(len(trainset),)).tolist() 98 | print('Bootstrapping training set with samples: {}'.format(bootstrap_samples[:100])) 99 | trainset = torch.utils.data.Subset(trainset, bootstrap_samples) 100 | 101 | if augment: 102 | trainset = ApplyTransform(trainset, train_transform) 103 | else: 104 | trainset = ApplyTransform(trainset, test_transform) 105 | 106 | train_sampler = torch.utils.data.RandomSampler(trainset, replacement=with_replace) 107 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=with_replace, 108 | sampler=train_sampler, num_workers=num_workers) 109 | 110 | valset = ApplyTransform(valset, test_transform) 111 | val_loader = torch.utils.data.DataLoader(valset, batch_size=test_batch_size, 112 | shuffle=False, num_workers=num_workers) 113 | train_test_loader = torch.utils.data.DataLoader(trainset, batch_size=test_batch_size, 114 | shuffle=False, num_workers=num_workers) 115 | 116 | return train_loader, val_loader, train_test_loader, 
num_classes 117 | else: 118 | trainset = dataset_cls(root=data_root, train=True, download=True) 119 | if bootstrap_train: 120 | bootstrap_samples = torch.randint(low=0, high=len(trainset), size=(len(trainset),)).tolist() 121 | print('Bootstrapping training set with samples: {}'.format(bootstrap_samples[:100])) 122 | trainset = torch.utils.data.Subset(trainset, bootstrap_samples) 123 | 124 | if augment: 125 | trainset = ApplyTransform(trainset, train_transform) 126 | else: 127 | trainset = ApplyTransform(trainset, test_transform) 128 | 129 | train_sampler = torch.utils.data.RandomSampler(trainset, replacement=with_replace) 130 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=with_replace, 131 | sampler=train_sampler, num_workers=num_workers) 132 | 133 | testset = dataset_cls(root=data_root, train=False, download=True, transform=test_transform) 134 | test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, 135 | shuffle=False, num_workers=num_workers) 136 | train_test_loader = torch.utils.data.DataLoader(trainset, batch_size=test_batch_size, 137 | shuffle=False, num_workers=num_workers) 138 | 139 | return train_loader, test_loader, train_test_loader, num_classes -------------------------------------------------------------------------------- /nn_experiments/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class RandLinear(torch.nn.Linear): 10 | """ 11 | Linear layer with randomized automatic differentiation. Supports both 12 | random projections (sparse=False) and sampling (sparse=True). 13 | 14 | Arguments: 15 | *args, **kwargs: The regular arguments to torch.nn.Linear. 16 | keep_frac: The fraction of hidden units to keep after reduction with randomized autodiff. 17 | full_random: If true, different hidden units are sampled for each batch element. 18 | Only compatible with sparse=True, as it leads to extreme memory usage with random projections. 19 | sparse: Sampling if true, random projections if false. 20 | """ 21 | 22 | def __init__(self, *args, keep_frac=0.5, full_random=False, sparse=False, **kwargs): 23 | super(RandLinear, self).__init__(*args, **kwargs) 24 | self.keep_frac = keep_frac 25 | self.full_random = full_random 26 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 27 | self.sparse = sparse 28 | 29 | def forward(self, input, retain=False, skip_rand=False): 30 | """ 31 | If retain is True, uses the same random projection or sample as the last time this was called. 32 | This is achieved through reusing random seeds. 33 | 34 | If skip_rand is True, behaves like a regular torch.nn.Linear layer (sets keep_frac=1.0). 35 | """ 36 | if not retain: 37 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 38 | 39 | if skip_rand: 40 | keep_frac = 1.0 41 | else: 42 | keep_frac = self.keep_frac 43 | 44 | return RandMatMul.apply(input, self.weight, self.bias, keep_frac, self.full_random, self.random_seed, self.sparse) 45 | 46 | 47 | class RandConv2dLayer(torch.nn.Conv2d): 48 | """ 49 | Conv2d layer with randomized automatic differentiation. Supports both 50 | random projections (sparse=False) and sampling (sparse=True). 51 | 52 | Arguments: 53 | *args, **kwargs: The regular arguments to torch.nn.Conv2d. 
54 | keep_frac: The fraction of hidden units to keep after reduction with randomized autodiff. 55 | full_random: If true, different hidden units are sampled for each batch element. 56 | Only compatible with sparse=True, as it leads to extreme memory usage with random projections. 57 | sparse: Sampling if true, random projections if false. 58 | """ 59 | 60 | def __init__(self, *args, keep_frac=0.5, full_random=False, sparse=False, **kwargs): 61 | super(RandConv2dLayer, self).__init__(*args,**kwargs) 62 | self.conv_params = { 63 | 'stride': self.stride, 64 | 'padding': self.padding, 65 | 'dilation': self.dilation, 66 | 'groups': self.groups, 67 | } 68 | self.keep_frac = keep_frac 69 | self.full_random = full_random 70 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 71 | self.sparse = sparse 72 | 73 | def forward(self, input, retain=False, skip_rand=False): 74 | """ 75 | If retain is True, uses the same random projection or sample as the last time this was called. 76 | This is achieved through reusing random seeds. 77 | 78 | If skip_rand is True, behaves like a regular torch.nn.Conv2d layer (sets keep_frac=1.0). 79 | """ 80 | if not retain: 81 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 82 | 83 | if skip_rand: 84 | keep_frac = 1.0 85 | else: 86 | keep_frac = self.keep_frac 87 | 88 | return RandConv2d.apply(input, self.weight, self.bias, \ 89 | self.conv_params, keep_frac, self.full_random, self.random_seed, self.sparse) 90 | 91 | 92 | class RandReLULayer(torch.nn.ReLU): 93 | """ 94 | ReLU layer with randomized automatic differentiation. Supports both 95 | random projections (sparse=False) and sampling (sparse=True). 96 | 97 | Not used in experiments as it leads to gradients with high variance. 98 | 99 | Arguments: 100 | *args, **kwargs: The regular arguments to torch.nn.ReLU. 101 | keep_frac: The fraction of hidden units to keep after reduction with randomized autodiff. 102 | full_random: If true, different hidden units are sampled for each batch element. 103 | Only compatible with sparse=True, as it leads to extreme memory usage with random projections. 104 | sparse: Sampling if true, random projections if false. 105 | """ 106 | 107 | def __init__(self, *args, keep_frac=0.5, full_random=False, sparse=False, **kwargs): 108 | super(RandReLULayer, self).__init__(*args, **kwargs) 109 | self.keep_frac = keep_frac 110 | self.full_random = full_random 111 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 112 | self.sparse = sparse 113 | 114 | def forward(self, input, retain=False, skip_rand=False): 115 | """ 116 | If retain is True, uses the same random projection or sample as the last time this was called. 117 | This is achieved through reusing random seeds. 118 | 119 | If skip_rand is True, behaves like a regular torch.nn.ReLU layer (sets keep_frac=1.0). 
120 | """ 121 | if not retain: 122 | self.random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 123 | 124 | if skip_rand: 125 | keep_frac = 1.0 126 | else: 127 | keep_frac = self.keep_frac 128 | 129 | return RandReLU.apply(input, keep_frac, self.full_random, self.random_seed, self.sparse) 130 | 131 | 132 | 133 | ################################################################################################### 134 | ############################ BEGIN: Common Methods for all layers ################################# 135 | ################################################################################################### 136 | 137 | 138 | def input2rp(input, kept_feature_size, full_random=False, random_seed=None): 139 | """ 140 | Converts either a Linear layer or Conv2d layer input into a dimension reduced form 141 | using random projections. 142 | 143 | If the input is 2D, it is interpreted as (batch size) x (hidden size). The (hidden size) dimension 144 | will be reduced and the output will be (batch size) x (kept_feature_size). 145 | 146 | If the input is 4D, it is interpreted as (batch size) x (feature size) x (height) x (width). The 147 | (batch size) x (feature size) dims will be interpreted as a "effective batch size", and the array will 148 | be reduced along (height) x (width). The reduced tensor will be (batch size * feature size) x (kept_feature_size). 149 | 150 | Returns the reduced input and the random matrix used for the random projection. If using a random seed, 151 | the random matrix can be discarded, and the random seed can be used again in the rp2input method to regenerate 152 | the same random matrix. 153 | 154 | Arguments: 155 | input: Tensor of size (batch size) x (hidden size) or (batch size) x (feature size) x (height) x (width) 156 | kept_feature_size: The number to reduce the dimension to. 157 | full_random: If true, different hidden units are sampled for each effective batch element. 158 | WARNING: Will lead to extensive memory use if set to True in this method. 159 | random_seed: Use this random seed if not None. 160 | """ 161 | 162 | def shp(t): 163 | return tuple(t.size()) 164 | 165 | if len(shp(input)) == 4: 166 | batch_size = (shp(input)[0], shp(input)[1]) 167 | feature_len = shp(input)[2] * shp(input)[3] 168 | elif len(shp(input)) == 2: 169 | batch_size = (shp(input)[0], ) 170 | feature_len = shp(input)[1] 171 | 172 | if full_random: 173 | rand_matrix_size = (*batch_size, feature_len, kept_feature_size) 174 | matmul_view = input.view(*batch_size, 1, feature_len) 175 | else: 176 | rand_matrix_size = (feature_len, kept_feature_size) 177 | matmul_view = input.view(*batch_size, feature_len) 178 | 179 | # Create random matrix 180 | def gen_rad_mat(rm_size, feat_size, device): 181 | bern = torch.randint(2, size=rm_size, device=device, requires_grad=False) 182 | return (2.0 * bern - 1) / feat_size**0.5 183 | 184 | if random_seed: 185 | with torch.random.fork_rng(): 186 | torch.random.manual_seed(random_seed) 187 | rand_matrix = gen_rad_mat(rand_matrix_size, kept_feature_size, input.device) 188 | else: 189 | rand_matrix = gen_rad_mat(rand_matrix_size, kept_feature_size, input.device) 190 | 191 | with torch.autograd.grad_mode.no_grad(): 192 | dim_reduced_input = \ 193 | torch.matmul(matmul_view, rand_matrix) 194 | return dim_reduced_input, rand_matrix 195 | 196 | 197 | def rp2input(dim_reduced_input, input_shape, rand_matrix=None, random_seed=None, full_random=False): 198 | """ 199 | Inverse of input2rp. 
Accepts the outputted reduced tensor from input2rp along with 200 | the expected size of the input. 201 | 202 | One and only one of rand_matrix or random_seed must be provided. 203 | This method must take either the rand_matrix outputted by input2rp, or the random seed 204 | used by input2rp. If the random seed is provided, this method will reconstruct rand_matrix, which 205 | contains the random matrix used to project the input. 206 | 207 | Arguments: 208 | dim_reduced_input: The outputted reduced tensor from input2rp. 209 | input_shape: The shape of the input tensor fed into input2rp. 210 | rand_matrix: The random matrix generated by input2rp. 211 | random_seed: Set this random seed to the same one as input2rp to reconstruct the random indices. 212 | full_random: Must be set to the same value as when input2rp was called. 213 | """ 214 | 215 | def shp(t): 216 | return tuple(t.size()) 217 | 218 | if rand_matrix is None and random_seed is None: 219 | print("ERROR in rp2input: One of rand_matrix or random_seed must be provided.") 220 | return 221 | if rand_matrix is not None and random_seed is not None: 222 | print("ERROR in rp2input: Only one of rand_matrix or random_seed must be provided.") 223 | return 224 | 225 | if len(input_shape) == 4: 226 | batch_size = (input_shape[0], input_shape[1]) 227 | feature_len = input_shape[2] * input_shape[3] 228 | elif len(input_shape) == 2: 229 | batch_size = (input_shape[0], ) 230 | feature_len = input_shape[1] 231 | 232 | kept_feature_size = shp(dim_reduced_input)[-1] 233 | if full_random: 234 | rand_matrix_shape = (*batch_size, feature_len, kept_feature_size) 235 | else: 236 | rand_matrix_shape = (feature_len, kept_feature_size) 237 | 238 | # Create random matrix 239 | def gen_rad_mat(rm_size, feat_size, device): 240 | bern = torch.randint(2, size=rm_size, device=device, requires_grad=False) 241 | return (2.0 * bern - 1) / feat_size**0.5 242 | 243 | if random_seed is not None: 244 | with torch.random.fork_rng(): 245 | torch.random.manual_seed(random_seed) 246 | rand_matrix = gen_rad_mat(rand_matrix_shape, kept_feature_size, dim_reduced_input.device) 247 | 248 | with torch.autograd.grad_mode.no_grad(): 249 | input = torch.matmul(dim_reduced_input, torch.transpose(rand_matrix, -2, -1)) 250 | input = input.view(input_shape) 251 | 252 | return input 253 | 254 | 255 | def input2sparse(input, kept_feature_size, full_random=False, random_seed=None): 256 | """ 257 | Converts either a Linear layer or Conv2d layer input into a dimension reduced form 258 | using sampling. 259 | 260 | If the input is 2D, it is interpreted as (batch size) x (hidden size). The (hidden size) dimension 261 | will be reduced and the output will be (batch size) x (kept_feature_size). 262 | 263 | If the input is 4D, it is interpreted as (batch size) x (feature size) x (height) x (width). The 264 | (batch size) x (feature size) dims will be interpreted as a "effective batch size", and the input will 265 | be reduced along (height) x (width). The reduced tensor will be (batch size * feature size) x (kept_feature_size). 266 | 267 | Returns the reduced input and the random indices used for the sampling. If using a random seed, 268 | the random indices can be discarded, and the random seed can be used again in the sparse2input method to regenerate 269 | the same random indices. 270 | 271 | Arguments: 272 | input: Tensor of size (batch size) x (hidden size) or (batch size) x (feature size) x (height) x (width) 273 | kept_feature_size: The number to reduce the dimension to. 
274 | full_random: If true, different hidden units are sampled for each effective batch element. 275 | random_seed: Use this random seed if not None. 276 | """ 277 | 278 | def shp(t): 279 | return tuple(t.size()) 280 | 281 | if len(shp(input)) == 4: 282 | batch_size = shp(input)[0] * shp(input)[1] 283 | feature_len = shp(input)[2] * shp(input)[3] 284 | elif len(shp(input)) == 2: 285 | batch_size = shp(input)[0] 286 | feature_len = shp(input)[1] 287 | 288 | if full_random: 289 | gather_index_shape = (batch_size, kept_feature_size) 290 | else: 291 | gather_index_shape = (1, kept_feature_size) 292 | 293 | # Create random matrix 294 | if random_seed is not None: 295 | with torch.random.fork_rng(): 296 | torch.random.manual_seed(random_seed) 297 | gather_index = torch.randint(feature_len, gather_index_shape, device=input.device, dtype=torch.long) 298 | else: 299 | gather_index = torch.randint(feature_len, gather_index_shape, device=input.device, dtype=torch.long) 300 | 301 | with torch.autograd.grad_mode.no_grad(): 302 | gathered_input = \ 303 | torch.gather(input.view(batch_size, feature_len), 304 | index=gather_index.expand(batch_size, -1), dim=-1).clone() 305 | # Normalization to ensure unbiased. 306 | gathered_input *= feature_len / kept_feature_size 307 | 308 | return gathered_input, gather_index 309 | 310 | 311 | def sparse2input(gathered_input, input_shape, gather_index=None, random_seed=None, full_random=False): 312 | """ 313 | Inverse of input2sparse. Accepts the outputted reduced tensor from input2sparse along with 314 | the expected size of the input. 315 | 316 | One and only one of gather_index or random_seed must be provided. 317 | This method must take either the gather_index outputted by input2sparse, or the random seed 318 | used by input2sparse. If the random seed is provided, this method will reconstruct gather_index, which 319 | contains the random indices used to sample the input. 320 | 321 | Arguments: 322 | gathered_input: The outputted reduced tensor from input2sparse. 323 | input_shape: The shape of the input tensor fed into input2sparse. 324 | gather_index: The random indices generated by input2sparse. 325 | random_seed: Set this random seed to the same one as input2sparse to reconstruct the random indices. 326 | full_random: Must be set to the same value as when input2sparse was called. 
327 | """ 328 | 329 | def shp(t): 330 | return tuple(t.size()) 331 | 332 | if gather_index is None and random_seed is None: 333 | print("ERROR in sparse2input: One of gather_index or random_seed must be provided.") 334 | return 335 | if gather_index is not None and random_seed is not None: 336 | print("ERROR in sparse2input: Only one of gather_index or random_seed must be provided.") 337 | return 338 | 339 | if len(input_shape) == 4: 340 | batch_size = input_shape[0] * input_shape[1] 341 | feature_len = input_shape[2] * input_shape[3] 342 | elif len(input_shape) == 2: 343 | batch_size = input_shape[0] 344 | feature_len = input_shape[1] 345 | 346 | kept_feature_size = shp(gathered_input)[-1] 347 | if full_random: 348 | gather_index_shape = (batch_size, kept_feature_size) 349 | else: 350 | gather_index_shape = (1, kept_feature_size) 351 | 352 | if random_seed is not None: 353 | with torch.random.fork_rng(): 354 | torch.random.manual_seed(random_seed) 355 | gather_index = torch.randint(feature_len, gather_index_shape, device=gathered_input.device, dtype=torch.long) 356 | 357 | with torch.autograd.grad_mode.no_grad(): 358 | input = torch.zeros(batch_size, feature_len, device=gathered_input.device) 359 | 360 | batch_index = torch.arange(batch_size).view(batch_size, 1) 361 | input.index_put_((batch_index, gather_index), gathered_input, accumulate=True) 362 | input = input.view(input_shape) 363 | 364 | return input 365 | 366 | 367 | ################################################################################################# 368 | ############################ END: Common Methods for all layers ################################# 369 | ################################################################################################# 370 | 371 | 372 | class RandReLU(torch.autograd.Function): 373 | @staticmethod 374 | def forward(ctx, input, keep_frac, full_random, random_seed, sparse): 375 | batch_size = input.size()[:-1] 376 | num_activations = input.size()[-1] 377 | 378 | ctx.input_shape = tuple(input.size()) 379 | ctx.num_activations = num_activations 380 | ctx.keep_frac = keep_frac 381 | ctx.full_random = full_random 382 | ctx.random_seed = random_seed 383 | ctx.sparse = sparse 384 | kept_activations = int(num_activations * keep_frac + 0.999) 385 | 386 | # If we don't need to project, just fast-track. 
387 | if ctx.keep_frac == 1.0: 388 | ctx.save_for_backward(input) 389 | return F.relu(input) 390 | 391 | if sparse: 392 | dim_reduced_input, _ = input2sparse(input, kept_activations, random_seed=random_seed, full_random=full_random) 393 | else: 394 | dim_reduced_input, _ = input2rp(input, kept_activations, random_seed=random_seed, full_random=full_random) 395 | 396 | # Saved Tensors should be low rank 397 | ctx.save_for_backward(dim_reduced_input) 398 | 399 | with torch.autograd.grad_mode.no_grad(): 400 | return F.relu(input) 401 | 402 | @staticmethod 403 | def backward(ctx, grad_output): 404 | if ctx.keep_frac < 1.0: 405 | (dim_reduced_input,) = ctx.saved_tensors 406 | if ctx.sparse: 407 | input = sparse2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 408 | else: 409 | input = rp2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 410 | else: 411 | (input,) = ctx.saved_tensors 412 | 413 | def cln(t): 414 | if t is None: 415 | return None 416 | ct = t.clone().detach() 417 | ct.requires_grad_(True) 418 | return ct 419 | 420 | cinput = cln(input) 421 | 422 | with torch.autograd.grad_mode.enable_grad(): 423 | output = F.relu(cinput) 424 | input_grad_input = output.grad_fn(grad_output) 425 | 426 | return input_grad_input, None, None, None, None 427 | 428 | 429 | class RandMatMul(torch.autograd.Function): 430 | @staticmethod 431 | def forward(ctx, input, weight, bias, keep_frac, full_random, random_seed, sparse): 432 | # Calculate dimensions according to input and keep_frac 433 | batch_size = input.size()[:-1] 434 | num_activations = input.size()[-1] 435 | 436 | ctx.input_shape = tuple(input.size()) 437 | ctx.num_activations = num_activations 438 | ctx.keep_frac = keep_frac 439 | ctx.full_random = full_random 440 | ctx.random_seed = random_seed 441 | ctx.sparse = sparse 442 | kept_activations = int(num_activations * keep_frac + 0.999) 443 | 444 | # If we don't need to project, just fast-track. 
445 | if ctx.keep_frac == 1.0: 446 | ctx.save_for_backward(input, weight, bias) 447 | linear_out = F.linear(input, weight, bias=bias) 448 | return linear_out 449 | 450 | if sparse: 451 | dim_reduced_input, _ = input2sparse(input, kept_activations, random_seed=random_seed, full_random=full_random) 452 | else: 453 | dim_reduced_input, _ = input2rp(input, kept_activations, random_seed=random_seed, full_random=full_random) 454 | 455 | # Saved Tensors should be low rank 456 | ctx.save_for_backward(dim_reduced_input, weight, bias) 457 | 458 | with torch.autograd.grad_mode.no_grad(): 459 | return F.linear(input, weight, bias=bias) 460 | 461 | @staticmethod 462 | def backward(ctx, grad_output): 463 | if ctx.keep_frac < 1.0: 464 | dim_reduced_input, weight, bias = ctx.saved_tensors 465 | if ctx.sparse: 466 | input = sparse2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 467 | else: 468 | input = rp2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 469 | else: 470 | input, weight, bias = ctx.saved_tensors 471 | 472 | def cln(t): 473 | if t is None: 474 | return None 475 | ct = t.clone().detach() 476 | ct.requires_grad_(True) 477 | return ct 478 | 479 | cinput = cln(input) 480 | cweight = cln(weight) 481 | cbias = cln(bias) 482 | 483 | with torch.autograd.grad_mode.enable_grad(): 484 | output = F.linear(cinput, cweight, bias=cbias) 485 | bias_grad_input, input_grad_input, weight_grad_input = output.grad_fn(grad_output) 486 | 487 | # Why are the gradients for F.linear like this??? 488 | return input_grad_input, weight_grad_input.T, bias_grad_input.sum(axis=0), None, None, None, None 489 | 490 | 491 | class RandConv2d(torch.autograd.Function): 492 | @staticmethod 493 | def forward(ctx, input, weight, bias, conv_params, keep_frac, full_random, random_seed, sparse): 494 | ctx.input_shape = tuple(input.size()) 495 | ctx.keep_frac = keep_frac 496 | ctx.conv_params = conv_params 497 | ctx.full_random = full_random 498 | ctx.random_seed = random_seed 499 | ctx.sparse = sparse 500 | 501 | # If we don't need to project, just fast-track. 502 | if keep_frac == 1.0: 503 | ctx.save_for_backward(input, weight, bias) 504 | conv_out = F.conv2d(input, weight, bias=bias, **ctx.conv_params) 505 | return conv_out 506 | 507 | kept_image_size = int(keep_frac * ctx.input_shape[2] * ctx.input_shape[3] + 0.999) 508 | if ctx.sparse: 509 | dim_reduced_input, _ = input2sparse(input, kept_image_size, full_random=full_random, random_seed=random_seed) 510 | else: 511 | dim_reduced_input, _ = input2rp(input, kept_image_size, full_random=full_random, random_seed=random_seed) 512 | 513 | with torch.autograd.grad_mode.no_grad(): 514 | conv_out = F.conv2d(input, weight, bias=bias, **ctx.conv_params) 515 | 516 | # Save appropriate for backward pass. 
517 | ctx.save_for_backward(dim_reduced_input, weight, bias) 518 | 519 | with torch.autograd.grad_mode.no_grad(): 520 | return conv_out 521 | 522 | @staticmethod 523 | def backward(ctx, grad_output): 524 | if ctx.keep_frac < 1.0: 525 | dim_reduced_input, weight, bias = ctx.saved_tensors 526 | if ctx.sparse: 527 | input = sparse2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 528 | else: 529 | input = rp2input(dim_reduced_input, ctx.input_shape, random_seed=ctx.random_seed, full_random=ctx.full_random) 530 | else: 531 | input, weight, bias = ctx.saved_tensors 532 | 533 | def cln(t): 534 | if t is None: 535 | return None 536 | ct = t.clone().detach() 537 | ct.requires_grad_(True) 538 | return ct 539 | 540 | cinput = cln(input) 541 | cweight = cln(weight) 542 | cbias = cln(bias) 543 | 544 | with torch.autograd.grad_mode.enable_grad(): 545 | output = F.conv2d(cinput, cweight, bias=cbias, **ctx.conv_params) 546 | 547 | input_grad_output = grad_output 548 | input_grad_input, weight_grad_input, bias_grad_input = output.grad_fn(input_grad_output) 549 | 550 | return input_grad_input, weight_grad_input, bias_grad_input, None, None, None, None, None 551 | -------------------------------------------------------------------------------- /nn_experiments/mnist_launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import random 5 | import signal 6 | import shutil 7 | import argparse 8 | import numpy as np 9 | import collections 10 | 11 | import data as rpdata 12 | import models as rpmodels 13 | import utils as rputils 14 | 15 | from train_and_eval import run_model 16 | 17 | import torch 18 | import torch.utils.tensorboard as tb 19 | 20 | ''' 21 | To override these parameters, specify in command line, such as: 22 | python mnist_launch.py --batch_size=100 23 | 24 | It is important to precede the argument with '--' and include an '=' sign. 25 | Do not use quotes around the value. 26 | ''' 27 | args_template = rputils.ParameterMap( 28 | # experiment name. prepended to pickle name 29 | exp_name='', 30 | 31 | # input batch size for training (default: 64) 32 | batch_size=64, 33 | 34 | # accepted dataset 35 | dataset='mnist', 36 | 37 | # input batch size for testing (default: 1000) 38 | test_batch_size=1000, 39 | 40 | # number of epochs to train (default: 14) 41 | epochs=20, 42 | 43 | # learning rate (default: 1.0) 44 | lr=0.1, 45 | 46 | # Learning rate step gamma (default: 0.7) 47 | gamma=0.2, 48 | 49 | # disables CUDA training 50 | no_cuda=False, 51 | 52 | # random seed. do not set seed if 0. 53 | seed=0, 54 | 55 | # how many batches to wait before logging training status 56 | log_interval=10, 57 | 58 | # the fraction of activations to reduce to with RAD 59 | keep_frac=1.0, 60 | 61 | # hidden layer size for MNISTFCNet 62 | hidden_size=300, 63 | 64 | # number of epochs after which to drop the lr 65 | lr_drop_step=1, 66 | 67 | # l2 weight decay on the parameters 68 | weight_decay=0.0001, 69 | 70 | # For Saving the current Model 71 | save_model=True, 72 | 73 | # CIFAR lr schedule 74 | training_schedule='cifar', 75 | 76 | # Use adam optimizer in cifar 77 | cifar_adam=False, 78 | 79 | # If true, runs baseline without RP 80 | rp_layer='rpconv', 81 | 82 | # Number of epochs to test on the train set. 
83 | train_test_interval=5, 84 | 85 | # Data Root 86 | data_root='./data', 87 | 88 | # Experiment Root 89 | exp_root='', 90 | 91 | # If true, randomly splits the training set into train/val (validation 5000). Ignores test set. 92 | validation=False, 93 | 94 | # Whether to sample with replacement or not while training. True gives real SGD. 95 | with_replace=False, 96 | 97 | # Whether to use augmentation in training dataset. 98 | augment=True, 99 | 100 | # Additive noise to add. 101 | rand_noise=0.0, 102 | 103 | # Wide-ResNet width multiplier. 104 | width_multiplier=1, 105 | 106 | # If experiment exists, overwrite 107 | override=False, 108 | 109 | # If true, generate random matrix independent across batch 110 | full_random=False, 111 | 112 | # If > 0, save this many intermediate checkpoints 113 | save_inter=0, 114 | 115 | # Whether to do simple iteration based training instead of epoch based. 116 | simple=False, 117 | 118 | # Following are only used when simple is True. 119 | max_iterations=-1, 120 | simple_log_frequency=-1, 121 | simple_test_eval_frequency=-1, 122 | simple_test_eval_per_train_test=-1, 123 | simple_scheduler_step_frequency=-1, 124 | simple_model_checkpoint_frequency=-1, 125 | 126 | # If true, samples training set with replacement. 127 | bootstrap_train=False, 128 | 129 | # If false, uses random projections. If true, uses sampling. 130 | sparse=False, 131 | 132 | # If true, also uses RAD on ReLU layers. 133 | rand_relu=False, 134 | ) 135 | 136 | 137 | def main(additional_args): 138 | args = args_template.clone() 139 | rputils.override_arguments(args, additional_args) 140 | 141 | # If simple is set, default to these arguments. 142 | # Note that we override again at the end, so specified 143 | # arguments take precedence over defaults. 144 | if args.simple: 145 | args.max_iterations = 20000 146 | args.simple_log_frequency = 10 147 | args.simple_test_eval_frequency = 400 148 | args.simple_test_eval_per_train_test = 10 149 | args.simple_scheduler_step_frequency = 2000 150 | args.simple_model_checkpoint_frequency = 5000 151 | args.save_inter = 1 152 | 153 | args.batch_size = 150 154 | args.gamma = 0.6 155 | args.training_schedule = 'epoch_step' 156 | args.cifar_adam = True 157 | args.lr = 0.002 158 | args.with_replace = True 159 | args.augment = False 160 | args.validation = False 161 | args.lr_drop_step = 1 162 | rputils.override_arguments(args, additional_args) 163 | 164 | if not os.path.exists(args.exp_root): 165 | print('Creating experiment root directory {}'.format(args.exp_root)) 166 | os.mkdir(args.exp_root) 167 | if not args.exp_name: 168 | args.exp_name = 'exp{}'.format(random.randint(100000, 999999)) 169 | 170 | if args.seed == 0: 171 | args.seed = random.randint(10000000, 99999999) 172 | 173 | args.exp_dir = os.path.join(args.exp_root, args.exp_name) 174 | os.environ['LAST_EXPERIMENT_DIR'] = args.exp_dir 175 | if args.override and os.path.exists(args.exp_dir): 176 | print("Overriding existing directory.") 177 | shutil.rmtree(args.exp_dir) 178 | assert not os.path.exists(args.exp_dir) 179 | print("Creating experiment with name {} in {}".format(args.exp_name, args.exp_dir)) 180 | os.mkdir(args.exp_dir) 181 | with open(os.path.join(args.exp_dir, 'experiment_args.txt'), 'w') as f: 182 | f.write(str(args)) 183 | 184 | if args.save_inter > 0: 185 | args.inter_dir = os.path.join(args.exp_dir, 'intermediate_checkpoints') 186 | if not os.path.exists(args.inter_dir): 187 | print('Creating directory for intermediate checkpoints.') 188 | os.mkdir(args.inter_dir) 189 | 190 | 
args.pickle_dir = os.path.join(args.exp_dir, 'pickles') 191 | if not os.path.exists(args.pickle_dir): 192 | print('Creating pickle directory in experiment directory.') 193 | os.mkdir(args.pickle_dir) 194 | 195 | use_cuda = not args.no_cuda and torch.cuda.is_available() 196 | print('Seed is {}'.format(args.seed)) 197 | torch.manual_seed(args.seed) 198 | torch.backends.cudnn.deterministic = True 199 | torch.backends.cudnn.benchmark = False 200 | np.random.seed(args.seed) 201 | 202 | device = torch.device("cuda" if use_cuda else "cpu") 203 | 204 | rp_args = {} 205 | rp_args['rp_layer'] = args.rp_layer 206 | rp_args['keep_frac'] = args.keep_frac 207 | rp_args['rand_noise'] = args.rand_noise 208 | rp_args['width_multiplier'] = args.width_multiplier 209 | rp_args['full_random'] = args.full_random 210 | rp_args['sparse'] = args.sparse 211 | 212 | models = [ 213 | (rpmodels.MNISTFCNet(hidden_size=args.hidden_size, rp_args=rp_args, rand_relu=args.rand_relu), args.exp_name + "mnistfcnet8", args.exp_name + "mnistfcnet8"), 214 | ] 215 | 216 | # Check if correct dataset is used for each model. 217 | for model, _, _ in models: 218 | if model.kCompatibleDataset != args.dataset: 219 | raise NotImplementedError( 220 | 'Unsupported dataset {} with model {}'.format(args.dataset, model.__class__.__name__) 221 | ) 222 | 223 | for model, pickle_string, model_string in models: 224 | run_model(model, args, device, None, pickle_string, model_string) 225 | 226 | 227 | if __name__ == '__main__': 228 | main(sys.argv[1:]) 229 | 230 | -------------------------------------------------------------------------------- /nn_experiments/mnistffcommands.txt: -------------------------------------------------------------------------------- 1 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0000-project --simple=True --lr=0.000527 --weight_decay=1.009799e-03 --keep_frac=0.1 --bootstrap_train=True 2 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0001-project --simple=True --lr=0.000527 --weight_decay=1.009799e-03 --keep_frac=0.1 --bootstrap_train=True 3 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0002-project --simple=True --lr=0.000527 --weight_decay=1.009799e-03 --keep_frac=0.1 --bootstrap_train=True 4 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0003-project --simple=True --lr=0.000527 --weight_decay=1.009799e-03 --keep_frac=0.1 --bootstrap_train=True 5 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0004-project --simple=True --lr=0.000527 --weight_decay=1.009799e-03 --keep_frac=0.1 --bootstrap_train=True 6 | 7 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0005-smallbatch --simple=True --lr=0.000645 --weight_decay=5.687505e-05 --batch_size=20 --bootstrap_train=True 8 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0006-smallbatch --simple=True --lr=0.000645 --weight_decay=5.687505e-05 --batch_size=20 --bootstrap_train=True 9 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0007-smallbatch --simple=True --lr=0.000645 --weight_decay=5.687505e-05 --batch_size=20 --bootstrap_train=True 10 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0008-smallbatch --simple=True --lr=0.000645 --weight_decay=5.687505e-05 --batch_size=20 --bootstrap_train=True 11 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0009-smallbatch --simple=True --lr=0.000645 --weight_decay=5.687505e-05 --batch_size=20 --bootstrap_train=True 12 | 13 | python mnist_launch.py --exp_root=mnistexperiments 
--exp_name=0010-baseline --simple=True --lr=0.001350 --weight_decay=4.066478e-07 --bootstrap_train=True 14 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0011-baseline --simple=True --lr=0.001350 --weight_decay=4.066478e-07 --bootstrap_train=True 15 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0012-baseline --simple=True --lr=0.001350 --weight_decay=4.066478e-07 --bootstrap_train=True 16 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0013-baseline --simple=True --lr=0.001350 --weight_decay=4.066478e-07 --bootstrap_train=True 17 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0014-baseline --simple=True --lr=0.001350 --weight_decay=4.066478e-07 --bootstrap_train=True 18 | 19 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0015-samesample --simple=True --lr=0.000452 --weight_decay=4.058855e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True 20 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0016-samesample --simple=True --lr=0.000452 --weight_decay=4.058855e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True 21 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0017-samesample --simple=True --lr=0.000452 --weight_decay=4.058855e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True 22 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0018-samesample --simple=True --lr=0.000452 --weight_decay=4.058855e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True 23 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0019-samesample --simple=True --lr=0.000452 --weight_decay=4.058855e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True 24 | 25 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0020-diffsample --simple=True --lr=0.000934 --weight_decay=1.254031e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True --full_random=True 26 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0021-diffsample --simple=True --lr=0.000934 --weight_decay=1.254031e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True --full_random=True 27 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0022-diffsample --simple=True --lr=0.000934 --weight_decay=1.254031e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True --full_random=True 28 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0023-diffsample --simple=True --lr=0.000934 --weight_decay=1.254031e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True --full_random=True 29 | python mnist_launch.py --exp_root=mnistexperiments --exp_name=0024-diffsample --simple=True --lr=0.000934 --weight_decay=1.254031e-06 --keep_frac=0.1 --sparse=True --bootstrap_train=True --full_random=True 30 | -------------------------------------------------------------------------------- /nn_experiments/models.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import collections 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import layers as rpn 9 | 10 | 11 | class MNISTFCNet(torch.nn.Module): 12 | kCompatibleDataset = 'mnist' 13 | def __init__(self, hidden_size, rand_relu=False, rp_args={}): 14 | super(MNISTFCNet, self).__init__() 15 | self.rand_relu = rand_relu 16 | kept_keys = ['keep_frac', 'full_random', 'sparse'] 17 | kept_dict = {key: rp_args[key] for key in kept_keys} 18 | 19 | self.fc1 = rpn.RandLinear(784, hidden_size, **kept_dict) 20 | self.relu1 = 
rpn.RandReLULayer(**kept_dict) 21 | self.fc2 = rpn.RandLinear(hidden_size, hidden_size, **kept_dict) 22 | self.relu2 = rpn.RandReLULayer(**kept_dict) 23 | self.fc3 = rpn.RandLinear(hidden_size, hidden_size, **kept_dict) 24 | self.relu3 = rpn.RandReLULayer(**kept_dict) 25 | self.fc4 = rpn.RandLinear(hidden_size, 10, **kept_dict) 26 | 27 | def forward(self, x, retain=False, skip_rand=False): 28 | if self.rand_relu: 29 | skip_relu = False 30 | else: 31 | skip_relu = True 32 | 33 | x = nn.Flatten()(x) 34 | x = self.fc1(x, retain=retain, skip_rand=skip_rand) 35 | x = self.relu1(x, skip_rand=skip_relu) 36 | x = self.fc2(x, retain=retain, skip_rand=skip_rand) 37 | x = self.relu2(x, skip_rand=skip_relu) 38 | x = self.fc3(x, retain=retain, skip_rand=skip_rand) 39 | x = self.relu3(x, skip_rand=skip_relu) 40 | x = self.fc4(x, retain=retain, skip_rand=skip_rand) 41 | output = F.log_softmax(x, dim=1) 42 | return output 43 | 44 | 45 | class CIFARConvNet(torch.nn.Module): 46 | kCompatibleDataset = 'cifar10' 47 | 48 | def __init__(self, rand_relu=False, rp_args={}): 49 | super(CIFARConvNet, self).__init__() 50 | self.rand_relu = rand_relu 51 | kept_keys = ['keep_frac', 'full_random', 'sparse'] 52 | kept_dict = {key: rp_args[key] for key in kept_keys} 53 | 54 | self.conv1 = rpn.RandConv2dLayer(3, 16, 5, padding=2, **kept_dict) 55 | self.relu1 = rpn.RandReLULayer(**kept_dict) 56 | self.conv2 = rpn.RandConv2dLayer(16, 32, 5, padding=2, **kept_dict) 57 | self.relu2 = rpn.RandReLULayer(**kept_dict) 58 | self.conv3 = rpn.RandConv2dLayer(32, 32, 5, padding=2, **kept_dict) 59 | self.relu3 = rpn.RandReLULayer(**kept_dict) 60 | self.conv4 = rpn.RandConv2dLayer(32, 32, 5, padding=2, **kept_dict) 61 | self.relu4 = rpn.RandReLULayer(**kept_dict) 62 | 63 | self.fc5 = rpn.RandLinear(2048, 10, **kept_dict) 64 | 65 | def forward(self, x, retain=False, skip_rand=False): 66 | if self.rand_relu: 67 | skip_relu = False 68 | else: 69 | skip_relu = True 70 | 71 | x = self.conv1(x, retain=retain, skip_rand=skip_rand) 72 | x = self.relu1(x, skip_rand=skip_relu) 73 | 74 | x = self.conv2(x, retain=retain, skip_rand=skip_rand) 75 | x = self.relu2(x, skip_rand=skip_relu) 76 | 77 | x = F.avg_pool2d(x, 2) 78 | 79 | x = self.conv3(x, retain=retain, skip_rand=skip_rand) 80 | x = self.relu3(x, skip_rand=skip_relu) 81 | 82 | x = self.conv4(x, retain=retain, skip_rand=skip_rand) 83 | x = self.relu4(x, skip_rand=skip_relu) 84 | 85 | x = F.avg_pool2d(x, 2) 86 | 87 | x = torch.flatten(x, start_dim=1) 88 | x = self.fc5(x, retain=retain, skip_rand=skip_rand) 89 | output = F.log_softmax(x, dim=1) 90 | return output 91 | -------------------------------------------------------------------------------- /nn_experiments/plot_cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from IPython.display import display, HTML 9 | 10 | EXP_ROOT = './cifarexperiments' 11 | 12 | params = { 13 | 'axes.labelsize': 12, 14 | 'font.size': 12, 15 | 'legend.fontsize': 12, 16 | 'xtick.labelsize': 12, 17 | 'ytick.labelsize': 12, 18 | 'text.usetex': True, 19 | 'figure.figsize': [6, 4], 20 | 'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}', 21 | } 22 | plt.rcParams.update(params) 23 | 24 | def plot_everything(workers): 25 | worker_dirs = [os.path.join(EXP_ROOT, f[0]) for f in workers] 26 | worker_names = [f[1] for f in workers] 27 | worker_colors = [f[2] 
for f in workers] 28 | worker_markers = [f[3] for f in workers] 29 | 30 | fig = plt.figure(figsize=(10,40)) 31 | plt.axes(frameon=0) # turn off frames 32 | plt.grid(axis='y', color='0.9', linestyle='-', linewidth=1) 33 | 34 | ax = plt.subplot(511) 35 | plt.title('Training Loss vs Iterations for SmallConvNet on CIFAR-10') 36 | ax.set_yscale('log') 37 | 38 | ax2 = plt.subplot(512) 39 | plt.title('Training Accuracy vs Iterations for SmallConvNet on CIFAR-10') 40 | ax2.set_ylim((90, 101)) 41 | 42 | ax3 = plt.subplot(513) 43 | plt.title('Test Loss vs Iterations for SmallConvNet on CIFAR-10') 44 | ax3.set_yscale('log') 45 | 46 | ax4 = plt.subplot(514) 47 | plt.title('Test Accuracy vs Iterations for SmallConvNet on CIFAR-10') 48 | ax4.set_ylim((65, 75)) 49 | ax4.grid(True) 50 | 51 | ax5 = plt.subplot(515) 52 | plt.title('Training time vs Iterations for SmallConvNet on CIFAR-10') 53 | 54 | final_results = [] 55 | 56 | labeled = [] 57 | for worker, worker_name, color, marker in zip(worker_dirs, worker_names, worker_colors, worker_markers): 58 | one_pickle_dir = os.path.join(worker, 'pickles') 59 | one_pickle = os.path.join(one_pickle_dir, os.listdir(one_pickle_dir)[0]) 60 | with open(one_pickle, 'rb') as f: 61 | struct = pickle.load(f) 62 | 63 | train_curve = [] 64 | test_curve = [] 65 | train_test_curve = [] 66 | for (iteration, s) in struct: 67 | if 'train' in s: 68 | train_curve.append((iteration, s['train'])) 69 | if 'train_test' in s: 70 | train_test_curve.append((iteration, s['train_test'])) 71 | if 'test' in s: 72 | test_curve.append((iteration, s['test'])) 73 | train_test_iterations = [t[0] for t in train_test_curve if t[0] != 'final'] 74 | train_iterations = [t[0] for t in train_curve if t[0] != 'final'] 75 | train_test_loss = [t[1]['loss'] for t in train_test_curve if t[0] != 'final'] 76 | train_test_accuracy = [t[1]['accuracy'] for t in train_test_curve if t[0] != 'final'] 77 | train_time = [t[1]['time'] for t in train_curve if t[0] != 'final'] 78 | if worker_name in labeled: 79 | worker_name = None 80 | else: 81 | labeled.append(worker_name) 82 | 83 | marker_size = 10 84 | 85 | ax.plot(train_test_iterations, train_test_loss, marker=marker, label=worker_name, c=color, ms=marker_size) 86 | ax2.plot(train_test_iterations, train_test_accuracy, marker=marker, label=worker_name, c=color, ms=marker_size) 87 | 88 | test_iterations = [t[0] for t in test_curve if t[0] != 'final'] 89 | test_loss = [t[1]['loss'] for t in test_curve if t[0] != 'final'] 90 | test_accuracy = [t[1]['accuracy'] for t in test_curve if t[0] != 'final'] 91 | ax3.plot(test_iterations, test_loss, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 92 | 93 | ax4.plot(test_iterations, test_accuracy, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 94 | 95 | ax5.plot(train_iterations, train_time, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 96 | 97 | 98 | final_results.append({ 99 | 'name': worker_name, 100 | 'train_loss': train_test_loss[-1], 101 | 'train_accuracy': train_test_accuracy[-1], 102 | 'test_loss': test_loss[-1], 103 | 'test_accuracy': test_accuracy[-1], 104 | }) 105 | 106 | display(pd.DataFrame(final_results)) 107 | ax.legend() 108 | ax2.legend() 109 | ax3.legend() 110 | ax4.legend() 111 | ax5.legend() 112 | 113 | fig.savefig('cifar_all_curves_full.pdf') 114 | 115 | workers = [ 116 | ('0000-project', 'Project', 'b', 'h'), 117 | ('0001-project', 'Project', 'b', 'h'), 118 | ('0002-project', 'Project', 'b', 'h'), 119 | ('0003-project', 'Project', 
'b', 'h'), 120 | ('0004-project', 'Project', 'b', 'h'), 121 | ('0005-smallbatch', 'Reduced batch', 'r', '^'), 122 | ('0006-smallbatch', 'Reduced batch', 'r', '^'), 123 | ('0007-smallbatch', 'Reduced batch', 'r', '^'), 124 | ('0008-smallbatch', 'Reduced batch', 'r', '^'), 125 | ('0009-smallbatch', 'Reduced batch', 'r', '^'), 126 | ('0010-baseline', 'Baseline', 'pink', 'o'), 127 | ('0011-baseline', 'Baseline', 'pink', 'o'), 128 | ('0012-baseline', 'Baseline', 'pink', 'o'), 129 | ('0013-baseline', 'Baseline', 'pink', 'o'), 130 | ('0014-baseline', 'Baseline', 'pink', 'o'), 131 | ('0015-samesample', 'Same Sample', 'g', 'x'), 132 | ('0016-samesample', 'Same Sample', 'g', 'x'), 133 | ('0017-samesample', 'Same Sample', 'g', 'x'), 134 | ('0018-samesample', 'Same Sample', 'g', 'p'), 135 | ('0019-samesample', 'Same Sample', 'g', 'p'), 136 | ('0020-diffsample', 'Different Sample', 'black', '*'), 137 | ('0021-diffsample', 'Different Sample', 'black', '*'), 138 | ('0022-diffsample', 'Different Sample', 'black', '*'), 139 | ('0023-diffsample', 'Different Sample', 'black', '*'), 140 | ('0024-diffsample', 'Different Sample', 'black', '*'), 141 | ] 142 | 143 | plot_everything(workers) -------------------------------------------------------------------------------- /nn_experiments/plot_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from IPython.display import display, HTML 9 | 10 | EXP_ROOT = './mnistexperiments' 11 | 12 | params = { 13 | 'axes.labelsize': 12, 14 | 'font.size': 12, 15 | 'legend.fontsize': 12, 16 | 'xtick.labelsize': 12, 17 | 'ytick.labelsize': 12, 18 | 'text.usetex': True, 19 | 'figure.figsize': [6, 4], 20 | 'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}', 21 | } 22 | plt.rcParams.update(params) 23 | 24 | def plot_everything(workers): 25 | worker_dirs = [os.path.join(EXP_ROOT, f[0]) for f in workers] 26 | worker_names = [f[1] for f in workers] 27 | worker_colors = [f[2] for f in workers] 28 | worker_markers = [f[3] for f in workers] 29 | 30 | fig = plt.figure(figsize=(10,40)) 31 | plt.axes(frameon=0) # turn off frames 32 | plt.grid(axis='y', color='0.9', linestyle='-', linewidth=1) 33 | 34 | ax = plt.subplot(511) 35 | plt.title('Training Loss vs Iterations for SmallFCNet on MNIST') 36 | ax.set_yscale('log') 37 | 38 | ax2 = plt.subplot(512) 39 | plt.title('Training Accuracy vs Iterations for SmallFCNet on MNIST') 40 | ax2.set_ylim((98, 100)) 41 | 42 | ax3 = plt.subplot(513) 43 | plt.title('Test Loss vs Iterations for SmallFCNet on MNIST') 44 | ax3.set_yscale('log') 45 | 46 | ax4 = plt.subplot(514) 47 | plt.title('Test Accuracy vs Iterations for SmallFCNet on MNIST') 48 | ax4.set_ylim((95, 100)) 49 | ax4.grid(True) 50 | 51 | ax5 = plt.subplot(515) 52 | plt.title('Training time vs Iterations for SmallFCNet on MNIST') 53 | 54 | final_results = [] 55 | 56 | labeled = [] 57 | for worker, worker_name, color, marker in zip(worker_dirs, worker_names, worker_colors, worker_markers): 58 | one_pickle_dir = os.path.join(worker, 'pickles') 59 | one_pickle = os.path.join(one_pickle_dir, os.listdir(one_pickle_dir)[0]) 60 | with open(one_pickle, 'rb') as f: 61 | struct = pickle.load(f) 62 | 63 | train_curve = [] 64 | test_curve = [] 65 | train_test_curve = [] 66 | for (iteration, s) in struct: 67 | if 'train' in s: 68 | train_curve.append((iteration, s['train'])) 69 | if 'train_test' 
in s: 70 | train_test_curve.append((iteration, s['train_test'])) 71 | if 'test' in s: 72 | test_curve.append((iteration, s['test'])) 73 | train_test_iterations = [t[0] for t in train_test_curve if t[0] != 'final'] 74 | train_iterations = [t[0] for t in train_curve if t[0] != 'final'] 75 | train_test_loss = [t[1]['loss'] for t in train_test_curve if t[0] != 'final'] 76 | train_test_accuracy = [t[1]['accuracy'] for t in train_test_curve if t[0] != 'final'] 77 | train_time = [t[1]['time'] for t in train_curve if t[0] != 'final'] 78 | if worker_name in labeled: 79 | worker_name = None 80 | else: 81 | labeled.append(worker_name) 82 | 83 | marker_size = 10 84 | 85 | ax.plot(train_test_iterations, train_test_loss, marker=marker, label=worker_name, c=color, ms=marker_size) 86 | ax2.plot(train_test_iterations, train_test_accuracy, marker=marker, label=worker_name, c=color, ms=marker_size) 87 | 88 | test_iterations = [t[0] for t in test_curve if t[0] != 'final'] 89 | test_loss = [t[1]['loss'] for t in test_curve if t[0] != 'final'] 90 | test_accuracy = [t[1]['accuracy'] for t in test_curve if t[0] != 'final'] 91 | ax3.plot(test_iterations, test_loss, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 92 | 93 | ax4.plot(test_iterations, test_accuracy, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 94 | 95 | ax5.plot(train_iterations, train_time, marker=marker, label=worker_name, c=color, ms=marker_size, markevery=10) 96 | 97 | 98 | final_results.append({ 99 | 'name': worker_name, 100 | 'train_loss': train_test_loss[-1], 101 | 'train_accuracy': train_test_accuracy[-1], 102 | 'test_loss': test_loss[-1], 103 | 'test_accuracy': test_accuracy[-1], 104 | }) 105 | 106 | display(pd.DataFrame(final_results)) 107 | ax.legend() 108 | ax2.legend() 109 | ax3.legend() 110 | ax4.legend() 111 | ax5.legend() 112 | 113 | fig.savefig('mnist_all_curves_full.pdf') 114 | 115 | workers = [ 116 | ('0000-project', 'Project', 'b', 'h'), 117 | ('0001-project', 'Project', 'b', 'h'), 118 | ('0002-project', 'Project', 'b', 'h'), 119 | ('0003-project', 'Project', 'b', 'h'), 120 | ('0004-project', 'Project', 'b', 'h'), 121 | ('0005-smallbatch', 'Reduced batch', 'r', '^'), 122 | ('0006-smallbatch', 'Reduced batch', 'r', '^'), 123 | ('0007-smallbatch', 'Reduced batch', 'r', '^'), 124 | ('0008-smallbatch', 'Reduced batch', 'r', '^'), 125 | ('0009-smallbatch', 'Reduced batch', 'r', '^'), 126 | ('0010-baseline', 'Baseline', 'pink', 'o'), 127 | ('0011-baseline', 'Baseline', 'pink', 'o'), 128 | ('0012-baseline', 'Baseline', 'pink', 'o'), 129 | ('0013-baseline', 'Baseline', 'pink', 'o'), 130 | ('0014-baseline', 'Baseline', 'pink', 'o'), 131 | ('0015-samesample', 'Same Sample', 'g', 'x'), 132 | ('0016-samesample', 'Same Sample', 'g', 'x'), 133 | ('0017-samesample', 'Same Sample', 'g', 'x'), 134 | ('0018-samesample', 'Same Sample', 'g', 'x'), 135 | ('0019-samesample', 'Same Sample', 'g', 'x'), 136 | ('0020-diffsample', 'Different Sample', 'black', '*'), 137 | ('0021-diffsample', 'Different Sample', 'black', '*'), 138 | ('0022-diffsample', 'Different Sample', 'black', '*'), 139 | ('0023-diffsample', 'Different Sample', 'black', '*'), 140 | ('0024-diffsample', 'Different Sample', 'black', '*'), 141 | ] 142 | 143 | plot_everything(workers) -------------------------------------------------------------------------------- /nn_experiments/rnn_mnist_launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import 
shutil 5 | from operator import itemgetter 6 | 7 | import numpy as np 8 | 9 | import rnn_models as rnn_rpmodels 10 | import utils as rputils 11 | 12 | from rnn_train_and_eval import run_rnn_model 13 | 14 | import torch 15 | import torch.utils.tensorboard as tb 16 | 17 | ''' 18 | To override these parameters, specify in command line, such as: 19 | python rnn_mnist_launch.py --batch_size=100 20 | 21 | It is important to precede the argument with '--' and include an '=' sign. 22 | Do not use quotes around the value. 23 | ''' 24 | args_template = rputils.ParameterMap( 25 | # experiment name. prepended to pickle name 26 | exp_name='', 27 | 28 | # input batch size for training (default: 64) 29 | batch_size=128, 30 | 31 | # accepted dataset 32 | dataset='mnist', 33 | 34 | # input batch size for testing (default: 1000) 35 | test_batch_size=1000, 36 | 37 | # number of epochs to train (default: 14) 38 | epochs=200, 39 | 40 | # learning rate (default: 1.0) 41 | lr=0.1, 42 | 43 | # disables CUDA training 44 | no_cuda=False, 45 | 46 | # random seed. do not set seed if 0. 47 | seed=0, 48 | 49 | # how many batches to wait before logging training status 50 | log_interval=10, 51 | 52 | # keep_prob for the RandLinear layer 53 | keep_frac=1.0, 54 | 55 | # hidden layer size 56 | hidden_size=300, 57 | 58 | # l2 weight decay on the parameters 59 | weight_decay=5e-4, 60 | 61 | # For Saving the current Model 62 | save_model=True, 63 | 64 | # If true, runs baseline without RP 65 | rp_layer='rpconv', 66 | 67 | # Number of epochs to test on the train set. 68 | train_test_interval=5, 69 | 70 | # Data Root 71 | data_root='./data', 72 | 73 | # Experiment Root 74 | exp_root='', 75 | 76 | # If true, randomly splits the training set into train/val (validation 5000). Ignores test set. 77 | validation=False, 78 | 79 | # Whether to sample with replacement or not while training. True gives real SGD. 80 | with_replace=False, 81 | 82 | # Whether to use augmentation in training dataset. 83 | augment=True, 84 | 85 | # Additive noise to add. 86 | rand_noise=0.0, 87 | 88 | # If experiment exists, overwrite 89 | override=False, 90 | 91 | # If true, generate random matrix independent across batch 92 | full_random=False, 93 | 94 | # If > 0, save this many intermediate checkpoints (Not used for simple training, only needed to create intermediates 95 | # directory) 96 | save_inter=0, 97 | 98 | # Whether to do simple iteration based training instead of epoch based. 99 | simple=False, 100 | 101 | # Following are only used when simple is True. 102 | max_iterations=-1, 103 | simple_log_frequency=-1, 104 | simple_test_eval_frequency=-1, 105 | simple_test_eval_per_train_test=-1, 106 | simple_model_checkpoint_frequency=-1, 107 | 108 | # If true, samples training set with replacement. 109 | bootstrap_train=False, 110 | 111 | sparse=False, 112 | 113 | # If true, use a TensorBoard writer 114 | use_writer=False, 115 | 116 | # Gradient norm clipping parameter 117 | clip=0.25, 118 | 119 | # No. of workers for data loader 120 | num_workers=2, 121 | 122 | # Whether to resume training from the most recent checkpoint (if it exists) 123 | resume=False, 124 | ) 125 | 126 | 127 | 128 | def main(additional_args): 129 | #signal.signal(signal.SIGINT, receiveSignal) 130 | args = args_template.clone() 131 | rputils.override_arguments(args, additional_args) 132 | 133 | # If simple is set, default to these arguments. 134 | # Note that we override again at the end, so specified 135 | # arguments take precedence over defaults. 
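The overriding mentioned in the comment above is handled by `utils.override_arguments` together with `ParameterMap.set_with_cast`: each `--key=value` string is cast to the type of the default already stored in the template, which is also why the launch commands in the `*commands.txt` files spell out `--simple=True` rather than using bare flags. A small illustration, assuming this repository's `utils.py` is importable; `demo` and its values are invented for the example:

```
import utils as rputils

# A toy stand-in for args_template.
demo = rputils.ParameterMap(batch_size=128, keep_frac=1.0, simple=False, exp_name='')

# Values arrive as strings from sys.argv and are cast to the type of the
# corresponding default: '--simple=True' becomes a bool, '--keep_frac=0.1' a float.
rputils.override_arguments(demo, ['--simple=True', '--keep_frac=0.1', '--exp_name=demo'])

print(demo.simple, demo.keep_frac, demo.exp_name)  # -> True 0.1 demo
```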
136 | if args.simple: 137 | args.max_iterations = 100000 138 | args.simple_log_frequency = 10 139 | args.simple_test_eval_frequency = 400 140 | args.simple_test_eval_per_train_test = 10 141 | args.simple_model_checkpoint_frequency = 10000 142 | args.save_inter = 1 143 | 144 | args.batch_size = 150 145 | args.lr = 0.002 146 | args.with_replace = True 147 | args.augment = False 148 | args.validation = False 149 | rputils.override_arguments(args, additional_args) 150 | 151 | if not os.path.exists(args.exp_root): 152 | print('Creating experiment root directory {}'.format(args.exp_root)) 153 | os.mkdir(args.exp_root) 154 | if not args.exp_name: 155 | args.exp_name = 'exp{}'.format(random.randint(100000, 999999)) 156 | 157 | if args.seed == 0: 158 | args.seed = random.randint(10000000, 99999999) 159 | 160 | args.exp_dir = os.path.join(args.exp_root, args.exp_name) 161 | os.environ['LAST_EXPERIMENT_DIR'] = args.exp_dir 162 | if args.override and os.path.exists(args.exp_dir): 163 | print("Overriding existing directory.") 164 | shutil.rmtree(args.exp_dir) 165 | if not args.resume or not os.path.exists(args.exp_dir): 166 | assert not os.path.exists(args.exp_dir) 167 | print("Creating experiment with name {} in {}".format(args.exp_name, args.exp_dir)) 168 | os.mkdir(args.exp_dir) 169 | with open(os.path.join(args.exp_dir, 'experiment_args.txt'), 'w') as f: 170 | f.write(str(args)) 171 | assert os.path.exists(args.exp_dir) 172 | 173 | 174 | if args.save_inter > 0: 175 | args.inter_dir = os.path.join(args.exp_dir, 'intermediate_checkpoints') 176 | if not os.path.exists(args.inter_dir): 177 | print('Creating directory for intermediate checkpoints.') 178 | os.mkdir(args.inter_dir) 179 | 180 | args.pickle_dir = os.path.join(args.exp_dir, 'pickles') 181 | if not os.path.exists(args.pickle_dir): 182 | print('Creating pickle directory in experiment directory.') 183 | os.mkdir(args.pickle_dir) 184 | 185 | use_cuda = not args.no_cuda and torch.cuda.is_available() 186 | print('Seed is {}'.format(args.seed)) 187 | torch.manual_seed(args.seed) 188 | torch.backends.cudnn.deterministic = True 189 | torch.backends.cudnn.benchmark = False 190 | np.random.seed(args.seed) 191 | 192 | device = torch.device("cuda" if use_cuda else "cpu") 193 | 194 | # Tensorboard SummaryWriter 195 | writer = tb.SummaryWriter(log_dir=args.exp_dir) if args.use_writer else None 196 | 197 | rp_args = {} 198 | rp_args['rp_layer'] = args.rp_layer 199 | rp_args['keep_frac'] = args.keep_frac 200 | rp_args['rand_noise'] = args.rand_noise 201 | rp_args['full_random'] = args.full_random 202 | rp_args['sparse'] = args.sparse 203 | 204 | model = rnn_rpmodels.MNISTIRNN(hidden_size=args.hidden_size, rp_args=rp_args) 205 | pickle_string, model_string = args.exp_name + "mnistirnn", args.exp_name + "mnistirnn" 206 | 207 | optimizer_state_dict, iteration = None, 0 208 | if args.resume and args.save_inter > 0: 209 | # Get most recent checkpoint if it exists 210 | ckpt_filenames = os.listdir(args.inter_dir) 211 | if len(ckpt_filenames) > 0: 212 | most_recent_ckpt = sorted(ckpt_filenames, key=lambda f: int(f.split('_')[1].split('.')[0]))[-1] 213 | print("Using checkpoint {}".format(most_recent_ckpt)) 214 | iteration, model_state_dict, optimizer_state_dict = \ 215 | itemgetter('iteration', 'model_state_dict', 'optimizer_state_dict')( 216 | torch.load(os.path.join(args.inter_dir, most_recent_ckpt), map_location=device) 217 | ) 218 | print("Reloading model state.") 219 | model.load_state_dict(model_state_dict) 220 | 221 | 222 | # Check if correct dataset is used 
for model. 223 | if model.kCompatibleDataset != args.dataset: 224 | raise NotImplementedError( 225 | 'Unsupported dataset {} with model {}'.format(args.dataset, model.__class__.__name__) 226 | ) 227 | 228 | # Run training 229 | run_rnn_model( 230 | model, args, device, writer, pickle_string, model_string, num_workers=args.num_workers, iteration=iteration, 231 | optimizer_state_dict=optimizer_state_dict 232 | ) 233 | 234 | 235 | if __name__ == '__main__': 236 | main(sys.argv[1:]) 237 | 238 | -------------------------------------------------------------------------------- /nn_experiments/rnn_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from layers import RandLinear 6 | 7 | 8 | def irnn_initializer(m: nn.Module): 9 | nn.init.eye_(m.weight.data) 10 | nn.init.zeros_(m.bias.data) 11 | 12 | def gaussian_initializer(m: nn.Module): 13 | nn.init.normal_(m.weight.data, 0, 0.001) 14 | nn.init.zeros_(m.bias.data) 15 | 16 | 17 | class MNISTIRNN(nn.Module): 18 | """ 19 | IRNN where the only output that is returned is the logits for a distribution over classes. 20 | """ 21 | kCompatibleDataset = 'mnist' 22 | 23 | def __init__(self, hidden_size, num_classes=10, rp_args=None): 24 | super(MNISTIRNN, self).__init__() 25 | rp_args = {} if rp_args is None else rp_args 26 | kept_keys = ['keep_frac', 'full_random', 'sparse'] 27 | kept_dict = {key: rp_args[key] for key in kept_keys} 28 | 29 | self._hidden_size = hidden_size 30 | 31 | if kept_dict['keep_frac'] == 1.: 32 | self.i2h = nn.Linear(1, hidden_size) 33 | self.h2h = nn.Linear(hidden_size, hidden_size) 34 | self.h2o = nn.Linear(hidden_size, num_classes) 35 | print("Using nn.Linear layers") 36 | else: 37 | self.i2h = RandLinear(1, hidden_size, **kept_dict) 38 | self.h2h = RandLinear(hidden_size, hidden_size, **kept_dict) 39 | self.h2o = RandLinear(hidden_size, num_classes, **kept_dict) 40 | print("Using RandLinear layers") 41 | 42 | self.h2h.apply(irnn_initializer) 43 | self.i2h.apply(gaussian_initializer) 44 | self.h2o.apply(gaussian_initializer) 45 | 46 | def forward(self, inputs): 47 | batch_size = inputs.shape[0] 48 | hidden = self._init_hidden(batch_size) 49 | inputs = inputs.view(batch_size, -1, 1).permute(1, 0, 2) # (781, batch_size, 1) 50 | for t in range(inputs.shape[0]): 51 | hidden_part = self.h2h(hidden) 52 | input_part = self.i2h(inputs[t]) 53 | hidden = F.relu(hidden_part + input_part) 54 | output_logits = F.log_softmax(self.h2o(hidden), dim=1) 55 | return output_logits 56 | 57 | def _init_hidden(self, batch_size): 58 | weight = next(self.parameters()).data 59 | return weight.new(batch_size, self._hidden_size).zero_() 60 | 61 | 62 | class MNISTIRNNPyTorch(nn.RNN, nn.Module): 63 | """ 64 | Implementation of MNISTIRNN that uses the torch.nn.RNN implementation. Can't be used for Rand experiments. 65 | """ 66 | 67 | kCompatibleDataset = 'mnist' 68 | 69 | def __init__(self, hidden_size, *args, **kwargs): 70 | super().__init__(input_size=1, hidden_size=hidden_size, num_layers=1, nonlinearity='relu', *args, **kwargs) 71 | 72 | self._hidden_size = hidden_size 73 | 74 | self.h2o = nn.Linear(hidden_size, 10) 75 | self.h2o.apply(gaussian_initializer) 76 | 77 | # Initialise parameters. 
We assume single layer for now 78 | nn.init.normal_(self.weight_ih_l0, 0., 0.001) 79 | nn.init.zeros_(self.bias_ih_l0) 80 | 81 | nn.init.eye_(self.weight_hh_l0) 82 | nn.init.zeros_(self.bias_hh_l0) 83 | 84 | def forward(self, inputs, hx=None): 85 | batch_size = inputs.shape[0] 86 | hidden = self._init_hidden(batch_size) 87 | inputs = inputs.view(batch_size, -1, 1).permute(1, 0, 2) # (781, batch_size, 1) 88 | outputs, hidden = super().forward(inputs, hidden) 89 | output_logits = F.log_softmax(self.h2o(outputs[-1]), dim=1) 90 | return output_logits 91 | 92 | def _init_hidden(self, batch_size): 93 | weight = next(self.parameters()).data 94 | return weight.new(1, batch_size, self._hidden_size).zero_() 95 | 96 | 97 | class IRNN(nn.Module): 98 | """ 99 | Implementation of IRNN from "A Simple Way to Initialize Recurrent Networks of Rectified Linear Units" by Le. et al, 100 | 2015. 101 | 102 | It works just like a basic RNN but with the weights initialised to identity and biases to zero in addition to using 103 | a ReLU activation function instead of the typical Tanh. 104 | """ 105 | 106 | def __init__(self, input_size, hidden_size): 107 | super(IRNN, self).__init__() 108 | 109 | self.i2h = nn.Linear(input_size, hidden_size) 110 | self.h2h = nn.Linear(hidden_size, hidden_size) 111 | self.h2h.apply(irnn_initializer) 112 | 113 | #self.apply(irnn_initializer) 114 | 115 | def forward(self, inputs, hidden): 116 | hidden = hidden[0] 117 | outputs = [] 118 | for i in range(inputs.shape[0]): 119 | #combined = torch.cat((inputs[i], hidden), dim=1) 120 | hidden_part = self.h2h(hidden) 121 | input_part = self.i2h(inputs[i]) 122 | hidden = F.relu(hidden_part + input_part) 123 | # output = self.i2o(combined) 124 | output = hidden 125 | outputs.append(output) 126 | outputs = torch.stack(outputs) 127 | return outputs, hidden.view(1, hidden.size(0), hidden.size(1)) 128 | 129 | 130 | class RandIRNN(nn.Module): 131 | """ 132 | Implementation of random projection version of IRNN 133 | """ 134 | 135 | def __init__(self, input_size, hidden_size, keep_frac=0.9, full_random=False, sparse=False): 136 | super(RandIRNN, self).__init__() 137 | 138 | self.i2h = RandLinear( 139 | input_size, hidden_size, keep_frac=keep_frac, full_random=full_random, sparse=sparse 140 | ) 141 | self.h2h = RandLinear( 142 | hidden_size, hidden_size, keep_frac=keep_frac, full_random=full_random, sparse=sparse 143 | ) 144 | self.h2h.apply(irnn_initializer) 145 | 146 | #self.apply(irnn_initializer) 147 | 148 | def forward(self, inputs, hidden): 149 | 150 | hidden = hidden[0] 151 | outputs = [] 152 | for i in range(inputs.shape[0]): 153 | #combined = torch.cat((inputs[i], hidden), dim=1) 154 | hidden_part = self.h2h(hidden) 155 | input_part = self.i2h(inputs[i]) 156 | hidden = F.relu(hidden_part + input_part) 157 | # output = self.i2o(combined) 158 | output = hidden 159 | outputs.append(output) 160 | outputs = torch.stack(outputs) 161 | return outputs, hidden.view(1, hidden.size(0), hidden.size(1)) 162 | 163 | -------------------------------------------------------------------------------- /nn_experiments/rnn_plottingscripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from operator import itemgetter 3 | import re 4 | import pickle 5 | import matplotlib.pyplot as plt 6 | 7 | PKL_FILE_REGEX = re.compile(".+mnistirnn\.pkl") 8 | 9 | def extract_plot_data(fname): 10 | with open(fname, 'rb') as f: 11 | all_ckpts = pickle.load(f) 12 | 13 | iterations, train_losses, test_losses, test_accuracies = 
[], [], [], [] 14 | train_test_iterations, train_test_losses, train_test_accuracies = [], [], [] 15 | for ckpt in all_ckpts: 16 | iteration = ckpt[0] 17 | if iteration == 'final': # Skip for now 18 | continue 19 | train_ckpt, test_ckpt = itemgetter('train', 'test')(ckpt[1]) 20 | 21 | iterations.append(iteration) 22 | train_losses.append(train_ckpt['loss']) 23 | test_losses.append(test_ckpt['loss']) 24 | test_accuracies.append(test_ckpt['accuracy']) 25 | 26 | if 'train_test' in ckpt[1]: 27 | train_test_ckpt = ckpt[1]['train_test'] 28 | train_test_iterations.append(iteration) 29 | train_test_losses.append(train_test_ckpt['loss']) 30 | train_test_accuracies.append(train_test_ckpt['accuracy']) 31 | 32 | return iterations, train_losses, test_losses, test_accuracies, train_test_iterations, train_test_losses, \ 33 | train_test_accuracies 34 | 35 | 36 | 37 | params = { 38 | 'axes.labelsize': 12, 39 | 'font.size': 12, 40 | 'legend.fontsize': 12, 41 | 'xtick.labelsize': 12, 42 | 'ytick.labelsize': 12, 43 | 'text.usetex': True, 44 | 'figure.figsize': [6, 4], 45 | 'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}', 46 | } 47 | plt.rcParams.update(params) 48 | 49 | fig = plt.figure(figsize=(10,40)) 50 | plt.axes(frameon=0) # turn off frames 51 | plt.grid(axis='y', color='0.9', linestyle='-', linewidth=1) 52 | 53 | ax = plt.subplot(511) 54 | plt.title('Training Loss vs Iterations for IRNN on Sequential-MNIST') 55 | ax.set_yscale('log') 56 | 57 | marker_size = 10 58 | 59 | EXP_BASE_DIR = '../rnn_stuff/mnist_rnn_exps' # Replace with path to folder containing experiments 60 | exp_folders = ['irnn_baseline', 'irnn_small_batch', 'rand_irnn_sparse', 'rand_irnn_sparse_full', 'rand_irnn_rp', 'rand_irnn_rp_full'] 61 | colors = ['pink', 'r', 'g', 'k', 'b', 'y'] 62 | markers = ['o', '^', 'x', '*', 'o', 'd'] 63 | labels = ['Baseline', 'Reduced batch', 'Same Sample', 'Different Sample', 'Project', 'Different Project'] 64 | 65 | for i, exp_folder in enumerate(exp_folders): 66 | exp_folder = os.path.join(EXP_BASE_DIR, exp_folder) 67 | for j, seed in enumerate(os.listdir(exp_folder)): 68 | data_dir = os.path.join(exp_folder, seed, "pickles") 69 | data_filename = os.path.join(data_dir, list(filter(PKL_FILE_REGEX.match, os.listdir(data_dir)))[0]) 70 | 71 | _, _, _, _, train_test_iterations, train_test_losses, _ = extract_plot_data(data_filename) 72 | 73 | label = labels[i] if j == 0 else None 74 | 75 | iterations = train_test_iterations 76 | loss = train_test_losses 77 | marker = markers[i] 78 | exp_name = label 79 | color = colors[i] 80 | 81 | ax.plot(iterations, loss, marker=marker, label=exp_name, c=color, ms=marker_size, markevery=5) 82 | 83 | 84 | ax.legend() 85 | plt.show() 86 | #fig.savefig('../Plots/IRNN_PaperPlots_With_RP.pdf') 87 | 88 | -------------------------------------------------------------------------------- /nn_experiments/rnn_plottingscripts_appendix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import pickle 4 | from operator import itemgetter 5 | import matplotlib.pyplot as plt 6 | 7 | PKL_FILE_REGEX = re.compile(".+mnistirnn\.pkl") 8 | 9 | def extract_plot_data(fname): 10 | with open(fname, 'rb') as f: 11 | all_ckpts = pickle.load(f) 12 | 13 | iterations, train_losses, test_losses, test_accuracies = [], [], [], [] 14 | train_test_iterations, train_test_losses, train_test_accuracies = [], [], [] 15 | epoch_times = [] 16 | for ckpt in all_ckpts: 17 | iteration = ckpt[0] 18 | if iteration == 'final': # Skip for now 
19 | continue 20 | train_ckpt, test_ckpt = itemgetter('train', 'test')(ckpt[1]) 21 | 22 | iterations.append(iteration) 23 | train_losses.append(train_ckpt['loss']) 24 | test_losses.append(test_ckpt['loss']) 25 | test_accuracies.append(test_ckpt['accuracy']) 26 | epoch_times.append(float(train_ckpt['time'])) 27 | 28 | if 'train_test' in ckpt[1]: 29 | train_test_ckpt = ckpt[1]['train_test'] 30 | train_test_iterations.append(iteration) 31 | train_test_losses.append(train_test_ckpt['loss']) 32 | train_test_accuracies.append(train_test_ckpt['accuracy']) 33 | 34 | return iterations, train_losses, test_losses, test_accuracies, train_test_iterations, train_test_losses, \ 35 | train_test_accuracies, epoch_times 36 | 37 | params = { 38 | 'axes.labelsize': 12, 39 | 'font.size': 12, 40 | 'legend.fontsize': 12, 41 | 'xtick.labelsize': 12, 42 | 'ytick.labelsize': 12, 43 | 'text.usetex': True, 44 | 'figure.figsize': [6, 4], 45 | 'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}', 46 | } 47 | plt.rcParams.update(params) 48 | 49 | fig = plt.figure(figsize=(10,40)) 50 | plt.axes(frameon=0) # turn off frames 51 | plt.grid(axis='y', color='0.9', linestyle='-', linewidth=1) 52 | 53 | ax = plt.subplot(511) 54 | plt.title('Training Loss vs Iterations for IRNN on Sequential-MNIST') 55 | ax.set_yscale('log') 56 | 57 | ax2 = plt.subplot(512) 58 | plt.title('Training Accuracy vs Iterations for IRNN on Sequential-MNIST') 59 | ax2.set_ylim((15, 80)) 60 | 61 | ax3 = plt.subplot(513) 62 | plt.title('Test Loss vs Iterations for IRNN on Sequential-MNIST') 63 | ax3.set_yscale('log') 64 | 65 | ax4 = plt.subplot(514) 66 | plt.title('Test Accuracy vs Iterations for IRNN on Sequential-MNIST') 67 | ax4.set_ylim((15, 80)) 68 | #ax4.hlines(y=92.0, xmin=0, xmax=1000) 69 | ax4.grid(True) 70 | 71 | ax5 = plt.subplot(515) 72 | plt.title('Training time vs Iterations for IRNN on Sequential-MNIST') 73 | 74 | marker_size = 10 75 | 76 | EXP_BASE_DIR = '../rnn_stuff/mnist_rnn_exps' # Replace with path to folder containing experiments 77 | exp_folders = ['irnn_baseline', 'irnn_small_batch', 'rand_irnn_sparse', 'rand_irnn_sparse_full', 'rand_irnn_rp', 'rand_irnn_rp_full'] 78 | colors = ['pink', 'r', 'g', 'k', 'b', 'y'] 79 | markers = ['o', '^', 'x', '*', 'o', 'd'] 80 | labels = ['Baseline', 'Reduced batch', 'Same Sample', 'Different Sample', 'Project', 'Different Project'] 81 | 82 | for i, exp_folder in enumerate(exp_folders): 83 | exp_folder = os.path.join(EXP_BASE_DIR, exp_folder) 84 | for j, seed in enumerate(os.listdir(exp_folder)): 85 | data_dir = os.path.join(exp_folder, seed, "pickles") 86 | data_filename = os.path.join(data_dir, list(filter(PKL_FILE_REGEX.match, os.listdir(data_dir)))[0]) 87 | 88 | iterations, train_losses, test_losses, test_accuracies, train_test_iterations, train_test_losses, \ 89 | train_test_accuracies, epoch_times = extract_plot_data(data_filename) 90 | 91 | label = labels[i] if j == 0 else None 92 | 93 | marker = markers[i] 94 | exp_name = label 95 | color = colors[i] 96 | 97 | ax.plot(train_test_iterations, train_test_losses, marker=marker, label=exp_name, c=color, ms=marker_size, markevery=5) 98 | ax2.plot(train_test_iterations, train_test_accuracies, marker=marker, label=exp_name, c=color, ms=marker_size, markevery=5) 99 | ax3.plot(iterations, test_losses, marker=marker, label=exp_name, c=color, ms=marker_size, markevery=50) 100 | ax4.plot(iterations, test_accuracies, marker=marker, label=exp_name, c=color, ms=marker_size, markevery=50) 101 | ax5.plot(iterations, epoch_times, marker=marker, 
label=exp_name, c=color, ms=marker_size, markevery=50) 102 | 103 | 104 | ax.legend() 105 | ax2.legend() 106 | ax3.legend() 107 | ax4.legend() 108 | ax5.legend(loc='center right') 109 | 110 | #plt.show() 111 | fig.savefig('../Plots/IRNN_all_curves.pdf') 112 | 113 | -------------------------------------------------------------------------------- /nn_experiments/rnn_train_and_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | import torch 6 | import torchvision 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | import data as rpdata 12 | from train_and_eval import test, test_list 13 | 14 | def rnn_simple_train(args, model, device, train_loader, optimizer, test_loader, train_test_loader, iteration=0): 15 | all_checkpoints = [] 16 | before_epoch = time.time() 17 | 18 | params = list(model.parameters()) 19 | while iteration < args.max_iterations: 20 | for data, target in train_loader: 21 | iteration += 1 22 | if iteration > args.max_iterations: 23 | break 24 | 25 | random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 26 | model.train() 27 | optimizer.zero_grad() 28 | data, target = data.to(device), target.to(device) 29 | 30 | with torch.random.fork_rng(): 31 | torch.random.manual_seed(random_seed) 32 | output = model(data) 33 | 34 | loss = F.nll_loss(output, target) 35 | loss.backward() 36 | 37 | if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip) 38 | 39 | optimizer.step() 40 | 41 | if iteration % args.simple_log_frequency == 0: 42 | # Logging every 10 iterations. 43 | print('Train Iteration: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 44 | iteration, args.max_iterations, 45 | 100. * iteration / args.max_iterations, loss.item())) 46 | 47 | if iteration % args.simple_test_eval_frequency == 0: 48 | # Evaluate every 300 iterations. 49 | after_epoch = time.time() 50 | print('{} Iteration time: {}'.format(args.simple_test_eval_frequency, after_epoch - before_epoch)) 51 | train_ckpt = {} 52 | train_ckpt['loss'] = loss.item() 53 | train_ckpt['time'] = after_epoch - before_epoch 54 | train_ckpt['iteration'] = iteration 55 | 56 | start_testing_time = time.time() 57 | test_ckpt = test(args, model, device, test_loader) 58 | if (iteration // args.simple_test_eval_frequency) % args.simple_test_eval_per_train_test == 0: 59 | # Every few epochs calculate total train loss and accuracy. 60 | train_test_ckpt = test(args, model, device, train_test_loader, split='Train') 61 | all_checkpoints.append((iteration, {'train': train_ckpt, 'test': test_ckpt, 'train_test': train_test_ckpt})) 62 | else: 63 | all_checkpoints.append((iteration, {'train': train_ckpt, 'test': test_ckpt})) 64 | print("Test time at iteration {}: {}".format(iteration, time.time() - start_testing_time)) 65 | 66 | before_epoch = time.time() 67 | 68 | if iteration % args.simple_model_checkpoint_frequency == 0: 69 | if 'inter_dir' not in args: 70 | print("ERROR: Intermediate directory not created. 
This is a mistake.") 71 | else: 72 | torch.save({ 73 | 'iteration': iteration, 74 | 'model_state_dict': model.state_dict(), 75 | 'optimizer_state_dict': optimizer.state_dict(), 76 | 'loss': loss, 77 | }, os.path.join(args.inter_dir, 'iteration_{0:09}.pt'.format(iteration))) 78 | 79 | with open(os.path.join(args.pickle_dir, 'ckpts_iteration_{0:09}.pkl'.format(iteration)), 'wb') as f: 80 | pickle.dump(all_checkpoints, f) 81 | 82 | test_ckpt = test(args, model, device, test_loader) 83 | train_test_ckpt = test(args, model, device, train_test_loader, split='Train') 84 | all_checkpoints.append(('final', {'test': test_ckpt, 'train_test': train_test_ckpt})) 85 | 86 | return all_checkpoints 87 | 88 | def run_rnn_model(model, args, device, writer, pickle_string, model_string, num_workers=2, iteration=0, 89 | optimizer_state_dict=None): 90 | train_loader, test_loader, train_test_loader, num_classes = \ 91 | rpdata.get_dataset(args.dataset, batch_size=args.batch_size, test_batch_size=args.test_batch_size, 92 | validation=args.validation, with_replace=args.with_replace, 93 | data_root=args.data_root, augment=args.augment, 94 | bootstrap_train=args.bootstrap_train, num_workers=num_workers) 95 | # Calculate hash of test dataset to make sure test set is fixed. 96 | with torch.autograd.grad_mode.no_grad(): 97 | data_sum = 0.0 98 | for batch_idx, (data, target) in enumerate(test_loader): 99 | data_sum += torch.sum(data) 100 | print('Testing dataset has a sum of: {}'.format(data_sum)) 101 | if args.dataset == 'cifar10': 102 | if args.validation: 103 | assert data_sum == 116150.875 104 | else: 105 | assert data_sum == 492130.3125 106 | print('Make sure this is different: {}'.format(torch.randn(3))) 107 | print('Validation is {}'.format(args.validation)) 108 | 109 | data, _ = next(iter(train_loader)) 110 | if writer: 111 | grid = torchvision.utils.make_grid(data) 112 | writer.add_image('images', grid, 0) 113 | writer.add_graph(model, data) 114 | 115 | model = model.to(device) 116 | 117 | optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 118 | if optimizer_state_dict: 119 | print("Reloading optimizer state.") 120 | optimizer.load_state_dict(optimizer_state_dict) 121 | 122 | all_checkpoints = rnn_simple_train(args, model, device, train_loader, optimizer, test_loader, train_test_loader, iteration=iteration) 123 | 124 | with open(os.path.join(args.pickle_dir, '{}.pkl'.format(pickle_string)), 'wb') as f: 125 | pickle.dump(all_checkpoints, f) 126 | 127 | if args.save_model: 128 | torch.save(model.state_dict(), os.path.join(args.exp_dir, "{}.pt".format(model_string))) 129 | 130 | if writer: 131 | writer.close() 132 | -------------------------------------------------------------------------------- /nn_experiments/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import sys 4 | import time 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | 9 | import data as rpdata 10 | import models as rpmodels 11 | 12 | import torch 13 | import torchvision 14 | import torch.optim as optim 15 | import torch.utils.tensorboard as tb 16 | import torch.optim.lr_scheduler as lrs 17 | import torch.nn.functional as F 18 | 19 | 20 | def simple_train(args, model, device, train_loader, optimizer, scheduler, test_loader, train_test_loader): 21 | all_checkpoints = [] 22 | before_epoch = time.time() 23 | iteration = 0 24 | 25 | while iteration < args.max_iterations: 26 | for data, target in train_loader: 27 | 
iteration += 1 28 | if iteration > args.max_iterations: 29 | break 30 | 31 | random_seed = torch.randint(low=10000000000, high=99999999999, size=(1,)) 32 | model.train() 33 | optimizer.zero_grad() 34 | data, target = data.to(device), target.to(device) 35 | 36 | with torch.random.fork_rng(): 37 | torch.random.manual_seed(random_seed) 38 | output = model(data) 39 | 40 | loss = F.nll_loss(output, target) 41 | loss.backward() 42 | 43 | optimizer.step() 44 | 45 | if iteration % args.simple_log_frequency == 0: 46 | # Logging every 10 iterations. 47 | print('Train Iteration: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 48 | iteration, args.max_iterations, 49 | 100. * iteration / args.max_iterations, loss.item())) 50 | 51 | if iteration % args.simple_test_eval_frequency == 0: 52 | # Evaluate every 300 iterations. 53 | after_epoch = time.time() 54 | print('{} Iteration time: {}'.format(args.simple_test_eval_frequency, after_epoch - before_epoch)) 55 | train_ckpt = {} 56 | train_ckpt['loss'] = loss.item() 57 | train_ckpt['time'] = after_epoch - before_epoch 58 | train_ckpt['iteration'] = iteration 59 | 60 | test_ckpt = test(args, model, device, test_loader) 61 | if (iteration // args.simple_test_eval_frequency) % args.simple_test_eval_per_train_test == 0: 62 | # Every few epochs calculate total train loss and accuracy. 63 | train_test_ckpt = test(args, model, device, train_test_loader, split='Train') 64 | all_checkpoints.append((iteration, {'train': train_ckpt, 'test': test_ckpt, 'train_test': train_test_ckpt})) 65 | else: 66 | all_checkpoints.append((iteration, {'train': train_ckpt, 'test': test_ckpt})) 67 | before_epoch = time.time() 68 | 69 | if iteration % args.simple_scheduler_step_frequency == 0: 70 | scheduler.step() 71 | print('Learning rate is now decreased by {}'.format(args.gamma)) 72 | 73 | if iteration % args.simple_model_checkpoint_frequency == 0: 74 | if 'inter_dir' not in args: 75 | print("ERROR: Intermediate directory not created. This is a mistake.") 76 | else: 77 | torch.save({ 78 | 'iteration': iteration, 79 | 'model_state_dict': model.state_dict(), 80 | 'optimizer_state_dict': optimizer.state_dict(), 81 | 'loss': loss, 82 | }, os.path.join(args.inter_dir, 'iteration_{0:09}.pt'.format(iteration))) 83 | 84 | test_ckpt = test(args, model, device, test_loader) 85 | train_test_ckpt = test(args, model, device, train_test_loader, split='Train') 86 | all_checkpoints.append(('final', {'test': test_ckpt, 'train_test': train_test_ckpt})) 87 | 88 | return all_checkpoints 89 | 90 | 91 | def test(args, model, device, test_loader, writer=None, split='Test'): 92 | model.eval() 93 | test_loss = 0 94 | correct = 0 95 | before_test = time.time() 96 | with torch.no_grad(): 97 | for data, target in test_loader: 98 | data, target = data.to(device), target.to(device) 99 | output = model(data) 100 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 101 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 102 | correct += pred.eq(target.view_as(pred)).sum().item() 103 | 104 | test_loss /= len(test_loader.dataset) 105 | after_test = time.time() 106 | 107 | print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 108 | split, test_loss, correct, len(test_loader.dataset), 109 | 100. * correct / len(test_loader.dataset))) 110 | checkpoint = {} 111 | checkpoint['loss'] = test_loss 112 | checkpoint['accuracy'] = 100. 
* correct / len(test_loader.dataset) 113 | checkpoint['time'] = after_test - before_test 114 | 115 | return checkpoint 116 | 117 | 118 | def run_model(model, args, device, writer, pickle_string, model_string): 119 | train_loader, test_loader, train_test_loader, num_classes = \ 120 | rpdata.get_dataset(args.dataset, batch_size=args.batch_size, test_batch_size=args.test_batch_size, 121 | validation=args.validation, with_replace=args.with_replace, 122 | data_root=args.data_root, augment=args.augment, 123 | bootstrap_train=args.bootstrap_train) 124 | 125 | data, _ = next(iter(train_loader)) 126 | if writer: 127 | grid = torchvision.utils.make_grid(data) 128 | writer.add_image('images', grid, 0) 129 | writer.add_graph(model, data) 130 | 131 | model = model.to(device) 132 | 133 | if 'cifar_adam' in args and args.cifar_adam: 134 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 135 | weight_decay=args.weight_decay) 136 | else: 137 | optimizer = optim.SGD( 138 | model.parameters(), 139 | momentum=0.9, lr=args.lr, 140 | weight_decay=args.weight_decay) 141 | 142 | if 'training_schedule' not in args or args.training_schedule == 'cifar': 143 | print('Using CIFAR schedule.') 144 | scheduler = lrs.MultiStepLR(optimizer, milestones=[int(args.epochs * 0.3), 145 | int(args.epochs * 0.6), 146 | int(args.epochs * 0.8)], gamma=0.2) 147 | elif args.training_schedule == 'extended': 148 | print('Using extended CIFAR schedule.') 149 | scheduler = lrs.MultiStepLR(optimizer, milestones=[int(args.epochs * 0.1), 150 | int(args.epochs * 0.2), 151 | int(args.epochs * 0.3), 152 | int(args.epochs * 0.4), 153 | int(args.epochs * 0.5), 154 | int(args.epochs * 0.6), 155 | int(args.epochs * 0.7), 156 | int(args.epochs * 0.8), 157 | int(args.epochs * 0.9)], gamma=0.4) 158 | elif args.training_schedule == 'epoch_step': 159 | scheduler = lrs.StepLR(optimizer, step_size=args.lr_drop_step, gamma=args.gamma) 160 | else: 161 | raise NotImplementedError('Invalid training schedule.') 162 | 163 | all_checkpoints = simple_train(args, model, device, train_loader, optimizer, scheduler, test_loader, train_test_loader) 164 | 165 | with open(os.path.join(args.pickle_dir,'{}.pkl'.format(pickle_string)), 'wb') as f: 166 | pickle.dump(all_checkpoints, f) 167 | 168 | if args.save_model: 169 | torch.save(model.state_dict(), os.path.join(args.exp_dir, "{}.pt".format(model_string))) 170 | if writer: 171 | writer.close() 172 | -------------------------------------------------------------------------------- /nn_experiments/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://stackoverflow.com/a/32107024 - Accessed 02/05/2020 2 | class ParameterMap(dict): 3 | """ 4 | Example: 5 | m = ParameterMap({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer']) 6 | """ 7 | def __init__(self, *args, **kwargs): 8 | super(ParameterMap, self).__init__(*args, **kwargs) 9 | for arg in args: 10 | if isinstance(arg, dict): 11 | for k, v in arg.items(): 12 | self[k] = v 13 | 14 | if kwargs: 15 | for k, v in kwargs.items(): 16 | self[k] = v 17 | 18 | def __getattr__(self, attr): 19 | return self.get(attr) 20 | 21 | def __setattr__(self, key, value): 22 | self.__setitem__(key, value) 23 | 24 | def set_with_cast(self, key, value): 25 | if self[key] is None: 26 | return False 27 | else: 28 | t = type(self[key]) 29 | if t is bool: 30 | val = value == 'True' 31 | elif t is int: 32 | val = int(value) 33 | elif t is float: 34 | val = float(value) 35 | else: 36 | val = str(value) 37 | 
self[key] = val 38 | return True 39 | 40 | def __setitem__(self, key, value): 41 | super(ParameterMap, self).__setitem__(key, value) 42 | self.__dict__.update({key: value}) 43 | 44 | def __delattr__(self, item): 45 | self.__delitem__(item) 46 | 47 | def __delitem__(self, key): 48 | super(ParameterMap, self).__delitem__(key) 49 | del self.__dict__[key] 50 | 51 | def clone(self): 52 | new_map = ParameterMap() 53 | for k in self.keys(): 54 | new_map[k] = self[k] 55 | return new_map 56 | 57 | def from_file(self, file_path): 58 | with open(file_path, 'r') as f: 59 | lines = f.readlines() 60 | for line in lines: 61 | arg_split = line.split(': ') 62 | assert len(arg_split) == 2 63 | k = arg_split[0].strip() 64 | v = arg_split[1].strip() 65 | if not self.set_with_cast(k, v): 66 | raise AttributeError('Unknown parameter [{}]'.format(k)) 67 | 68 | def __str__(self): 69 | rep = '' 70 | for key in self: 71 | rep += '{}: {}\n'.format(key, self[key]) 72 | return rep 73 | 74 | 75 | def override_arguments(args, additional_args): 76 | for arg in additional_args: 77 | assert arg[:2] == '--' 78 | arg_split = arg.split('=') 79 | assert len(arg_split) == 2 80 | k = arg_split[0][2:] 81 | v = arg_split[1] 82 | if args.set_with_cast(k, v): 83 | print('Overriding argument [{}] with [{}]'.format(k, v)) 84 | else: 85 | raise AttributeError('Unknown parameter [{}]'.format(k)) 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | cachetools==4.1.0 3 | certifi==2020.4.5.2 4 | chardet==3.0.4 5 | cycler==0.10.0 6 | future==0.18.2 7 | google-auth==1.17.0 8 | google-auth-oauthlib==0.4.1 9 | grpcio==1.29.0 10 | idna==2.9 11 | importlib-metadata==1.6.1 12 | jax==0.1.70 13 | jaxlib==0.1.47 14 | kiwisolver==1.2.0 15 | Markdown==3.2.2 16 | matplotlib==3.2.1 17 | numpy==1.18.5 18 | oauthlib==3.1.0 19 | opt-einsum==3.2.1 20 | pandas==1.0.4 21 | Pillow==7.1.2 22 | protobuf==3.12.2 23 | pyasn1==0.4.8 24 | pyasn1-modules==0.2.8 25 | pyparsing==2.4.7 26 | python-dateutil==2.8.1 27 | pytz==2020.1 28 | requests==2.23.0 29 | requests-oauthlib==1.3.0 30 | rsa==4.1 31 | scipy==1.4.1 32 | six==1.15.0 33 | tensorboard==2.2.2 34 | tensorboard-plugin-wit==1.6.0.post3 35 | torch==1.3.1 36 | torchvision==0.4.2 37 | urllib3==1.25.9 38 | Werkzeug==1.0.1 39 | zipp==3.1.0 40 | --------------------------------------------------------------------------------
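The training loop in ``nn_experiments/train_and_eval.py`` draws a fresh seed each iteration and runs the forward pass inside ``torch.random.fork_rng()``, so the global RNG state is saved and restored around the forward computation; presumably this seeds the stochastic sampling in the randomized-autodiff layers explicitly each iteration without perturbing the RNG state used elsewhere in training. A minimal sketch of that pattern, with hypothetical names and no ties to the actual models:

```
import torch

def stochastic_forward(model, data):
    """Run a forward pass under a fresh, known seed without touching the global RNG."""
    seed = int(torch.randint(low=0, high=2**31 - 1, size=(1,)))
    with torch.random.fork_rng():       # global RNG state is saved here and restored on exit
        torch.random.manual_seed(seed)  # forward-pass randomness is driven by this seed
        output = model(data)
    return output, seed                 # the seed can be logged to reproduce the pass
```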
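``nn_experiments/utils.py`` stores experiment configuration in a ``ParameterMap``: defaults are plain Python values, attribute access falls through to the dictionary, and ``override_arguments`` casts any ``--key=value`` override to the type of the existing default. The sketch below is illustrative only; the parameter names and defaults are hypothetical, and it assumes it is run from inside ``nn_experiments/`` so that ``utils`` is importable (the launch scripts may wire this up differently, e.g. together with ``argparse``).

```
import sys

from utils import ParameterMap, override_arguments

# Hypothetical defaults; the types of these values drive the casting
# performed by ParameterMap.set_with_cast (bool / int / float / str).
args = ParameterMap(
    dataset='cifar10',
    lr=0.1,
    weight_decay=5e-4,
    max_iterations=10000,
    validation=False,
)

# Trailing arguments of the form --key=value override the defaults,
# e.g.: python my_experiment.py --lr=0.01 --validation=True
override_arguments(args, sys.argv[1:])

print(args)  # __str__ prints one "key: value" line per parameter
```

Only keys that already exist in the map with a non-``None`` default can be overridden this way, since the cast is inferred from the default's type.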