# ├── README.md ├── notebooks ├── runs.npz ├── results_optim.npz ├── results_shower.npz └── 03_VisualizeSimulator.ipynb ├── utils.py ├── shower_sim_baseline.py ├── plots_optimization.py ├── plots_gradients.py ├── shower_sim_optimize.py ├── plot_loss_landscape.py ├── LossLandscape.ipynb ├── shower_sim_redone.py ├── GradientEstimates.ipynb └── VisualizeDesign.ipynb
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # branches_of_a_tree 2 |
# --------------------------------------------------------------------------------
# /notebooks/runs.npz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/makagan/branches_of_a_tree/main/notebooks/runs.npz
# --------------------------------------------------------------------------------
# /notebooks/results_optim.npz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/makagan/branches_of_a_tree/main/notebooks/results_optim.npz
# --------------------------------------------------------------------------------
# /notebooks/results_shower.npz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/makagan/branches_of_a_tree/main/notebooks/results_shower.npz
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import numpy as np
import pandas as pd

# Plot colour per gradient estimator, shared by all plotting modules.
COLORS = {
    'stad': 'maroon',
    'scorebase': 'darkorange',
    'numeric': 'forestgreen',
    'score': 'steelblue',
}


def array_mean_and_quantiles(array, window=5, quantiles=(.10, .50, .90)):
    """Return rolling-smoothed per-row mean and quantile curves of a 2-D array.

    Statistics are taken across axis 1 (one value per row). Each resulting
    curve is then smoothed with a trailing rolling mean of length ``window``,
    so the first ``window - 1`` entries of every curve are NaN.

    Args:
        array: 2-D array-like; statistics are computed over its columns.
        window: trailing rolling-mean length (1 disables smoothing).
        quantiles: quantile levels to compute. With the default
            ``(.10, .50, .90)`` the curves are (low edge, median, high edge),
            which callers index as ``q[0]`` / ``q[2]`` for a 10-90% band.

    Returns:
        ``(mean, qcurves)``. ``mean`` is a 1-D array with one entry per row.
        ``qcurves`` is a tuple holding one such array per level in
        ``quantiles``, in the same order.
    """
    def _rolling(values):
        # pandas pads the start of the window with NaN.
        return pd.DataFrame(values).rolling(window).mean().to_numpy()[:, 0]

    array = np.asarray(array)
    mean = _rolling(np.mean(array, axis=1))
    qcurves = tuple(_rolling(np.quantile(array, q, axis=1)) for q in quantiles)
    return mean, qcurves

# --------------------------------------------------------------------------------
# /shower_sim_baseline.py:
# --------------------------------------------------------------------------------
# shower_sim_baseline

import numpy as np
import jax


def propagate_state(state):
    """Advance a particle state by one fixed straight-line step.

    Returns None for a None state, and the same object for a state whose
    ``'alive'`` flag is falsy. Otherwise returns a new dict with the position
    moved by ``0.02 * (px, py)``. Energy and momentum are unchanged, and the
    result is marked alive.
    """
    if state is None:
        return None
    if not state['alive']:
        return state

    time = 0.02
    return {
        'E': state['E'],
        'x': state['x'] + time * state['px'],
        'y': state['y'] + time * state['py'],
        'px': state['px'],
        'py': state['py'],
        'alive': True,
    }


def sample_stop_prob(score, state, sim_parameters):
    """Return True when the particle must stop, otherwise False.

    The decision is deterministic despite the name. A particle stops when its
    energy is below ``sim_parameters['thresh_E']`` or its radius exceeds 20.
    ``score`` is unused; it is kept for signature parity with the
    instrumented simulator.
    """
    r = np.sqrt(state['x'] ** 2 + state['y'] ** 2)
    return bool(state['E'] < sim_parameters['thresh_E'] or r > 20.)


def interact_prob(x, y, par):
    """Per-step interaction probability at position (x, y), in [0, 0.5].

    The probability is a product of sigmoids. Two of them create an oscillating
    pattern in angle and radius. The other two, ``start`` and ``end``, switch
    interactions on for radii between ``par`` and ``par + 10``. It is written
    with ``jax.numpy`` so it can be differentiated with respect to ``par``.
    """
    par_radial = 10
    par_azimutal = 10
    r = jax.numpy.sqrt(x ** 2 + y ** 2)
    # NOTE: arctan2(x, y) measures the angle from the +y axis.
    alpha = jax.numpy.arctan2(x, y)

    sampling1 = 1 / (1 + jax.numpy.exp(10 * jax.numpy.sin(par_radial * (alpha + 2 * r))))
    sampling2 = 1 / (1 + jax.numpy.exp(10 * jax.numpy.cos(par_azimutal * (r - 2))))
    start = 1 / (1 + jax.numpy.exp(-10 * (r - par)))
    end = 1 / (1 + jax.numpy.exp(10 * (r - (par + 10.0))))

    return 0.5 * start * sampling1 * sampling2 * end


########### Summary / Objective ###########

def per_hit_summary(hits):
    """Return the radius sqrt(x^2 + y^2) of each hit (hit columns 0 and 1 are x, y)."""
    return np.sqrt(hits[:, 0] ** 2 + hits[:, 1] ** 2)


def summary(generation):
    """Objective for a simulator output tuple; it uses the second element (active hits)."""
    _, active, *_ = generation
    return summary_metric(active)


def summary_metric(active):
    """Squared deviation of the mean hit radius from 2.0 (NaN for an empty ``active``)."""
    return (np.mean(per_hit_summary(active)) - 2.0) ** 2

# --------------------------------------------------------------------------------
# /plots_optimization.py:
# --------------------------------------------------------------------------------
import pandas as pd
import utils
import numpy as np


def moving_avg(array, window):
    """Return the trailing rolling mean of ``array`` as a DataFrame.

    The first ``window - 1`` rows are NaN (the pandas ``rolling`` default).
    """
    return pd.DataFrame(array).rolling(window).mean()


def plot_single_opt_comparison(ax, l_st, l_s, l_sb, l_n, window=10):
    """Plot smoothed loss-per-epoch curves of one optimisation run per estimator.

    Args:
        ax: matplotlib Axes to draw on.
        l_st, l_s, l_sb, l_n: per-epoch losses for the StochAD, score,
            score-with-baseline and numeric gradient estimators.
        window: rolling-mean smoothing window (default 10).
    """
    curves = (
        (l_st, 'STAD', 'stad'),
        (l_s, 'Score', 'score'),
        (l_sb, 'Score Baseline', 'scorebase'),
        (l_n, 'Numeric', 'numeric'),
    )
    for losses, label, color_key in curves:
        ax.plot(moving_avg(losses, window), label=label, color=utils.COLORS[color_key])

    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylim(0, 5)
    ax.legend()


