├── README.md ├── cvar ├── common │ ├── cvar_computation.py │ └── util.py ├── dqn │ ├── core │ │ ├── __init__.py │ │ ├── build_graph.py │ │ ├── models.py │ │ ├── plots.py │ │ ├── replay_buffer.py │ │ ├── simple.py │ │ └── static.py │ ├── ice_lake │ │ ├── README.md │ │ ├── __init__.py │ │ ├── icelake.png │ │ ├── icelake.py │ │ └── ple_env.py │ └── scripts │ │ ├── README.md │ │ ├── enjoy_atari.py │ │ ├── enjoy_ice.py │ │ ├── enjoy_pong.py │ │ ├── enjoy_simple.py │ │ ├── test_atari.py │ │ ├── train_atari.py │ │ ├── train_ice.py │ │ ├── train_pong.py │ │ └── train_simple.py └── gridworld │ ├── README.md │ ├── algorithms │ ├── q_learning.py │ └── value_iteration.py │ ├── cliffwalker.py │ ├── core │ ├── __init__.py │ ├── constants.py │ ├── cvar_computation.py │ ├── models.py │ ├── policies.py │ └── runs.py │ ├── exp_model.py │ ├── interactive.png │ ├── plots │ ├── __init__.py │ ├── grid.py │ ├── info_plots.py │ ├── other.py │ ├── thesis_plots.py │ └── vi_compare_plots.py │ ├── run_q.py │ └── run_vi.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # Risk-Averse Distributional Reinforcement Learning 2 | 3 | This package contains code for [my thesis](https://github.com/Silvicek/cvar-rl-thesis) including CVaR Value Iteration, CVaR Q-learning and Deep CVaR Q-learning. 4 | 5 | 6 | ## Abstract 7 | 8 | Conditional Value-at-Risk (CVaR) is a well-known measure of risk that has been 9 | used for decades in the financial sector and has been directly equated to robustness, 10 | an important component of Artificial Intelligence (AI) safety. In this thesis we 11 | focus on optimizing CVaR in the context of Reinforcement Learning, a branch of 12 | Machine Learning that has brought significant attention to AI due to its generality 13 | and potential. 14 | 15 | As a first original contribution, we extend the CVaR Value Iteration algorithm 16 | (Chow et al. [20]) by utilizing the distributional nature of the CVaR objective. The 17 | proposed extension reduces computational complexity of the original algorithm from 18 | polynomial to linear and we prove it is equivalent to the said algorithm for continuous 19 | distributions. 20 | 21 | Secondly, based on the improved procedure, we propose a sampling version of 22 | CVaR Value Iteration we call CVaR Q-learning. We also derive a distributional 23 | policy improvement algorithm, prove its validity, and later use it as a heuristic for 24 | extracting the optimal policy from the converged CVaR Q-learning algorithm. 25 | 26 | Finally, to show the scalability of our method, we propose an approximate Q-learning 27 | algorithm by reformulating the CVaR Temporal Difference update rule as 28 | a loss function which we later use in a deep learning context. 29 | 30 | All proposed methods are experimentally analyzed, using a risk-sensitive gridworld 31 | environment for CVaR Value Iteration and Q-learning and a challenging visual environment 32 | for the approximate CVaR Q-learning algorithm. All trained agents are 33 | able to learn risk-sensitive policies, including the Deep CVaR Q-learning agent which 34 | learns how to avoid risk from raw pixels. 
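## CVaR in a nutshell

For a return (reward-to-go) random variable Z, the Value-at-Risk VaR_α(Z) is the α-quantile of Z, and the Conditional Value-at-Risk CVaR_α(Z) = E[ Z | Z ≤ VaR_α(Z) ] is the expected value of the worst α-fraction of outcomes, so optimizing CVaR focuses on the bad tail of the return distribution rather than its mean. A minimal NumPy sketch of the empirical estimate (essentially what `cvar/common/cvar_computation.py` computes; the sample size and CVaR level below are arbitrary):

    import numpy as np

    def empirical_var_cvar(samples, alpha):
        samples = np.sort(samples)
        k = max(1, int(round(alpha * len(samples))))  # size of the worst alpha-tail
        var = samples[k - 1]                          # empirical alpha-quantile (VaR)
        cvar = samples[:k].mean()                     # mean of the worst alpha-tail (CVaR)
        return var, cvar

    returns = np.random.randn(10000)
    print(empirical_var_cvar(returns, 0.05))  # CVaR_0.05 = mean of the worst 500 returns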
35 | 36 | ## Installation 37 | 38 | Install TensorFlow by following the instructions at https://www.tensorflow.org/install/. 39 | Using a GPU during training is highly recommended (but not required): 40 | 41 | pip3 install tensorflow-gpu 42 | 43 | Next, install the [OpenAI baselines](https://github.com/Silvicek/baselines) fork: 44 | 45 | git clone https://github.com/Silvicek/baselines.git 46 | cd baselines/ 47 | pip3 install -e . 48 | 49 | Next, install the [Pygame Learning Environment](https://github.com/ntasfi/PyGame-Learning-Environment): 50 | 51 | git clone https://github.com/ntasfi/PyGame-Learning-Environment.git 52 | cd PyGame-Learning-Environment/ 53 | pip3 install -e . 54 | 55 | Lastly, install the cvar package (from the cvar-algorithms root directory): 56 | 57 | pip3 install -e . 58 | 59 | ### CVaR Value Iteration, CVaR Q-learning 60 | 61 | See the readme in [`cvar/gridworld`](https://github.com/Silvicek/cvar-algorithms/tree/master/cvar/gridworld) 62 | 63 | 64 | ### Deep CVaR Q-learning 65 | 66 | See the readme in [`cvar/dqn/scripts`](https://github.com/Silvicek/cvar-algorithms/tree/master/cvar/dqn/scripts) 67 | -------------------------------------------------------------------------------- /cvar/common/cvar_computation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def var_cvar_from_samples(samples, alpha): 5 | samples = np.sort(samples) 6 | alpha_ix = int(np.round(alpha * len(samples))) 7 | var = samples[alpha_ix - 1] 8 | cvar = np.mean(samples[:alpha_ix]) 9 | return var, cvar 10 | -------------------------------------------------------------------------------- /cvar/common/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def print_output(info_string): 5 | """ Decorator that prints the output of a function, together with an information string. """ 6 | def wrap_output(func): 7 | def func_wrapper(*args, **kwargs): 8 | output = func(*args, **kwargs) 9 | print(info_string, output) 10 | return output 11 | return func_wrapper 12 | return wrap_output 13 | 14 | 15 | def timed(func): 16 | """ Decorator that reports the runtime of a function.
""" 17 | def func_wrapper(*args, **kwargs): 18 | import time 19 | start = time.time() 20 | output = func(*args, **kwargs) 21 | print("Running {} took {:.1f}s.".format(func.__name__, time.time()-start)) 22 | return output 23 | return func_wrapper 24 | 25 | 26 | @print_output('ATOMS:') 27 | def spaced_atoms(nb_atoms, spacing, log_atoms, log_threshold): 28 | assert log_atoms <= nb_atoms 29 | assert spacing > 1 30 | 31 | if log_atoms != 0: 32 | lin = np.linspace(log_threshold, 1, nb_atoms - log_atoms) 33 | log_only = int(log_atoms == nb_atoms) 34 | if spacing < 2: 35 | log = np.array([0, log_threshold * 0.5 / spacing ** log_atoms] + [log_threshold / spacing ** (log_atoms - i) 36 | for i in range(log_only, log_atoms - 1 + log_only)]) 37 | else: 38 | log = np.array([0] + [log_threshold / spacing ** (log_atoms - i) 39 | for i in range(log_only, log_atoms + log_only)]) 40 | 41 | atoms = np.hstack((log, lin)) 42 | else: 43 | atoms = np.linspace(0, 1, nb_atoms+1) 44 | 45 | assert np.all(atoms == np.array(sorted(atoms))) 46 | assert atoms[0] == 0 47 | assert atoms[-1] == 1 48 | 49 | return atoms 50 | 51 | 52 | def softmax(x): 53 | exp = np.exp(x) 54 | if len(x.shape) > 1: 55 | return exp / np.sum(exp, axis=0) 56 | else: 57 | return exp / np.sum(exp) 58 | 59 | 60 | time_start = 0 61 | def tick(): 62 | import time 63 | global time_start 64 | time_start = time.time() 65 | 66 | def tock(): 67 | import time 68 | print("t = {:.2f}".format(time.time() - time_start)) 69 | -------------------------------------------------------------------------------- /cvar/dqn/core/__init__.py: -------------------------------------------------------------------------------- 1 | import cvar.dqn.core.models 2 | # 3 | from cvar.dqn.core.build_graph import build_act, build_train 4 | from cvar.dqn.core.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer 5 | from cvar.dqn.core.simple import learn, load, make_session 6 | from cvar.dqn.core.static import * 7 | -------------------------------------------------------------------------------- /cvar/dqn/core/build_graph.py: -------------------------------------------------------------------------------- 1 | """Deep Q learning graph 2 | 3 | The functions in this file can are used to create the following functions: 4 | 5 | ======= act ======== 6 | 7 | Function to chose an action given an observation 8 | 9 | Parameters 10 | ---------- 11 | observation: object 12 | Observation that can be feed into the output of make_obs_ph 13 | alpha: float 14 | Action is picked to maximize CVaR_alpha 15 | stochastic: bool 16 | if set to False all the actions are always deterministic (default False) 17 | update_eps_ph: float 18 | update epsilon a new value, if negative no update happens 19 | (default: no update) 20 | 21 | Returns 22 | ------- 23 | Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for 24 | every element of the batch. 25 | 26 | 27 | ======= train ======= 28 | 29 | Function that takes a transition (s,a,r,s') and optimizes CVaR TD error: 30 | 31 | TODO: describe error 32 | 33 | Parameters 34 | ---------- 35 | obs_t: object 36 | a batch of observations 37 | action: np.array 38 | actions that were selected upon seeing obs_t. 
39 | dtype must be int32 and shape must be (batch_size,) 40 | reward: np.array 41 | immediate reward attained after executing those actions 42 | dtype must be float32 and shape must be (batch_size,) 43 | obs_tp1: object 44 | observations that followed obs_t 45 | done: np.array 46 | 1 if obs_t was the last observation in the episode and 0 otherwise 47 | obs_tp1 gets ignored, but must be of the valid shape. 48 | dtype must be float32 and shape must be (batch_size,) 49 | weight: np.array 50 | imporance weights for every element of the batch (gradient is multiplied 51 | by the importance weight) dtype must be float32 and shape must be (batch_size,) 52 | 53 | Returns 54 | ------- 55 | td_error: np.array 56 | a list of differences between Q(s,a) and the target in CVaR Bellman's equation. 57 | dtype is float32 and shape is (batch_size,) 58 | 59 | ======= update_target ======== 60 | 61 | copy the parameters from optimized P function to the target P function. 62 | 63 | """ 64 | import tensorflow as tf 65 | import baselines.common.tf_util as U 66 | 67 | 68 | def pick_actions(cvar_values): 69 | """ 70 | Select actions based on optimal CVaR value for each atom. 71 | Parameters 72 | ---------- 73 | cvar_values: (?, actions, nb_atoms) 74 | 75 | Returns 76 | ------- 77 | (?, nb_atoms) 78 | """ 79 | deterministic_actions = tf.argmax(cvar_values, axis=-1, output_type=tf.int32) 80 | return deterministic_actions 81 | 82 | 83 | debug_expressions = [] 84 | 85 | 86 | def pick_action(cvar_values, alpha, nb_atoms): 87 | """ 88 | Pick a single action based on CVaR_alpha. 89 | Assumes linearly spaced atoms. 90 | 91 | Parameters 92 | ---------- 93 | cvar_values: (?, actions, nb_atoms) 94 | 95 | Returns 96 | ------- 97 | (?,) 98 | 99 | """ 100 | 101 | ix_f = alpha*nb_atoms - 1 102 | ix_int = tf.cast(tf.floor(ix_f), tf.int32) 103 | portion = ix_f - tf.cast(ix_int, tf.float32) 104 | # special case if alpha=1 105 | ix_next = tf.cond(tf.equal(alpha, tf.constant(1, tf.float32)), lambda: ix_int, lambda: ix_int+1, name='ix_next') 106 | 107 | cvar_alpha_std = cvar_values[:, :, ix_int] * (1-portion) + cvar_values[:, :, ix_next] * portion 108 | 109 | # if alpha is before first atom 110 | cvar_alpha_zero = cvar_values[:, :, ix_next] * portion 111 | 112 | cvar_alpha = tf.cond(tf.less(alpha, 1/nb_atoms), lambda: cvar_alpha_zero, lambda: cvar_alpha_std) 113 | 114 | return tf.argmax(cvar_alpha, axis=-1, output_type=tf.int32) 115 | 116 | 117 | def build_act(make_obs_ph, cvar_func, var_func, num_actions, nb_atoms, scope="cvar_dqn", reuse=None): 118 | """Creates the act function: 119 | 120 | Parameters 121 | ---------- 122 | make_obs_ph: str -> tf.placeholder or TfInput 123 | a function that take a name and creates a placeholder of input with that name 124 | cvar_func: (tf.Variable, int, str, bool) -> tf.Variable 125 | the model that takes the following inputs: 126 | observation_in: object 127 | the output of observation placeholder 128 | num_actions: int 129 | number of actions 130 | scope: str 131 | reuse: bool 132 | should be passed to outer variable scope 133 | and returns a tensor of shape (batch_size, num_actions) with values of every action. 134 | num_actions: int 135 | number of actions. 136 | nb_atoms: int 137 | number of linearly-spaced atoms 138 | scope: str or VariableScope 139 | optional scope for variable_scope. 140 | reuse: bool or None 141 | whether or not the variables should be reused. To be able to reuse the scope must be given. 
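    Notes
    -----
    The returned act function is epsilon-greedy: with probability eps (settable through
    update_eps_ph) it samples a uniformly random action, otherwise it returns
    argmax_a CVaR_alpha(observation, a), where CVaR_alpha is linearly interpolated
    between the nb_atoms linearly spaced atoms (see pick_action above).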
142 | 143 | Returns 144 | ------- 145 | act: (tf.Variable, bool, float, float) -> tf.Variable 146 | function to select and action given observation. 147 | ` See the top of the file for details. 148 | """ 149 | with tf.variable_scope(scope, reuse=reuse): 150 | observations_ph = U.ensure_tf_input(make_obs_ph("observation")) 151 | # alpha in cvar_alpha 152 | alpha_ph = U.ensure_tf_input(tf.placeholder(tf.float32, (), name="alpha")) 153 | 154 | stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic") 155 | update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps") 156 | 157 | # eps in epsilon-greedy 158 | eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0)) 159 | 160 | cvar_values = cvar_func(observations_ph.get(), num_actions, nb_atoms, scope="out_func") 161 | # keep here for plotting 162 | var_values = var_func(observations_ph.get(), num_actions, nb_atoms, scope="out_func", 163 | reuse_main=True, reuse_last=False) 164 | 165 | deterministic_actions = pick_action(cvar_values, alpha_ph.get(), nb_atoms) 166 | 167 | batch_size = tf.shape(observations_ph.get())[0] 168 | random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int32) 169 | chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps 170 | stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions) 171 | 172 | output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions) 173 | update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps)) 174 | act = U.function(inputs=[observations_ph, alpha_ph, stochastic_ph, update_eps_ph], 175 | outputs=output_actions, 176 | givens={update_eps_ph: -1.0, stochastic_ph: True}, 177 | updates=[update_eps_expr]) 178 | 179 | return act 180 | 181 | 182 | def extract_distribution(y_cvar, nb_atoms): 183 | """ Convert yC -> underlying distribution. 184 | y_cvar: (?, nb_atoms) 185 | """ 186 | 187 | dist_cropped = y_cvar[:, 1:] - y_cvar[:, :-1] 188 | dist = tf.concat((y_cvar[:, 0, None], dist_cropped), axis=1) * nb_atoms 189 | return dist 190 | 191 | 192 | def build_train(make_obs_ph, var_func, cvar_func, num_actions, nb_atoms, optimizer, grad_norm_clipping=None, gamma=1.0, 193 | scope="cvar_dqn", reuse=None): 194 | """Creates the train function: 195 | 196 | Parameters 197 | ---------- 198 | make_obs_ph: str -> tf.placeholder or TfInput 199 | a function that takes a name and creates a placeholder of input with that name 200 | var_func: (tf.Variable, int, int, str, bool) -> tf.Variable 201 | the model that takes the following inputs: 202 | observation_in: object 203 | the output of observation placeholder 204 | num_actions: int 205 | number of actions 206 | nb_atoms: int 207 | number of atoms 208 | scope: str 209 | reuse: bool 210 | should be passed to outer variable scope 211 | and returns a tensor of shape (batch_size, num_actions) with values of every action. 212 | cvar_func: (tf.Variable, int, str, bool) -> tf.Variable 213 | see var_func 214 | num_actions: int 215 | number of actions 216 | reuse: bool 217 | whether or not to reuse the graph variables 218 | optimizer: tf.train.Optimizer 219 | optimizer to use for the Q-learning objective. 220 | grad_norm_clipping: float or None 221 | clip gradient norms to this value. If None no clipping is performed. 222 | gamma: float 223 | discount rate. 224 | scope: str or VariableScope 225 | optional scope for variable_scope. 
226 | reuse: bool or None 227 | whether or not the variables should be reused. To be able to reuse the scope must be given. 228 | 229 | Returns 230 | ------- 231 | act: (tf.Variable, bool, float) -> tf.Variable 232 | function to select and action given observation. 233 | ` See the top of the file for details. 234 | train: (object, np.array, np.array, object, np.array, np.array) -> np.array 235 | optimize the error in Bellman's equation. 236 | ` See the top of the file for details. 237 | update_target: () -> () 238 | copy the parameters from optimized Q function to the target Q function. 239 | ` See the top of the file for details. 240 | debug: {str: function} 241 | a bunch of functions to print debug data like q_values. 242 | """ 243 | act_f = build_act(make_obs_ph, cvar_func, var_func, num_actions, nb_atoms, scope=scope, reuse=reuse) 244 | 245 | with tf.variable_scope(scope, reuse=reuse): 246 | # set up placeholders 247 | obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t")) 248 | act_t_ph = tf.placeholder(tf.int32, [None], name="action") 249 | rew_t_ph = tf.placeholder(tf.float32, [None], name="reward") 250 | obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1")) 251 | done_mask_ph = tf.placeholder(tf.float32, [None], name="done") 252 | importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight") 253 | # atoms 254 | y = tf.range(1, nb_atoms + 1, dtype=tf.float32, name='y') * 1. / nb_atoms 255 | 256 | # ------------------------------- Core networks --------------------------------- 257 | # var network 258 | var_t = var_func(obs_t_input.get(), num_actions, nb_atoms, scope="out_func", 259 | reuse_main=True, reuse_last=True) # reuse from act 260 | 261 | # vars for actions which we know were selected in the given state. 262 | var_t_selected = gather_along_second_axis(var_t, act_t_ph) 263 | var_t_selected.set_shape([None, nb_atoms]) 264 | 265 | # cvar network 266 | cvar_t = cvar_func(obs_t_input.get(), num_actions, nb_atoms, scope="out_func", 267 | reuse_main=True, reuse_last=True) # reuse from act 268 | 269 | # cvars for actions which we know were selected in the given state. 270 | cvar_t_selected = gather_along_second_axis(cvar_t, act_t_ph) 271 | cvar_t_selected.set_shape([None, nb_atoms]) 272 | 273 | # target cvar network 274 | cvar_tp1 = cvar_func(obs_tp1_input.get(), num_actions, nb_atoms, scope="target_cvar_func") 275 | 276 | # extract variables 277 | joint_variables = U.scope_vars(U.absolute_scope_name("out_func/net")) 278 | var_variables = U.scope_vars(U.absolute_scope_name("out_func/var")) 279 | cvar_variables = U.scope_vars(U.absolute_scope_name("out_func/cvar")) 280 | target_cvar_func_variables = U.scope_vars(U.absolute_scope_name("target_cvar_func")) 281 | 282 | # ------------------------------------------------------------------------------- 283 | 284 | # ----------------------------- Extract distribution ---------------------------- 285 | # construct a new cvar with different actions for each atom 286 | cvar_tp1_star = tf.reduce_max(cvar_tp1, axis=1) 287 | cvar_tp1_star.set_shape([None, nb_atoms]) 288 | # construct a distribution from the new cvar 289 | ycvar_tp1_star = cvar_tp1_star * y 290 | dist_tp1_star_ = extract_distribution(ycvar_tp1_star, nb_atoms) 291 | 292 | # apply done mask 293 | dist_tp1_star = tf.einsum('ij,i->ij', dist_tp1_star_, 1. 
- done_mask_ph) 294 | 295 | # Td = r + gamma * dist 296 | dist_target = tf.identity(rew_t_ph[:, None] + gamma * dist_tp1_star, name='dist_target') 297 | # dist is always non-differentiable 298 | dist_target = tf.stop_gradient(dist_target) 299 | 300 | # ------------------------------------------------------------------------------- 301 | 302 | # ---------------------------------- VaR loss ----------------------------------- 303 | 304 | td_error = dist_target[:, :, None] - var_t_selected[:, None, :] 305 | # td_error[0]= 306 | # [[Td1-v1 Td1-v2 ... Td1-vn] 307 | # [Td2-v1 Td2-v2 ... Td2-vn] 308 | # [... ] 309 | # [Tdn-v1 Tdn-v2 ... Tdn-vn]] 310 | 311 | negative_indicator = tf.cast(td_error < 0, tf.float32) 312 | 313 | var_weights = tf.stop_gradient(y - negative_indicator) # XXX: stop gradient? 314 | quantile_loss = var_weights * td_error 315 | 316 | var_error = tf.reduce_mean(quantile_loss) 317 | # ------------------------------------------------------------------------------- 318 | 319 | # ---------------------------------- CVaR loss ---------------------------------- 320 | # Minimizing the MSE of: 321 | # V_i + 1/y_i(Td_j - V_i)^- - C_i 322 | 323 | min_target_diff = negative_indicator / y * tf.stop_gradient(td_error) 324 | cvar_loss = tf.stop_gradient(var_t_selected)[:, None, :] + min_target_diff - cvar_t_selected[:, None, :] 325 | 326 | cvar_error = tf.reduce_mean(tf.square(cvar_loss)) 327 | 328 | # ------------------------------------------------------------------------------- 329 | 330 | # ------------------------------- Finalizing ------------------------------------ 331 | 332 | error = var_error + cvar_error 333 | # compute optimization op (potentially with gradient clipping) 334 | var_list = [joint_variables, var_variables, cvar_variables] 335 | if grad_norm_clipping is not None: 336 | optimize_expr = U.minimize_and_clip(optimizer, error, var_list, clip_val=grad_norm_clipping) 337 | else: 338 | optimize_expr = optimizer.minimize(error, var_list=var_list) 339 | 340 | # update_target_fn will be called periodically to copy cvar network to target cvar network 341 | # Note: var has no target 342 | update_target_expr = [] 343 | for cvar_variable, target_cvar_variable in zip(sorted(joint_variables+cvar_variables, key=lambda v: v.name), 344 | sorted(target_cvar_func_variables, key=lambda v: v.name)): 345 | update_target_expr.append(target_cvar_variable.assign(cvar_variable)) 346 | update_target_expr = tf.group(*update_target_expr) 347 | 348 | # Create callable functions 349 | train = U.function( 350 | inputs=[ 351 | obs_t_input, 352 | act_t_ph, 353 | rew_t_ph, 354 | obs_tp1_input, 355 | done_mask_ph, 356 | importance_weights_ph 357 | ], 358 | outputs=error, 359 | updates=[optimize_expr] 360 | ) 361 | update_target = U.function([], [], updates=[update_target_expr]) 362 | 363 | # ------------------------------------------------------------------------------- 364 | 365 | # --------------------------------- Debug --------------------------------------- 366 | # a = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], var_t_selected) 367 | # b = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], cvar_t_selected) 368 | # c = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], big_dist_target*y) 369 | # b = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], var_t) 370 | # c = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], negative_indicator) 371 | # d = U.function([obs_t_input, 
act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], big_yc_target) 372 | # e = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], cvar_t) 373 | # f = U.function([obs_t_input, act_t_ph, rew_t_ph, obs_tp1_input, done_mask_ph], cvar_loss) 374 | # atoms = U.function([obs_tp1_input], atoms) 375 | # ------------------------------------------------------------------------------- 376 | 377 | return act_f, train, update_target, [] 378 | 379 | 380 | def gather_along_second_axis(data, indices): 381 | batch_offset = tf.range(0, tf.shape(data)[0]) 382 | flat_indices = tf.stack([batch_offset, indices], axis=1) 383 | return tf.gather_nd(data, flat_indices) 384 | -------------------------------------------------------------------------------- /cvar/dqn/core/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | 4 | 5 | def atari_model(): 6 | model = cnn_to_mlp( 7 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 8 | hiddens=[512]) 9 | return model 10 | 11 | 12 | def _mlp(hiddens, inpt, scope, reuse, layer_norm): 13 | with tf.variable_scope(scope, reuse=reuse): 14 | out = inpt 15 | for hidden in hiddens: 16 | out = layers.fully_connected(out, num_outputs=hidden, activation_fn=None) 17 | if layer_norm: 18 | out = layers.layer_norm(out, center=True, scale=True) 19 | out = tf.nn.relu(out) 20 | 21 | return out 22 | 23 | 24 | def mlp(hiddens, layer_norm=False): 25 | """This model takes as input an observation and returns values of all actions. 26 | 27 | Parameters 28 | ---------- 29 | hiddens: [int] 30 | list of sizes of hidden layers 31 | 32 | Returns 33 | ------- 34 | var_func: function 35 | representing the VaR of the CVaR DQN algorithm 36 | cvar_func: function 37 | representing the CVaRy of the CVaR DQN algorithm 38 | """ 39 | 40 | def last_layer(name, inpt, num_actions, nb_atoms, scope, reuse_main=False, reuse_last=False): 41 | out = _mlp(hiddens, inpt, scope + '/net', reuse_main, layer_norm) 42 | with tf.variable_scope('{}/{}'.format(scope, name), reuse=reuse_last): 43 | out = layers.fully_connected(out, num_outputs=num_actions * nb_atoms, activation_fn=None) 44 | out = tf.reshape(out, shape=[-1, num_actions, nb_atoms], name='out') 45 | return out 46 | 47 | var_func = lambda *args, **kwargs: last_layer('var', *args, **kwargs) 48 | cvar_func = lambda *args, **kwargs: last_layer('cvar', *args, **kwargs) 49 | 50 | return var_func, cvar_func 51 | 52 | 53 | def _cnn_to_mlp(convs, hiddens, inpt, scope, reuse=False, layer_norm=False): 54 | with tf.variable_scope(scope, reuse=reuse): 55 | out = inpt 56 | with tf.variable_scope("convnet"): 57 | for num_outputs, kernel_size, stride in convs: 58 | out = layers.convolution2d(out, 59 | num_outputs=num_outputs, 60 | kernel_size=kernel_size, 61 | stride=stride, 62 | activation_fn=tf.nn.relu) 63 | conv_out = layers.flatten(out) 64 | with tf.variable_scope("action_value"): 65 | action_out = conv_out 66 | for hidden in hiddens: 67 | action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None) 68 | if layer_norm: 69 | action_out = layers.layer_norm(action_out, center=True, scale=True) 70 | out = tf.nn.relu(action_out) 71 | 72 | return out 73 | 74 | 75 | def cnn_to_mlp(convs, hiddens, layer_norm=False): 76 | """This model takes as input an observation and returns values of all actions. 
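    More precisely, it returns a pair of functions (var_func, cvar_func). Both share the
    same convolutional trunk (variable scope '<scope>/net') and differ only in their last
    fully connected head ('var' vs. 'cvar'); each produces a tensor of shape
    (batch_size, num_actions, nb_atoms).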
77 | 78 | Parameters 79 | ---------- 80 | convs: [(int, int int)] 81 | list of convolutional layers in form of 82 | (num_outputs, kernel_size, stride) 83 | hiddens: [int] 84 | list of sizes of hidden layers 85 | 86 | Returns 87 | ------- 88 | var_func: function 89 | representing the VaR of the CVaR DQN algorithm 90 | cvar_func: function 91 | representing the CVaRy of the CVaR DQN algorithm 92 | """ 93 | 94 | def last_layer(name, inpt, num_actions, nb_atoms, scope, reuse_main=False, reuse_last=False): 95 | out = _cnn_to_mlp(convs, hiddens, inpt, scope + '/net', reuse_main, layer_norm) 96 | with tf.variable_scope('{}/{}'.format(scope, name), reuse=reuse_last): 97 | out = layers.fully_connected(out, num_outputs=num_actions * nb_atoms, activation_fn=None) 98 | out = tf.reshape(out, shape=[-1, num_actions, nb_atoms], name='out') 99 | return out 100 | 101 | var_func = lambda *args, **kwargs: last_layer('var', *args, **kwargs) 102 | cvar_func = lambda *args, **kwargs: last_layer('cvar', *args, **kwargs) 103 | 104 | return var_func, cvar_func -------------------------------------------------------------------------------- /cvar/dqn/core/plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import tensorflow as tf 4 | from cvar.gridworld.core.cvar_computation import yc_to_var 5 | 6 | 7 | class PlotMachine: 8 | 9 | def __init__(self, nb_atoms, nb_actions, action_set=None): 10 | 11 | # TODO: unify linear atoms 12 | self.atoms = np.arange(0, nb_atoms + 1) / nb_atoms 13 | 14 | self.limits = None 15 | self.yc_limits = None 16 | 17 | self.fig, self.ax = plt.subplots(1, 3, figsize=(8,4)) 18 | self.fig.canvas.draw() 19 | 20 | self.yc_plot = [self.ax[0].plot(self.atoms, np.zeros(nb_atoms+1))[0] for _ in range(nb_actions)] 21 | self.dist_plot = [self.ax[1].step(np.insert(self.atoms, -1, 1), np.zeros(nb_atoms+2))[0] for _ in range(nb_actions)] 22 | self.var_plot = [self.ax[2].step(np.insert(self.atoms, -1, 1), np.zeros(nb_atoms+2))[0] for _ in range(nb_actions)] 23 | 24 | if action_set is not None: 25 | plt.legend(action_set, loc='upper left') 26 | 27 | self.sess = tf.get_default_session() 28 | self.act_cvar = tf.get_default_graph().get_tensor_by_name("cvar_dqn/out_func/cvar/out:0") 29 | self.act_var = tf.get_default_graph().get_tensor_by_name("cvar_dqn/out_func/var/out:0") 30 | 31 | # titles 32 | self.ax[0].set_title('yCVaR') 33 | self.ax[1].set_title('Extracted Distribution') 34 | self.ax[2].set_title('VaR') 35 | 36 | # grids 37 | self.ax[0].grid() 38 | self.ax[1].grid() 39 | self.ax[2].grid() 40 | 41 | def plot_distribution(self, obs): 42 | yc_out = self.sess.run(self.act_cvar, {"cvar_dqn/observation:0": obs})[0]*self.atoms[1:] 43 | var_out = self.sess.run(self.act_var, {"cvar_dqn/observation:0": obs})[0] 44 | dist_out = [yc_to_var(self.atoms, yc_out[a]) for a in range(len(yc_out))] 45 | 46 | if self.limits is None: 47 | self.limits = [np.min(dist_out), np.max(dist_out)] 48 | self.yc_limits = [np.min(yc_out), np.max(yc_out)] 49 | else: 50 | self.limits = [min(np.min(dist_out), self.limits[0]), max(np.max(dist_out), self.limits[1])] 51 | self.yc_limits = [min(np.min(yc_out), self.yc_limits[0]), max(np.max(yc_out), self.yc_limits[1])] 52 | 53 | self.ax[0].set_ylim(self.yc_limits) 54 | self.ax[1].set_ylim(self.limits) 55 | self.ax[2].set_ylim(self.limits) 56 | 57 | # -------- yCVaR -------- 58 | plot = self.yc_plot 59 | values = yc_out 60 | for line, data in zip(plot, values): 61 | y_data = 
np.zeros(len(data)+1) 62 | y_data[1:] = data 63 | line.set_ydata(y_data) 64 | 65 | # ------ Dist + VaR ------ 66 | for plot, values in zip([self.dist_plot, self.var_plot], [dist_out, var_out]): 67 | for line, data in zip(plot, values): 68 | y_data = np.zeros(len(data)+2) 69 | y_data[1:-1] = data 70 | y_data[0] = self.limits[0] 71 | y_data[-1] = self.limits[-1] 72 | line.set_ydata(y_data) 73 | 74 | self.fig.canvas.draw() 75 | self.fig.canvas.flush_events() 76 | # plt.savefig('test.pdf') 77 | plt.pause(1e-10) 78 | 79 | -------------------------------------------------------------------------------- /cvar/dqn/core/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree 5 | 6 | 7 | class ReplayBuffer(object): 8 | def __init__(self, size): 9 | """Create basic Replay buffer. 10 | 11 | Parameters 12 | ---------- 13 | size: int 14 | Max number of transitions to store in the buffer. When the buffer 15 | overflows the old memories are dropped. 16 | """ 17 | self._storage = [] 18 | self._maxsize = size 19 | self._next_idx = 0 20 | 21 | def __len__(self): 22 | return len(self._storage) 23 | 24 | def add(self, obs_t, action, reward, obs_tp1, done): 25 | data = (obs_t, action, reward, obs_tp1, done) 26 | 27 | if self._next_idx >= len(self._storage): 28 | self._storage.append(data) 29 | else: 30 | self._storage[self._next_idx] = data 31 | self._next_idx = (self._next_idx + 1) % self._maxsize 32 | 33 | def _encode_sample(self, idxes): 34 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 35 | for i in idxes: 36 | data = self._storage[i] 37 | obs_t, action, reward, obs_tp1, done = data 38 | obses_t.append(np.array(obs_t, copy=False)) 39 | actions.append(np.array(action, copy=False)) 40 | rewards.append(reward) 41 | obses_tp1.append(np.array(obs_tp1, copy=False)) 42 | dones.append(done) 43 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 44 | 45 | def sample(self, batch_size): 46 | """Sample a batch of experiences. 47 | 48 | Parameters 49 | ---------- 50 | batch_size: int 51 | How many transitions to sample. 52 | 53 | Returns 54 | ------- 55 | obs_batch: np.array 56 | batch of observations 57 | act_batch: np.array 58 | batch of actions executed given obs_batch 59 | rew_batch: np.array 60 | rewards received as results of executing act_batch 61 | next_obs_batch: np.array 62 | next set of observations seen after executing act_batch 63 | done_mask: np.array 64 | done_mask[i] = 1 if executing act_batch[i] resulted in 65 | the end of an episode and 0 otherwise. 66 | """ 67 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 68 | return self._encode_sample(idxes) 69 | 70 | 71 | class PrioritizedReplayBuffer(ReplayBuffer): 72 | def __init__(self, size, alpha): 73 | """Create Prioritized Replay buffer. 74 | 75 | Parameters 76 | ---------- 77 | size: int 78 | Max number of transitions to store in the buffer. When the buffer 79 | overflows the old memories are dropped. 
80 | alpha: float 81 | how much prioritization is used 82 | (0 - no prioritization, 1 - full prioritization) 83 | 84 | See Also 85 | -------- 86 | ReplayBuffer.__init__ 87 | """ 88 | super(PrioritizedReplayBuffer, self).__init__(size) 89 | assert alpha > 0 90 | self._alpha = alpha 91 | 92 | it_capacity = 1 93 | while it_capacity < size: 94 | it_capacity *= 2 95 | 96 | self._it_sum = SumSegmentTree(it_capacity) 97 | self._it_min = MinSegmentTree(it_capacity) 98 | self._max_priority = 1.0 99 | 100 | def add(self, *args, **kwargs): 101 | """See ReplayBuffer.store_effect""" 102 | idx = self._next_idx 103 | super().add(*args, **kwargs) 104 | self._it_sum[idx] = self._max_priority ** self._alpha 105 | self._it_min[idx] = self._max_priority ** self._alpha 106 | 107 | def _sample_proportional(self, batch_size): 108 | res = [] 109 | for _ in range(batch_size): 110 | # TODO(szymon): should we ensure no repeats? 111 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 112 | idx = self._it_sum.find_prefixsum_idx(mass) 113 | res.append(idx) 114 | return res 115 | 116 | def sample(self, batch_size, beta): 117 | """Sample a batch of experiences. 118 | 119 | compared to ReplayBuffer.sample 120 | it also returns importance weights and idxes 121 | of sampled experiences. 122 | 123 | 124 | Parameters 125 | ---------- 126 | batch_size: int 127 | How many transitions to sample. 128 | beta: float 129 | To what degree to use importance weights 130 | (0 - no corrections, 1 - full correction) 131 | 132 | Returns 133 | ------- 134 | obs_batch: np.array 135 | batch of observations 136 | act_batch: np.array 137 | batch of actions executed given obs_batch 138 | rew_batch: np.array 139 | rewards received as results of executing act_batch 140 | next_obs_batch: np.array 141 | next set of observations seen after executing act_batch 142 | done_mask: np.array 143 | done_mask[i] = 1 if executing act_batch[i] resulted in 144 | the end of an episode and 0 otherwise. 145 | weights: np.array 146 | Array of shape (batch_size,) and dtype np.float32 147 | denoting importance weight of each sampled transition 148 | idxes: np.array 149 | Array of shape (batch_size,) and dtype np.int32 150 | idexes in buffer of sampled experiences 151 | """ 152 | assert beta > 0 153 | 154 | idxes = self._sample_proportional(batch_size) 155 | 156 | weights = [] 157 | p_min = self._it_min.min() / self._it_sum.sum() 158 | max_weight = (p_min * len(self._storage)) ** (-beta) 159 | 160 | for idx in idxes: 161 | p_sample = self._it_sum[idx] / self._it_sum.sum() 162 | weight = (p_sample * len(self._storage)) ** (-beta) 163 | weights.append(weight / max_weight) 164 | weights = np.array(weights) 165 | encoded_sample = self._encode_sample(idxes) 166 | return tuple(list(encoded_sample) + [weights, idxes]) 167 | 168 | def update_priorities(self, idxes, priorities): 169 | """Update priorities of sampled transitions. 170 | 171 | sets priority of transition at index idxes[i] in buffer 172 | to priorities[i]. 173 | 174 | Parameters 175 | ---------- 176 | idxes: [int] 177 | List of idxes of sampled transitions 178 | priorities: [float] 179 | List of updated priorities corresponding to 180 | transitions at the sampled idxes denoted by 181 | variable `idxes`. 
182 | """ 183 | assert len(idxes) == len(priorities) 184 | for idx, priority in zip(idxes, priorities): 185 | assert priority > 0 186 | assert 0 <= idx < len(self._storage) 187 | self._it_sum[idx] = priority ** self._alpha 188 | self._it_min[idx] = priority ** self._alpha 189 | 190 | self._max_priority = max(self._max_priority, priority) 191 | -------------------------------------------------------------------------------- /cvar/dqn/core/simple.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import dill 4 | import tempfile 5 | import tensorflow as tf 6 | import zipfile 7 | 8 | import baselines.common.tf_util as U 9 | 10 | from baselines import logger 11 | from baselines.common.schedules import LinearSchedule 12 | from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer 13 | from .build_graph import build_act, build_train 14 | from cvar.common.util import timed 15 | from .static import make_session 16 | 17 | 18 | class ActWrapper(object): 19 | def __init__(self, act, act_params): 20 | self._act = act 21 | self._act_params = act_params 22 | 23 | @staticmethod 24 | def load(path, num_cpu=4): 25 | with open(path, "rb") as f: 26 | model_data, act_params = dill.load(f) 27 | act = build_act(**act_params) 28 | sess = make_session(num_cpu=num_cpu) 29 | sess.__enter__() 30 | with tempfile.TemporaryDirectory() as td: 31 | arc_path = os.path.join(td, "packed.zip") 32 | with open(arc_path, "wb") as f: 33 | f.write(model_data) 34 | 35 | zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td) 36 | U.load_state(os.path.join(td, "model")) 37 | 38 | return ActWrapper(act, act_params) 39 | 40 | @staticmethod 41 | def reload(path): 42 | with open(path, "rb") as f: 43 | model_data, act_params = dill.load(f) 44 | 45 | with tempfile.TemporaryDirectory() as td: 46 | arc_path = os.path.join(td, "packed.zip") 47 | with open(arc_path, "wb") as f: 48 | f.write(model_data) 49 | 50 | zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td) 51 | U.load_state(os.path.join(td, "model")) 52 | 53 | def __call__(self, *args, **kwargs): 54 | return self._act(*args, **kwargs) 55 | 56 | def save(self, path): 57 | """Save model to a pickle located at `path`""" 58 | with tempfile.TemporaryDirectory() as td: 59 | U.save_state(os.path.join(td, "model")) 60 | arc_name = os.path.join(td, "packed.zip") 61 | with zipfile.ZipFile(arc_name, 'w') as zipf: 62 | for root, dirs, files in os.walk(td): 63 | for fname in files: 64 | file_path = os.path.join(root, fname) 65 | if file_path != arc_name: 66 | zipf.write(file_path, os.path.relpath(file_path, td)) 67 | with open(arc_name, "rb") as f: 68 | model_data = f.read() 69 | with open(path, "wb") as f: 70 | dill.dump((model_data, self._act_params), f) 71 | 72 | def get_nb_atoms(self): 73 | return self._act_params['nb_atoms'] 74 | 75 | 76 | def load(path, num_cpu=16): 77 | """Load act function that was returned by learn function. 78 | 79 | Parameters 80 | ---------- 81 | path: str 82 | path to the act function pickle 83 | num_cpu: int 84 | number of cpus to use for executing the policy 85 | 86 | Returns 87 | ------- 88 | act: ActWrapper 89 | function that takes a batch of observations 90 | and returns actions. 
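    Examples
    --------
    A rough usage sketch (the model path and the CVaR level 0.1 are illustrative), given
    some environment `env`:

        act = load("icelake_model.pkl")
        obs = env.reset()
        action = act(np.array(obs)[None], 0.1)[0]  # action maximizing CVaR_0.1 in obs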
91 | """ 92 | return ActWrapper.load(path, num_cpu=num_cpu) 93 | 94 | 95 | @timed 96 | def learn(env, 97 | var_func, 98 | cvar_func, 99 | nb_atoms, 100 | run_alpha=None, 101 | lr=5e-4, 102 | max_timesteps=100000, 103 | buffer_size=50000, 104 | exploration_fraction=0.1, 105 | exploration_final_eps=0.01, 106 | train_freq=1, 107 | batch_size=32, 108 | print_freq=1, 109 | checkpoint_freq=10000, 110 | learning_starts=1000, 111 | gamma=0.95, 112 | target_network_update_freq=500, 113 | num_cpu=4, 114 | callback=None, 115 | periodic_save_freq=1000000, 116 | periodic_save_path=None, 117 | grad_norm_clip=None, 118 | ): 119 | """Train a CVaR DQN model. 120 | 121 | Parameters 122 | ------- 123 | env: gym.Env 124 | environment to train on 125 | var_func: (tf.Variable, int, str, bool) -> tf.Variable 126 | the model that takes the following inputs: 127 | observation_in: object 128 | the output of observation placeholder 129 | num_actions: int 130 | number of actions 131 | scope: str 132 | reuse: bool 133 | should be passed to outer variable scope 134 | and returns a tensor of shape (batch_size, num_actions) with values of every action. 135 | cvar_func: function 136 | same as var_func 137 | nb_atoms: int 138 | number of atoms used in CVaR discretization 139 | run_alpha: float 140 | optimize CVaR_alpha while running. None if you want random alpha each episode. 141 | lr: float 142 | learning rate for adam optimizer 143 | max_timesteps: int 144 | number of env steps to optimizer for 145 | buffer_size: int 146 | size of the replay buffer 147 | exploration_fraction: float 148 | fraction of entire training period over which the exploration rate is annealed 149 | exploration_final_eps: float 150 | final value of random action probability 151 | train_freq: int 152 | update the model every `train_freq` steps. 153 | set to None to disable printing 154 | batch_size: int 155 | size of a batched sampled from replay buffer for training 156 | print_freq: int 157 | how often to print out training progress 158 | set to None to disable printing 159 | checkpoint_freq: int 160 | how often to save the best model. This is so that the best version is restored 161 | at the end of the training. If you do not wish to restore the best version at 162 | the end of the training set this variable to None. 163 | learning_starts: int 164 | how many steps of the model to collect transitions for before learning starts 165 | gamma: float 166 | discount factor 167 | target_network_update_freq: int 168 | update the target network every `target_network_update_freq` steps. 169 | num_cpu: int 170 | number of cpus to use for training 171 | callback: (locals, globals) -> None 172 | function called at every steps with state of the algorithm. 173 | If callback returns true training stops. 174 | periodic_save_freq: int 175 | How often do we save the model - periodically 176 | periodic_save_path: str 177 | Where do we save the model - periodically 178 | grad_norm_clip: float 179 | Clip gradient to this value. No clipping if None 180 | Returns 181 | ------- 182 | act: ActWrapper 183 | Wrapper over act function. Adds ability to save it and load it. 184 | See header of baselines/distdeepq/categorical.py for details on the act function. 
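    Examples
    --------
    A rough training sketch (the environment id is the registered Ice Lake state
    environment; hyperparameters are illustrative):

        import gym
        import cvar.dqn.ice_lake  # registers IceLake-v0 / IceLakeRGB-v0
        import cvar.dqn.core as dqn_core

        env = gym.make("IceLake-v0")
        var_func, cvar_func = dqn_core.models.mlp([64])
        act = dqn_core.learn(env, var_func, cvar_func, nb_atoms=10,
                             run_alpha=None, max_timesteps=100000)
        act.save("icelake_model.pkl")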
185 | """ 186 | # Create all the functions necessary to train the model 187 | 188 | sess = make_session(num_cpu=num_cpu) 189 | sess.__enter__() 190 | 191 | obs_space_shape = env.observation_space.shape 192 | 193 | def make_obs_ph(name): 194 | return U.BatchInput(obs_space_shape, name=name) 195 | 196 | act, train, update_target, debug = build_train( 197 | make_obs_ph=make_obs_ph, 198 | var_func=var_func, 199 | cvar_func=cvar_func, 200 | num_actions=env.action_space.n, 201 | optimizer=tf.train.AdamOptimizer(learning_rate=lr), 202 | gamma=gamma, 203 | nb_atoms=nb_atoms, 204 | grad_norm_clipping=grad_norm_clip 205 | ) 206 | 207 | act_params = { 208 | 'make_obs_ph': make_obs_ph, 209 | 'cvar_func': cvar_func, 210 | 'var_func': var_func, 211 | 'num_actions': env.action_space.n, 212 | 'nb_atoms': nb_atoms 213 | } 214 | 215 | # Create the replay buffer 216 | replay_buffer = ReplayBuffer(buffer_size) 217 | beta_schedule = None 218 | # Create the schedule for exploration starting from 1. 219 | exploration = LinearSchedule(schedule_timesteps=int(exploration_fraction * max_timesteps), 220 | initial_p=1.0, 221 | final_p=exploration_final_eps) 222 | 223 | # Initialize the parameters and copy them to the target network. 224 | U.initialize() 225 | update_target() 226 | 227 | episode_rewards = [0.0] 228 | saved_mean_reward = None 229 | obs = env.reset() 230 | reset = True 231 | episode = 0 232 | alpha = 1. 233 | 234 | # --------------------------------- RUN --------------------------------- 235 | with tempfile.TemporaryDirectory() as td: 236 | model_saved = False 237 | model_file = os.path.join(td, "model") 238 | for t in range(max_timesteps): 239 | if callback is not None: 240 | if callback(locals(), globals()): 241 | print('Target reached') 242 | model_saved = False 243 | break 244 | # Take action and update exploration to the newest value 245 | update_eps = exploration.value(t) 246 | 247 | update_param_noise_threshold = 0. 248 | 249 | action = act(np.array(obs)[None], alpha, update_eps=update_eps)[0] 250 | reset = False 251 | new_obs, rew, done, _ = env.step(action) 252 | 253 | # ===== DEBUG ===== 254 | 255 | # s = np.ones_like(np.array(obs)[None]) 256 | # a = np.ones_like(act(np.array(obs)[None], run_alpha, update_eps=update_eps)) 257 | # r = np.array([0]) 258 | # s_ = np.ones_like(np.array(obs)[None]) 259 | # d = np.array([False]) 260 | # s = obs[None] 261 | # a = np.array([action]) 262 | # r = np.array([rew]) 263 | # s_ = new_obs[None] 264 | # d = np.array([done]) 265 | # if t % 100 == 0: 266 | # for f in debug: 267 | # print(f(s, a, r, s_, d)) 268 | # print('-------------') 269 | # 270 | # # print([sess.run(v) for v in tf.global_variables('cvar_dqn/cvar_func')]) 271 | # # print([sess.run(v) for v in tf.global_variables('cvar_dqn/var_func')]) 272 | 273 | # ================= 274 | 275 | # Store transition in the replay buffer. 276 | replay_buffer.add(obs, action, rew, new_obs, float(done)) 277 | obs = new_obs 278 | 279 | episode_rewards[-1] += rew 280 | if done: 281 | obs = env.reset() 282 | episode_rewards.append(0.0) 283 | reset = True 284 | if run_alpha is None: 285 | alpha = np.random.random() 286 | 287 | if t > learning_starts and t % train_freq == 0: 288 | # Minimize the error in Bellman's equation on a batch sampled from replay buffer. 
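                # (The batch is sampled uniformly; prioritized replay is not used here, so the
                # importance weights passed to `train` are all ones.)
                # `train` minimizes the joint objective assembled in build_graph.build_train:
                #   - a quantile (pinball) loss that fits the VaR head:
                #       mean over i,j of (y_i - 1{Td_j < V_i}) * (Td_j - V_i)
                #   - an MSE loss that fits the CVaR head C_i against
                #       V_i + (1 / y_i) * (Td_j - V_i)^-
                # where Td_j = r + gamma * z'_j are atoms of the (done-masked) target
                # distribution extracted from the target network's yCVaR.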
289 | 290 | obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size) 291 | weights, batch_idxes = np.ones_like(rewards), None 292 | 293 | errors = train(obses_t, actions, rewards, obses_tp1, dones, weights) 294 | 295 | if t > learning_starts and t % target_network_update_freq == 0: 296 | # Update target network periodically. 297 | update_target() 298 | 299 | # Log results and periodically save the model 300 | mean_100ep_reward = round(float(np.mean(episode_rewards[-101:-1])), 1) 301 | num_episodes = len(episode_rewards) 302 | if done and print_freq is not None and len(episode_rewards) % print_freq == 0: 303 | logger.record_tabular("steps", t) 304 | logger.record_tabular("episodes", num_episodes) 305 | logger.record_tabular("mean 100 episode reward", mean_100ep_reward) 306 | logger.record_tabular("% time spent exploring", int(100 * exploration.value(t))) 307 | logger.record_tabular("(current alpha)", "%.2f" % alpha) 308 | logger.dump_tabular() 309 | 310 | # save and report best model 311 | if (checkpoint_freq is not None and t > learning_starts and 312 | num_episodes > 100 and t % checkpoint_freq == 0): 313 | if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward: 314 | if print_freq is not None: 315 | logger.log("Saving model due to mean reward increase: {} -> {}".format( 316 | saved_mean_reward, mean_100ep_reward)) 317 | U.save_state(model_file) 318 | model_saved = True 319 | saved_mean_reward = mean_100ep_reward 320 | 321 | # save periodically 322 | if periodic_save_freq is not None and periodic_save_path is not None and t > learning_starts: 323 | if t % periodic_save_freq == 0: 324 | ActWrapper(act, act_params).save("{}-{}.pkl".format(periodic_save_path, int(t/periodic_save_freq))) 325 | 326 | if model_saved: 327 | if print_freq is not None: 328 | logger.log("Restored model with mean reward: {}".format(saved_mean_reward)) 329 | U.load_state(model_file) 330 | 331 | return ActWrapper(act, act_params) 332 | -------------------------------------------------------------------------------- /cvar/dqn/core/static.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def parent_path(path): 8 | if path.endswith('/'): 9 | path = path[:-1] 10 | return os.path.join(*os.path.split(path)[:-1]) 11 | 12 | 13 | atari_actions = ['noop', 'fire', 'up', 'right', 'left', 14 | 'down', 'up-right', 'up-left', 'down-right', 'down-left', 15 | 'up-fire', 'right-fire', 'left-fire', 'down-fire', 'up-right-fire', 16 | 'up-left-fire', 'down-right-fire', 'down-left-fire'] 17 | 18 | 19 | def actions_from_env(env): 20 | """ Propagate through all wrappers to get action indices. 
""" 21 | while True: 22 | if isinstance(env, gym.Wrapper): 23 | env = env.env 24 | else: 25 | break 26 | if isinstance(env, gym.Env): 27 | if hasattr(env, 'ale'): 28 | actions = env.ale.getMinimalActionSet() 29 | return [atari_actions[i] for i in actions] 30 | 31 | 32 | def make_env_atari(game_name, random_action_eps=0.): 33 | from baselines.common.atari_wrappers import wrap_deepmind, make_atari 34 | env = make_atari(game_name + "NoFrameskip-v4") 35 | if random_action_eps > 0: 36 | env = ActionRandomizer(env, random_action_eps) 37 | monitored_env = SimpleMonitor(env) 38 | env = wrap_deepmind(monitored_env, frame_stack=True, scale=True) 39 | return env, monitored_env 40 | 41 | 42 | def make_env_ice(game_name): 43 | from baselines.common.atari_wrappers import FrameStack, WarpFrame, MaxAndSkipEnv, ScaledFloatFrame 44 | import gym 45 | import cvar.dqn.ice_lake 46 | 47 | env = gym.make(game_name) 48 | # env = MaxAndSkipEnv(env, skip=4) 49 | env = WarpFrame(env) 50 | env = ScaledFloatFrame(env) 51 | env = FrameStack(env, 4) 52 | return env 53 | 54 | 55 | def make_session(num_cpu): 56 | tf_config = tf.ConfigProto( 57 | inter_op_parallelism_threads=num_cpu, 58 | intra_op_parallelism_threads=num_cpu) 59 | gpu_frac = 0.25 60 | tf_config.gpu_options.per_process_gpu_memory_fraction = gpu_frac 61 | import warnings 62 | warnings.warn("GPU is using a fixed fraction of memory: %.2f" % gpu_frac) 63 | 64 | return tf.Session(config=tf_config) 65 | 66 | 67 | class ActionRandomizer(gym.ActionWrapper): 68 | 69 | def __init__(self, env, eps): 70 | super().__init__(env) 71 | self.eps = eps 72 | 73 | def _action(self, action): 74 | if np.random.random() < self.eps: 75 | # pick action with uniform probability 76 | return self.action_space.sample() 77 | else: 78 | return action 79 | 80 | def _reverse_action(self, action): 81 | pass 82 | 83 | 84 | # hard copy from old baselines.common.misc_util 85 | # TODO: remove? 86 | import time 87 | class SimpleMonitor(gym.Wrapper): 88 | def __init__(self, env): 89 | """Adds two qunatities to info returned by every step: 90 | num_steps: int 91 | Number of steps takes so far 92 | rewards: [float] 93 | All the cumulative rewards for the episodes completed so far. 
94 | """ 95 | super().__init__(env) 96 | # current episode state 97 | self._current_reward = None 98 | self._num_steps = None 99 | # temporary monitor state that we do not save 100 | self._time_offset = None 101 | self._total_steps = None 102 | # monitor state 103 | self._episode_rewards = [] 104 | self._episode_lengths = [] 105 | self._episode_end_times = [] 106 | 107 | def _reset(self): 108 | obs = self.env.reset() 109 | # recompute temporary state if needed 110 | if self._time_offset is None: 111 | self._time_offset = time.time() 112 | if len(self._episode_end_times) > 0: 113 | self._time_offset -= self._episode_end_times[-1] 114 | if self._total_steps is None: 115 | self._total_steps = sum(self._episode_lengths) 116 | # update monitor state 117 | if self._current_reward is not None: 118 | self._episode_rewards.append(self._current_reward) 119 | self._episode_lengths.append(self._num_steps) 120 | self._episode_end_times.append(time.time() - self._time_offset) 121 | # reset episode state 122 | self._current_reward = 0 123 | self._num_steps = 0 124 | 125 | return obs 126 | 127 | def _step(self, action): 128 | obs, rew, done, info = self.env.step(action) 129 | self._current_reward += rew 130 | self._num_steps += 1 131 | self._total_steps += 1 132 | info['steps'] = self._total_steps 133 | info['rewards'] = self._episode_rewards 134 | return (obs, rew, done, info) 135 | 136 | def get_state(self): 137 | return { 138 | 'env_id': self.env.unwrapped.spec.id, 139 | 'episode_data': { 140 | 'episode_rewards': self._episode_rewards, 141 | 'episode_lengths': self._episode_lengths, 142 | 'episode_end_times': self._episode_end_times, 143 | 'initial_reset_time': 0, 144 | } 145 | } 146 | 147 | def set_state(self, state): 148 | assert state['env_id'] == self.env.unwrapped.spec.id 149 | ed = state['episode_data'] 150 | self._episode_rewards = ed['episode_rewards'] 151 | self._episode_lengths = ed['episode_lengths'] 152 | self._episode_end_times = ed['episode_end_times'] 153 | 154 | -------------------------------------------------------------------------------- /cvar/dqn/ice_lake/README.md: -------------------------------------------------------------------------------- 1 | # Ice Lake 2 | 3 | Ice Lake is a visual environment specifically designed for risk-sensitive decision 4 | making. Imagine you are standing on an ice lake and you want to travel fast to a point 5 | on the lake. Will you take the shortcut and risk falling into the cold water, or 6 | will you be more patient and go around? 7 | 8 | The agent has five discrete actions, namely go Left, Right, Up, Down and Noop. 9 | These correspond to moving in the respective directions or no operation. Since the 10 | agent is on ice, there is a sliding element in the movement - this is mainly done to 11 | introduce time dependency and to make the environment a little harder. The 12 | environment is updated thirty times per second.
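The environment is also registered with gym (see `ice_lake/__init__.py`): `IceLake-v0` exposes the low-dimensional game state, `IceLakeRGB-v0` raw pixels, both capped at 500 steps per episode. A rough random-agent sketch, assuming the standard gym interface of the registered classes:

    import gym
    import cvar.dqn.ice_lake  # importing registers IceLake-v0 / IceLakeRGB-v0

    env = gym.make('IceLake-v0')
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())

For the pixel version wrapped the way the Deep CVaR Q-learning scripts use it, see `make_env_ice` in `cvar/dqn/core/static.py`.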
13 | 14 | ![icelake](icelake.png "Logo Title Text 1") 15 | 16 | Try it yourself by running 17 | 18 | python3 icelake.py 19 | 20 | Controls: W-S-A-D 21 | -------------------------------------------------------------------------------- /cvar/dqn/ice_lake/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import registry, register, make, spec 2 | from cvar.dqn.ice_lake.ple_env import DiscreteStateEnv, DiscreteVisualEnv, DummyStateEnv 3 | # headless 4 | import os 5 | os.putenv('SDL_VIDEODRIVER', 'fbcon') 6 | os.environ["SDL_VIDEODRIVER"] = "dummy" 7 | 8 | register( 9 | id='IceLake-v0', 10 | entry_point='cvar.dqn.ice_lake:DiscreteStateEnv', 11 | kwargs={'game_name': 'IceLake', 'display_screen': False}, 12 | tags={'wrapper_config.TimeLimit.max_episode_steps': 500}, 13 | nondeterministic=False, 14 | ) 15 | 16 | register( 17 | id='IceLakeRGB-v0', 18 | entry_point='cvar.dqn.ice_lake:DiscreteVisualEnv', 19 | kwargs={'game_name': 'IceLake', 'display_screen': False}, 20 | tags={'wrapper_config.TimeLimit.max_episode_steps': 500}, 21 | nondeterministic=False, 22 | ) 23 | -------------------------------------------------------------------------------- /cvar/dqn/ice_lake/icelake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Silvicek/cvar-algorithms/ec60696d269857f78213b8cbde506bb5e94f34a3/cvar/dqn/ice_lake/icelake.png -------------------------------------------------------------------------------- /cvar/dqn/ice_lake/icelake.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | import numpy as np 4 | 5 | from ple.games.base.pygamewrapper import PyGameWrapper 6 | 7 | from pygame.constants import K_w, K_a, K_s, K_d 8 | from ple.games.utils import percent_round_int 9 | 10 | BG_COLOR = (255, 255, 255) 11 | 12 | 13 | class GameObject(pygame.sprite.Sprite): 14 | 15 | def __init__(self, position, radius, color): 16 | super().__init__() 17 | self.position = position 18 | self.velocity = np.zeros(position.shape) 19 | self.radius = radius 20 | self.color = color 21 | 22 | image = pygame.Surface([radius * 2, radius * 2]) 23 | image.set_colorkey((0, 0, 0)) 24 | 25 | pygame.draw.circle( 26 | image, 27 | color, 28 | (radius, radius), 29 | radius, 30 | 0 31 | ) 32 | 33 | self.image = image.convert() 34 | self.rect = self.image.get_rect() 35 | 36 | def draw(self, screen): 37 | self.rect = pygame.Rect(self.position[0]-self.radius, self.position[1]-self.radius, self.radius, self.radius) 38 | screen.blit(self.image, self.rect) 39 | 40 | def update(self, velocity, dt): 41 | self.velocity += velocity 42 | 43 | self.position = self.position + self.velocity * dt 44 | 45 | self.velocity *= 0.975 46 | 47 | @staticmethod 48 | def distance(a, b): 49 | return np.sqrt(np.sum(np.square(a.position - b.position))) 50 | 51 | 52 | class IceLake(PyGameWrapper): 53 | """ 54 | 55 | Parameters 56 | ---------- 57 | width : int 58 | Screen width. 59 | 60 | height : int 61 | Screen height, recommended to be same dimension as width. 62 | 63 | """ 64 | actions = { 65 | "up": K_w, 66 | "left": K_a, 67 | "right": K_d, 68 | "down": K_s 69 | } 70 | 71 | rewards = { 72 | "tick": -40. / 30, 73 | "ice": -50.0, 74 | "win": 100.0, 75 | "wall": 0., 76 | } 77 | 78 | def __init__(self, width=84, height=84): 79 | 80 | PyGameWrapper.__init__(self, width, height, actions=self.actions) 81 | 82 | self.dx = 0. 83 | self.dy = 0. 
84 | self.ticks = 0 85 | 86 | self._game_ended = False 87 | 88 | def _handle_player_events(self): 89 | self.dx = 0.0 90 | self.dy = 0.0 91 | 92 | agent_speed = 0.1 * self.width 93 | for event in pygame.event.get(): 94 | if event.type == pygame.QUIT: 95 | pygame.quit() 96 | sys.exit() 97 | 98 | if event.type == pygame.KEYDOWN: 99 | key = event.key 100 | 101 | if key == self.actions["left"]: 102 | self.dx -= agent_speed 103 | 104 | if key == self.actions["right"]: 105 | self.dx += agent_speed 106 | 107 | if key == self.actions["up"]: 108 | self.dy -= agent_speed 109 | 110 | if key == self.actions["down"]: 111 | self.dy += agent_speed 112 | 113 | def getGameState(self): 114 | """ 115 | Gets a non-visual state representation of the game. 116 | XXX: should be a dict, to np in preprocess 117 | """ 118 | return np.hstack((self.player.position, self.player.velocity)) 119 | 120 | def getGameStateDims(self): 121 | """ 122 | Gets the games non-visual state dimensions. 123 | 124 | Returns 125 | ------- 126 | list of tuples (min, max, discrete_steps) corresponding to each observation index 127 | TODO: constants 128 | """ 129 | return np.array([(0, self.width, 50), (0, self.height, 50), (-200, 200, 2), (-200, 200, 2)], dtype=object) 130 | 131 | def getScore(self): 132 | return self._score 133 | 134 | def game_over(self): 135 | """ 136 | Return bool if the game has 'finished' 137 | """ 138 | return self._game_ended 139 | 140 | def init(self): 141 | """ 142 | Starts/Resets the game to its initial state 143 | """ 144 | target_radius = percent_round_int(self.width, 0.047) 145 | self.target = GameObject(np.array([self.width-target_radius, self.height-target_radius]), 146 | target_radius, (40, 140, 40)) 147 | 148 | ice_radius = percent_round_int(self.width, 0.3) 149 | self.ice = GameObject(np.array([self.width/2, self.height-ice_radius/2]), 150 | ice_radius, (0, 110, 255)) 151 | 152 | player_radius = percent_round_int(self.width, 0.047) 153 | self.player = GameObject(np.array([1+player_radius, self.height-1-player_radius]), 154 | player_radius, (1, 1, 1)) 155 | self.playerGroup = pygame.sprite.GroupSingle(self.player) 156 | 157 | self.creeps = pygame.sprite.Group() 158 | self.creeps.add(self.target) 159 | self.creeps.add(self.ice) 160 | 161 | self._score = 0. 162 | self.ticks = 0 163 | self.lives = -1 164 | 165 | self._game_ended = False 166 | 167 | def step(self, dt): 168 | """ 169 | Perform one step of game emulation. 
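dt is the elapsed time in milliseconds (converted to seconds below). Every step incurs the
'tick' penalty, reaching the green target ends the episode with the 'win' reward, and while
the agent overlaps the blue ice patch it falls through with probability 0.02 per step,
ending the episode with the 'ice' penalty.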
170 | """ 171 | dt /= 1000.0 172 | self.ticks += 1 173 | 174 | self._score += IceLake.rewards["tick"] 175 | 176 | self._handle_player_events() 177 | self.player.update(np.array([self.dx, self.dy]), dt) 178 | 179 | if GameObject.distance(self.target, self.player) < self.target.radius: 180 | self._game_ended = True 181 | self._score += IceLake.rewards['win'] 182 | elif self.wall_collide(): 183 | self._score += IceLake.rewards['wall'] 184 | if GameObject.distance(self.ice, self.player) < self.ice.radius: 185 | if np.random.random() < 0.02: 186 | self._game_ended = True 187 | self._score += IceLake.rewards['ice'] 188 | 189 | def draw(self): 190 | self.screen.fill(BG_COLOR) 191 | self.target.draw(self.screen) 192 | self.ice.draw(self.screen) 193 | self.player.draw(self.screen) 194 | 195 | def wall_collide(self): 196 | x = self.player.position[0] 197 | y = self.player.position[1] 198 | r = self.player.radius 199 | collision = False 200 | if x <= r: 201 | self.player.position[0] = r 202 | self.player.velocity[0] = 0 203 | collision = True 204 | elif x >= self.width - r: 205 | self.player.position[0] = self.width - r 206 | self.player.velocity[0] = 0 207 | collision = True 208 | if y <= r: 209 | self.player.position[1] = r 210 | self.player.velocity[1] = 0 211 | collision = True 212 | elif y >= self.height - r: 213 | self.player.position[1] = self.height - r 214 | self.player.velocity[1] = 0 215 | collision = True 216 | 217 | return collision 218 | 219 | 220 | if __name__ == "__main__": 221 | 222 | pygame.init() 223 | game = IceLake(width=256, height=256) 224 | game.screen = pygame.display.set_mode(game.getScreenDims(), 0, 32) 225 | game.clock = pygame.time.Clock() 226 | game.rng = np.random.RandomState(24) 227 | 228 | while True: 229 | game.init() 230 | while not game.game_over(): 231 | dt = game.clock.tick_busy_loop(30) 232 | game.step(dt) 233 | game.draw() 234 | pygame.display.update() 235 | print("Episode reward", game.getScore()) 236 | 237 | 238 | -------------------------------------------------------------------------------- /cvar/dqn/ice_lake/ple_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | from ple import PLE 4 | import numpy as np 5 | 6 | 7 | class LazyDrawPLE(PLE): 8 | 9 | def __init__(self, draw_function, args, **kwargs): 10 | super().__init__(args, **kwargs) 11 | self.draw_function = draw_function 12 | 13 | def getScreenRGB(self): 14 | self.draw_function() 15 | return super().getScreenRGB() 16 | 17 | 18 | class Env(gym.Env): 19 | metadata = {'render.modes': ['human', 'rgb_array']} 20 | 21 | def __init__(self, game_name, display_screen=True): 22 | # open up a game state to communicate with emulator 23 | import importlib 24 | game_module_name = ('cvar.dqn.ice_lake.%s' % game_name).lower() 25 | game_module = importlib.import_module(game_module_name) 26 | game = getattr(game_module, game_name)() 27 | self.game_state = LazyDrawPLE(game.draw, game, fps=30, display_screen=display_screen, 28 | state_preprocessor=state_preprocessor) 29 | self.game_state.init() 30 | self._action_set = sorted(self.game_state.getActionSet(), key=lambda x: (x is None, x)) 31 | self.screen_width, self.screen_height = self.game_state.getScreenDims() 32 | 33 | self.action_space = None 34 | self.observation_space = None 35 | self.viewer = None 36 | 37 | def _step(self, a): 38 | raise NotImplementedError 39 | 40 | def _get_image(self): 41 | image_rotated = np.fliplr(np.rot90(self.game_state.getScreenRGB(),3)) # Hack to fix the 
rotated image returned by ple 42 | return image_rotated 43 | 44 | @property 45 | def _n_actions(self): 46 | return len(self._action_set) 47 | 48 | # return: (states, observations) 49 | def _reset(self): 50 | raise NotImplementedError 51 | 52 | def _render(self, mode='human', close=False): 53 | if close: 54 | if self.viewer is not None: 55 | self.viewer.close() 56 | self.viewer = None 57 | return 58 | img = self._get_image() 59 | if mode == 'rgb_array': 60 | return img 61 | elif mode == 'human': 62 | from gym.envs.classic_control import rendering 63 | if self.viewer is None: 64 | self.viewer = rendering.SimpleImageViewer() 65 | self.viewer.imshow(img) 66 | 67 | def _seed(self, seed=0): 68 | rng = np.random.seed(seed) 69 | self.game_state.rng = rng 70 | self.game_state.game.rng = self.game_state.rng 71 | 72 | self.game_state.init() 73 | 74 | 75 | class DiscreteVisualEnv(Env): 76 | 77 | def __init__(self, game_name, display_screen=True): 78 | super().__init__(game_name, display_screen) 79 | 80 | self.action_space = spaces.Discrete(len(self._action_set)) 81 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3)) 82 | 83 | def _step(self, a): 84 | reward = self.game_state.act(self._action_set[a]) 85 | state = self._get_image() 86 | terminal = self.game_state.game_over() 87 | return state, reward, terminal, {} 88 | 89 | # return: (states, observations) 90 | def _reset(self): 91 | self.game_state.reset_game() 92 | state = self._get_image() 93 | return state 94 | 95 | 96 | class DiscreteStateEnv(Env): 97 | 98 | def __init__(self, game_name, display_screen=True): 99 | super().__init__(game_name, display_screen) 100 | 101 | self.action_space = spaces.Discrete(len(self._action_set)) 102 | self.observation_space = spaces.Box(low=-1e3, high=1e3, shape=self.game_state.getGameStateDims()) 103 | 104 | def _step(self, a): 105 | reward = self.game_state.act(self._action_set[a]) 106 | state = self.game_state.getGameState() 107 | terminal = self.game_state.game_over() 108 | return state, reward, terminal, {} 109 | 110 | def _reset(self): 111 | self.game_state.reset_game() 112 | return self.game_state.getGameState() 113 | 114 | 115 | class DummyStateEnv(Env): 116 | 117 | def __init__(self, game_name, display_screen=True): 118 | super().__init__(game_name, display_screen) 119 | 120 | self.action_space = spaces.Discrete(len(self._action_set)) 121 | self.observation_space = spaces.Box(low=0, high=1, shape=(100,)) 122 | 123 | def _step(self, a): 124 | reward = self.game_state.act(self._action_set[a]) 125 | state = self.game_state.getGameState() 126 | terminal = self.game_state.game_over() 127 | return state, reward, terminal, {} 128 | 129 | def _reset(self): 130 | self.game_state.reset_game() 131 | return self.game_state.getGameState() 132 | 133 | 134 | def state_preprocessor(s): 135 | return s 136 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/README.md: -------------------------------------------------------------------------------- 1 | # CVaR DQN Scripts 2 | 3 | Run your training and testing from here. 
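All scripts share the same pattern: a `train_*` script builds the value networks, calls `dqn_core.learn(...)` and saves an `act` function, while the matching `enjoy_*` script loads it back and queries actions for a chosen risk level alpha. A sketch of the enjoy side (the alpha value 0.1 is illustrative; the model path is the one produced by `train_simple.py`):

    import gym
    import numpy as np
    import cvar.dqn.core as dqn_core
    import cvar.dqn.ice_lake  # registers IceLake-v0

    env = gym.make("IceLake-v0")
    act = dqn_core.load("../models/icelake_model.pkl")  # saved by train_simple.py

    obs, done = env.reset(), False
    while not done:
        # greedy action with respect to the estimated CVaR at level alpha=0.1
        obs, rew, done, _ = env.step(act(np.array(obs)[None], 0.1, stochastic=False)[0])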
4 | 5 | For baseline IceLake benchmarks run 6 | 7 | python3 train_simple.py --env IceLake --nb-atoms 100 --run-alpha -1 --num-steps 2000000 --buffer-size 200000 8 | 9 | For IceLakeRGB run 10 | 11 | python3 train_ice.py 12 | 13 | Also try Atari with 14 | 15 | python3 train_atari.py --help 16 | 17 | or faster Atari benchmark 18 | 19 | python3 train_pong.py 20 | 21 | 22 | After learning, run 23 | 24 | python3 enjoy_[{simple, ice, pong, atari}].py 25 | 26 | for visualizations. 27 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/enjoy_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import baselines.common.tf_util as U 6 | import numpy as np 7 | from baselines.common.misc_util import boolean_flag 8 | from gym.monitoring import VideoRecorder 9 | 10 | import cvar.dqn.core as dqn_core 11 | from cvar.dqn.core.plots import PlotMachine 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser("Run an already learned DQN model.") 16 | # Environment 17 | parser.add_argument("--env", type=str, required=True, help="name of the game") 18 | parser.add_argument("--model-dir", type=str, default=None, help="load model from this directory. ") 19 | parser.add_argument("--video", type=str, default=None, help="Path to mp4 file where the video of first episode will be recorded.") 20 | boolean_flag(parser, "stochastic", default=False, help="whether or not to use stochastic actions according to models eps value") 21 | boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model") 22 | boolean_flag(parser, "visual", default=False, help="whether or not to show the distribution output") 23 | 24 | parser.add_argument("--alpha", type=str, default=1.0, help="alpha in CVaR_alpha(x_0)") 25 | parser.add_argument("--random-action", type=float, default=0., 26 | help="probability of selecting a random action (for more risk sensitivity)") 27 | 28 | return parser.parse_args() 29 | 30 | 31 | def play(env, act, stochastic, video_path, nb_atoms): 32 | num_episodes = 0 33 | video_recorder = VideoRecorder( 34 | env, video_path, enabled=video_path is not None) 35 | obs = env.reset() 36 | if args.visual: 37 | action_names = dqn_core.actions_from_env(env) 38 | plot_machine = PlotMachine(nb_atoms, env.action_space.n, action_names) 39 | while True: 40 | env.unwrapped.render() 41 | video_recorder.capture_frame() 42 | action = act(np.array(obs)[None], args.alpha, stochastic=stochastic)[0] 43 | obs, rew, done, info = env.step(action) 44 | if args.visual: 45 | plot_machine.plot_distribution(np.array(obs)[None]) 46 | 47 | if done: 48 | obs = env.reset() 49 | if len(info["rewards"]) > num_episodes: 50 | if len(info["rewards"]) == 1 and video_recorder.enabled: 51 | # save video of first episode 52 | print("Saved video.") 53 | video_recorder.close() 54 | video_recorder.enabled = False 55 | print(info["rewards"][-1]) 56 | num_episodes = len(info["rewards"]) 57 | # input() 58 | 59 | 60 | if __name__ == '__main__': 61 | with U.make_session(4) as sess: 62 | args = parse_args() 63 | env, _ = dqn_core.make_env_atari(args.env) 64 | 65 | if args.random_action > 0: 66 | env = dqn_core.ActionRandomizer(env, args.random_action) 67 | 68 | model_parent_path = dqn_core.parent_path(args.model_dir) 69 | old_args = json.load(open(model_parent_path + '/args.json')) 70 | 71 | var_func, cvar_func = dqn_core.models.atari_model() 72 | act = dqn_core.build_act( 73 | 
make_obs_ph=lambda name: U.BatchInput(env.observation_space.shape, name=name), 74 | var_func=var_func, 75 | cvar_func=cvar_func, 76 | num_actions=env.action_space.n, 77 | nb_atoms=old_args['nb_atoms']) 78 | U.load_state(os.path.join(args.model_dir, "saved")) 79 | play(env, act, args.stochastic, args.video, old_args['nb_atoms']) 80 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/enjoy_ice.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | import cvar.dqn.core as dqn_core 5 | from cvar.dqn.core.plots import PlotMachine 6 | 7 | 8 | def main(): 9 | env = dqn_core.make_env_ice("IceLakeRGB-v0") 10 | act = dqn_core.load("../models/ice_rgb_model.pkl") 11 | 12 | action_set = ['Left', 'Right', 'Down', 'Up', '-'] 13 | plot_machine = PlotMachine(act.get_nb_atoms(), env.action_space.n, action_set) 14 | 15 | while True: 16 | obs, done = env.reset(), False 17 | episode_rew = 0 18 | while not done: 19 | env.render() 20 | obs, rew, done, _ = env.step(act(np.array(obs)[None], 1.0, stochastic=False)[0]) 21 | plot_machine.plot_distribution(np.array(obs)[None]) 22 | episode_rew += rew 23 | print("Episode reward", episode_rew) 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/enjoy_pong.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import cvar.dqn.core as dqn_core 4 | from cvar.dqn.core.plots import PlotMachine 5 | 6 | 7 | def main(): 8 | env, _ = dqn_core.make_env_atari("Pong") 9 | act = dqn_core.load("../models/pong_model.pkl") 10 | print(act) 11 | action_set = dqn_core.actions_from_env(env) 12 | plot_machine = PlotMachine(act.get_nb_atoms(), env.action_space.n, action_set) 13 | 14 | while True: 15 | obs, done = env.reset(), False 16 | episode_rew = 0 17 | while not done: 18 | env.render() 19 | obs, rew, done, _ = env.step(act(np.array(obs)[None], 1.0, stochastic=False)[0]) 20 | plot_machine.plot_distribution(np.array(obs)[None]) 21 | episode_rew += rew 22 | print("Episode reward", episode_rew) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/enjoy_simple.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | from baselines.common.misc_util import boolean_flag 5 | 6 | import cvar.dqn.core as dqn_core 7 | from cvar.dqn.core.plots import PlotMachine 8 | import cvar.dqn.ice_lake 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser("CVaR DQN experiments for simple environments") 13 | 14 | parser.add_argument("--env", type=str, default="IceLake", help="name of the game") 15 | parser.add_argument("--random-action", type=float, default=0., help="probability of selecting a random action (for more risk sensitivity)") 16 | parser.add_argument("--num-steps", type=int, default=50000, help="total number of steps to run the environment for") 17 | 18 | boolean_flag(parser, "visual", default=True, help="whether or not to show the distribution plots") 19 | 20 | # CVaR 21 | parser.add_argument("--nb-atoms", type=int, default=10, help="number of cvar and quantile atoms (linearly spaced)") 22 | parser.add_argument("--run-alpha", type=float, default=1., help="alpha for policy used during training") 23 | 24 | return 
parser.parse_args() 25 | 26 | 27 | def main(): 28 | args = parse_args() 29 | env = gym.make(args.env+"-v0") 30 | if args.random_action > 0: 31 | env = dqn_core.ActionRandomizer(env, args.random_action) 32 | 33 | act = dqn_core.load("../models/"+args.env.lower()+"_model.pkl") 34 | action_set = dqn_core.actions_from_env(env) 35 | plot_machine = PlotMachine(act.get_nb_atoms(), env.action_space.n, action_set) 36 | while True: 37 | obs, done = env.reset(), False 38 | episode_rew = 0 39 | while not done: 40 | env.render() 41 | obs, rew, done, _ = env.step(act(obs[None], args.run_alpha, stochastic=False)[0]) 42 | if args.visual: 43 | plot_machine.plot_distribution(obs[None]) 44 | episode_rew += rew 45 | print("Episode reward", episode_rew) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/test_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import numpy as np 5 | import json 6 | 7 | from gym.monitoring import VideoRecorder 8 | 9 | import baselines.common.tf_util as U 10 | 11 | import cvar.dqn.core as dqn_core 12 | from cvar.common.cvar_computation import var_cvar_from_samples 13 | from baselines.common.misc_util import boolean_flag 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser("Run an already learned DQN model.") 18 | # Environment 19 | parser.add_argument("--env", type=str, required=True, help="name of the game") 20 | parser.add_argument("--model-dir", type=str, required=True, help="load model from this directory. ") 21 | boolean_flag(parser, "stochastic", default=False, help="whether or not to use stochastic actions according to models eps value") 22 | 23 | parser.add_argument("--alpha", type=str, default=1.0, help="alpha in CVaR_alpha(x_0)") 24 | parser.add_argument("--random-action", type=float, default=0., 25 | help="probability of selecting a random action (for more risk sensitivity)") 26 | 27 | parser.add_argument("--nb-episodes", type=int, default=1000, help="run how many episodes") 28 | 29 | return parser.parse_args() 30 | 31 | 32 | def run(env, act, stochastic, nb_episodes): 33 | episode = 0 34 | info = {} 35 | 36 | obs = env.reset() 37 | 38 | while episode < nb_episodes: 39 | 40 | action = act(np.array(obs)[None], args.alpha, stochastic=stochastic)[0] 41 | obs, rew, done, info = env.step(action) 42 | 43 | if done: 44 | obs = env.reset() 45 | if len(info["rewards"]) > episode: 46 | episode = len(info["rewards"]) 47 | print('{}: {}'.format(episode, info["rewards"][-1])) 48 | 49 | return info['rewards'] 50 | 51 | 52 | if __name__ == '__main__': 53 | with U.make_session(4) as sess: 54 | args = parse_args() 55 | env, _ = dqn_core.make_env_atari(args.env) 56 | 57 | if args.random_action > 0: 58 | env = dqn_core.ActionRandomizer(env, args.random_action) 59 | 60 | model_parent_path = dqn_core.parent_path(args.model_dir) 61 | old_args = json.load(open(model_parent_path + '/args.json')) 62 | 63 | var_func, cvar_func = dqn_core.models.atari_model() 64 | act = dqn_core.build_act( 65 | make_obs_ph=lambda name: U.BatchInput(env.observation_space.shape, name=name), 66 | var_func=var_func, 67 | cvar_func=cvar_func, 68 | num_actions=env.action_space.n, 69 | nb_atoms=old_args['nb_atoms']) 70 | U.load_state(os.path.join(args.model_dir, "saved")) 71 | 72 | rewards = run(env, act, args.stochastic, args.nb_episodes) 73 | 74 | print('---------------------') 75 | for alpha in 
np.arange(0.05, 1.05, 0.05): 76 | v, cv = var_cvar_from_samples(rewards, alpha) 77 | print('CVaR_{:.2f} = {}'.format(alpha, cv)) 78 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/train_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import numpy as np 4 | import os 5 | import tensorflow as tf 6 | import tempfile 7 | import time 8 | import json 9 | 10 | import baselines.common.tf_util as U 11 | 12 | from baselines import logger 13 | import cvar.dqn.core as dqn_core 14 | from cvar.dqn.core.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer 15 | from cvar.dqn.core.simple import make_session 16 | from baselines.common.misc_util import ( 17 | boolean_flag, 18 | pickle_load, 19 | pretty_eta, 20 | relatively_safe_pickle_dump, 21 | set_global_seeds, 22 | RunningAvg, 23 | ) 24 | from baselines.common.schedules import LinearSchedule, PiecewiseSchedule 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser("DQN experiments for Atari games") 29 | # Environment 30 | parser.add_argument("--env", type=str, default="Pong", help="name of the game") 31 | parser.add_argument("--seed", type=int, default=42, help="which seed to use") 32 | parser.add_argument("--random-action", type=float, default=0., help="probability of selecting a random action (for more risk sensitivity)") 33 | # Core DQN parameters 34 | parser.add_argument("--replay-buffer-size", type=int, default=int(1e6), help="replay buffer size") 35 | parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Adam optimizer") 36 | parser.add_argument("--num-steps", type=int, default=int(4e7), help="total number of steps to run the environment for") 37 | parser.add_argument("--batch-size", type=int, default=32, help="number of transitions to optimize at the same time") 38 | parser.add_argument("--learning-freq", type=int, default=4, help="number of iterations between every optimization step") 39 | parser.add_argument("--target-update-freq", type=int, default=10000, help="number of iterations between every target network update") 40 | # Bells and whistles 41 | boolean_flag(parser, "layer-norm", default=False, help="whether or not to use layer norm (should be True if param_noise is used)") 42 | boolean_flag(parser, "gym-monitor", default=False, help="whether or not to use a OpenAI Gym monitor (results in slower training due to video recording)") 43 | # CVaR 44 | parser.add_argument("--nb-atoms", type=int, default=10, help="number of cvar and quantile atoms (linearly spaced)") 45 | parser.add_argument("--run-alpha", type=float, default=1., help="alpha for policy used during training") 46 | # Checkpointing 47 | parser.add_argument("--save-dir", type=str, default=None, help="directory in which training state and model should be saved.") 48 | parser.add_argument("--save-freq", type=int, default=1e6, help="save model once every time this many iterations are completed") 49 | boolean_flag(parser, "load-on-start", default=True, help="if true and model was previously saved then training will be resumed") 50 | return parser.parse_args() 51 | 52 | 53 | def maybe_save_model(savedir, state): 54 | """This function checkpoints the model and state of the training algorithm.""" 55 | if savedir is None: 56 | return 57 | start_time = time.time() 58 | model_dir = "model-{}".format(state["num_iters"]) 59 | U.save_state(os.path.join(savedir, model_dir, "saved")) 60 | 61 | # requires 32+gb of memory 62 
| relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True) 63 | relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl')) 64 | logger.log("Saved model in {} seconds\n".format(time.time() - start_time)) 65 | 66 | 67 | def maybe_load_model(savedir): 68 | """Load model if present at the specified path.""" 69 | if savedir is None: 70 | return 71 | 72 | state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip')) 73 | found_model = os.path.exists(state_path) 74 | if found_model: 75 | state = pickle_load(state_path, compression=True) 76 | model_dir = "model-{}".format(state["num_iters"]) 77 | U.load_state(os.path.join(savedir, model_dir, "saved")) 78 | logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"])) 79 | return state 80 | 81 | 82 | if __name__ == '__main__': 83 | args = parse_args() 84 | 85 | # Parse savedir 86 | savedir = args.save_dir 87 | if savedir is None: 88 | savedir = os.getenv('OPENAI_LOGDIR', None) 89 | 90 | # Create and seed the env. 91 | env, monitored_env = dqn_core.make_env_atari(args.env) 92 | if args.random_action > 0: 93 | env = dqn_core.ActionRandomizer(env, args.random_action) 94 | if args.seed > 0: 95 | set_global_seeds(args.seed) 96 | env.unwrapped.seed(args.seed) 97 | 98 | if args.gym_monitor and savedir: 99 | env = gym.wrappers.Monitor(env, os.path.join(savedir, 'gym_monitor'), force=True) 100 | 101 | if savedir: 102 | with open(os.path.join(savedir, 'args.json'), 'w') as f: 103 | json.dump(vars(args), f) 104 | 105 | var_func, cvar_func = dqn_core.models.atari_model() 106 | 107 | sess = make_session(num_cpu=4) 108 | sess.__enter__() 109 | 110 | # Create training graph 111 | act, train, update_target, debug = dqn_core.build_train( 112 | make_obs_ph=lambda name: U.BatchInput(env.observation_space.shape, name=name), 113 | var_func=var_func, 114 | cvar_func=cvar_func, 115 | num_actions=env.action_space.n, 116 | optimizer=tf.train.AdamOptimizer(learning_rate=args.lr), 117 | gamma=0.99, 118 | nb_atoms=args.nb_atoms 119 | ) 120 | 121 | # Create the schedule for exploration starting from 1. 122 | final_p = 0 if args.random_action > 0 else 0.01 123 | exploration = LinearSchedule(schedule_timesteps=int(0.1 * args.num_steps), 124 | initial_p=1.0, 125 | final_p=final_p) 126 | # approximate_num_iters = args.num_steps / 4 127 | # exploration = PiecewiseSchedule([ 128 | # (0, 1.0), 129 | # (approximate_num_iters / 50, 0.1), 130 | # (approximate_num_iters / 5, 0.01) 131 | # ], outside_value=0.01) 132 | 133 | replay_buffer = ReplayBuffer(args.replay_buffer_size) 134 | 135 | U.initialize() 136 | update_target() 137 | num_iters = 0 138 | 139 | # Load the model 140 | state = maybe_load_model(savedir) 141 | if state is not None: 142 | num_iters, replay_buffer = state["num_iters"], state["replay_buffer"], 143 | monitored_env.set_state(state["monitor_state"]) 144 | 145 | start_time, start_steps = None, None 146 | steps_per_iter = RunningAvg(0.999) 147 | iteration_time_est = RunningAvg(0.999) 148 | obs = env.reset() 149 | num_iters_since_reset = 0 150 | reset = True 151 | 152 | # Main training loop 153 | while True: 154 | num_iters += 1 155 | num_iters_since_reset += 1 156 | 157 | # Take action and store transition in the replay buffer. 158 | kwargs = {} 159 | 160 | update_eps = exploration.value(num_iters) 161 | update_param_noise_threshold = 0. 
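# act() selects the action for risk level run_alpha (the alpha used by the policy during training);
# update_eps sets the epsilon-greedy exploration rate fed into the act function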
162 | 163 | action = act(np.array(obs)[None], args.run_alpha, update_eps=update_eps, **kwargs)[0] 164 | reset = False 165 | new_obs, rew, done, info = env.step(action) 166 | replay_buffer.add(obs, action, rew, new_obs, float(done)) 167 | obs = new_obs 168 | if done: 169 | num_iters_since_reset = 0 170 | obs = env.reset() 171 | reset = True 172 | 173 | if (num_iters > max(5 * args.batch_size, args.replay_buffer_size // 20) and 174 | num_iters % args.learning_freq == 0): 175 | # Sample a bunch of transitions from replay buffer 176 | obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size) 177 | weights = np.ones_like(rewards) 178 | # Minimize the error in Bellman's equation and compute TD-error 179 | td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights) 180 | 181 | # Update target network. 182 | if num_iters % args.target_update_freq == 0: 183 | update_target() 184 | 185 | if start_time is not None: 186 | steps_per_iter.update(info['steps'] - start_steps) 187 | iteration_time_est.update(time.time() - start_time) 188 | start_time, start_steps = time.time(), info["steps"] 189 | 190 | # Save the model and training state. 191 | if num_iters > 0 and (num_iters % args.save_freq == 0 or info["steps"] > args.num_steps): 192 | maybe_save_model(savedir, { 193 | 'replay_buffer': replay_buffer, 194 | 'num_iters': num_iters, 195 | 'monitor_state': monitored_env.get_state(), 196 | }) 197 | 198 | if info["steps"] > args.num_steps: 199 | break 200 | 201 | if done: 202 | steps_left = args.num_steps - info["steps"] 203 | completion = np.round(100*info["steps"] / args.num_steps, 2) 204 | 205 | logger.record_tabular("% completion", completion) 206 | logger.record_tabular("steps", info["steps"]) 207 | logger.record_tabular("iters", num_iters) 208 | logger.record_tabular("episodes", len(info["rewards"])) 209 | logger.record_tabular("reward (100 epi mean)", np.mean(info["rewards"][-100:])) 210 | logger.record_tabular("exploration", exploration.value(num_iters)) 211 | 212 | fps_estimate = (float(steps_per_iter) / (float(iteration_time_est) + 1e-6) 213 | if steps_per_iter._value is not None else "calculating...") 214 | logger.dump_tabular() 215 | logger.log() 216 | logger.log("ETA: " + pretty_eta(int(steps_left / fps_estimate))) 217 | logger.log() 218 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/train_ice.py: -------------------------------------------------------------------------------- 1 | import cvar.dqn.core as dqn_core 2 | from baselines.common import set_global_seeds 3 | from cvar.dqn.core.static import make_env_ice 4 | 5 | 6 | def main(): 7 | # set_global_seeds(1337) 8 | env = make_env_ice("IceLakeRGB-v0") 9 | 10 | var_func, cvar_func = dqn_core.models.cnn_to_mlp( 11 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 12 | hiddens=[256], 13 | ) 14 | act = dqn_core.learn( 15 | env, 16 | var_func=var_func, 17 | cvar_func=cvar_func, 18 | lr=1e-4, 19 | max_timesteps=10000000 + 1, 20 | buffer_size=500000, 21 | exploration_fraction=0.2, 22 | exploration_final_eps=0.3, 23 | train_freq=4, 24 | learning_starts=10000, 25 | target_network_update_freq=1000, 26 | gamma=0.99, 27 | batch_size=32, 28 | nb_atoms=100, 29 | print_freq=25, 30 | periodic_save_path="../models/ice_rgb", 31 | periodic_save_freq=100000, 32 | grad_norm_clip=10. 
33 | ) 34 | act.save("../models/ice_rgb_model.pkl") 35 | env.close() 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/train_pong.py: -------------------------------------------------------------------------------- 1 | import cvar.dqn.core as dqn_core 2 | from baselines.common import set_global_seeds 3 | 4 | 5 | def main(): 6 | set_global_seeds(1337) 7 | env, _ = dqn_core.make_env_atari("Pong") 8 | 9 | var_func, cvar_func = dqn_core.models.cnn_to_mlp( 10 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 11 | hiddens=[256], 12 | ) 13 | act = dqn_core.learn( 14 | env, 15 | run_alpha=1.0, 16 | var_func=var_func, 17 | cvar_func=cvar_func, 18 | lr=1e-4, 19 | max_timesteps=2000000, 20 | buffer_size=100000, 21 | exploration_fraction=0.1, 22 | exploration_final_eps=0.01, 23 | train_freq=4, 24 | learning_starts=10000, 25 | target_network_update_freq=1000, 26 | gamma=0.99, 27 | batch_size=32, 28 | nb_atoms=50 29 | ) 30 | act.save("../models/pong_model.pkl") 31 | env.close() 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /cvar/dqn/scripts/train_simple.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from baselines.common import set_global_seeds 3 | import cvar.dqn.core as dqn_core 4 | import argparse 5 | import cvar.dqn.ice_lake 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser("CVaR DQN experiments for simple environments") 10 | 11 | parser.add_argument("--env", type=str, default="IceLake", help="name of the game") 12 | parser.add_argument("--random-action", type=float, default=0., help="probability of selecting a random action (for more risk sensitivity)") 13 | parser.add_argument("--num-steps", type=int, default=50000, help="total number of steps to run the environment for") 14 | parser.add_argument("--buffer-size", type=int, default=50000, help="size of replay memory") 15 | 16 | # CVaR 17 | parser.add_argument("--nb-atoms", type=int, default=10, help="number of cvar and quantile atoms (linearly spaced)") 18 | parser.add_argument("--run-alpha", type=float, default=1., help="alpha for policy used during training. 
-1 " 19 | "means") 20 | 21 | return parser.parse_args() 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | env = gym.make(args.env+"-v0") 27 | 28 | if args.random_action > 0: 29 | env = dqn_core.ActionRandomizer(env, args.random_action) 30 | exploration_final_eps = 0 31 | else: 32 | exploration_final_eps = 0.3 33 | 34 | set_global_seeds(1337) 35 | 36 | var_func, cvar_func = dqn_core.models.mlp([64]) 37 | act = dqn_core.learn( 38 | env, 39 | var_func, 40 | cvar_func, 41 | nb_atoms=args.nb_atoms, 42 | run_alpha=args.run_alpha if args.run_alpha > 0 else None, 43 | lr=1e-4, 44 | max_timesteps=args.num_steps+1, 45 | buffer_size=args.buffer_size, 46 | exploration_fraction=0.2, 47 | exploration_final_eps=exploration_final_eps, 48 | print_freq=10, 49 | batch_size=32, 50 | periodic_save_path="../models/"+args.env.lower() 51 | ) 52 | act.save("../models/"+args.env.lower()+"_model.pkl") 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /cvar/gridworld/README.md: -------------------------------------------------------------------------------- 1 | # Gridworld 2 | 3 | This subpackage contains a risk-sensitive gridworld environment (`cliffwalker.py`) and algorithms that solve it. 4 | 5 | 6 | Tweak the file `core/constants` if you want to test with different atoms. Note that Q-learning (empirically) prefers uniform atoms while VI performs better with log-spaced atoms. 7 | 8 | ---------------------------------------------------------------------- 9 | 10 | For CVaR Value Iteration run 11 | 12 | python3 run_vi.py 13 | 14 | If the shown deterministic path stops before target, increase the number of atoms around the critical point. 15 | 16 | ---------------------------------------------------------------------- 17 | 18 | For CVaR Q-learning run 19 | 20 | python3 run_q.py 21 | 22 | 23 | ---------------------------------------------------------------------- 24 | 25 | You can also check how standard RL behaves on the gridworld by running 26 | 27 | python3 exp_model.py 28 | 29 | See the main function for different algorithms. 30 | 31 | ---------------------------------------------------------------------- 32 | 33 | Both CVaR VI and Q-learning have interactive plots - click somewhere on the learned grid and a distribution plot will pop out. 34 | 35 | 36 | ![interactive](interactive.png) -------------------------------------------------------------------------------- /cvar/gridworld/algorithms/q_learning.py: -------------------------------------------------------------------------------- 1 | from cvar.gridworld.cliffwalker import * 2 | from cvar.gridworld.core.constants import * 3 | from cvar.gridworld.core import cvar_computation 4 | import numpy as np 5 | from cvar.gridworld.plots.grid import InteractivePlotMachine 6 | from cvar.common.util import timed, spaced_atoms 7 | 8 | 9 | class ActionValueFunction: 10 | 11 | def __init__(self, world, atoms): 12 | self.world = world 13 | self.atoms = atoms 14 | self.atom_p = self.atoms[1:] - self.atoms[:-1] 15 | 16 | self.Q = np.empty((world.height, world.width, len(world.ACTIONS)), dtype=object) 17 | for ix in np.ndindex(self.Q.shape): 18 | self.Q[ix] = MarkovQState(self.atoms) 19 | 20 | def update_safe(self, x, a, x_, r, beta, id=None): 21 | """ Naive TD update that ensures yCVaR convexity. 
""" 22 | V_x = self.joint_action_dist(x_) 23 | 24 | for v in V_x: 25 | for i, atom in enumerate(self.atoms[1:]): 26 | V = self.Q[x.y, x.x, a].V[i] 27 | yC = self.Q[x.y, x.x, a].yc[i] 28 | 29 | # learning rates 30 | lr_v = beta * self.atom_p[i] # p mirrors magnitude (for log-spaced) 31 | lr_yc = beta * self.atom_p[i] 32 | # lr_yc = beta * atom_p[i] / atom # /atom for using the same beta when estimating cvar (not yc) 33 | 34 | if self.Q[x.y, x.x, a].V[i] >= r + gamma*v: 35 | update = lr_v*(1-1/atom) 36 | else: 37 | update = lr_v 38 | 39 | # UPDATE VaR 40 | if i == 0: 41 | self.Q[x.y, x.x, a].V[i] = min(self.Q[x.y, x.x, a].V[i] + update, self.Q[x.y, x.x, a].V[i+1]) 42 | elif i == (len(self.atoms)-2): 43 | self.Q[x.y, x.x, a].V[i] = max(self.Q[x.y, x.x, a].V[i] + update, self.Q[x.y, x.x, a].V[i-1]) 44 | else: 45 | self.Q[x.y, x.x, a].V[i] = min(max(self.Q[x.y, x.x, a].V[i] + update, self.Q[x.y, x.x, a].V[i-1]), 46 | self.Q[x.y, x.x, a].V[i+1]) 47 | 48 | # UPDATE CVaR 49 | yCn = (1 - lr_yc) * yC + lr_yc * (atom*V + min(0, r+gamma*v - V)) 50 | if i == 0: 51 | self.Q[x.y, x.x, a].yc[i] = yCn 52 | elif i == 1: 53 | ddy = self.Q[x.y, x.x, a].yc[0] / self.atom_p[0] # TODO: check 54 | self.Q[x.y, x.x, a].yc[i] = max(yCn, self.Q[x.y, x.x, a].yc[i - 1] + ddy * self.atom_p[i]) 55 | else: 56 | ddy = (self.Q[x.y, x.x, a].yc[i-1] - self.Q[x.y, x.x, a].yc[i-2]) / self.atom_p[i-1] # TODO: check 57 | self.Q[x.y, x.x, a].yc[i] = max(yCn, self.Q[x.y, x.x, a].yc[i-1] + ddy*self.atom_p[i]) 58 | 59 | def update_naive(self, x, a, x_, r, beta, id=None): 60 | """ Naive (read slow) CVaR TD update. """ 61 | V_x = self.joint_action_dist(x_) 62 | print('standard') 63 | # TODO: vectorize/cythonize 64 | for iv, v in enumerate(V_x): 65 | for i, atom in enumerate(self.atoms[1:]): 66 | V = self.Q[x.y, x.x, a].V[i] 67 | yC = self.Q[x.y, x.x, a].yc[i] 68 | 69 | # learning rates 70 | lr_v = beta * self.atom_p[iv] # p mirrors magnitude (for log-spaced) 71 | # lr_yc = beta * atom_p[iv] / atom # /atom for using the same beta when estimating cvar (not yc) 72 | lr_yc = beta * self.atom_p[iv] 73 | 74 | if self.Q[x.y, x.x, a].V[i] >= r + gamma * v: 75 | update = lr_v * (1 - 1 / atom) 76 | else: 77 | update = lr_v 78 | 79 | # UPDATE VaR 80 | self.Q[x.y, x.x, a].V[i] += update 81 | 82 | # UPDATE CVaR 83 | yCn = (1 - lr_yc) * yC + lr_yc * (atom*V + min(0, r+gamma*v - V)) 84 | self.Q[x.y, x.x, a].yc[i] = yCn 85 | 86 | def update(self, x, a, x_, r, beta, id=None): 87 | """ Vectorized CVaR TD update. """ 88 | d = self.joint_action_dist(x_) 89 | 90 | V = np.array(self.Q[x.y, x.x, a].V) 91 | C = np.array(self.Q[x.y, x.x, a].yc) / self.atoms[1:] 92 | 93 | # row is a single atom update 94 | # shape=(n, n) 95 | indicator_mask = self.Q[x.y, x.x, a].V[:, None] >= r + gamma * d 96 | 97 | V_update = 1 - indicator_mask / self.atoms[1:, None] 98 | 99 | self.Q[x.y, x.x, a].V += beta * np.average(V_update, axis=1, weights=self.atom_p) 100 | 101 | C_update = V[:, None] + np.clip(r + gamma * d - V[:, None], a_min=None, a_max=0) / self.atoms[1:, None] 102 | 103 | C_new = (1 - beta) * C + beta * np.average(C_update, axis=1, weights=self.atom_p) 104 | self.Q[x.y, x.x, a].yc = C_new * self.atoms[1:] 105 | 106 | 107 | def next_action_alpha(self, x, alpha): 108 | yc = [self.Q[x.y, x.x, a].yc_alpha(alpha) for a in self.world.ACTIONS] 109 | return np.argmax(yc) 110 | 111 | def next_action_s(self, x, s): 112 | """ 113 | Select best action according to E[(Z-s)^-]. 114 | If multiple 0's, use yCVaR_0. 
115 | """ 116 | return max(self.world.ACTIONS, key=lambda a: (self.Q[x.y, x.x, a].cvar_pre_s(s), self.Q[x.y, x.x, a].yc[0])) 117 | 118 | def joint_action_dist(self, x, return_yc=False): 119 | """ 120 | Returns a distribution representing the value function at state x. 121 | Constructed by taking a supremum of yC over actions for each atom. 122 | """ 123 | yc = [np.max([self.Q[x.y, x.x, a].yc[i] for a in self.world.ACTIONS]) for i in range(NB_ATOMS)] 124 | 125 | if return_yc: 126 | return yc 127 | else: 128 | return cvar_computation.yc_to_var(self.atoms, yc) 129 | 130 | def joint_action_dist_var(self, x): 131 | """ 132 | Returns VaR estimates of the joint distribution. 133 | Constructed by taking a supremum of yC over actions for each atom. 134 | """ 135 | info = [max([(self.Q[x.y, x.x, a].yc[i], self.Q[x.y, x.x, a].V[i]) for a in self.world.ACTIONS]) for i in range(NB_ATOMS)] 136 | 137 | return [ycv[1] for ycv in info] 138 | 139 | def var_alpha(self, x, a, alpha): 140 | """ 141 | Get VaR_alpha using interpolation 142 | """ 143 | i = 0 144 | for i in range(len(self.atoms)): 145 | if alpha < self.atoms[i]: 146 | break 147 | v_low = self.Q[x.y, x.x, a].V[i-2] 148 | v_high = self.Q[x.y, x.x, a].V[i-1] 149 | 150 | p_low = self.atoms[i-1] 151 | p_high = self.atoms[i] 152 | 153 | return v_low + (alpha - p_low) / (p_high - p_low) * (v_high - v_low) 154 | 155 | def alpha_from_var(self, x, s): 156 | """ 157 | Get alpha from joint VaRs using interpolation 158 | """ 159 | var = self.joint_action_dist_var(x) 160 | for i in range(len(var)): 161 | if s < var[i]: 162 | break 163 | 164 | # clip alpha to lowest atom (less won't make a difference) 165 | if i == 0: 166 | return self.atoms[1] 167 | # 1 is max 168 | elif s > var[-1]: 169 | return 1. 170 | 171 | v_low = var[i-1] 172 | v_high = var[i] 173 | 174 | p_low = self.atoms[i] 175 | p_high = self.atoms[i+1] 176 | 177 | return p_low + (s - v_low) / (v_high - v_low) * (p_high - p_low) 178 | 179 | def optimal_path(self, alpha): 180 | """ Optimal deterministic path. 
""" 181 | from cvar.gridworld.core.policies import VarBasedQPolicy, XiBasedQPolicy, NaiveQPolicy, VarXiQPolicy 182 | from cvar.gridworld.core.runs import optimal_path 183 | # policy = VarBasedQPolicy(self, alpha) 184 | policy = VarXiQPolicy(self, alpha) 185 | # policy = XiBasedQPolicy(self, alpha) 186 | # policy = NaiveQPolicy(self, alpha) 187 | return optimal_path(self.world, policy) 188 | 189 | 190 | def is_ordered(v): 191 | for i in range(1, len(v)): 192 | if v[i-1] - v[i] > 1e-6: 193 | return False 194 | return True 195 | 196 | 197 | def is_convex(yc, atoms): 198 | assert LOG_NB_ATOMS == 0 199 | return is_ordered(cvar_computation.yc_to_var(atoms, yc)) 200 | 201 | 202 | class MarkovQState: 203 | 204 | def __init__(self, atoms): 205 | self.atoms = atoms 206 | self.V = np.zeros(NB_ATOMS) # VaR estimate 207 | self.yc = np.zeros(NB_ATOMS) # CVaR estimate 208 | 209 | def plot(self, show=True, ax=None): 210 | import matplotlib.pyplot as plt 211 | if ax is None: 212 | _, ax = plt.subplots(1, 3) 213 | 214 | # yC 215 | ax[0].plot(self.atoms, np.insert(self.yc, 0, 0), '-') 216 | 217 | # yC-> var 218 | v = self.dist_from_yc() 219 | ax[1].step(self.atoms, list(v) + [v[-1]], '-', where='post') 220 | 221 | # var 222 | ax[2].step(self.atoms, list(self.V) + [self.V[-1]], '-', where='post') 223 | 224 | # titles 225 | ax[0].set_title('yCVaR') 226 | ax[1].set_title('Extracted Distribution') 227 | ax[2].set_title('VaR') 228 | 229 | if show: 230 | plt.show() 231 | 232 | def expected_value(self): 233 | return self.yc[-1] 234 | 235 | def yc_alpha(self, alpha): 236 | """ linear interpolation: yC(alpha)""" 237 | i = 0 238 | for i in range(1, len(self.atoms)): 239 | if alpha < self.atoms[i]: 240 | break 241 | alpha_portion = (alpha - self.atoms[i-1]) / (self.atoms[i] - self.atoms[i-1]) 242 | if i == 1: # between 0 and first atom 243 | return alpha_portion * self.yc[i-1] 244 | else: 245 | return self.yc[i-2] + alpha_portion * (self.yc[i-1] - self.yc[i-2]) 246 | 247 | def var_alpha(self, alpha): 248 | """ VaR estimate of alpha. """ 249 | # TODO: check 250 | last_v = self.V[0] 251 | for p, v in zip(self.atoms[1:], self.V): 252 | if p > alpha: 253 | break 254 | last_v = v 255 | return last_v 256 | 257 | def cvar_pre_s(self, s): 258 | """ E[(V-s)^-] + ys. 259 | 260 | Uses the actual VaR for th cutoff and yC->VaR for the expectation. 
261 | """ 262 | yc = 0 263 | 264 | for ix, v_yc in enumerate(self.dist_from_yc()): 265 | v = self.V[ix] 266 | p = self.atoms[ix+1] - self.atoms[ix] 267 | if v < s: 268 | yc += p * v_yc 269 | else: 270 | break 271 | 272 | return yc 273 | 274 | def dist_from_yc(self): 275 | return cvar_computation.yc_to_var(self.atoms, self.yc) 276 | 277 | 278 | @timed 279 | def q_learning(world, alpha, max_episodes=2e3, max_episode_length=100): 280 | Q = ActionValueFunction(world, spaced_atoms(NB_ATOMS, SPACING, LOG_NB_ATOMS, LOG_THRESHOLD)) 281 | 282 | # learning parameters 283 | eps = 0.5 284 | beta = 0.4 285 | 286 | # count visits for debugging purposes 287 | counter = np.zeros((world.height, world.width), dtype=int) 288 | 289 | e = 0 290 | while e < max_episodes: 291 | if e % 10 == 0: 292 | print("e:{}, beta:{}".format(e, beta)) 293 | beta = max(beta*0.995, 0.01) 294 | x = world.initial_state 295 | 296 | i = 0 297 | while x not in world.goal_states and i < max_episode_length: 298 | 299 | counter[x.y, x.x] += 1 300 | 301 | a = eps_greedy(Q.next_action_alpha(x, alpha), eps, world.ACTIONS) 302 | t = world.sample_transition(x, a) 303 | x_, r = t.state, t.reward 304 | 305 | Q.update(x, a, x_, r, beta, id=(e, i)) 306 | 307 | x = x_ 308 | 309 | i += 1 310 | e += 1 311 | 312 | # # show visit counts 313 | # import matplotlib.pyplot as plt 314 | # fig, ax = plt.subplots() 315 | # ax.imshow(counter) 316 | # for (j, i), label in np.ndenumerate(counter): 317 | # ax.text(i, j, label, ha='center', va='center', color='white') 318 | # ax.set_title('Run alpha={}'.format(alpha)) 319 | 320 | return Q 321 | 322 | 323 | def pseudo_q_learning(world, max_episodes): 324 | Q = ActionValueFunction(world) 325 | 326 | e = 0 327 | beta = 0.01 / NB_ATOMS 328 | while e < max_episodes: 329 | if e % 10 == 0: 330 | print(e, beta) 331 | for x in world.states(): 332 | if x in world.goal_states or x in world.cliff_states: 333 | continue 334 | a = np.random.randint(0, 4) 335 | 336 | t = world.sample_transition(x, a) 337 | x_, r = t.state, t.reward 338 | 339 | # Q.update(x, a, x_, r, beta) 340 | Q.update_safe(x, a, x_, r, beta) 341 | 342 | e += 1 343 | 344 | return Q 345 | 346 | 347 | def eps_greedy(a, eps, action_space): 348 | if np.random.random() < eps: 349 | return np.random.choice(action_space) 350 | else: 351 | return a 352 | 353 | 354 | def q_to_v_exp(Q): 355 | return np.max(np.array([Q.Q[ix].expected_value() for ix in np.ndindex(Q.Q.shape)]).reshape(Q.Q.shape), axis=-1) 356 | -------------------------------------------------------------------------------- /cvar/gridworld/algorithms/value_iteration.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from cvar.gridworld.cliffwalker import * 3 | from cvar.gridworld.core import cvar_computation 4 | from cvar.gridworld.core.constants import * 5 | from cvar.common.util import timed, spaced_atoms 6 | 7 | # use LP when computing CVaRs 8 | # TAMAR_LP = True 9 | TAMAR_LP = False 10 | 11 | 12 | class ValueFunction: 13 | 14 | def __init__(self, world): 15 | self.world = world 16 | 17 | self.V = np.empty((world.height, world.width), dtype=object) 18 | for ix in np.ndindex(self.V.shape): 19 | self.V[ix] = MarkovState() 20 | 21 | print('ATOMS:', list(self.V[0, 0].atoms)) 22 | 23 | def update(self, y, x, check_bound=False): 24 | 25 | v_a, yc_a = self.action_v_yc(y, x) 26 | 27 | best_args = np.argmax(yc_a, axis=0) 28 | 29 | self.V[y, x].yc = np.array([yc_a[best_args[i], i] for i in range(len(self.V[y, x].yc))]) 30 | 31 | self.V[y, x].c_0 = 
max([cvar_computation.v_0_from_transitions(self.V, list(self.transitions(y, x, a)), gamma) 32 | for a in self.world.ACTIONS]) 33 | 34 | # check for error bound 35 | if check_bound: 36 | eps = 1. 37 | c_0 = v_a[best_args[0], 0] 38 | if c_0 - self.V[y, x].c_0 > eps: 39 | # if deep and self.V[y, x].nb_atoms < 100: 40 | self.V[y, x].increase_precision(eps) 41 | 42 | def action_v_yc(self, y, x): 43 | """ Extract transition distributions for each action. """ 44 | yc_a = [] 45 | v_a = [] 46 | 47 | for a in self.world.ACTIONS: 48 | t = list(self.transitions(y, x, a)) 49 | 50 | if TAMAR_LP: 51 | v, yc = self.V[y, x].compute_cvar_by_lp([t_.prob for t_ in t], self.transition_ycs(y, x, a), 52 | [self.V[tr.state.y, tr.state.x].atoms for tr in t]) 53 | else: 54 | v, yc = self.V[y, x].compute_cvar_by_sort([t_.prob for t_ in t], self.transition_vars(y, x, a), 55 | [self.V[tr.state.y, tr.state.x].atoms for tr in t]) 56 | yc_a.append(yc) 57 | v_a.append(v) 58 | 59 | return np.array(v_a), np.array(yc_a) 60 | 61 | def next_action(self, y, x, alpha): 62 | if alpha == 0: 63 | print('alpha=0') 64 | a_best = max(self.world.ACTIONS, 65 | key=lambda a:cvar_computation.v_0_from_transitions(self.V, list(self.transitions(y, x, a)), gamma)) 66 | return a_best, np.zeros(len(list(self.transitions(y, x, a_best)))) 67 | 68 | assert alpha != 0 69 | # self.plot_full_actions(31, 59) 70 | best = (-1e6, 0, 0) 71 | for a in self.world.ACTIONS: 72 | 73 | if TAMAR_LP: 74 | cv, xis = self.single_yc_xis_lp(y, x, a, alpha) 75 | else: 76 | _, cv, xis = self.single_var_yc_xis(y, x, a, alpha) 77 | 78 | if cv > best[0]: 79 | best = (cv, xis, a) 80 | 81 | _, xis, a = best 82 | return a, xis 83 | 84 | def single_yc_xis_lp(self, y, x, a, alpha): 85 | transition_p = [t.prob for t in self.transitions(y, x, a)] 86 | atom_values = [self.V[t.state.y, t.state.x].atoms for t in self.transitions(y, x, a)] 87 | yc = self.transition_ycs(y, x, a) 88 | return cvar_computation.single_yc_lp_from_t(transition_p, atom_values, yc, alpha, xis=True) 89 | 90 | def single_var_yc_xis(self, y, x, a, alpha): 91 | """ 92 | Compute VaR, CVaR and xi values in O(nlogn) 93 | """ 94 | 95 | transitions = list(self.transitions(y, x, a)) 96 | var_values = self.transition_vars(y, x, a) 97 | transition_p = [t.prob for t in transitions] 98 | t_atoms = [self.V[t.state.y, t.state.x].atoms for t in transitions] 99 | 100 | return cvar_computation.single_var_yc_xis_from_t(transition_p, t_atoms, var_values, alpha) 101 | 102 | def plot_full_actions(self, y, x): 103 | """ 104 | Plot actions without value approximation - used for debugging 105 | """ 106 | import matplotlib.pyplot as plt 107 | fig, ax = plt.subplots() 108 | 109 | for a in self.world.ACTIONS: 110 | transitions = list(self.transitions(y, x, a)) 111 | var_values = self.transition_vars(y, x, a) 112 | transition_p = [t.prob for t in transitions] 113 | t_atoms = [self.V[t.state.y, t.state.x].atoms for t in transitions] 114 | 115 | info = cvar_computation.extract_distribution(transition_p, t_atoms, var_values) 116 | t_atoms = np.cumsum([0]+[p for p, ix, v in info]) 117 | t_vars = [v for p, ix, v in info] 118 | t_yc = cvar_computation.var_to_ycvar([p for p, ix, v in info], t_vars) 119 | print(a, cvar_computation.single_alpha_to_yc([p for p, ix, v in info], t_vars, 0.036)) 120 | 121 | ax.plot(t_atoms, [0]+list(t_yc), 'o-') 122 | 123 | ax.legend([self.world.ACTION_NAMES[a] for a in self.world.ACTIONS]) 124 | 125 | plt.show() 126 | 127 | def y_var(self, y, x, a, var): 128 | """ E[(Z-var)^-] + yvar""" 129 | 130 | transitions = 
list(self.transitions(y, x, a)) 131 | var_values = self.transition_vars(y, x, a) 132 | 133 | info = cvar_computation.extract_distribution(transitions, var_values, 134 | [self.V[tr.state.y, tr.state.x].atom_p for tr in transitions]) 135 | 136 | yv = 0. 137 | p = 0 138 | for p_, _, v_ in info: 139 | if v_ >= var: # TODO: solve for discrete distributions 140 | break 141 | else: 142 | yv += p_ * v_ 143 | p += p_ 144 | 145 | return p, yv 146 | 147 | def transitions(self, y, x, a): 148 | for t in self.world.transitions(State(y, x))[a]: 149 | yield t 150 | 151 | def transition_vars(self, y, x, a): 152 | return np.array([t.reward + gamma * self.V[t.state.y, t.state.x].var for t in self.transitions(y, x, a)]) 153 | 154 | def transition_ycs(self, y, x, a): 155 | return np.array([t.reward*self.V[t.state.y, t.state.x].atoms[1:] + gamma * self.V[t.state.y, t.state.x].yc for t in self.transitions(y, x, a)]) 156 | 157 | def optimal_path(self, alpha): 158 | """ Optimal deterministic path. """ 159 | from cvar.gridworld.core.policies import XiBasedPolicy 160 | from cvar.gridworld.core.runs import optimal_path 161 | policy = XiBasedPolicy(self, alpha) 162 | return optimal_path(self.world, policy) 163 | 164 | def plot(self, y, x, a, show=False, ax=None): 165 | import matplotlib.pyplot as plt 166 | if ax is None: 167 | fig, ax = plt.subplots(1, 2) 168 | 169 | var_a, yc_a = self.action_v_yc(y, x) 170 | var = var_a[a] 171 | yc = yc_a[a] 172 | # var 173 | ax[1].step(self.V[y, x].atoms, list(var) + [var[-1]], 'o-', where='post') 174 | 175 | # yC 176 | ax[0].plot(self.V[y, x].atoms, np.insert(yc, 0, 0), 'o-') 177 | 178 | ax[0].set_title('yCVaR') 179 | ax[1].set_title('Extracted Distribution') 180 | 181 | if show: 182 | plt.show() 183 | 184 | 185 | class MarkovState: 186 | 187 | def __init__(self): 188 | self.yc = np.zeros(NB_ATOMS) 189 | self.atoms = spaced_atoms(NB_ATOMS, SPACING, LOG_NB_ATOMS, LOG_THRESHOLD) # e.g. [0, 0.25, 0.5, 1] 190 | self.atom_p = self.atoms[1:] - self.atoms[:-1] # [0.25, 0.25, 0.5] 191 | 192 | self.c_0 = 0 # separate estimate for CVaR_0 193 | 194 | def plot(self, show=True, figax=None): 195 | import matplotlib.pyplot as plt 196 | if figax is None: 197 | fig, ax = plt.subplots(1, 2) 198 | else: 199 | fig, ax = figax 200 | 201 | # var 202 | ax[1].step(self.atoms, list(self.var) + [self.var[-1]], 'o-', where='post') 203 | 204 | # yC 205 | ax[0].plot(self.atoms, np.insert(self.yc, 0, 0), 'o-') 206 | 207 | if show: 208 | plt.show() 209 | 210 | @property 211 | def var(self): 212 | return cvar_computation.yc_to_var(self.atoms, self.yc) 213 | 214 | @property 215 | def nb_atoms(self): 216 | return len(self.yc) 217 | 218 | def increase_precision(self, eps): 219 | """ Bound error by adding atoms. Follows the adaptive procedure from RSRDM. 
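New atoms are prepended below the current smallest non-zero atom, starting at
eps*atom_p[0] / |yc[0] - c_0| and growing geometrically by the SPACING factor, together
with matching yCVaR values.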
""" 220 | new_atoms = [] 221 | v_0 = self.yc[0] 222 | y = (eps*self.atom_p[0])/(np.abs(v_0-self.c_0)) 223 | if y < 1e-15: 224 | print('SMALL') 225 | 226 | while y < self.atom_p[0]: 227 | new_atoms.append(y) 228 | y *= SPACING 229 | 230 | self.atoms = np.hstack((np.array([0]), new_atoms, self.atoms[1:])) 231 | self.atom_p = self.atoms[1:] - self.atoms[:-1] 232 | 233 | self.yc = np.hstack((v_0*np.array(new_atoms), self.yc)) 234 | 235 | def cvar_alpha(self, alpha): 236 | return cvar_computation.single_alpha_to_cvar(self.atom_p, self.var, alpha) 237 | 238 | def expected_value(self): 239 | return np.dot(self.atom_p, self.var) 240 | 241 | def compute_cvar_by_sort(self, transition_p, var_values, t_atoms): 242 | return cvar_computation.v_yc_from_t(self.atoms, transition_p, var_values, t_atoms) 243 | 244 | def compute_cvar_by_lp(self, transition_p, t_ycs, t_atoms): 245 | return cvar_computation.v_yc_from_t_lp(self.atoms, transition_p, t_ycs, t_atoms) 246 | 247 | 248 | def value_update(world, V, id=0, figax=None): 249 | 250 | V_ = copy.deepcopy(V) 251 | for s in world.states(): 252 | V_.update(s.y, s.x) 253 | 254 | return V_ 255 | 256 | 257 | def value_difference(V, V_, world): 258 | max_val = -1 259 | max_state = None 260 | for s in world.states(): 261 | # dist = np.max(np.abs(V.V[s.y, s.x].var-V_.V[s.y, s.x].var)) 262 | cvars = np.array([V.V[s.y, s.x].cvar_alpha(alpha) for alpha in V.V[s.y, s.x].atoms[1:]]) 263 | cvars_ = np.array([V_.V[s.y, s.x].cvar_alpha(alpha) for alpha in V_.V[s.y, s.x].atoms[1:]]) 264 | if cvars.shape != cvars_.shape: 265 | return float('inf'), None 266 | dist = np.max(np.abs(cvars - cvars_)) 267 | if dist > max_val: 268 | max_state = s 269 | max_val = dist 270 | 271 | return max_val, max_state 272 | 273 | @timed 274 | def value_iteration(world, V=None, max_iters=1e3, eps_convergence=1e-3): 275 | if V is None: 276 | V = ValueFunction(world) 277 | i = 0 278 | figax = None 279 | while True: 280 | if i == 28: 281 | import matplotlib.pyplot as plt 282 | figax = plt.subplots(1, 2) 283 | V_ = value_update(world, V, i, figax) 284 | 285 | error, worst_state = value_difference(V, V_, world) 286 | if error < eps_convergence: 287 | print("value fully learned after %d iterations" % (i,)) 288 | break 289 | elif i > max_iters: 290 | print("value finished without convergence after %d iterations" % (i,)) 291 | break 292 | V = V_ 293 | i += 1 294 | 295 | print('Iteration:{}, error={} ({})'.format(i, error, worst_state)) 296 | 297 | return V 298 | 299 | -------------------------------------------------------------------------------- /cvar/gridworld/cliffwalker.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | 4 | 5 | # helper data structures: 6 | # a state is given by row and column positions designated (y, x) 7 | State = namedtuple('State', ['y', 'x']) 8 | 9 | # encapsulates a transition to state and its probability 10 | Transition = namedtuple('Transition', ['state', 'prob', 'reward']) # transition to state with probability prob 11 | 12 | 13 | class GridWorld: 14 | """ Cliffwalker. 
""" 15 | 16 | ACTION_LEFT = 0 17 | ACTION_RIGHT = 1 18 | ACTION_UP = 2 19 | ACTION_DOWN = 3 20 | ACTIONS = [ACTION_LEFT, ACTION_RIGHT, ACTION_UP, ACTION_DOWN] 21 | FALL_REWARD = -40 22 | ACTION_NAMES = {ACTION_LEFT: "Left", ACTION_RIGHT: "Right", ACTION_UP: "Up", ACTION_DOWN: "Down"} 23 | 24 | def __init__(self, height, width, random_action_p=0.1, risky_p_loss=0.15): 25 | 26 | self.height, self.width = height, width 27 | self.risky_p_loss = risky_p_loss 28 | self.random_action_p = random_action_p 29 | 30 | # self.risky_goal_states = {State(0, 5)} 31 | self.risky_goal_states = {} 32 | 33 | self.initial_state = State(self.height - 1, 0) 34 | self.goal_states = {State(self.height - 1, self.width - 1)} 35 | 36 | self.cliff_states = set() 37 | if height != 1: 38 | for x in range(width): 39 | for y in range(height): 40 | s = State(y, x) 41 | p_cliff = 0.1 * (y / height)**2 * bool(x > 1 and y > 0 and x < width-2 and y < height-1) 42 | if s == self.initial_state or s in self.goal_states: 43 | continue 44 | 45 | if np.random.random() < p_cliff: 46 | self.cliff_states.add(s) 47 | 48 | def states(self): 49 | """ iterator over all possible states """ 50 | for y in range(self.height): 51 | for x in range(self.width): 52 | s = State(y, x) 53 | if s in self.cliff_states: 54 | continue 55 | yield s 56 | 57 | def target_state(self, s, a): 58 | """ Return the next deterministic state """ 59 | x, y = s.x, s.y 60 | if a == self.ACTION_LEFT: 61 | return State(y, max(x - 1, 0)) 62 | if a == self.ACTION_RIGHT: 63 | return State(y, min(x + 1, self.width - 1)) 64 | if a == self.ACTION_UP: 65 | return State(max(y - 1, 0), x) 66 | if a == self.ACTION_DOWN: 67 | return State(min(y + 1, self.height - 1), x) 68 | 69 | def transitions(self, s): 70 | """ 71 | returns a list of Transitions from the state s for each action, only non zero probabilities are given 72 | serves the lists for all actions at once 73 | """ 74 | if s in self.goal_states: 75 | return [[Transition(state=s, prob=1.0, reward=0)] for a in self.ACTIONS] 76 | 77 | if s in self.risky_goal_states: 78 | goal = next(iter(self.goal_states)) 79 | return [[Transition(state=goal, prob=self.risky_p_loss, reward=-50), 80 | Transition(state=goal, prob=1-self.risky_p_loss, reward=100)] for a in self.ACTIONS] 81 | 82 | transitions_full = [] 83 | for a in self.ACTIONS: 84 | transitions_actions = [] 85 | 86 | # over all *random* actions 87 | for a_ in self.ACTIONS: 88 | s_ = self.target_state(s, a_) 89 | if s_ in self.cliff_states: 90 | r = self.FALL_REWARD 91 | # s_ = self.initial_state 92 | s_ = next(iter(self.goal_states)) 93 | else: 94 | r = -1 95 | p = 1.0 - self.random_action_p if a_ == a else self.random_action_p / 3 96 | if p != 0: 97 | transitions_actions.append(Transition(s_, p, r)) 98 | transitions_full.append(transitions_actions) 99 | 100 | return transitions_full 101 | 102 | def sample_transition(self, s, a): 103 | """ Sample a single transition, duh. 
""" 104 | trans = self.transitions(s)[a] 105 | state_probs = [tran.prob for tran in trans] 106 | return trans[np.random.choice(len(trans), p=state_probs)] 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | from cvar.gridworld.core.constants import * 112 | from cvar.gridworld.plots.grid import grid_plot 113 | import matplotlib.pyplot as plt 114 | 115 | world = GridWorld(40, 60) 116 | grid_plot(world) 117 | plt.show() 118 | for i in range(20): 119 | print('seed=', i) 120 | np.random.seed(i) 121 | world = GridWorld(40, 60) 122 | grid_plot(world) 123 | plt.show() -------------------------------------------------------------------------------- /cvar/gridworld/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Silvicek/cvar-algorithms/ec60696d269857f78213b8cbde506bb5e94f34a3/cvar/gridworld/core/__init__.py -------------------------------------------------------------------------------- /cvar/gridworld/core/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.set_printoptions(8) 4 | 5 | gamma = 0.95 6 | 7 | 8 | # atom spacing 9 | NB_ATOMS = 50 10 | LOG_NB_ATOMS = 50 # number of log atoms 11 | LOG_THRESHOLD = 1. # where does the log start (1 for full log) 12 | SPACING = 2 13 | 14 | -------------------------------------------------------------------------------- /cvar/gridworld/core/cvar_computation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Different CVaR computations and conversions. 3 | Naming conventions: 4 | 5 | v=VaR, c=CVaR, yc=yCVaR, t=transition 6 | 7 | single_: extract a single value from a distribution 8 | 9 | _from_t: extract the desired value from weighted list of distributions 10 | 11 | """ 12 | import numpy as np 13 | 14 | 15 | # =================================================================== 16 | # Single: 17 | # gets the desired values from a single distribution 18 | # =================================================================== 19 | 20 | def single_var_to_alpha(p_sorted, v_sorted, s): 21 | """ """ 22 | alpha = 0 23 | for v, p in zip(v_sorted, p_sorted): 24 | if v > s: 25 | # TODO: middle/linear interp 26 | break 27 | alpha += p 28 | return alpha 29 | 30 | 31 | def single_alpha_to_var(p_sorted, v_sorted, alpha): 32 | p = 0. 33 | for p_, v_ in zip(p_sorted, v_sorted): 34 | p += p_ 35 | if p >= alpha: 36 | return v_ 37 | # numerical 1 != 1 38 | return v_sorted[-1] 39 | 40 | 41 | def single_alpha_to_cvar(p_sorted, v_sorted, alpha): 42 | if alpha == 0: 43 | return v_sorted[0] 44 | return single_alpha_to_yc(p_sorted, v_sorted, alpha) / alpha 45 | 46 | 47 | def single_alpha_to_yc(p_sorted, v_sorted, alpha): 48 | p = 0. 49 | yc = 0. 50 | for p_, v_ in zip(p_sorted, v_sorted): 51 | if p + p_ >= alpha: 52 | yc += (alpha - p)*v_ 53 | break 54 | else: 55 | p += p_ 56 | yc += p_*v_ 57 | 58 | return yc 59 | 60 | 61 | # =================================================================== 62 | # Single from transitions: 63 | # gets the desired values from transition distributions 64 | # =================================================================== 65 | 66 | def single_var_yc_xis_from_t(transition_p, t_atoms, var_values, alpha): 67 | """ 68 | Compute VaR, CVaR and xi values, using uniform last probabilities. 69 | 70 | """ 71 | 72 | info = extract_distribution(transition_p, t_atoms, var_values) 73 | 74 | xis = np.zeros(len(transition_p)) 75 | p = 0. 76 | yc = 0. 
77 | 78 | v_alpha = single_alpha_to_var([p_ for p_, i_t, v in info], [v for p_, i_t, v in info], alpha) 79 | ix = 0 80 | for ix, (p_, t_i, v) in enumerate(info): 81 | if v >= v_alpha: 82 | yc += (alpha - p) * v 83 | break 84 | else: 85 | xis[t_i] += p_ 86 | yc += p_ * v 87 | p += p_ 88 | 89 | # same last atom -> weight uniformly 90 | last_v_info = [] 91 | while ix < len(info): 92 | p_, t_i, v = info[ix] 93 | if v != v_alpha: 94 | break 95 | last_v_info.append(info[ix]) 96 | ix += 1 97 | last_v_p = np.array([p_ for p_, t_i, v in last_v_info]) 98 | fractions = last_v_p/np.sum(last_v_p) 99 | 100 | for fr, (p_, t_i, v) in zip(fractions, last_v_info): 101 | xis[t_i] += (alpha - p) * fr 102 | 103 | return v_alpha, yc, xis / transition_p 104 | 105 | 106 | def single_yc_lp_from_t(transition_p, t_atoms, yc_values, alpha, xis=False): 107 | """ 108 | Create LP: 109 | min Sum p_t * I 110 | 111 | 0 <= xi <= 1/alpha 112 | Sum p_t * xi == 1 113 | 114 | I = max{y_cvar} 115 | 116 | return y_cvar[alpha] 117 | """ 118 | from pulp import LpVariable, LpProblem, value 119 | 120 | if alpha == 0: 121 | return 0. 122 | 123 | nb_transitions = len(transition_p) 124 | 125 | Xi = [LpVariable('xi_' + str(i)) for i in range(nb_transitions)] 126 | I = [LpVariable('I_' + str(i)) for i in range(nb_transitions)] 127 | 128 | prob = LpProblem(name='tamar') 129 | 130 | for xi in Xi: 131 | prob.addConstraint(0 <= xi) 132 | prob.addConstraint(xi <= 1./alpha) 133 | prob.addConstraint(sum([xi*p for xi, p in zip(Xi, transition_p)]) == 1) 134 | 135 | for xi, i, yc, atoms in zip(Xi, I, yc_values, t_atoms): 136 | last_yc = 0. 137 | f_params = [] 138 | atom_p = atoms[1:] - atoms[:-1] 139 | for ix in range(len(yc)): 140 | # linear interpolation as a solution to 'y = kx + q' 141 | k = (yc[ix]-last_yc)/atom_p[ix] 142 | 143 | q = last_yc - k * atoms[ix] 144 | prob.addConstraint(i >= k * xi * alpha + q) 145 | f_params.append((k, q)) 146 | last_yc = yc[ix] 147 | 148 | # opt criterion 149 | prob.setObjective(sum([i * p for i, p in zip(I, transition_p)])) 150 | 151 | prob.solve() 152 | 153 | if xis: 154 | return value(prob.objective), [value(xi)*alpha for xi in Xi] 155 | else: 156 | return value(prob.objective) 157 | 158 | 159 | # =================================================================== 160 | # Distribution <=> vector 161 | # =================================================================== 162 | 163 | def yc_to_var(atoms, y_cvar): 164 | """ yCVaR -> distribution. Outputs same atoms as input. """ 165 | last = 0. 
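    # y * CVaR_y is piecewise linear in y with knots at the atoms, and its slope on each
    # segment equals the VaR there; the quantile values are therefore recovered as consecutive
    # differences of y_cvar divided by the corresponding atom probabilities.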
166 | var = np.zeros_like(y_cvar) 167 | 168 | for i in range(len(atoms) - 1): 169 | p = atoms[i + 1] - atoms[i] 170 | ddy = (y_cvar[i] - last) / p 171 | var[i] = ddy 172 | last = y_cvar[i] 173 | 174 | return var 175 | 176 | 177 | def var_to_ycvar(p_sorted, v_sorted): # TODO: name 178 | yc = np.zeros_like(p_sorted) 179 | yc_last = 0 180 | for i in range(len(yc)): 181 | yc[i] = yc_last + p_sorted[i] * v_sorted[i] 182 | yc_last = yc[i] 183 | return yc 184 | 185 | 186 | def var_vector(atoms, p_sorted, v_sorted): 187 | """ 188 | :param atoms: full atoms, shape=[n] 189 | :param p_sorted: probability mass, shape=[m] 190 | :param v_sorted: quantile values, shape=[m] 191 | :return: VaR at atoms[1:], shape=[n-1] 192 | """ 193 | v = np.zeros(len(atoms)-1) 194 | p = 0 195 | ix = 0 # index in p,v 196 | atom_ix = 1 197 | p_ = 0 198 | while p < 1 and ix < len(p_sorted): 199 | if p_ == 0: 200 | p_ = p_sorted[ix] 201 | 202 | if p + p_ >= atoms[atom_ix]: 203 | p_difference = atoms[atom_ix] - p 204 | p = atoms[atom_ix] 205 | v[atom_ix-1] = v_sorted[ix] 206 | atom_ix += 1 207 | p_ -= p_difference 208 | else: 209 | p += p_ 210 | ix += 1 211 | p_ = 0 212 | assert abs(p - 1) < 1e-5 213 | 214 | return v 215 | 216 | 217 | def ycvar_vector(atoms, p_sorted, v_sorted): 218 | """ 219 | Compute yCvAR at desired atom locations. 220 | len(p) == len(var) 221 | :param atoms: desired atom locations. 222 | """ 223 | y_cvar = np.zeros(len(atoms)-1) 224 | p = 0 225 | ycv = 0 226 | ix = 0 # index in p,v 227 | atom_ix = 1 228 | p_ = 0 229 | while p < 1 and ix < len(p_sorted): 230 | if p_ == 0: 231 | p_ = p_sorted[ix] 232 | v_ = v_sorted[ix] 233 | 234 | if p + p_ >= atoms[atom_ix]: 235 | p_difference = atoms[atom_ix] - p 236 | ycv += p_difference * v_ 237 | p = atoms[atom_ix] 238 | y_cvar[atom_ix-1] = ycv 239 | atom_ix += 1 240 | p_ -= p_difference 241 | if p_ == 0: 242 | ix += 1 243 | else: 244 | ycv += p_ * v_ 245 | p += p_ 246 | ix += 1 247 | p_ = 0 248 | 249 | # numerical errors 250 | if p != 1: 251 | y_cvar[-1] = ycv 252 | 253 | assert abs(p-1) < 1e-5 254 | assert abs(y_cvar[-1] - np.dot(p_sorted, v_sorted)) < 1e-5 255 | 256 | return y_cvar 257 | 258 | 259 | # =================================================================== 260 | # Transitions => vector 261 | # =================================================================== 262 | 263 | 264 | def v_yc_from_t_lp(atoms, transition_p, t_yc, t_atoms): 265 | """ CVaR computation by dual decomposition LP. """ 266 | y_cvar = [single_yc_lp_from_t(transition_p, t_atoms, t_yc, alpha) for alpha in atoms[1:]] 267 | # extract vars: 268 | var = yc_to_var(atoms, y_cvar) 269 | 270 | return var, y_cvar 271 | 272 | 273 | def v_yc_from_t(atoms, transition_p, var_values, t_atoms): 274 | """ 275 | CVaR computation by using underlying distributions. 276 | :param atoms: points of interest 277 | :param transition_p: 278 | :param var_values: (transitions, nb_atoms) 279 | :param t_atoms: (transitions, nb_atoms+1) e.g. 
[0, 0.25, 0.5, 1] 280 | :return: 281 | """ 282 | # 0) weight by transition probs 283 | p = np.concatenate([transition_p[i]*(t_atoms[i][1:] - t_atoms[i][:-1]) for i in range(len(transition_p))]) 284 | 285 | # 1) sort 286 | sortargs = np.concatenate(var_values).argsort() 287 | var_sorted = np.concatenate(var_values)[sortargs] 288 | p_sorted = p[sortargs] 289 | 290 | # 2) compute y_cvar for each atom 291 | y_cvar = ycvar_vector(atoms, p_sorted, var_sorted) 292 | 293 | # 3) get vars from y_cvar 294 | var = yc_to_var(atoms, y_cvar) 295 | 296 | return var, y_cvar 297 | 298 | 299 | def v_0_from_transitions(V, transitions, gamma): 300 | return min([t.reward + gamma*V[t.state.y, t.state.x].c_0 for t in transitions]) 301 | 302 | 303 | def extract_distribution(transition_p, t_atoms, var_values): 304 | """ 305 | :return: sorted list of tuples (probability, index, var) 306 | """ 307 | info = [] 308 | for i_t, t_p in enumerate(transition_p): 309 | for v, p_ in zip(var_values[i_t], t_atoms[i_t][1:]-t_atoms[i_t][:-1]): 310 | info.append((p_ * t_p, i_t, v)) 311 | 312 | info.sort(key=lambda x: x[-1]) 313 | return info 314 | 315 | 316 | # =================================================================== 317 | # Other 318 | # =================================================================== 319 | # TODO: move to common 320 | def var_cvar_from_samples(samples, alpha): 321 | samples = np.sort(samples) 322 | alpha_ix = int(np.round(alpha * len(samples))) 323 | var = samples[alpha_ix - 1] 324 | cvar = np.mean(samples[:alpha_ix]) 325 | return var, cvar 326 | -------------------------------------------------------------------------------- /cvar/gridworld/core/models.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | # class Model: 5 | # """ Container for safe saving and retrieval of VI/Q-learning models.""" 6 | # 7 | # def __init__(self, world, model, **kwargs): 8 | # self.world = world 9 | # self.model = model 10 | # self.info = kwargs 11 | 12 | def save(path, world, model, **kwargs): 13 | pickle.dump((world, model, kwargs), open(path, mode='wb')) 14 | 15 | 16 | -------------------------------------------------------------------------------- /cvar/gridworld/core/policies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cvar.gridworld.core.constants import gamma 3 | from cvar.gridworld.core import cvar_computation 4 | 5 | 6 | class Policy: 7 | """ Abstract class representing different policies. """ 8 | 9 | __name__ = 'Policy' 10 | 11 | def next_action(self, t): 12 | raise NotImplementedError() 13 | 14 | def reset(self): 15 | pass 16 | 17 | 18 | class FixedPolicy(Policy): 19 | __name__ = 'Fixed' 20 | 21 | def __init__(self, P, alpha=None): 22 | self.P = P 23 | 24 | def next_action(self, t): 25 | return self.P[t.state.y, t.state.x] 26 | 27 | 28 | class GreedyPolicy(Policy): 29 | __name__ = 'Greedy' 30 | 31 | def __init__(self, Q, alpha=None): 32 | self.Q = Q 33 | 34 | def next_action(self, t): 35 | s = t.state 36 | return np.argmax(self.Q[:, s.y, s.x]) 37 | 38 | 39 | class NaiveCvarPolicy(Policy): 40 | __name__ = 'Naive CVaR' 41 | 42 | def __init__(self, Q, alpha): 43 | self.Q = Q 44 | self.alpha = alpha 45 | 46 | def next_action(self, t): 47 | s = t.state 48 | action_distributions = self.Q[:, s.y, s.x] 49 | a = np.argmax([d.cvar(self.alpha) for d in action_distributions]) 50 | return a 51 | 52 | 53 | class AlphaBasedPolicy(Policy): 54 | """ Deprecated, diverging policy. 
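    Kept for reference only: a counterexample was found where the per-step alpha update
    diverges, hence the constructor raises immediately.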
""" 55 | __name__ = 'alpha-based CVaR' 56 | 57 | def __init__(self, Q, alpha): 58 | raise DeprecationWarning('counterexample found') 59 | self.Q = Q 60 | self.init_alpha = alpha 61 | self.alpha = alpha 62 | self.s_old = None 63 | self.a_old = None 64 | 65 | def next_action(self, t): 66 | s = t.state 67 | 68 | action_distributions = self.Q[:, s.y, s.x] 69 | old_action = np.argmax(expected_value(action_distributions)) 70 | 71 | if self.alpha > 0.999: 72 | return old_action 73 | 74 | if self.s_old is not None: 75 | self.update_alpha(self.s_old, self.a_old, t) 76 | a = np.argmax([d.cvar(self.alpha) for d in action_distributions]) 77 | self.s_old = s 78 | self.a_old = a 79 | 80 | return a 81 | 82 | def update_alpha(self, s, a, t): 83 | """ 84 | Correctly updates next alpha with discrete variables. 85 | :param s: state we came from 86 | :param a: action we took 87 | :param t: transition we sampled 88 | """ 89 | s_ = t.state 90 | s_dist = self.Q[a, s.y, s.x] 91 | 92 | a_ = np.argmax(expected_value(self.Q[:, s_.y, s_.x])) 93 | 94 | s__dist = self.Q[a_, s_.y, s_.x] 95 | 96 | var_ix = s_dist.var_index(self.alpha) 97 | 98 | var__ix = clip(var_ix - t.reward) 99 | 100 | # prob before var 101 | p_pre = np.sum(s_dist.p[:var_ix]) 102 | # prob at var 103 | p_var = s_dist.p[var_ix] 104 | # prob before next var 105 | p__pre = np.sum(s__dist.p[:var__ix]) 106 | # prob at next var 107 | p__var = s__dist.p[var__ix] 108 | 109 | # how much does t add to the full var 110 | # p_portion = (t.prob * p__var) / self.p_portion_sum(s, a, var_ix) 111 | p_portion = 1 112 | 113 | # we care about this portion of var 114 | p_active = (self.alpha - p_pre) / p_var 115 | 116 | self.alpha = p__pre + p_active * p__var * p_portion 117 | 118 | # def p_portion_sum(self, s, a, var_ix): 119 | # 120 | # p_portion = 0. 121 | # 122 | # for t_ in transitions(s)[a]: 123 | # action_distributions = self.Q[:, t_.state.y, t_.state.x] 124 | # a_ = np.argmax(expected_value(action_distributions)) 125 | # p_portion += t_.prob*action_distributions[a_].p[clip(var_ix - t_.reward)] 126 | # 127 | # return p_portion 128 | 129 | def reset(self): 130 | self.alpha = self.init_alpha 131 | self.s_old = None 132 | self.a_old = None 133 | 134 | 135 | class XiBasedPolicy(Policy): 136 | __name__ = 'Tamar-like' 137 | 138 | def __init__(self, V, alpha): 139 | self.V = V 140 | self.alpha = alpha 141 | self.orig_alpha = alpha 142 | self.last_state = None 143 | self.last_action = None 144 | self.last_xis = None 145 | 146 | def next_action(self, transition): 147 | 148 | if self.last_state is not None: 149 | t_ix = list(self.V.transitions(self.last_state.y, self.last_state.x, self.last_action)).index(transition) 150 | self.alpha = self.last_xis[t_ix] 151 | 152 | self.last_action, self.last_xis = self.V.next_action(transition.state.y, transition.state.x, self.alpha) 153 | self.last_state = transition.state 154 | 155 | # print('alpha:', self.alpha) 156 | 157 | return self.last_action 158 | 159 | def reset(self): 160 | self.alpha = self.orig_alpha 161 | self.last_state = None 162 | self.last_action = None 163 | self.last_xis = None 164 | 165 | 166 | class TamarVarBasedPolicy(Policy): # TODO: whats this? delete? 
X transform to Q-like var based 167 | __name__ = 'Tamar VaR-based CVaR' 168 | 169 | def __init__(self, V, alpha): 170 | self.V = V 171 | self.alpha = alpha 172 | self.var = None 173 | 174 | def next_action(self, t): 175 | if self.var is None: 176 | best = (0, -1e6, 0) # (var, cvar, action) 177 | for a in self.V.world.ACTIONS: 178 | v, cv, _ = self.V.single_var_yc_xis(t.state.y, t.state.x, a, self.alpha) 179 | if cv > best[1]: 180 | best = v, cv, a 181 | v, _, a = best 182 | self.var = v 183 | return a 184 | else: 185 | self.var = (self.var - t.reward)/gamma 186 | 187 | a = np.argmax([self.V.y_var(t.state.y, t.state.x, a, self.var)[1] for a in self.V.world.ACTIONS]) 188 | 189 | # print('alpha:', self.V.y_var(t.state.y, t.state.x, a, self.var)[0]) 190 | return a 191 | 192 | def reset(self): 193 | self.var = None 194 | 195 | # ========================================== 196 | # Q-learning 197 | # ========================================== 198 | 199 | 200 | class VarBasedQPolicy(Policy): 201 | """ For Q-learning with CVaR. """ 202 | __name__ = 'VaR-based CVaR' 203 | 204 | def __init__(self, Q, alpha): 205 | self.Q = Q 206 | self.alpha = alpha 207 | self.s = None 208 | 209 | def next_action(self, t): 210 | x, r = t.state, t.reward 211 | 212 | if self.s is None: 213 | a = self.Q.next_action_alpha(x, self.alpha) 214 | self.s = self.Q.var_alpha(x, a, self.alpha) 215 | else: 216 | self.s = (self.s - t.reward) / gamma 217 | a = self.Q.next_action_s(x, self.s) 218 | return a 219 | 220 | def reset(self): 221 | self.s = None 222 | 223 | 224 | class VarXiQPolicy(Policy): 225 | """ For Q-learning with CVaR. """ 226 | __name__ = 'VaRXi-based CVaR' 227 | 228 | def __init__(self, Q, alpha): 229 | self.Q = Q 230 | self.alpha = alpha 231 | self.orig_alpha = alpha 232 | self.s = None 233 | 234 | def next_action(self, t): 235 | x, r = t.state, t.reward 236 | 237 | if self.s is not None: 238 | s = (self.s - t.reward) / gamma 239 | self.alpha = self.Q.alpha_from_var(x, s) 240 | 241 | a = self.Q.next_action_alpha(x, self.alpha) 242 | self.s = self.Q.var_alpha(x, a, self.alpha) 243 | 244 | return a 245 | 246 | def reset(self): 247 | self.s = None 248 | self.alpha = self.orig_alpha 249 | 250 | 251 | class XiBasedQPolicy(Policy): 252 | """ For Q-learning with CVaR. """ 253 | __name__ = 'VaR-based CVaR' 254 | 255 | def __init__(self, Q, alpha): 256 | self.Q = Q 257 | self.alpha = alpha 258 | self.orig_alpha = alpha 259 | 260 | self.last_t = None 261 | self.last_a = None 262 | 263 | def next_action(self, t): 264 | x, r = t.state, t.reward 265 | 266 | if self.last_t is not None: 267 | last_s = self.Q.var_alpha(self.last_t.state, self.last_a, self.alpha) 268 | s = (last_s - t.reward) / gamma 269 | var_dist = self.Q.joint_action_dist_var(x) 270 | self.alpha = cvar_computation.single_var_to_alpha(self.Q.atom_p, var_dist, s) 271 | 272 | a = self.Q.next_action_alpha(x, self.alpha) 273 | 274 | self.last_t = t 275 | self.last_a = a 276 | return a 277 | 278 | def reset(self): 279 | self.alpha = self.orig_alpha 280 | self.last_t = None 281 | 282 | 283 | class NaiveQPolicy(Policy): 284 | """ For Q-learning with CVaR. 
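    Keeps alpha fixed for the whole trajectory and always acts greedily with respect to
    CVaR at that level, unlike the adaptive policies above.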
""" 285 | __name__ = 'VaR-based CVaR' 286 | 287 | def __init__(self, Q, alpha): 288 | self.Q = Q 289 | self.alpha = alpha 290 | 291 | def next_action(self, t): 292 | x, r = t.state, t.reward 293 | 294 | a = self.Q.next_action_alpha(x, self.alpha) 295 | 296 | return a 297 | -------------------------------------------------------------------------------- /cvar/gridworld/core/runs.py: -------------------------------------------------------------------------------- 1 | import time 2 | from cvar.gridworld.cliffwalker import * 3 | 4 | 5 | def epoch(world, policy, max_iters=100, plot_machine=None): 6 | """ 7 | Evaluates a single epoch starting at start_state, using a given policy. 8 | :param start_state: 9 | :param policy: Policy instance 10 | :param max_iters: end the epoch after this #steps 11 | :return: States, Actions, Rewards 12 | """ 13 | s = world.initial_state 14 | S = [s] 15 | A = [] 16 | R = [] 17 | i = 0 18 | t = Transition(s, 0, 0) 19 | while s not in world.goal_states and i < max_iters: 20 | a = policy.next_action(t) 21 | A.append(a) 22 | 23 | if plot_machine is not None: 24 | plot_machine.step(s, a) 25 | time.sleep(0.5) 26 | 27 | t = world.sample_transition(s, a) 28 | 29 | r = t.reward 30 | s = t.state 31 | 32 | R.append(r) 33 | S.append(s) 34 | i += 1 35 | 36 | return S, A, R 37 | 38 | 39 | def optimal_path(world, policy): 40 | """ Optimal deterministic path. """ 41 | s = world.initial_state 42 | states = [s] 43 | t = Transition(s, 0, 0) 44 | while s not in world.goal_states: 45 | a = policy.next_action(t) 46 | t = max(world.transitions(s)[a], key=lambda t: t.prob) 47 | s = t.state 48 | if s in states: 49 | print("ERROR: path repeats {}, last action={}".format(s, world.ACTION_NAMES[a])) 50 | return states 51 | states.append(s) 52 | return states -------------------------------------------------------------------------------- /cvar/gridworld/exp_model.py: -------------------------------------------------------------------------------- 1 | """ Standard RL methods stored here - VI, PI, Q-learning. 2 | Is not fully compatible with distributional setting. 
3 | """ 4 | 5 | from cvar.gridworld.cliffwalker import GridWorld 6 | from cvar.gridworld.plots.grid import show_fixed 7 | from cvar.gridworld.core.constants import gamma 8 | import numpy as np 9 | 10 | # ======================================== 11 | # algorithms 12 | # ======================================== 13 | 14 | 15 | def value_iteration(world): 16 | Q = np.zeros((len(world.ACTIONS), world.height, world.width)) 17 | i = 0 18 | while True: 19 | Q_ = value_update(world, Q, np.argmax(Q, axis=0)) 20 | if converged(Q, Q_) and i != 0: 21 | print("value fully learned after %d iterations" % (i,)) 22 | break 23 | Q = Q_ 24 | i += 1 25 | return Q 26 | 27 | 28 | def policy_iteration(world): 29 | Q = np.zeros((len(world.ACTIONS), world.height, world.width)) 30 | i = 0 31 | while True: 32 | Q_ = eval_fixed_policy(np.argmax(Q, axis=0)) 33 | print(i) 34 | if np.all(np.argmax(Q, axis=0) == np.argmax(Q_, axis=0)) and i != 0: 35 | print("policy fully learned after %d iterations" % (i,)) 36 | break 37 | i += 1 38 | Q = Q_ 39 | 40 | return Q 41 | 42 | 43 | def q_learning(world, max_episodes=1e3, max_iters=100): 44 | Q = np.zeros((len(world.ACTIONS), world.height, world.width)) 45 | 46 | beta = 0.4 # learning rate 47 | eps = 0.5 48 | 49 | iter = 0 50 | while True: 51 | if iter % 10 == 0: 52 | beta *= 0.995 53 | print("{}: beta={}".format(iter, beta)) 54 | # ========================== 55 | s = world.initial_state 56 | 57 | i = 0 58 | while i < max_iters: 59 | # sample next action 60 | a = policy_sample(epsilon_greedy_policy(eps), s, Q) 61 | 62 | # sample next transition 63 | t = world.sample_transition(s, a) 64 | r, s_ = t.reward, t.state 65 | 66 | # update Q 67 | if s_ in world.goal_states: 68 | Q[a, s.y, s.x] = (1 - beta) * Q[a, s.y, s.x] + beta * r 69 | break 70 | else: 71 | a_ = np.argmax(Q[:, s_.y, s_.x]) 72 | Q[a, s.y, s.x] = (1-beta)*Q[a, s.y, s.x] + beta*(r + gamma * Q[a_, s_.y, s_.x]) 73 | 74 | s = s_ 75 | 76 | # update learning parameters 77 | # if iter > 0.3*max_episodes: 78 | # eps = 1/(iter + 1) 79 | # print("{}: eps={}".format(iter, eps)) 80 | 81 | 82 | iter += 1 83 | 84 | if iter > max_episodes: 85 | break 86 | 87 | return Q 88 | 89 | 90 | # ======================================== 91 | # policies 92 | # ======================================== 93 | 94 | 95 | # random policy: each action has the same probability 96 | def random_policy(s, Q): 97 | return [1.0 / len(GridWorld.ACTIONS) for a in GridWorld.ACTIONS] 98 | 99 | 100 | # greedy policy gives the best action (based on action value function Q) a probability of 1, others are given 0 101 | def greedy_policy(s, Q): 102 | probs = np.zeros_like(GridWorld.ACTIONS, dtype=float) 103 | probs[np.argmax(Q[:, s.y, s.x])] = 1.0 104 | return probs 105 | 106 | 107 | # epsilon-greedy policy gives random action with probability eps or greedy one otherwise 108 | def epsilon_greedy_policy(eps=0.0): 109 | def epsilon_greedy_policy_helper(s, Q): 110 | if np.random.uniform() < eps: 111 | return random_policy(s, Q) 112 | else: 113 | return greedy_policy(s, Q) 114 | 115 | return epsilon_greedy_policy_helper 116 | 117 | 118 | def policy_sample(policy, *args): 119 | p = policy(*args) 120 | return np.random.choice(GridWorld.ACTIONS, p=p) 121 | 122 | # ======================================== 123 | # other 124 | # ======================================== 125 | 126 | 127 | def value_update(world, Q, P): 128 | """ 129 | One value update step. 
130 | :param Q: (A, M, N): current Q-values 131 | :param P: (M, N): indices of actions to be selected 132 | :return: (A, M, N): new Q-values 133 | """ 134 | Q_ = np.array(Q) 135 | for s in world.states(): 136 | for a, action_transitions in zip(world.ACTIONS, world.transitions(s)): 137 | t_p = np.array([t.prob for t in action_transitions]) 138 | 139 | t_q = [t.reward + gamma * Q[P[t.state.y, t.state.x], t.state.y, t.state.x] for t in action_transitions] 140 | Q_[a, s.y, s.x] = np.dot(t_p, t_q) 141 | 142 | return Q_ 143 | 144 | 145 | def converged(Q, Q_): 146 | print('error:', np.max(Q-Q_)) 147 | return np.max(Q-Q_) < 1e-5 148 | 149 | 150 | def eval_fixed_policy(world, P): 151 | Q = np.zeros((len(world.ACTIONS), world.height, world.width)) 152 | i = 0 153 | while True: 154 | Q_ = value_update(world, Q, P) 155 | if converged(Q, Q_) and i != 0: 156 | break 157 | Q = Q_ 158 | i += 1 159 | return Q 160 | 161 | 162 | def q_to_v_argmax(world, Q): 163 | """ Converts Q function to V by choosing the best action. """ 164 | Vnew = np.zeros((world.height, world.width)) 165 | for s in world.states(): 166 | a = np.argmax(Q[:, s.y, s.x]) 167 | Vnew[s.y, s.x] = Q[a, s.y, s.x] 168 | return Vnew 169 | 170 | 171 | if __name__ == '__main__': 172 | import pickle 173 | np.random.seed(2) 174 | world = GridWorld(10, 15, random_action_p=0.1) 175 | 176 | # Q = policy_iteration(world) 177 | # Q = value_iteration(world) 178 | Q = q_learning(world, max_episodes=10000) 179 | 180 | pickle.dump((world, Q), open('data/models/exp_10_15.pkl', mode='wb')) 181 | 182 | show_fixed(world, q_to_v_argmax(world, Q), np.argmax(Q, axis=0)) 183 | -------------------------------------------------------------------------------- /cvar/gridworld/interactive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Silvicek/cvar-algorithms/ec60696d269857f78213b8cbde506bb5e94f34a3/cvar/gridworld/interactive.png -------------------------------------------------------------------------------- /cvar/gridworld/plots/__init__.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from cycler import cycler 3 | 4 | 5 | # ======================================== 6 | # plt + latex settings 7 | # ======================================== 8 | # plt.rc('axes', prop_cycle=(cycler('color', ['#1f77b4', '#d62728']))) 9 | 10 | # plt.rc('text', usetex=True) 11 | # plt.rc('font', family='serif') 12 | -------------------------------------------------------------------------------- /cvar/gridworld/plots/grid.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.axes_grid1 import make_axes_locatable 3 | import numpy as np 4 | from cvar.gridworld.cliffwalker import State 5 | from cvar.gridworld.core import cvar_computation 6 | 7 | # arrows 8 | offsets = {0: (0.4, 0), 1: (-0.4, 0), 2: (0, 0.4), 3: (0, -0.4)} 9 | dirs = {0: (-0.8, 0), 1: (0.8, 0), 2: (0, -0.8), 3: (0, 0.8)} 10 | 11 | 12 | class PlotMachine: 13 | 14 | def __init__(self, world, V=None): 15 | if V is None: 16 | self.V = -1 * np.ones((world.height, world.width)) 17 | else: 18 | self.V = V 19 | # darken cliff 20 | cool = np.min(self.V) * 1.1 21 | for s in world.cliff_states: 22 | self.V[s.y, s.x] = cool 23 | 24 | plt.ion() 25 | 26 | self.fig, self.ax = plt.subplots() 27 | 28 | im = self.ax.imshow(self.V, interpolation='nearest', origin='upper') 29 | plt.tick_params(axis='both', 
which='both', bottom='off', top='off', 30 | labelbottom='off', right='off', left='off', labelleft='off') 31 | divider = make_axes_locatable(self.ax) 32 | cax = divider.append_axes("right", size="5%", pad=0.05) 33 | plt.colorbar(im, cax=cax) 34 | 35 | self.ax.text(world.initial_state.x, world.initial_state.y, 'S', ha='center', va='center', fontsize=20) 36 | for s in world.goal_states: 37 | self.ax.text(s[1], s[0], 'G', ha='center', va='center', fontsize=20) 38 | for s in world.risky_goal_states: 39 | self.ax.text(s[1], s[0], 'R', ha='center', va='center', fontsize=20) 40 | 41 | self.arrow = self.ax.add_patch(plt.Arrow(0, 0, 1, 1, color='white')) 42 | 43 | def step(self, s, a): 44 | 45 | self.arrow.remove() 46 | arrow = plt.Arrow(s.x + offsets[a][0], s.y + offsets[a][1], dirs[a][0], dirs[a][1], color='white') 47 | self.arrow = self.ax.add_patch(arrow) 48 | 49 | self.fig.canvas.draw() 50 | self.fig.canvas.flush_events() 51 | 52 | 53 | # TODO: unify imshow 54 | class InteractivePlotMachine: 55 | 56 | def __init__(self, world, V, action_value=False, alpha=1): 57 | self.world = world 58 | self.V = V 59 | if action_value: 60 | img = np.max(np.array([V.Q[ix].yc_alpha(alpha)/alpha for ix in np.ndindex(V.Q.shape)]).reshape(V.Q.shape), axis=-1) 61 | print(img.shape) 62 | 63 | self.fig, self.ax = grid_plot(world, img) 64 | self.fig.canvas.mpl_connect('button_press_event', self.handle_click_q) 65 | else: 66 | img = np.array([V.V[ix].cvar_alpha(alpha) for ix in np.ndindex(V.V.shape)]).reshape(V.V.shape) 67 | print(img.shape) 68 | 69 | self.fig, self.ax = grid_plot(world, img) 70 | self.fig.canvas.mpl_connect('button_press_event', self.handle_click_v) 71 | 72 | self.ax.set_title("$\\alpha={:.2f}$".format(alpha)) 73 | # Optimal path 74 | path = self.V.optimal_path(alpha) 75 | print(path) 76 | self.ax.plot([s[1] for s in path], [s[0] for s in path], 'o-', color='white') 77 | 78 | # 79 | self.state_fig = None 80 | self.state_ax = None 81 | 82 | def handle_click_v(self, event): 83 | 84 | if event.xdata is None: 85 | return 86 | x, y = self._canvas_to_grid(event.xdata, event.ydata) 87 | 88 | if self.state_fig is None: 89 | self.state_fig, self.state_ax = plt.subplots(1, 2) 90 | 91 | # clear axes 92 | for ax in self.state_ax: 93 | ax.clear() 94 | 95 | for a in self.world.ACTIONS: 96 | self.V.plot(y, x, a, ax=self.state_ax, show=False) 97 | ax.legend([self.world.ACTION_NAMES[a] for a in self.world.ACTIONS]) 98 | 99 | # combination of all actions 100 | V_x = self.V.V[y, x].var 101 | yc_x = self.V.V[y, x].yc 102 | self.state_ax[1].step(self.V.V[y, x].atoms, list(V_x) + [V_x[-1]], '--', where='post') 103 | self.state_ax[0].plot(self.V.V[y, x].atoms, np.insert(yc_x, 0, 0), '--') 104 | 105 | # titles 106 | self.state_fig.suptitle("(y={}, x={})".format(y, x)) 107 | 108 | # show 109 | self.state_fig.show() 110 | 111 | def handle_click_q(self, event): 112 | if event.xdata is None: 113 | return 114 | x, y = self._canvas_to_grid(event.xdata, event.ydata) 115 | 116 | if self.state_fig is None: 117 | self.state_fig, self.state_ax = plt.subplots(1, 3) 118 | 119 | # clear axes 120 | for ax in self.state_ax: 121 | ax.clear() 122 | 123 | for a in self.world.ACTIONS: 124 | self.V.Q[y, x, a].plot(ax=self.state_ax, show=False) 125 | ax.legend([self.world.ACTION_NAMES[a] for a in self.world.ACTIONS]) 126 | 127 | # combination of all actions 128 | V_x = self.V.joint_action_dist(State(y, x)) 129 | yc_x = self.V.joint_action_dist(State(y, x), True) 130 | self.state_ax[1].step(self.V.atoms, list(V_x) + [V_x[-1]], '--', where='post') 
131 | self.state_ax[0].plot(self.V.atoms, np.insert(yc_x, 0, 0), '--') 132 | 133 | # titles 134 | self.state_fig.suptitle("(y={}, x={})".format(y, x)) 135 | 136 | # show 137 | self.state_fig.show() 138 | 139 | 140 | def _canvas_to_grid(self, xd, yd): 141 | offset = -0.5 142 | cell_length = 1 143 | x = int((xd - offset) / cell_length) 144 | y = int((yd - offset) / cell_length) 145 | return x, y 146 | 147 | def show(self): 148 | plt.show() 149 | 150 | 151 | # visualizes the final value function with a fixed policy 152 | def show_fixed(world, V, P): 153 | 154 | ax = plt.gca() 155 | 156 | # darken cliff 157 | cool = np.min(V) * 1.1 158 | for s in world.cliff_states: 159 | V[s.y, s.x] = cool 160 | 161 | im = ax.imshow(V, interpolation='nearest', origin='upper') 162 | plt.tick_params(axis='both', which='both', bottom='off', top='off', 163 | labelbottom='off', right='off', left='off', labelleft='off') 164 | divider = make_axes_locatable(ax) 165 | cax = divider.append_axes("right", size="5%", pad=0.05) 166 | 167 | plt.colorbar(im, cax=cax) 168 | 169 | ax.text(world.initial_state[1], world.initial_state[0], 'S', ha='center', va='center', fontsize=20) 170 | for s in world.goal_states: 171 | ax.text(s[1], s[0], 'G', ha='center', va='center', fontsize=20) 172 | for s in world.risky_goal_states: 173 | ax.text(s[1], s[0], 'R', ha='center', va='center', fontsize=20) 174 | 175 | for s in world.states(): 176 | if s in world.cliff_states: 177 | continue 178 | if s in world.goal_states: 179 | continue 180 | if s in world.risky_goal_states: 181 | continue 182 | 183 | a = P[s.y, s.x] 184 | ax.add_patch(plt.Arrow(s.x + offsets[a][0], s.y + offsets[a][1], dirs[a][0], dirs[a][1], color='white')) 185 | 186 | plt.show() 187 | 188 | 189 | def grid_plot(world, img=None, figax=None, sg_size=20): 190 | 191 | if figax is None: 192 | fig, ax = plt.subplots() 193 | else: 194 | fig, ax = figax 195 | 196 | if img is None: 197 | img = -1 * np.ones((world.height, world.width)) 198 | # darken cliff 199 | cool = np.min(img) * 1.1 200 | for s in world.cliff_states: 201 | img[s.y, s.x] = cool 202 | 203 | im = ax.imshow(img, interpolation='nearest', origin='upper') 204 | plt.tick_params(axis='both', which='both', bottom='off', top='off', 205 | labelbottom='off', right='off', left='off', labelleft='off') 206 | divider = make_axes_locatable(ax) 207 | cax = divider.append_axes("right", size="5%", pad=0.05) 208 | plt.colorbar(im, cax=cax) 209 | 210 | ax.text(world.initial_state.x, world.initial_state.y, 'S', ha='center', va='center', fontsize=sg_size) 211 | for s in world.goal_states: 212 | ax.text(s[1], s[0], 'G', ha='center', va='center', fontsize=sg_size) 213 | for s in world.risky_goal_states: 214 | ax.text(s[1], s[0], 'R', ha='center', va='center', fontsize=sg_size) 215 | return fig, ax 216 | 217 | -------------------------------------------------------------------------------- /cvar/gridworld/plots/info_plots.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import cvar.gridworld.plots.grid as grid 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | from cvar.gridworld.algorithms.q_learning import ActionValueFunction, MarkovQState 7 | from cvar.gridworld.algorithms.value_iteration import ValueFunction, MarkovState 8 | from cycler import cycler 9 | from cvar.gridworld.core.runs import epoch 10 | 11 | model_path = '../data/models/' 12 | plots_path = '../data/plots/' 13 | 14 | # ============================= SETTINGS 15 | plt.rc('text', 
usetex=True) 16 | plt.rc('font', family='serif') 17 | matplotlib.rcParams.update({'font.size': 8}) 18 | # plt.rc('axes', prop_cycle=(cycler('color', ['#1f77b4', '#d62728']))) 19 | 20 | 21 | def optimal_paths_grids(file_name, save_name=None, vi=False): 22 | world, model = pickle.load(open(model_path+file_name, 'rb')) 23 | alphas = [0.1, 0.2, 0.3, 1.] 24 | fig, axs = plt.subplots(2, 2, figsize=(8.5, 5)) 25 | 26 | for ax, alpha in zip(axs.flatten(), alphas): 27 | if vi: 28 | img = np.array([model.V[ix].cvar_alpha(alpha) for ix in np.ndindex(model.V.shape)]).reshape(model.V.shape) 29 | else: 30 | img = np.max(np.array([model.Q[ix].yc_alpha(alpha)/alpha for ix in np.ndindex(model.Q.shape)]).reshape(model.Q.shape), axis=-1) 31 | grid.grid_plot(world, img=img, figax=(fig, ax), sg_size=10) 32 | 33 | path = model.optimal_path(alpha) 34 | print(path) 35 | ax.plot([s[1] for s in path], [s[0] for s in path], '--', color='white') 36 | 37 | ax.set_title("$\\alpha={}$".format(alpha)) 38 | ax.axis('off') 39 | if save_name is None: 40 | plt.show() 41 | else: 42 | plt.savefig(plots_path+save_name, bbox_inches='tight') 43 | 44 | 45 | # ============================= RUNS -> stats 46 | def generate_samples(world, policy, nb_episodes=1000): 47 | scores = [] 48 | for i in range(nb_episodes): 49 | S, A, R = epoch(world, policy) 50 | policy.reset() 51 | scores.append(np.sum(R)) 52 | if i % 10 == 0: 53 | print('e:', i) 54 | return scores 55 | 56 | 57 | def sample_histograms(alpha, suffix): 58 | from cvar.common.cvar_computation import var_cvar_from_samples 59 | from cvar.gridworld.core.policies import GreedyPolicy, VarXiQPolicy 60 | 61 | # exp VI 62 | world, Q = pickle.load(open(model_path+'exp_'+suffix+'.pkl', 'rb')) 63 | scores_exp = generate_samples(world, GreedyPolicy(Q), nb_episodes=1000) 64 | v_exp, c_exp = var_cvar_from_samples(scores_exp, alpha) 65 | print('CVaR_{}(exp)={}'.format(alpha, c_exp)) 66 | 67 | # CVaR VI 68 | # world, Q = pickle.load(open('../data/models/vi_10_15.pkl', 'rb')) 69 | # scores_vi = generate_samples(world, XiBasedPolicy(Q, alpha)) 70 | 71 | # Q-learned 72 | world, Q = pickle.load(open('../data/models/q_'+suffix+'.pkl', 'rb')) 73 | scores_q = generate_samples(world, VarXiQPolicy(Q, alpha), nb_episodes=1000) 74 | v_q, c_q = var_cvar_from_samples(scores_q, alpha) 75 | print('CVaR_{}(q)={}'.format(alpha, c_q)) 76 | fig = plt.figure(figsize=(5, 3)) 77 | plt.grid() 78 | plt.hist(scores_exp, density=True, bins=20, edgecolor='black') 79 | plt.hist(scores_q, density=True, bins=20, edgecolor='black') 80 | plt.legend(['Q-learning', 'CVaR Q-learning']) 81 | 82 | plt.savefig(plots_path + 'sample_hist.pdf', bbox_inches='tight') 83 | # plt.show() 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | sample_histograms(0.05, suffix='10_15') 89 | 90 | # optimal_paths_grids('vi_40_60.pkl', 'vi_optimal_paths.pdf', vi=True) 91 | # optimal_paths_grids('q_10_15.pkl', 'q_optimal_paths.pdf', vi=False) 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /cvar/gridworld/plots/other.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_cvars(): 7 | """ Old function to plot policy improvement comparisons. 
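    Expects data/stats.pkl as written by exhaustive_stats in run_vi.py: a dict with keys
    'cvars' (policies x alphas array), 'alphas' and 'names'.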
""" 8 | import pickle 9 | 10 | data = pickle.load(open('data/stats.pkl', 'rb')) 11 | 12 | cvars = data['cvars'] 13 | alphas = np.tile(data['alphas'], (len(cvars), 1)) 14 | ax = plt.gca() 15 | ax.plot(alphas.T, cvars.T, '-') 16 | ax.set_xscale('log') 17 | ax.set_xticks(alphas[0]) 18 | ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) 19 | ax.invert_xaxis() 20 | ax.set_xlabel('$\\alpha$') 21 | ax.set_ylabel('CVaR$_\\alpha$') 22 | # ax.set_ylim([-50, -10]) 23 | ax.legend(data['names']) 24 | ax.grid() 25 | plt.show() 26 | 27 | -------------------------------------------------------------------------------- /cvar/gridworld/plots/thesis_plots.py: -------------------------------------------------------------------------------- 1 | """ Used for generating various unrelated plots used in the thesis. """ 2 | import matplotlib.pyplot as plt 3 | import scipy.stats 4 | from cycler import cycler 5 | # from cvar.gridworld.core.util import softmax 6 | 7 | from cvar.gridworld.core.constants import * 8 | 9 | # plt.rc('axes', prop_cycle=(cycler('color', ['#1f77b4', '#d62728']))) 10 | 11 | # tex 12 | plt.rc('text', usetex=True) 13 | plt.rc('font', family='serif') 14 | # ==================== 15 | 16 | 17 | def pdf_to_cvar(prob, var, alpha): 18 | print(prob[:10]/alpha) 19 | print(np.sum(prob[:300])) 20 | print(sum(prob)) 21 | p = 0. 22 | cv = 0. 23 | for p_, v_ in zip(prob, var): 24 | if p + p_ >= alpha: 25 | cv += (alpha - p) * v_ 26 | break 27 | else: 28 | cv += p_ * v_ 29 | p += p_ 30 | return v_, cv / alpha 31 | 32 | 33 | def plot_cvar_pdf(prob, vars, alpha, discrete=False): 34 | var, cvar = pdf_to_cvar(prob, vars, alpha) 35 | 36 | if discrete: 37 | n, bins, patches = plt.hist(x, 50, normed=1, facecolor='green', alpha=0.75, edgecolor='black') 38 | print(n, bins, patches) 39 | plt.show() 40 | else: 41 | fig, ax = plt.subplots(1, figsize=(8,4)) 42 | 43 | ax.plot(vars, prob) 44 | ax.vlines(x=var, ymin=0., ymax=0.001, linestyles='--', colors='g', label='$VaR_{%.2f}$' % alpha) 45 | ax.vlines(x=cvar, ymin=0., ymax=0.001, linestyles='--', colors='r', label='$CVaR_{%.2f}$' % alpha) 46 | 47 | # plt.axis('off') 48 | ax.set_yticklabels([]) 49 | ax.set_xticklabels([]) 50 | plt.title('Probability Distribution Function') 51 | plt.grid() 52 | plt.legend() 53 | plt.show() 54 | 55 | 56 | def cvar_multinomial(): 57 | alpha = 0.05 58 | 59 | # student 60 | distribution = scipy.stats.t(1) 61 | vars = np.arange(-10, 10, 0.01) 62 | prob = distribution.pdf(vars) / 100 63 | 64 | # gaussian mixture 65 | d1 = scipy.stats.t(1, -2) 66 | d2 = scipy.stats.norm(5,1) 67 | d3 = scipy.stats.norm(-0,1) 68 | d4 = scipy.stats.norm(-7,0.5) 69 | vars = np.arange(-10, 10, 0.01) 70 | prob = (0.3 * d1.pdf(vars) + 0.3 * d2.pdf(vars) + 0.37 * d3.pdf(vars) + 0.03*d4.pdf(vars)) / 100 71 | # prob = ( d3.pdf(vars)) / 100 72 | 73 | 74 | # multinomial 75 | nb_atoms = 50 76 | atoms = softmax(np.random.random(nb_atoms)) 77 | var_values = np.random.random([nb_atoms])*10 - 5 78 | var_values.sort() 79 | 80 | x = np.random.lognormal(sigma=0.6, size=10000) + np.random.randn(10000) 81 | 82 | # the histogram of the data 83 | # n, bins, patches = plt.hist(x, 50, normed=1, facecolor='green', alpha=0.75, edgecolor='black') 84 | # plt.show() 85 | 86 | plot_cvar_pdf(prob, vars, alpha) 87 | # plot_cvar_pdf(atoms, var_values, alpha, discrete=True) 88 | 89 | 90 | def loss_functions(): 91 | # fig, ax = plt.subplots(1,4) 92 | 93 | x = np.linspace(-1, 1, 201) 94 | print(x) 95 | 96 | 97 | y_mse = x ** 2 98 | 99 | # y_10 = np.where(x >= 0, 0.1*x, 
(0.1-1)*(-x)) 100 | y_med = x * (0.5 - (x < 0)) 101 | y_30 = x * (0.3 - (x < 0)) 102 | y_70 = x * (0.7 - (x < 0)) 103 | 104 | 105 | fig = plt.figure(figsize=(5,4)) 106 | # y_90 = np.abs(x) 107 | plt.plot(x, y_mse) 108 | plt.plot(x, y_med) 109 | plt.plot(x, y_30) 110 | plt.plot(x, y_70) 111 | 112 | axes = plt.gca() 113 | axes.set_xlim([-1, 1]) 114 | axes.set_ylim([0, 1]) 115 | 116 | plt.grid() 117 | 118 | plt.tick_params(axis='x', which='both', bottom=False, labelbottom=False) 119 | plt.tick_params(axis='y', which='both', left=False, labelleft=False) 120 | 121 | plt.legend(['MSE', '$\\alpha=0.5$', '$\\alpha=0.3$', '$\\alpha=0.7$']) 122 | 123 | plt.savefig('../data/plots/losses.pdf', bbox_inches='tight') 124 | plt.show() 125 | 126 | if __name__ == '__main__': 127 | loss_functions() 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /cvar/gridworld/plots/vi_compare_plots.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plots comparisons between tamar, sort, wasserstein. 3 | """ 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | from pulp import * 7 | from cvar.gridworld.core import cvar_computation 8 | import numpy as np 9 | from cycler import cycler 10 | 11 | 12 | plt.rc('axes', prop_cycle=(cycler('color', ['#1f77b4', '#d62728']))) 13 | 14 | plt.rc('text', usetex=True) 15 | plt.rc('font', family='serif') 16 | matplotlib.rcParams.update({'font.size': 8}) 17 | 18 | 19 | # TODO: fix and move 20 | def wasserstein_lp(): 21 | # 0) weight by transition probs 22 | p = np.outer(transition_p, atom_p).flatten() 23 | 24 | # 1) create quantile function 25 | sortargs = var_values.flatten().argsort() 26 | var_sorted = var_values.flatten()[sortargs] 27 | p_sorted = p.flatten()[sortargs] 28 | p_sorted, var_sorted = further_split(p_sorted, var_sorted) 29 | 30 | 31 | # 2) create LP minimizing |y-var| 32 | Y = [LpVariable('y_'+str(i)) for i in range(nb_atoms)] 33 | U = [LpVariable('u_'+str(i)) for i in range(len(p_sorted))] # abs value slack 34 | 35 | prob = LpProblem(name='wasserstein') 36 | 37 | cp = 0. 38 | atom_ix = 1 39 | for u, p_, v_ in zip(U, p_sorted, var_sorted): 40 | cp += p_ 41 | 42 | prob.addConstraint(u >= Y[atom_ix-1] - v_) 43 | prob.addConstraint(u >= v_ - Y[atom_ix-1]) 44 | 45 | if cp == atoms[atom_ix]: 46 | atom_ix += 1 47 | 48 | # opt criterion 49 | prob.setObjective(sum([u*p for u, p in zip(U, p_sorted)])) 50 | 51 | prob.solve() 52 | 53 | print(value(prob.objective)) 54 | 55 | return [value(y_) for y_ in Y] 56 | 57 | 58 | def wasserstein_median(): 59 | # 0) weight by transition probs 60 | p = np.outer(transition_p, atom_p).flatten() 61 | 62 | # 1) create quantile function 63 | sortargs = var_values.flatten().argsort() 64 | var_sorted = var_values.flatten()[sortargs] 65 | p_sorted = p.flatten()[sortargs] 66 | p_sorted, var_sorted = further_split(p_sorted, var_sorted) 67 | 68 | # 2) median minimizes wasserstein 69 | cp = 0. 
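    # The scan below walks the sorted quantile function and, for each target atom, keeps the
    # value that covers the midpoint of that atom's probability interval; when a step falls
    # within a small tolerance of that midpoint, the two neighbouring values are averaged.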
70 | var_solution = [] 71 | atom_ix = 0 72 | for ix, p_, v_ in zip(range(len(p_sorted)), p_sorted, var_sorted): 73 | 74 | median_p = atoms[atom_ix] + atom_p[atom_ix]/2 75 | 76 | if abs(cp + p_ - median_p) < atom_p[atom_ix]/100: # there is a step near the middle 77 | var_solution.append((v_ + var_sorted[ix+1])/2) 78 | atom_ix += 1 79 | elif cp + p_ > atoms[atom_ix] + atom_p[atom_ix]/2: 80 | atom_ix += 1 81 | var_solution.append(v_) 82 | 83 | cp += p_ 84 | 85 | if atom_ix == nb_atoms: 86 | break 87 | 88 | return var_solution 89 | 90 | 91 | def exact_pv(): 92 | p = np.outer(transition_p, atom_p).flatten() 93 | 94 | # 1) sort 95 | sortargs = var_values.flatten().argsort() 96 | var_sorted = var_values.flatten()[sortargs] 97 | p_sorted = p.flatten()[sortargs] 98 | return p_sorted, var_sorted 99 | 100 | 101 | def plot(*solutions, legend=True): 102 | # solution = (name, (prob, var)) 103 | 104 | fig, axs = plt.subplots(1, 2, figsize=(7, 3)) 105 | axs = np.array(axs) 106 | axs = axs.reshape(-1) 107 | 108 | # var 109 | ax = axs[0] 110 | for _, (p, sol) in solutions: 111 | sol = list(sol) 112 | ax.step(np.insert(np.cumsum(p), 0, 0), sol + [sol[-1]], where='post') 113 | ax.set_title('Quantile function') 114 | 115 | # yV 116 | ax = axs[1] 117 | for _, (p, sol) in solutions: 118 | ax.plot(np.insert(np.cumsum(p), 0, 0), np.insert(np.cumsum(p * sol), 0, 0), 'o-') 119 | ax.set_title('$y$CVaR$_y$') 120 | 121 | # # cvar 122 | # ax = axs[2] 123 | # for _, (p, sol) in solutions: 124 | # p, v = var_to_cvar_approx(p, sol) 125 | # ax.plot(p, v) 126 | # ax.set_title('CVaR') 127 | 128 | # cvar_s 129 | # ax = axs[3] 130 | # for _, (p, sol) in solutions: 131 | # a = [cvar_computation.s_to_alpha(s, p, sol) for s in s_range] 132 | # cv = [cvar_computation.single_cvar(p, sol, alpha) for alpha in a] 133 | # ax.plot(s_range, cv) 134 | # 135 | # var_at_atoms = cvar_computation.var_vector(atoms, ex_p, ex_v) 136 | # a = np.array([cvar_computation.single_var_to_alpha(atom_p, var_at_atoms, s) for s in s_range]) 137 | # cv = [cvar_computation.single_alpha_to_cvar(atom_p, ss, alpha) for alpha in a] 138 | # ax.plot(s_range, cv) 139 | # ax.set_title('CVaR(s)') 140 | 141 | # ===================================================== 142 | 143 | # legend 144 | if legend: 145 | for ax in axs: 146 | ax.legend([name for name, _ in solutions]) 147 | 148 | # hide last plot 149 | # ax[1][1].axis('off') 150 | 151 | # grid: on 152 | for ax in axs: 153 | # ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False) 154 | # ax.tick_params(axis='y', which='both', left=False, labelleft=False) 155 | ax.grid() 156 | 157 | # hide upper x axis 158 | # plt.setp(ax[0].get_xticklabels(), visible=False) 159 | 160 | # plt.show() 161 | plt.savefig('../data/plots/cvar_visualized.pdf', bbox_inches='tight') 162 | 163 | 164 | def var_to_cvar_approx(p, var, res=0.001): 165 | cvar = np.zeros(int(1/res)) 166 | 167 | cp = 0. 168 | ccp = 0. 169 | cv = 0. 
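    # Discretizes [0, 1] at resolution res and walks the quantile function, accumulating
    # probability (ccp) and partial expectation (cv); the CVaR at each grid point is the
    # running average cv / ccp.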
170 | ix = 0 171 | for p_, v_ in zip(p, var): 172 | 173 | while ccp < min(1, cp+p_): 174 | 175 | ccp += res 176 | cv += res*v_ 177 | cvar[ix] = cv / ccp 178 | ix += 1 179 | cp = ccp 180 | return np.arange(res, 1+res, res), cvar 181 | 182 | 183 | def plot_process(): 184 | plt.rcParams['axes.grid'] = True 185 | 186 | fig, ax = plt.subplots(2, 4, figsize=(16, 8), sharey=True) 187 | 188 | 189 | # var 190 | p, v = atom_p, var_values[0] 191 | ax[0][0].step(np.insert(np.cumsum(p), 0, 0), np.insert(v, 0, v[0]), 'o-', where='pre') 192 | p, v = atom_p, var_values[1] 193 | ax[0][1].step(np.insert(np.cumsum(p), 0, 0), np.insert(v, 0, v[0]), 'o-', where='pre') 194 | p, v = exact_pv() 195 | ax[0][2].step(np.insert(np.cumsum(p), 0, 0), np.insert(v, 0, v[0]), 'o-', where='pre') 196 | p, v = atom_p, cvar_computation.var_from_transitions_lp(atoms, transition_p, var_values) 197 | ax[0][3].step(np.insert(np.cumsum(p), 0, 0), np.insert(v, 0, v[0]), 'o-', where='pre') 198 | 199 | plt.savefig('data/multivar.pdf') 200 | plt.show() 201 | 202 | fig, ax = plt.subplots(2, 4, figsize=(16, 8), sharey=True) 203 | # yCVaR 204 | p, v = atom_p, var_values[0] 205 | ax[1][0].plot(np.insert(np.cumsum(p), 0, 0), np.insert(np.cumsum(p * v), 0, 0), 'o-') 206 | p, v = atom_p, var_values[1] 207 | ax[1][1].plot(np.insert(np.cumsum(p), 0, 0), np.insert(np.cumsum(p * v), 0, 0), 'o-') 208 | p, v = exact_pv() 209 | ax[1][2].plot(np.insert(np.cumsum(p), 0, 0), np.insert(np.cumsum(p * v), 0, 0), 'o-') 210 | p, v = atom_p, cvar_computation.var_from_transitions_lp(atoms, transition_p, var_values) 211 | ax[1][3].plot(np.insert(np.cumsum(p), 0, 0), np.insert(np.cumsum(p * v), 0, 0), 'o-') 212 | 213 | plt.savefig('data/multiycvar.pdf') 214 | plt.show() 215 | 216 | 217 | if __name__ == '__main__': 218 | # nb_atoms = 3 219 | # nb_transitions = 2 220 | # 221 | # transition_p = np.array([0.25, 0.75]) 222 | # 223 | # atoms = np.array([0., 0.25, 0.5, 1.]) 224 | # atom_p = atoms[1:] - atoms[:-1] 225 | # 226 | # var_values = np.array([[-1, 0, 0.5], 227 | # [-3, -2, -1]]) 228 | 229 | nb_atoms = 4 230 | nb_transitions = 2 231 | 232 | transition_p = np.array([0.25, 0.75]) 233 | 234 | atoms = np.array([0., 0.25, 0.5, 0.75, 1.]) 235 | atom_p = atoms[1:] - atoms[:-1] 236 | 237 | t_atoms = np.tile(atoms, nb_transitions).reshape((nb_transitions, -1)) 238 | 239 | var_values = np.array([[-0.5, 0.25, 0.5, 1], 240 | [-3, -2, -1, 0]]) 241 | 242 | # ================================================ 243 | 244 | # nb_atoms = 4 245 | # nb_transitions = 2 246 | # var_values = np.random.randint(-10, 10, [nb_transitions, nb_atoms]) 247 | # var_values.sort() 248 | # 249 | # transition_p = softmax(np.random.random(nb_transitions)) 250 | # atoms = np.zeros(nb_atoms + 1) 251 | # atoms[1:] = np.cumsum(softmax(np.random.random(nb_atoms))) 252 | # atom_p = atoms[1:] - atoms[:-1] 253 | # 254 | # var_values = np.random.randint(-10, 10, [nb_transitions, nb_atoms]) 255 | # var_values.sort() 256 | 257 | print(atoms) 258 | print(atom_p) 259 | print(var_values) 260 | print('-----------------------') 261 | 262 | ss, _ = cvar_computation.v_yc_from_t(atoms, transition_p, var_values, t_atoms) 263 | # wm = wasserstein_median() 264 | # tam, _ = cvar_computation.v_yc_from_t_lp(atoms, transition_p, var_values, t_atoms) 265 | 266 | ex_p, ex_v = exact_pv() 267 | 268 | print('sort:', ss) 269 | # print('wasserstein med:', wm) 270 | # print('tamar:', tam) 271 | 272 | s_range = np.arange(ex_v[0], ex_v[-1]+0.05, 0.01) 273 | # plt.plot(s_range, [cvar_s(s, ss, atom_p) for s in s_range]) 274 | # 
plt.show() 275 | # quit() 276 | 277 | 278 | # plot(exact_pv(), ('sort', ss), ('wasserstein', wm), ('tamar', tam)) 279 | # plot(('Exact', (ex_p, ex_v)), ('CVaR VI', (atom_p, tam)), ('Wasserstein', (atom_p, wm))) 280 | plot(('Exact', (ex_p, ex_v)), ('CVaR VI', (atom_p, ss))) 281 | # plot(('Exact', (ex_p, ex_v)), ('CVaR VI', (atoms, tam)), ('Wasserstein', (atoms, wm))) 282 | # plot_process() 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /cvar/gridworld/run_q.py: -------------------------------------------------------------------------------- 1 | from cvar.gridworld.cliffwalker import * 2 | from cvar.gridworld.plots.grid import InteractivePlotMachine 3 | from cvar.gridworld.core.policies import VarBasedQPolicy, XiBasedPolicy 4 | from cvar.gridworld.algorithms.q_learning import q_learning 5 | 6 | if __name__ == '__main__': 7 | import pickle 8 | 9 | # ============================= new config 10 | run_alpha = 1. 11 | world = GridWorld(10, 15, random_action_p=0.1) 12 | 13 | Q = q_learning(world, run_alpha, max_episodes=10000) 14 | 15 | pickle.dump((world, Q), open('data/models/q_10_15.pkl', mode='wb')) 16 | 17 | # ============================= load 18 | world, Q = pickle.load(open('data/models/q_10_15.pkl', 'rb')) 19 | 20 | # ============================= RUN 21 | print('ATOMS:', Q.atoms) 22 | 23 | for alpha in np.arange(0.05, 1.05, 0.05): 24 | print(alpha) 25 | pm = InteractivePlotMachine(world, Q, alpha=alpha, action_value=True) 26 | pm.show() 27 | 28 | # =============== plot dynamic 29 | # V_visual = q_learning.q_to_v_exp(Q) 30 | # 31 | # # print(V_visual) 32 | # plot_machine = PlotMachine(world, V_visual) 33 | # # policy = var_policy 34 | # for i in range(100): 35 | # S, A, R = epoch(world, policy, plot_machine=plot_machine) 36 | # print('{}: {}'.format(i, np.sum(R))) 37 | # policy.reset() -------------------------------------------------------------------------------- /cvar/gridworld/run_vi.py: -------------------------------------------------------------------------------- 1 | from cvar.gridworld.core.constants import gamma 2 | from cvar.gridworld.cliffwalker import * 3 | from cvar.gridworld.core import cvar_computation 4 | from cvar.gridworld.core.constants import gamma 5 | from cvar.gridworld.core.runs import epoch 6 | from cvar.gridworld.algorithms.value_iteration import value_iteration 7 | 8 | 9 | def several_epochs(arg): 10 | np.random.seed() 11 | world, policy, nb_epochs = arg 12 | rewards = np.zeros(nb_epochs) 13 | 14 | for i in range(nb_epochs): 15 | S, A, R = epoch(world, policy) 16 | policy.reset() 17 | rewards[i] = np.sum(R) 18 | rewards[i] = np.dot(R, np.array([gamma ** i for i in range(len(R))])) 19 | 20 | return rewards 21 | 22 | 23 | def policy_stats(world, policy, alpha, nb_epochs, verbose=True): 24 | import copy 25 | import multiprocessing as mp 26 | threads = 4 27 | 28 | with mp.Pool(threads) as p: 29 | rewards = p.map(several_epochs, [(world, copy.deepcopy(policy), int(nb_epochs/threads)) for _ in range(threads)]) 30 | 31 | rewards = np.array(rewards).flatten() 32 | 33 | var, cvar = cvar_computation.var_cvar_from_samples(rewards, alpha) 34 | if verbose: 35 | print('----------------') 36 | print(policy.__name__) 37 | print('expected value=', np.mean(rewards)) 38 | print('cvar_{}={}'.format(alpha, cvar)) 39 | # print('var_{}={}'.format(alpha, var)) 40 | 41 | return cvar, rewards 42 | 43 | 44 | def exhaustive_stats(world, epochs, *args): 45 | V = value_iteration(world) 46 | 47 | alphas = 
np.array([1.0, 0.5, 0.25, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001]) 48 | 49 | cvars = np.zeros((len(args), len(alphas))) 50 | names = [] 51 | 52 | for i, policy in enumerate(args): 53 | names.append(policy.__name__) 54 | for j, alpha in enumerate(alphas): 55 | pol = policy(V, alpha) 56 | 57 | cvars[i, j], _ = policy_stats(world, pol, alpha=alpha, nb_epochs=int(epochs), verbose=False) 58 | 59 | print('{}_{} done...'.format(pol.__name__, alpha)) 60 | 61 | import pickle 62 | pickle.dump({'cvars': cvars, 'alphas': alphas, 'names': names}, open('data/stats.pkl', 'wb')) 63 | print(cvars) 64 | 65 | from cvar.gridworld.plots.other import plot_cvars 66 | plot_cvars() 67 | 68 | 69 | if __name__ == '__main__': 70 | import pickle 71 | from cvar.gridworld.plots.grid import InteractivePlotMachine 72 | 73 | np.random.seed(2) 74 | # ============================= new config 75 | world = GridWorld(10, 15, random_action_p=0.1) 76 | V = value_iteration(world, max_iters=10000, eps_convergence=1e-5) 77 | pickle.dump((world, V), open('data/models/vi_test.pkl', mode='wb')) 78 | 79 | # ============================= load 80 | world, V = pickle.load(open('data/models/vi_test.pkl', 'rb')) 81 | 82 | # ============================= RUN 83 | for alpha in np.arange(0.05, 1.01, 0.05): 84 | print(alpha) 85 | pm = InteractivePlotMachine(world, V, alpha=alpha) 86 | pm.show() 87 | 88 | # =============== VI stats 89 | # nb_epochs = int(1e6) 90 | # rewards_sample = [] 91 | # for alpha in [0.1, 0.25, 0.5, 1.]: 92 | # _, rewards = policy_stats(world, TamarPolicy(V, alpha), alpha, nb_epochs=nb_epochs) 93 | # rewards_sample.append(rewards) 94 | # np.save('files/sample_rewards_tamar.npy', np.array(rewards_sample)) 95 | # policy_stats(world, var_policy, alpha, nb_epochs=nb_epochs) 96 | 97 | # =============== plot dynamic 98 | # V_visual = np.array([[V.V[i, j].cvar_alpha(alpha) for j in range(len(V.V[i]))] for i in range(len(V.V))]) 99 | # # print(V_visual) 100 | # plot_machine = PlotMachine(world, V_visual) 101 | # # policy = var_policy 102 | # for i in range(100): 103 | # S, A, R = epoch(world, policy, plot_machine=plot_machine) 104 | # print('{}: {}'.format(i, np.sum(R))) 105 | # policy.reset() 106 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import sys 3 | 4 | if sys.version_info.major != 3: 5 | print('This Python is only compatible with Python 3, but you are running ' 6 | 'Python {}. The installation will likely fail.'.format(sys.version_info.major)) 7 | 8 | print([package for package in find_packages() 9 | if package.startswith('cvar')]) 10 | setup(name='cvar-algorithms', 11 | packages=['cvar'], 12 | install_requires=[ 13 | 'matplotlib', 14 | 'opencv-python', 15 | 'pygame' 16 | ], 17 | description='Risk-Averse DistributionalReinforcement Learning', 18 | author='Silvestr Stanko', 19 | url='todo', 20 | author_email='silvicek@gmail.com', 21 | version='1.0.0') 22 | --------------------------------------------------------------------------------