def plot_optimization_comparison(
    ax,
    l_st_list,
    l_s_list,
    l_n_list,
    l_sb_list,
):
    """Plot the mean loss and a 10-90% quantile band across repeated optimisations.

    Each ``*_list`` is transposed before the statistics are computed, so
    statistics are taken over its first axis. The plotted x-axis has one point
    per entry along its second axis (steps). The step count is taken from the
    data rather than being fixed.
    """
    curves = (
        (l_n_list, 'numeric', 'numeric'),
        (l_s_list, 'score', 'SCORE'),
        (l_sb_list, 'scorebase', 'SCORB'),
        (l_st_list, 'stad', 'STAD'),
    )
    for losses, color_key, label in curves:
        mean, quants = utils.array_mean_and_quantiles(np.asarray(losses).T)
        steps = np.arange(len(mean))
        color = utils.COLORS[color_key]
        ax.plot(steps, mean, c=color, label=label)
        # quants[0] / quants[2] are the 10% / 90% curves.
        ax.fill_between(steps, quants[0], quants[2], facecolor=color, alpha=0.2)

    ax.set_ylabel('Loss')
    ax.set_xlabel('Steps')
    ax.set_title('Design Optimization')
    ax.legend(loc='lower left')
# --------------------------------------------------------------------------------
# /plots_gradients.py:
# --------------------------------------------------------------------------------
import numpy as np
import seaborn as sns
import utils


def analyse_at_point(ax, runs1, eps=0.05, nsig=50, nbins=101, legend=True):
    """Summarise and box-plot per-run gradient estimates at a single parameter value.

    Args:
        ax: matplotlib Axes to draw the box plot on.
        runs1: sequence of run dicts, each with ``'primal'``, ``'dlogp'`` and a
            ``'grad_dict'`` holding ``'stad'``, ``'score'`` and ``'numeric'``.
        eps, nsig, nbins, legend: accepted for call compatibility; unused.

    Returns:
        Dict with the mean (``*_m``) and standard deviation (``*_s``) of the
        ``stad``, ``score``, ``score_baseline`` and ``numeric`` estimators.
        The same mean/std pairs are also printed.
    """
    values = np.array([r['primal'] for r in runs1])
    dlogp = np.array([r['dlogp'] for r in runs1])
    stad = np.array([r['grad_dict']['stad'] for r in runs1])
    score = np.array([r['grad_dict']['score'] for r in runs1])
    numeric = np.array([r['grad_dict']['numeric'] for r in runs1])
    # Score estimator with a control-variate baseline equal to the mean primal.
    score_baseline = score - dlogp * values.mean()

    for tag, est in (('stad', stad), ('scorb', score_baseline),
                     ('score', score), ('numeric', numeric)):
        print(f'{tag} {est.mean():.2f},{est.std():.2f}')
    print('----')

    sns.boxplot(
        ax=ax,
        data=np.column_stack([stad, score_baseline, score, numeric]),
        orient='v', fliersize=1, meanline=True, showmeans=True,
        meanprops={'c': 'k'},
        palette=[utils.COLORS[k] for k in ('stad', 'scorebase', 'score', 'numeric')],
    )

    return {
        'stad_m': stad.mean(), 'stad_s': stad.std(),
        'score_m': score.mean(), 'score_s': score.std(),
        'score_baseline_m': score_baseline.mean(), 'score_baseline_s': score_baseline.std(),
        'numeric_m': numeric.mean(), 'numeric_s': numeric.std(),
    }


def plot_variance_with_inset(axarr, runs):
    """Box-plot the four estimators on ``axarr`` (one Axes), with a wider-range inset.

    The main view shows StochAD, Score w/ Baseline and Score within +-15. The
    inset zooms out to +-60 on the Score and Numeric boxes.
    """
    ax = axarr
    analyse_at_point(ax, runs, nsig=50, nbins=101, legend=True)
    ax.set_ylim(-15, 15)
    ax.set_xticklabels(['StochAD', 'Score w/ Baseline', 'Score', 'Numeric'], rotation=20)
    ax.set_xlim(-0.5, 2.5)
    # Braces must balance, otherwise mathtext fails to parse the label when drawing.
    ax.set_ylabel(r'$g \sim \partial_\theta\mathbb{E}[X(\theta)]$')

    iax = ax.inset_axes([0.3, 0.1, 0.3, 0.2])
    analyse_at_point(iax, runs, nsig=50, nbins=101, legend=True)
    iax.set_xticklabels(['ST', 'SB', 'Score', 'Numeric'])
    iax.set_xlim(1.5, 3.5)
    iax.set_ylim(-60, 60)

# --------------------------------------------------------------------------------
# /shower_sim_optimize.py:
# --------------------------------------------------------------------------------
import numpy as np
import tqdm
import optax
import jax.numpy as jnp

# Gradient estimators understood by program_to_optimize.
_GRAD_TYPES = ("stad", "score", "numeric")


def program_to_optimize(simulator, objective, sim_kwargs=None):
    """Wrap a simulator and an objective into a loss-plus-gradient-estimate function.

    Args:
        simulator: callable ``simulator(theta, **sim_kwargs)``. It must return
            ``(hits, active, history, scores, out_st)``, where ``out_st`` is a
            dict with ``'d'``, ``'w'`` and ``'y'`` (``out_st['y']['active']`` is
            passed to ``objective``).
        objective: callable mapping an ``active`` hits array to a scalar loss.
        sim_kwargs: extra keyword arguments forwarded to ``simulator``
            (default: none).

    Returns:
        ``program_for_optimizer(theta, grad_type="stad", eps=0.01,
        keep_all_grads=False)``. It runs the simulator and returns a dict with
        ``primal``, ``grad``, ``grad_type``, ``grad_dict`` and ``dlogp``.
        An unknown ``grad_type`` raises ValueError.
    """
    sim_kwargs = {} if sim_kwargs is None else dict(sim_kwargs)

    def program_for_optimizer(theta, grad_type="stad", eps=0.01, keep_all_grads=False):
        if grad_type not in _GRAD_TYPES:
            raise ValueError(
                f"grad_type={grad_type!r} not recognized; expected one of {_GRAD_TYPES}"
            )

        hits, active, history, scores, out_st = simulator(theta, **sim_kwargs)
        primal = objective(active)

        grad_dict = {}
        if grad_type == "stad" or keep_all_grads:
            alt = objective(out_st['y']['active'])
            # Continuous part d plus weight w times the finite difference to the alternative.
            grad_dict["stad"] = out_st['d'] + out_st['w'] * (alt - primal)

        if grad_type == "score" or keep_all_grads:
            grad_dict["score"] = scores * primal

        if grad_type == "numeric" or keep_all_grads:
            # Forward difference using an independent second simulation at theta + eps.
            _, active2, _, _, _ = simulator(theta + eps, **sim_kwargs)
            grad_dict["numeric"] = (objective(active2) - primal) / eps

        return {
            "primal": primal,
            "grad": grad_dict[grad_type],
            "grad_type": grad_type,
            "grad_dict": grad_dict,
            "dlogp": scores,
        }

    return program_for_optimizer


def minibatch_primal_and_grad(program, theta, Nmini, grad_type="stad", dobaseline=True):
    """Average the loss and gradient estimate over ``Nmini`` independent program runs.

    For ``grad_type="score"`` with ``dobaseline`` and ``Nmini > 1``, the
    minibatch mean loss is used as a baseline: each gradient becomes
    ``grad - dlogp * mean_primal``.
    NOTE(review): the baseline includes each run's own primal; a
    leave-one-out mean would remove the resulting small bias.
    """
    runs = [program(theta, grad_type) for _ in range(Nmini)]
    primal = np.mean([r["primal"] for r in runs])

    if grad_type == "score" and dobaseline and Nmini > 1:
        grads = [r["grad"] - r["dlogp"] * primal for r in runs]
    else:
        grads = [r["grad"] for r in runs]

    return primal, np.mean(grads)


def optimize(program, init, LR, Nepoch, Nmini, grad_type, dobaseline=True, doprint=True):
    """Minimise the program's loss with Adam, clamping the parameter at 0.

    Args:
        program: function returned by ``program_to_optimize``.
        init: initial parameter value.
        LR: Adam learning rate.
        Nepoch: number of optimisation steps.
        Nmini: number of program runs averaged per step.
        grad_type: gradient estimator name (see ``program_to_optimize``).
        dobaseline: use the score-function baseline (score estimator only).
        doprint: show the tqdm progress bar.

    Returns:
        ``(theta, traj_theta, traj_v, traj_g)``: the final parameter, and per
        step the parameter before the update, the minibatch loss and the
        minibatch gradient.
    """
    traj_theta, traj_v, traj_g = [], [], []
    theta = jnp.array(init)

    optimizer = optax.adam(learning_rate=LR)
    adam_state = optimizer.init(theta)

    for _ in tqdm.tqdm(range(Nepoch), disable=not doprint):
        traj_theta.append(theta)
        v, g = minibatch_primal_and_grad(program, theta, Nmini, grad_type, dobaseline)
        updates, adam_state = optimizer.update(g, adam_state, theta)
        theta = optax.apply_updates(theta, updates)
        # Clamp elementwise, keeping theta a jax array so optax sees a consistent type.
        theta = jnp.maximum(theta, 0.)

        traj_v.append(v)
        traj_g.append(g)
    return theta, traj_theta, traj_v, traj_g

# --------------------------------------------------------------------------------
# /plot_loss_landscape.py:
# --------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import utils


def plot_loss_landscape_primal(ax, numeric_fit, par_vals, primal_list, n_scatter=20):
    """Plot per-event losses, their mean and 10-90% band, and a fitted curve.

    Args:
        ax: matplotlib Axes.
        numeric_fit: callable evaluated on ``par_vals`` (e.g. a ``numpy.poly1d``).
        par_vals: 1-D array of scanned parameter values (one per row of ``primal_list``).
        primal_list: 2-D array, one row per parameter value and one loss per column (event).
        n_scatter: maximum number of events per row to scatter. It is clipped
            to the number of columns.
    """
    par_vals = np.asarray(par_vals)
    primal_list = np.asarray(primal_list)
    n_show = min(n_scatter, primal_list.shape[1])

    # Repeat each parameter value once per scattered event so x and y shapes match.
    xs = np.repeat(par_vals[:, None], n_show, axis=1)
    ax.scatter(xs, primal_list[:, :n_show], alpha=0.1, label="per event primal", c='k')

    mean, mean_q = utils.array_mean_and_quantiles(primal_list, window=1)
    ax.plot(par_vals, np.array([mean_q[0], mean_q[2]]).T, color='k', alpha=0.5)
    ax.plot(par_vals, mean, label='primal mean', c='k')

    ax.plot(par_vals, numeric_fit(par_vals), color='maroon', label='poly. fit', linestyle='dashed')
    ax.axhline(0.0, c='k')
    ax.set_xlim(0.0, 4.1)
    ax.set_ylim(-0.2, 10)
    ax.set_ylabel('Loss')
    ax.set_xlabel('parameter')
    ax.legend()


def plot_loss_landscape_gradients(
    ax,
    numeric_fit,
    par_vals,
    numeric_list,
    score_list,
    score_baseline_list,
    stad_list,
):
    """Plot each gradient estimator's mean (dashed) and 10-90% band against the parameter.

    The derivative of ``numeric_fit`` (which must provide ``.deriv()``, as
    ``numpy.poly1d`` does) is overlaid as a reference. Each ``*_list`` has one
    row per value in ``par_vals``.
    """
    grad_from_fit = numeric_fit.deriv()

    curves = (
        (numeric_list, 'Numeric', 'numeric'),
        (stad_list, 'StochAD', 'stad'),
        (score_baseline_list, 'SCORB', 'scorebase'),
        (score_list, 'SCORE', 'score'),
    )
    for grads, label, color_key in curves:
        mean, quants = utils.array_mean_and_quantiles(grads, window=1)
        color = utils.COLORS[color_key]
        ax.plot(par_vals, mean, label=label, color=color, linestyle='dashed')
        ax.plot(par_vals, np.array([quants[0], quants[2]]).T, alpha=1.0, color=color)

    ax.plot(
        par_vals, grad_from_fit(par_vals), color='black', linestyle='dashed',
        label='grad from fit',
    )

    ax.set_xlim(0.0, 4.0)
    ax.set_ylim(-15.0, 20.0)
    ax.set_ylabel('Grad')
    ax.set_xlabel('parameter')
    ax.legend(loc='upper left')

# --------------------------------------------------------------------------------
/LossLandscape.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import jax\n", 12 | "import copy\n", 13 | "import queue\n", 14 | "from shower_sim_instrumented import simulator\n", 15 | "from shower_sim_baseline import summary_metric\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "0.87279123 48.394210412632674 2736 2848\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "hits,active,history,scores,out_st = simulator(3.5)\n", 36 | "print(scores, out_st['w'], hits.size, out_st['y']['hits'].size)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "############ par value= 0.25 ##################\n", 49 | "############ 0 ##################\n", 50 | "############ 100 ##################\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "par_vals = np.arange(0.25, 4.0, 0.1)\n", 56 | "N=200\n", 57 | "eps = 0.01\n", 58 | "\n", 59 | "primal_list = []\n", 60 | "primal_st_list = []\n", 61 | "\n", 62 | "score_list = []\n", 63 | "stad_list = []\n", 64 | "numeric_list = []\n", 65 | "dlogp_list = []\n", 66 | "\n", 67 | "\n", 68 | "for par_v in par_vals:\n", 69 | " print(\"############ par value=\", par_v, \"##################\")\n", 70 | " \n", 71 | " primal = []\n", 72 | " primal_st = []\n", 73 | "\n", 74 | " score_val = []\n", 75 | " stad_val = []\n", 76 | " numeric_val = []\n", 77 | " dlogp_val = []\n", 78 | " \n", 79 | " for i in range(N):\n", 80 | " if i%100 == 0: print(\"############\", i, 
\"##################\")\n", 81 | " hits,active,history,scores,out_st = simulator(par_v)\n", 82 | " _, active2, _, _, _ = simulator(par_v+eps)\n", 83 | " \n", 84 | " _val = summary_metric(active)\n", 85 | " _val2 = summary_metric(active2)\n", 86 | " \n", 87 | " primal.append(_val)\n", 88 | " primal_st.append(summary_metric(out_st['y']['active']))\n", 89 | " \n", 90 | " numeric_val.append( (_val2 - _val)/eps )\n", 91 | " \n", 92 | " score_val.append(scores*primal[i])\n", 93 | " dlogp_val.append(scores)\n", 94 | " stad_val.append(out_st['d'] + out_st['w']*(primal_st[i] - primal[i]))\n", 95 | " \n", 96 | " primal_list.append(primal)\n", 97 | " primal_st_list.append(primal_st)\n", 98 | " \n", 99 | " score_list.append(score_val)\n", 100 | " stad_list.append(stad_val)\n", 101 | " numeric_list.append(numeric_val)\n", 102 | " dlogp_list.append(dlogp_val)\n", 103 | " \n", 104 | "\n", 105 | "primal_list = np.array(primal_list)\n", 106 | "primal_list_m = primal_list.mean(axis=1)\n", 107 | "primal_list_s = primal_list.std(axis=1)\n", 108 | "\n", 109 | "numeric_list = np.array(numeric_list)\n", 110 | "numeric_m = numeric_list.mean(axis=1)\n", 111 | "numeric_s = numeric_list.std(axis=1)\n", 112 | "\n", 113 | "score_list = np.array(score_list)\n", 114 | "score_m = score_list.mean(axis=1)\n", 115 | "score_s = score_list.std(axis=1)\n", 116 | "\n", 117 | "dlogp_list = np.array(dlogp_list)\n", 118 | "score_baseline_list = score_list - dlogp_list*primal_list_m.reshape(-1,1)\n", 119 | "score_baseline_m = score_baseline_list.mean(axis=1)\n", 120 | "score_baseline_s = score_baseline_list.std(axis=1)\n", 121 | "\n", 122 | "stad_list = np.array(stad_list)\n", 123 | "stad_m = stad_list.mean(axis=1)\n", 124 | "stad_s = stad_list.std(axis=1)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python 3.11.4 64-bit 
# ('stochad_env')", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.11.4" 152 | }, 153 | "orig_nbformat": 4, 154 | "vscode": { 155 | "interpreter": { 156 | "hash": "29b777788fad9121f9b4a41d949494280ba66b13a17f3ccd9e5dfa0de3270b9d" 157 | } 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 2 162 | } 163 |
# --------------------------------------------------------------------------------
# /shower_sim_redone.py:
# --------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import jax
import copy
import queue


def make_simulator():
    """Build the instrumented particle-shower simulator.

    The returned ``simulator(par, split_prob, reuse_rvs=True)`` runs one shower.
    Alongside the primal result it tracks:

    * the score d log p / d thresh_x, for score-function gradients; and
    * a single, stochastically pruned alternative shower, held as a
      stochastic triple ``{'d', 'y', 'w'}``. In the alternative, one
      particle's interaction decision is flipped.

    Building the simulator JIT-compiles and warms up the jax helpers once.
    """

    def stochasticTriple(d=0., y=None, w=0.):
        """Return a stochastic triple: derivative ``d``, alternative ``y``, weight ``w``."""
        return {"d": d, "y": y, "w": w}

    def bernoulli_basic(p, get_omega=False, u_input=None):
        """Sample Bernoulli(p) as ``u > 1 - p``.

        ``u`` is ``u_input`` when given, otherwise a fresh U(0, 1) draw.
        Returns the 0/1 outcome, or ``(outcome, u)`` when ``get_omega`` is set.
        """
        u = np.random.uniform() if u_input is None else u_input
        b = 1 if u > (1. - p) else 0
        return (b, u) if get_omega else b

    def compose_derivs(func, st1, st2):
        """Chain-rule composition of stochastic triples ``st1`` (input) and ``st2`` (function).

        The result has ``d = d1*d2`` and ``w = |w1| + |d1||w2|``. Its
        alternative is ``func(y1)`` with probability ``|w1| / w``, otherwise
        ``y2``; when both weights are zero it is simply ``y2``.
        """
        d1, y1, w1 = st1["d"], st1["y"], st1["w"]
        d2, y2, w2 = st2["d"], st2["y"], st2["w"]

        w1_iszero = (w1 == 0 or w1 == 0.)
        w2_iszero = (w2 == 0 or w2 == 0.)

        if w1_iszero and w2_iszero:
            y = y2
        else:
            prob = 0 if w1_iszero else np.fabs(w1) / (np.fabs(w1) + np.fabs(d1) * np.fabs(w2))
            # A uniform is drawn even when prob == 0, so the RNG stream always advances.
            option = bernoulli_basic(prob, get_omega=False)
            y = func(y1) if option == 1 else y2

        d = d1 * d2
        w = np.fabs(w1) + np.fabs(d1) * np.fabs(w2)
        return stochasticTriple(d, y, w)

    def do_prune_away_old(st_new, st_old):
        """Decide whether a new alternative replaces the tracked one.

        Returns False when both weights are zero. Otherwise returns a 0/1 draw
        that is 1 with probability ``|w_new| / (|w_new| + |w_old|)``.
        """
        w_new, w_old = st_new["w"], st_old["w"]
        w_new_iszero = (w_new == 0 or w_new == 0.)
        w_old_iszero = (w_old == 0 or w_old == 0.)

        if w_new_iszero and w_old_iszero:
            return False
        prob = 0 if w_new_iszero else np.fabs(w_new) / (np.fabs(w_new) + np.fabs(w_old))
        return bernoulli_basic(prob, get_omega=False)

    def bernoulli(p, p_st, get_omega=False):
        """Sample Bernoulli(p) together with its stochastic-derivative triple.

        ``p_st`` is the triple of ``p`` itself (or None). Returns
        ``(outcome, triple)``, plus the uniform ``u`` when ``get_omega`` is set.
        """
        def _fwd(p, get_omega=False):
            u = np.random.uniform()
            b = 1 if u > (1. - p) else 0
            return (b, u) if get_omega else b

        def _deriv(x, p, direction=1.0):
            if int(direction) == 1:
                # Right derivative: raising p can only flip an outcome of 0 to 1.
                if x == 0:
                    return stochasticTriple(0., 1, 1. / (1. - p))
                return stochasticTriple(0., 0, 0.)
            # Left derivative: lowering p can only flip an outcome of 1 to 0.
            if x == 1:
                return stochasticTriple(0., 0, 1. / p)
            return stochasticTriple(0., 0, 0.)

        if get_omega:
            out, u = _fwd(p, get_omega=True)
        else:
            out = _fwd(p, get_omega=False)

        # Right derivative by default; switch to the left one when dp/dtheta < 0.
        direction = 1.0
        if p_st is not None and p_st['d'] < 0:
            direction = -1.0

        out_st = _deriv(out, p, direction)
        if p_st is not None:
            out_st = compose_derivs(_fwd, p_st, out_st)

        if get_omega:
            return out, out_st, u
        return out, out_st

    def propagate_state(state):
        """Move a live particle one straight-line step of 0.02 * (px, py).

        Returns None for None, and the same object for a dead particle.
        """
        if state is None:
            return None
        if not state['alive']:
            return state

        time = 0.02
        return {
            'E': state['E'],
            'x': state['x'] + time * state['px'],
            'y': state['y'] + time * state['py'],
            'px': state['px'],
            'py': state['py'],
            'alive': True,
        }

    def sample_stop_prob(score, state, sim_parameters):
        """Return True when energy is below ``thresh_E`` or radius exceeds 20 (deterministic)."""
        r = np.sqrt(state['x'] ** 2 + state['y'] ** 2)
        return bool(state['E'] < sim_parameters['thresh_E'] or r > 20.)

    # d/dp log Bernoulli(k; p), compiled once and warmed up.
    score_bernoulli = jax.jit(jax.grad(jax.scipy.stats.bernoulli.logpmf, argnums=1))
    score_bernoulli(1, 0.5)

    def interact_prob(x, y, par):
        """Per-step interaction probability in [0, 0.5], active for radii between ``par`` and ``par + 10``."""
        par_radial = 10
        par_azimutal = 10
        r = jax.numpy.sqrt(x ** 2 + y ** 2)
        alpha = jax.numpy.arctan2(x, y)

        sampling1 = 1 / (1 + jax.numpy.exp(10 * jax.numpy.sin(par_radial * (alpha + 2 * r))))
        sampling2 = 1 / (1 + jax.numpy.exp(10 * jax.numpy.cos(par_azimutal * (r - 2))))
        start = 1 / (1 + jax.numpy.exp(-10 * (r - par)))
        end = 1 / (1 + jax.numpy.exp(10 * (r - (par + 10.0))))

        return 0.5 * start * sampling1 * sampling2 * end

    interact_prob_and_g = jax.jit(jax.value_and_grad(interact_prob, argnums=2))
    interact_prob_and_g(0.5, 0.5, 0.5)

    def interact_prob_and_grad(x, y, par):
        """Return ``(p, dp/dpar)``, with a NaN gradient replaced by 0."""
        p, g = interact_prob_and_g(x, y, par)
        g = jax.lax.cond(jax.numpy.isnan(g), lambda x: 0., lambda x: x, g)
        return p, g

    interact_prob_and_grad = jax.jit(interact_prob_and_grad)
    interact_prob_and_grad(0.5, 0.5, 0.5)

    def sample_interact(score, state, sim_parameters, keep_derivs=True, fifos=None):
        """Sample whether the particle interacts at its current position.

        With ``keep_derivs=False`` (used when advancing the alternative world)
        a queued uniform from ``fifos['interact']`` is reused when available,
        and ``(outcome, None)`` is returned. Otherwise it returns
        ``(outcome, triple)``, pushes the uniform it used onto the queue (when
        ``fifos`` is given) and adds the score contribution to
        ``score['thresh_x']``.
        """
        x, y = state['x'], state['y']
        par_thresh_x = sim_parameters['thresh_x']

        p_int, p_int_grad = interact_prob_and_grad(x, y, par_thresh_x)

        if not keep_derivs:
            u_input = None
            if fifos is not None and not fifos['interact'].empty():
                u_input = fifos['interact'].get()
            return bernoulli_basic(p_int, get_omega=False, u_input=u_input), None

        p_int_st = stochasticTriple(p_int_grad, 0., 0.)

        if fifos is not None:
            interact, interact_st, u_int = bernoulli(p_int, p_int_st, get_omega=True)
            fifos['interact'].put(u_int)
        else:
            interact, interact_st = bernoulli(p_int, p_int_st, get_omega=False)

        # Score function: d log Bernoulli(interact; p)/dp * dp/dthresh_x.
        score['thresh_x'] += score_bernoulli(interact, p_int) * p_int_grad

        return interact, interact_st

    def sample_fate(score, state, program_st, sim_parameters, check_alts=True, fifos=None):
        """Decide what happens to a propagated particle.

        If the particle stops, it marks ``state['alive'] = False`` and returns
        None. Otherwise returns a fate dict with ``interact``, ``interact_st``,
        ``keep_new_alt``, ``split``, ``eloss``, ``bumpx`` and ``bumpy``.
        With ``check_alts`` it also updates ``program_st``: the weight is
        accumulated, and ``d`` is replaced when the new alternative is kept.
        """
        if sample_stop_prob(score, state, sim_parameters):
            state['alive'] = False
            return None

        if not check_alts:
            interact, _ = sample_interact(None, state, sim_parameters, keep_derivs=False, fifos=fifos)
            interact_st = None
            keep_new_alt = False
        else:
            interact, interact_st = sample_interact(score, state, sim_parameters, keep_derivs=True, fifos=fifos)

            if program_st['y'] is None:
                keep_new_alt = True
            else:
                keep_new_alt = do_prune_away_old(interact_st, program_st)

            program_st['w'] += np.fabs(interact_st['w'])
            if keep_new_alt:
                program_st['d'] = interact_st['d']
                # Queued uniforms belonged to the discarded alternative; drop them.
                # NOTE(review): nesting under keep_new_alt inferred from the original
                # comment ("used to create alternative"); confirm against upstream.
                if fifos is not None and not fifos['interact'].empty():
                    fifos['interact'].queue.clear()

        bumpx = 0.05
        bumpy = 0.05
        # Always drawn, even without an interaction, so the RNG order is fixed.
        split = np.random.binomial(1, sim_parameters['split_prob'])

        return {
            'interact': bool(interact),
            'interact_st': interact_st,
            'keep_new_alt': keep_new_alt,
            # Without an interaction the raw draw is kept; with one it is a bool.
            'split': bool(split) if interact else split,
            'eloss': 2.0,
            'bumpx': bumpx,
            'bumpy': bumpy,
        }

    def fate2state(fate, state):
        """Apply a fate to a propagated state.

        Returns ``(state1, state2, stateY1, stateY2)``: up to two primal next
        states and, when ``fate['keep_new_alt']`` is set, up to two next states
        for the alternative world, where the interaction decision is flipped.
        Unused slots are None.
        """
        def _update_stop(st):
            return {
                'E': st['E'],
                'x': st['x'],
                'y': st['y'],
                'px': st['px'] + fate['bumpx'],
                'py': st['py'] + fate['bumpy'],
                'alive': False,
            }

        def _update_split(st):
            # Two daughters with half the energy, momenta bumped in opposite directions.
            px1, py1 = st['px'] + fate['bumpx'], st['py'] + fate['bumpy']
            px2, py2 = st['px'] - fate['bumpx'], st['py'] - fate['bumpy']
            norm1 = np.sqrt(px1 ** 2 + py1 ** 2)
            norm2 = np.sqrt(px2 ** 2 + py2 ** 2)
            child1 = {'E': st['E'] / 2, 'x': st['x'], 'y': st['y'],
                      'px': px1 / norm1, 'py': py1 / norm1, 'alive': True}
            child2 = {'E': st['E'] / 2, 'x': st['x'], 'y': st['y'],
                      'px': px2 / norm2, 'py': py2 / norm2, 'alive': True}
            return child1, child2

        def _update_eloss(st):
            # Lose `eloss` energy (floored at 0) and get a random-sign momentum bump.
            new_E = st['E'] - fate['eloss']
            if new_E < 0.:
                new_E = 0.
            sign_x = 1.0 if np.random.binomial(1, 0.5) else -1.0
            sign_y = 1.0 if np.random.binomial(1, 0.5) else -1.0
            bump_px = st['px'] + fate['bumpx'] * sign_x
            bump_py = st['py'] + fate['bumpy'] * sign_y
            norm = np.sqrt(bump_px ** 2 + bump_py ** 2)
            return {'E': new_E, 'x': st['x'], 'y': st['y'],
                    'px': bump_px / norm, 'py': bump_py / norm, 'alive': True}

        if fate is None:
            return _update_stop(state), None, None, None

        keep_alt = fate['keep_new_alt']

        if not fate['interact']:
            # Primal: no interaction. Alternative: the particle does interact.
            stateY1, stateY2 = None, None
            if keep_alt:
                if fate['split']:
                    stateY1, stateY2 = _update_split(state)
                else:
                    stateY1 = _update_eloss(state)
            return state, None, stateY1, stateY2

        if fate['split']:
            state1, state2 = _update_split(state)
        else:
            state1, state2 = _update_eloss(state), None
        # Alternative: the interaction does not happen; the particle continues unchanged.
        stateY1 = state if keep_alt else None
        return state1, state2, stateY1, None

    def _advance_alternative(program_st, sim_parameters, fifos):
        """Advance every alive state of the tracked alternative world by one step."""
        alt = program_st['y']
        next_states = []
        for state in alt['alive_states']:
            new_state = propagate_state(state)
            fate = sample_fate(None, new_state, None, sim_parameters, check_alts=False, fifos=fifos)

            if fate is None:
                alt['hits'].append([state['x'], state['y'], state['E'], 1])
                continue

            alt['hits'].append([state['x'], state['y'], state['E'], fate['interact']])
            alt['history'].append([[state['x'], state['y']], [new_state['x'], new_state['y']]])
            next1, next2, _, _ = fate2state(fate, new_state)
            if next1 is not None:
                next_states.append(next1)
            if next2 is not None:
                next_states.append(next2)

        alt['alive_states'] = next_states

    def _run_step(score, history, hits, alive_states, program_st, sim_parameters, fifos):
        """Advance the primal world by one step and update the alternative world.

        Appends to ``hits`` and ``history`` in place. If a particle spawned a
        kept alternative, ``program_st['y']`` becomes a snapshot of the primal
        world with that particle's interaction flipped. Otherwise the existing
        alternative (if any) is advanced by one step. Returns the primal
        states still alive.
        """
        next_alive_states = []
        next_alive_states_st = []
        found_alt = False
        alt_index = None

        for state in alive_states:
            new_state = propagate_state(state)
            fate = sample_fate(score, new_state, program_st, sim_parameters, fifos=fifos)

            if fate is None:
                hits.append([state['x'], state['y'], state['E'], 1])
                continue

            hits.append([state['x'], state['y'], state['E'], fate['interact']])
            history.append([[state['x'], state['y']], [new_state['x'], new_state['y']]])
            next1, next2, nextY1, nextY2 = fate2state(fate, new_state)

            if fate['keep_new_alt']:
                # Branch here: earlier particles' next states plus this particle's alternative.
                found_alt = True
                alt_index = len(hits) - 1
                next_alive_states_st = copy.deepcopy(next_alive_states)
                next_alive_states_st.extend(s for s in (nextY1, nextY2) if s is not None)

            for nxt in (next1, next2):
                if nxt is None:
                    continue
                next_alive_states.append(nxt)
                if found_alt and not fate['keep_new_alt']:
                    next_alive_states_st.append(nxt)

        if found_alt:
            program_st['y'] = {
                'history': copy.deepcopy(history),
                'hits': copy.deepcopy(hits),
                'alive_states': next_alive_states_st,
            }
            # In the alternative world the branching hit made the opposite interaction choice.
            alt_hits = program_st['y']['hits']
            alt_hits[alt_index][3] = not alt_hits[alt_index][3]
        elif program_st['y'] is not None:
            _advance_alternative(program_st, sim_parameters, fifos)

        return next_alive_states

    def run(score, history, hits, alive_states, program_st, sim_parameters, step_count, fifos=None):
        """Run the shower until neither the primal nor the alternative world has alive particles.

        Loops one step at a time rather than recursing, so shower length is not
        bounded by Python's recursion limit. ``step_count['n']`` is incremented
        once per step.
        """
        while True:
            alive_states = _run_step(score, history, hits, alive_states,
                                     program_st, sim_parameters, fifos)
            step_count['n'] += 1

            alt = program_st['y']
            alt_alive = alt is not None and len(alt['alive_states']) > 0
            if not alive_states and not alt_alive:
                return

    def generate_random_init():
        """Return a unit-momentum particle at the origin with E=25 and a uniformly random direction."""
        phi = np.random.uniform(-np.pi, np.pi)
        return {'x': 0, 'y': 0, 'r': 0, 'px': np.cos(phi), 'py': np.sin(phi), 'E': 25, 'alive': True}

    def generate(init, program_st, sim_parameters, reuse_rvs=False):
        """Simulate one shower from ``init``.

        Returns ``(hits, active, history, scores, out_st)``. ``active`` holds
        the hits whose column 3 equals 1. ``out_st['y']`` holds the
        alternative world's hits, active hits and history.
        """
        history, hits = [], []
        score = {'thresh_x': 0.0, 'thresh_E': 0.0}
        step_count = {'n': 0}
        fifos = {"interact": queue.Queue()} if reuse_rvs else None

        run(score, history, hits, alive_states=[init],
            program_st=program_st, sim_parameters=sim_parameters,
            step_count=step_count, fifos=fifos)

        hits = np.array(hits)
        act = hits[hits[:, 3] == 1]

        alt = program_st['y']
        if alt is None:
            # No alternative was ever recorded (w stays 0): mirror the primal outcome.
            hits_st, history_st = hits.copy(), np.array(history)
        else:
            hits_st, history_st = np.array(alt['hits']), np.array(alt['history'])

        out_st = stochasticTriple(
            program_st['d'],
            {'hits': hits_st, 'active': hits_st[hits_st[:, 3] == 1], 'history': history_st},
            program_st['w'],
        )
        return hits, act, np.array(history), {k: np.array(v) for k, v in score.items()}, out_st

    def simulator(par, split_prob, reuse_rvs=True):
        """Run one shower with ``thresh_x = par``.

        Returns ``(hits, active, history, dlogp/dpar, out_st)``.
        """
        init = generate_random_init()
        program_st = stochasticTriple(0., None, 0.)
        sim_parameters = {'thresh_E': 0.5, 'split_prob': split_prob, 'thresh_x': par}

        hits, active, history, scores, out_st = generate(init, program_st, sim_parameters, reuse_rvs)
        scores = jax.lax.stop_gradient(scores['thresh_x'])
        return hits, active, history, scores, out_st

    return simulator

# --------------------------------------------------------------------------------
# /GradientEstimates.ipynb:
# --------------------------------------------------------------------------------
# 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "import seaborn as sns\n", 11 | "from shower_sim_baseline import summary_metric\n", 12 | "from shower_sim_instrumented import simulator\n", 13 | "from tqdm import tqdm\n", 14 | "import numpy as np\n", 15 | "from shower_sim_optimize import optimize, program_to_optimize\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "the_program = program_to_optimize(simulator, summary_metric)\n", 28 | "\n", 29 | "def the_program_all_grads(theta):\n", 30 | " return the_program(theta, keep_all_grads=True)\n", 31 | "\n", 32 | "runs1 = [the_program_all_grads(2.5) for _ in tqdm(range(500))]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 18, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "image/png": "", 43 | "text/plain": [ 44 | "
" 45 | ] 46 | }, 47 | "metadata": {}, 48 | "output_type": "display_data" 49 | } 50 | ], 51 | "source": [ 52 | "from plots_gradients import plot_variance_with_inset\n", 53 | "f,axarr = plt.subplots(1,1)\n", 54 | "f.set_size_inches(6,6)\n", 55 | "plot_variance_with_inset(axarr,runs1)\n", 56 | "f.savefig('gradient_variance.pdf')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3.11.4 64-bit ('stochad_env')", 70 | "language": "python", 71 | "name": "python3" 72 | }, 73 | "language_info": { 74 | "codemirror_mode": { 75 | "name": "ipython", 76 | "version": 3 77 | }, 78 | "file_extension": ".py", 79 | "mimetype": "text/x-python", 80 | "name": "python", 81 | "nbconvert_exporter": "python", 82 | "pygments_lexer": "ipython3", 83 | "version": "3.11.4" 84 | }, 85 | "orig_nbformat": 4, 86 | "vscode": { 87 | "interpreter": { 88 | "hash": "29b777788fad9121f9b4a41d949494280ba66b13a17f3ccd9e5dfa0de3270b9d" 89 | } 90 | } 91 | }, 92 | "nbformat": 4, 93 | "nbformat_minor": 2 94 | } 95 | -------------------------------------------------------------------------------- /notebooks/03_VisualizeSimulator.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "id": "726e1c72-9de8-4334-b8c9-c203da688e57", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "The autoreload extension is already loaded. 
To reload it, use:\n", 14 | " %reload_ext autoreload\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import sys\n", 20 | "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "import jax\n", 25 | "import copy\n", 26 | "import queue\n", 27 | "from shower_sim_redone import make_simulator\n", 28 | "from shower_sim_baseline import summary_metric, interact_prob\n", 29 | "\n", 30 | "simulator = make_simulator()\n", 31 | "%load_ext autoreload\n", 32 | "%autoreload 2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "12874c6c", 38 | "metadata": {}, 39 | "source": [ 40 | "# Check it Runs" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 7, 46 | "id": "ac9ffd10-1148-4e73-82a6-bdad052d94df", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "-6.4722357 26.9953502063654 2824 3088\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "\n", 59 | "\n", 60 | "hits,active,history,scores,out_st = simulator(3.5, split_prob = 1.0, reuse_rvs=True)\n", 61 | "print(scores, out_st['w'], hits.size, out_st['y']['hits'].size)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 8, 67 | "id": "45d0d360-cd91-4155-bc0b-520f87059b8e", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "from shower_sim_optimize import optimize, program_to_optimize\n", 72 | "from shower_sim_baseline import per_hit_summary\n", 73 | "the_program = program_to_optimize(simulator, summary_metric, sim_kwargs=dict(split_prob = 1.0, reuse_rvs=True))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 58, 79 | "id": "92181b2b-db93-40dc-9265-f4c959e56dcc", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "########### Plotting ###########\n", 84 | "\n", 85 | "def _plot_bkg(ax,par,min = -5,max = 5):\n", 86 | " grid = np.mgrid[-min:min:701j,-min:min:701j]\n", 87 | " points = 
np.swapaxes(grid,0,-1).reshape(-1,2)\n", 88 | " vals = jax.vmap(interact_prob,in_axes = (0,0,None))(points[:,0],points[:,1],par)\n", 89 | " vals = vals.reshape(701,701).T\n", 90 | " ax.contourf(grid[0],grid[1],vals, cmap = 'Greys', vmin = 0,vmax = 1, alpha = 1.0) \n", 91 | " ax.set_xlabel('x')\n", 92 | " ax.set_ylabel('y')\n", 93 | " ax.set_ylim(-6,6)\n", 94 | " ax.set_xlim(-6,6)\n", 95 | "\n", 96 | "def _plot_event(ax,generation):\n", 97 | " hits,active,history,scores,out_st = generation\n", 98 | "\n", 99 | " alt_active = generation[-1]['y']['active']\n", 100 | " ax.plot(history[:,:,0].T,history[:,:,1].T, c = 'k', alpha = 0.2);\n", 101 | " # ax.scatter(hits[:,0],hits[:,1], c = hits[:,3], alpha = 0.4)\n", 102 | " ax.scatter(active[:,0],active[:,1], c = active[:,3], alpha = 0.4)\n", 103 | "\n", 104 | "def _plot_event_summary(ax,generation):\n", 105 | " hits,active,history,scores,out_st = generation\n", 106 | " ax.hist(per_hit_summary(active),bins = np.linspace(0,5,301), density=True)\n", 107 | " ax.set_ylim(0,3)\n", 108 | " ax.set_xlim(0,5)\n", 109 | " ax.set_xlabel('r')\n", 110 | " ax.set_ylabel('p_hits(r)')\n", 111 | "\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 59, 117 | "id": "3d0df2f6", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "np.random.seed(8)\n", 122 | "generation = simulator(2.5, **dict(split_prob = 1.0))" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 60, 128 | "id": "514cd715", 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "image/png": "", 134 | "text/plain": [ 135 | "
" 136 | ] 137 | }, 138 | "metadata": {}, 139 | "output_type": "display_data" 140 | } 141 | ], 142 | "source": [ 143 | "f,ax = plt.subplots(1,1)\n", 144 | "par = 2.5\n", 145 | "_plot_bkg(ax,par,-6,6)\n", 146 | "_plot_event(ax, generation)\n", 147 | "ax.set_xlim(-3,-1)\n", 148 | "ax.set_ylim(1.,3.)\n", 149 | "f.set_size_inches(4,4)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "47b08e4d", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 3.11.4 64-bit ('stochad_env')", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.11.4" 178 | }, 179 | "vscode": { 180 | "interpreter": { 181 | "hash": "29b777788fad9121f9b4a41d949494280ba66b13a17f3ccd9e5dfa0de3270b9d" 182 | } 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 5 187 | } 188 | -------------------------------------------------------------------------------- /VisualizeDesign.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "id": "726e1c72-9de8-4334-b8c9-c203da688e57", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "The autoreload extension is already loaded. 
To reload it, use:\n", 14 | " %reload_ext autoreload\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import jax\n", 22 | "import copy\n", 23 | "import queue\n", 24 | "from shower_sim_instrumented import simulator\n", 25 | "from shower_sim_baseline import summary_metric, interact_prob\n", 26 | "\n", 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "12874c6c", 34 | "metadata": {}, 35 | "source": [ 36 | "# Check it Runs" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 11, 42 | "id": "ac9ffd10-1148-4e73-82a6-bdad052d94df", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "-9.576843 49.74283491693495 3440 3100\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "hits,active,history,scores,out_st = simulator(3.5)\n", 55 | "print(scores, out_st['w'], hits.size, out_st['y']['hits'].size)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 12, 61 | "id": "45d0d360-cd91-4155-bc0b-520f87059b8e", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from shower_sim_optimize import optimize, program_to_optimize\n", 66 | "from shower_sim_baseline import per_hit_summary\n", 67 | "the_program = program_to_optimize(simulator, summary_metric)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 15, 73 | "id": "92181b2b-db93-40dc-9265-f4c959e56dcc", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "########### Plotting ###########\n", 78 | "\n", 79 | "def _plot_bkg(ax,par,min = -5,max = 5):\n", 80 | " grid = np.mgrid[-min:min:701j,-min:min:701j]\n", 81 | " points = np.swapaxes(grid,0,-1).reshape(-1,2)\n", 82 | " vals = jax.vmap(interact_prob,in_axes = (0,0,None))(points[:,0],points[:,1],par)\n", 83 | " vals = vals.reshape(701,701).T\n", 84 | " ax.contourf(grid[0],grid[1],vals, cmap = 'Greys', vmin 
= 0,vmax = 1, alpha = 1.0) \n", 85 | " ax.set_xlabel('x')\n", 86 | " ax.set_ylabel('y')\n", 87 | " ax.set_ylim(-6,6)\n", 88 | " ax.set_xlim(-6,6)\n", 89 | "\n", 90 | "def _plot_event(ax,generation):\n", 91 | " hits,active,history,scores,out_st = generation\n", 92 | " ax.plot(history[:,:,0].T,history[:,:,1].T, c = 'k', alpha = 0.2);\n", 93 | " # ax.scatter(hits[:,0],hits[:,1], c = hits[:,3], alpha = 0.4)\n", 94 | " ax.scatter(active[:,0],active[:,1], c = active[:,3], alpha = 0.4)\n", 95 | "\n", 96 | "def _plot_event_summary(ax,generation):\n", 97 | " hits,active,history,scores,out_st = generation\n", 98 | " ax.hist(per_hit_summary(active),bins = np.linspace(0,5,301), density=True)\n", 99 | " ax.set_ylim(0,3)\n", 100 | " ax.set_xlim(0,5)\n", 101 | " ax.set_xlabel('r')\n", 102 | " ax.set_ylabel('p_hits(r)')\n", 103 | "\n", 104 | "def plot_config_and_summary_onax(axarr,par):\n", 105 | " ax = axarr[0]\n", 106 | "\n", 107 | " generation = simulator(par)\n", 108 | " _plot_bkg(ax,par,-6,6)\n", 109 | " _plot_event(ax, generation)\n", 110 | "\n", 111 | " ax = axarr[1]\n", 112 | " _plot_event_summary(ax,generation)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 16, 118 | "id": "514cd715", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "image/png": "", 124 | "text/plain": [ 125 | "
" 126 | ] 127 | }, 128 | "metadata": {}, 129 | "output_type": "display_data" 130 | } 131 | ], 132 | "source": [ 133 | "f,axarr = plt.subplots(1,2)\n", 134 | "f.set_size_inches(10,5)\n", 135 | "plot_config_and_summary_onax(axarr,2.5)" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Python 3.11.4 64-bit ('stochad_env')", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.11.4" 156 | }, 157 | "vscode": { 158 | "interpreter": { 159 | "hash": "29b777788fad9121f9b4a41d949494280ba66b13a17f3ccd9e5dfa0de3270b9d" 160 | } 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 5 165 | } 166 | --------------------------------------------------------------------------------