├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── vcs.xml ├── modules.xml ├── misc.xml └── spinal_navigation_rl.iml ├── .gitmodules ├── utils ├── __pycache__ │ ├── utils.cpython-37.pyc │ ├── resnet.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ └── visualization.cpython-37.pyc ├── __init__.py ├── visualization.py ├── resnet.py └── utils.py ├── requirements.txt ├── LICENSE ├── hyperparams ├── dqn.yml └── ppo2.yml ├── polyaxonfile.yaml ├── README.md └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /runs/ 3 | 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "gym_sacrum_env"] 2 | path = gym_sacrum_env 3 | url = https://github.com/hhase/gym_sacrum_env 4 | -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/visualization.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import make_env, ALGOS, linear_schedule, create_test_env,\ 2 | get_trained_models, CustomDQNPolicy, get_latest_run_id,\ 3 | get_saved_hyperparams, get_wrapper_class 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow==6.2.0 2 | numpy==1.16.4 3 | scipy==1.3.0 4 | matplotlib==3.1.0 5 | polyaxon-client==0.5.6 6 | gym==0.14.0 7 | optuna==0.18.1 8 | scikit-image==0.15.0 9 | opencv-python==4.1.1.26 10 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/spinal_navigation_rl.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hannes Hase 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /hyperparams/dqn.yml: -------------------------------------------------------------------------------- 1 | atari: 2 | policy: 'CnnPolicy' 3 | n_timesteps: !!float 1e7 4 | buffer_size: 10000 5 | learning_rate: !!float 1e-4 6 | learning_starts: 10000 7 | target_network_update_freq: 1000 8 | train_freq: 4 9 | exploration_final_eps: 0.01 10 | exploration_fraction: 0.1 11 | prioritized_replay_alpha: 0.6 12 | prioritized_replay: True 13 | 14 | CartPole-v1: 15 | n_timesteps: !!float 1e5 16 | policy: 'CustomDQNPolicy' 17 | learning_rate: !!float 1e-3 18 | buffer_size: 50000 19 | exploration_fraction: 0.1 20 | exploration_final_eps: 0.02 21 | prioritized_replay: True 22 | 23 | MountainCar-v0: 24 | n_timesteps: 100000 25 | policy: 'CustomDQNPolicy' 26 | learning_rate: !!float 1e-3 27 | buffer_size: 50000 28 | exploration_fraction: 0.1 29 | exploration_final_eps: 0.1 30 | param_noise: True 31 | 32 | LunarLander-v2: 33 | n_timesteps: !!float 2e5 34 | policy: 'CustomDQNPolicy' 35 | learning_rate: !!float 1e-3 36 | buffer_size: 100000 37 | exploration_fraction: 0.1 38 | exploration_final_eps: 0.05 39 | prioritized_replay: True 40 | 41 | Acrobot-v1: 42 | n_timesteps: !!float 1e5 43 | policy: 'CustomDQNPolicy' 44 | learning_rate: !!float 1e-3 45 | buffer_size: 50000 46 | exploration_fraction: 0.1 47 | exploration_final_eps: 0.02 48 | prioritized_replay: True 49 | 50 | # DQN PARAMS 51 | gym_sacrum_nav:sacrum_nav-v0: 52 | n_timesteps: !!float 1e5 53 | #policy: 'CnnPolicy' 54 | learning_rate: !!float 1e-3 55 | buffer_size: 50000 56 | exploration_fraction: 0.1 57 | exploration_final_eps: 0.02 58 | prioritized_replay: True 59 | 60 | # Previous action mem params 61 | gym_sacrum_nav:sacrum_nav-v2: 62 | n_timesteps: !!float 2e5 #4e3 63 | #policy: 'CnnPolicy' 64 | learning_rate: !!float 1e-3 65 | buffer_size: 5000 #50000 66 | exploration_fraction: 0.3 67 | exploration_final_eps: 0.02 68 | prioritized_replay: True 69 | -------------------------------------------------------------------------------- /polyaxonfile.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 1 3 | 4 | kind: experiment 5 | 6 | framework: tensorflow 7 | 8 | tags: [sacrum_navigation] 9 | 10 | build: 11 | image: tensorflow/tensorflow:1.15.0-gpu-py3 12 | build_steps: 13 | - pip install -r requirements.txt 14 | 15 | environment: 16 | resources: 17 | cpu: 18 | requests: 3 19 | limits: 4 20 | memory: 21 | requests: 8048 22 | limits: 32784 23 | gpu: 24 | requests: 1 25 | limits: 1 26 | 27 | declarations: 28 | run_name: "Gym_Sacrum_Env" 29 | env: "gym_sacrum_nav:sacrum_nav-v2" 30 | tensorboard_log: "./runs/" 31 | 32 | # Framework parameters 33 | trained_agent: "" 34 | algo: "dqn" 35 | log_interval: -1 36 | log_folder: "logs" 37 | data_folder: "./data/" 38 | seed: 0 39 | verbose: 1 40 | 41 | run: 42 | cmd: apt update && apt install -y git && 43 | apt install -y libopenmpi-dev && 44 | pip install --upgrade pip && 45 | pip install mpi4py && 46 | apt-get install -y libsm6 libxext6 libxrender-dev && 47 | pip install -e gym_sacrum_nav && 48 | pip install -e git+git://github.com/hhase/stable-baselines#egg=stable-baselines && 49 | python -u main_prev_actions.py --env={{ env }}\ 50 | -tb={{ tensorboard_log }}\ 51 | -i={{ trained_agent }}\ 52 | --algo={{ algo }}\ 53 | --log-interval={{ log_interval }}\ 54 | -f={{ log_folder }}\ 55 | --data-folder={{ data_folder }}\ 56 | --seed={{ seed }}\ 57 | --n-trials={{ n_trials }}\ 58 
| --n-jobs={{ n_jobs }}\ 59 | --sampler={{ sampler }}\ 60 | --pruner={{ pruner }}\ 61 | --verbose={{ verbose }} 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ultrasound-Guided Robotic Navigation with Deep Reinforcement Learning 2 | 3 | 4 | Code for: 5 | ``` 6 | @misc{hase2020ultrasoundguided, 7 | title={Ultrasound-Guided Robotic Navigation with Deep Reinforcement Learning}, 8 | author={Hannes Hase and Mohammad Farid Azampour and Maria Tirindelli and Magdalini Paschali and Walter Simson and Emad Fatemizadeh and Nassir Navab}, 9 | year={2020}, 10 | eprint={2003.13321}, 11 | archivePrefix={arXiv}, 12 | primaryClass={cs.LG} 13 | } 14 | ``` 15 | # The project 16 | This project aims to learn a policy for autonomously navigating to the sacrum in simulated lower-back environments created from volunteer data. As the deep reinforcement learning agent, we use a double dueling DQN with a prioritized replay memory. 17 | 18 | For the implementation, we used the [rl-zoo](https://github.com/araffin/rl-baselines-zoo) framework, a slightly adapted [stable-baselines](https://github.com/hhase/stable-baselines) library, and an [environment](https://github.com/hhase/gym_sacrum_env) built with the [gym](https://gym.openai.com/) toolkit. 19 | 20 | # Setup 21 | Before running the code, the following parameters must be set in `main.py` (an example configuration is shown after this list). 22 | 23 | - `DATA_PATH`: location of the [dataset](https://github.com/hhase/sacrum_data-set). 24 | - `OUTPUT_PATH`: path where logs and trained models are written. 25 | - `test_patients`: number of test environments. 26 | - `val_patients`: number of validation environments. 27 | - `prev_actions`: size of the action memory. 28 | - `prev_frames`: size of the previous-frame memory. 29 | - `val_set`: if defined, fixes the environments used for validation. 30 | - `test_set`: if defined, fixes the environments used for testing. 31 | - `shuffles`: number of random shuffles for train/val/test set creation. Only relevant if the test and validation sets are not defined. 32 | - `chebishev`: boolean that enables diagonal (Chebyshev-neighborhood) movements. 33 | - `no_nop`: boolean that removes the stopping action from the action space. Used for the MS-DQN architecture. 34 | - `max_time_steps`: boolean that enables resetting the agent when it takes too long to reach a goal state. 35 | - `time_step_limit`: number of time steps the agent has to reach a goal state. 36 | - `reward_[action]`: reward given to the agent for each type of action in the environment.
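All of these values are collected in the `env_params` dictionary near the top of `main.py`, together with the `DATA_PATH` and `OUTPUT_PATH` constants. For reference, the snippet below mirrors the default configuration shipped in `main.py`; the two paths are left empty and have to be filled in for your machine.

```python
import numpy as np

# Paths: dataset location and output/logging location (fill in before training).
DATA_PATH = ""
OUTPUT_PATH = ""

# Environment parameters as defined in main.py.
env_params = {'data_path': DATA_PATH,
              'verbose': 1,
              'test_patients': 5,
              'val_patients': 4,
              'prev_actions': 5,
              'prev_frames': 3,
              'val_set': np.array([25, 26, 27, 28]),
              'test_set': np.array([29, 30, 31, 32, 33]),
              'shuffles': 20,
              'chebishev': False,          # diagonal movements disabled
              'no_nop': False,             # keep the stopping action
              'max_time_steps': True,      # reset after too many steps
              'time_step_limit': 50,
              'reward_goal_correct': 1.0,
              'reward_goal_incorrect': -0.25,
              'reward_move_closer': 0.05,
              'reward_move_further': -0.1,
              'reward_border_collision': -0.1}
```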
37 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | import tensorflow as tf 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | 7 | def reachability_plot(env, patients, reach_maps): 8 | single_env = isinstance(patients, (int, np.int64)) 9 | envs = 1 if single_env else len(patients) 10 | goals = env.goals 11 | goal_rows = [] 12 | goal_cols = [] 13 | numerator = np.zeros([envs, env.num_rows * 2 + 1, env.num_cols * 2 + 1]) 14 | denominator = np.zeros_like(numerator) 15 | 16 | plot_size = 5 17 | fig, ax = plt.subplots(1, 1, figsize=(plot_size, plot_size)) 18 | for i in range(envs): 19 | goal = goals[patients] if single_env else goals[patients[i]] 20 | goal_row, goal_col = env.val_to_coords(goal[-1]) if isinstance(goal, list) else env.val_to_coords(goal) 21 | goal_rows.append(goal_row) 22 | goal_cols.append(goal_col) 23 | reach_map = reach_maps[i, :, :] 24 | x_shift = env.num_cols - goal_col 25 | y_shift = env.num_rows - goal_row 26 | numerator[i, :, :] = np.pad(reach_map, [[y_shift, env.num_rows - y_shift + 1], [x_shift, env.num_cols - x_shift + 1]], mode="constant") 27 | denominator[i, :, :] = np.pad(np.ones_like(reach_map), [[y_shift, env.num_rows - y_shift + 1], [x_shift, env.num_cols - x_shift + 1]], mode="constant") 28 | 29 | denominator = np.sum(denominator, axis=0) 30 | rows = np.any(denominator, axis=1) 31 | cols = np.any(denominator, axis=0) 32 | first_row, last_row = np.where(rows)[0][[0, -1]] 33 | first_col, last_col = np.where(cols)[0][[0, -1]] 34 | 35 | denominator += (denominator == 0) * 1 36 | im = ax.matshow(np.sum(numerator, axis=0)[first_row:last_row + 1, first_col:last_col + 1] / denominator[first_row:last_row + 1, first_col:last_col + 1], cmap='Greens') 37 | ax.scatter(np.max(goal_cols), np.max(goal_rows), marker='s', c='red', s=100) 38 | 39 | divider = make_axes_locatable(ax) 40 | cax = divider.append_axes("right", size="5%", pad=0.05) 41 | fig.colorbar(im, cax=cax) 42 | ax.set_title("Average reachability: {}".format(np.sum(reach_maps)/np.prod(reach_maps.shape))) 43 | 44 | return fig 45 | 46 | def plot2fig(fig): 47 | """Create a pyplot plot and save to buffer.""" 48 | buf = io.BytesIO() 49 | plt.savefig(buf, format='png') 50 | plt.close(fig) 51 | buf.seek(0) 52 | image = tf.image.decode_png(buf.getvalue(), channels=4) 53 | image = tf.expand_dims(image, 0) 54 | buf.close() 55 | return image 56 | -------------------------------------------------------------------------------- /hyperparams/ppo2.yml: -------------------------------------------------------------------------------- 1 | atari: 2 | policy: 'CnnPolicy' 3 | n_envs: 8 4 | n_steps: 128 5 | noptepochs: 4 6 | nminibatches: 4 7 | n_timesteps: !!float 1e7 8 | learning_rate: lin_2.5e-4 9 | cliprange: lin_0.1 10 | vf_coef: 0.5 11 | ent_coef: 0.01 12 | cliprange_vf: -1 13 | 14 | Pendulum-v0: 15 | n_envs: 8 16 | n_timesteps: !!float 2e6 17 | policy: 'MlpPolicy' 18 | n_steps: 2048 19 | nminibatches: 32 20 | lam: 0.95 21 | gamma: 0.99 22 | noptepochs: 10 23 | ent_coef: 0.0 24 | learning_rate: !!float 3e-4 25 | cliprange: 0.2 26 | 27 | # Tuned 28 | CartPole-v1: 29 | n_envs: 8 30 | n_timesteps: !!float 1e5 31 | policy: 'MlpPolicy' 32 | n_steps: 32 33 | nminibatches: 1 34 | lam: 0.8 35 | gamma: 0.98 36 | noptepochs: 20 37 | ent_coef: 0.0 38 | learning_rate: lin_0.001 39 | cliprange: lin_0.2 40 | 41 | 
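# Note: values prefixed with 'lin_' (e.g. learning_rate: lin_0.001) are parsed by main.py
# into linear_schedule(initial_value), a schedule that decays linearly with training
# progress; plain floats are wrapped into a constant schedule via stable_baselines'
# constfn. This parsing is only applied to ppo2 runs.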
MountainCar-v0: 42 | normalize: true 43 | n_envs: 16 44 | n_timesteps: !!float 1e6 45 | policy: 'MlpPolicy' 46 | n_steps: 16 47 | nminibatches: 1 48 | lam: 0.98 49 | gamma: 0.99 50 | noptepochs: 4 51 | ent_coef: 0.0 52 | 53 | MountainCarContinuous-v0: 54 | normalize: true 55 | n_envs: 16 56 | n_timesteps: !!float 1e6 57 | policy: 'MlpPolicy' 58 | n_steps: 256 59 | nminibatches: 8 60 | lam: 0.94 61 | gamma: 0.99 62 | noptepochs: 4 63 | ent_coef: 0.0 64 | 65 | Acrobot-v1: 66 | normalize: true 67 | n_envs: 16 68 | n_timesteps: !!float 1e6 69 | policy: 'MlpPolicy' 70 | n_steps: 256 71 | nminibatches: 8 72 | lam: 0.94 73 | gamma: 0.99 74 | noptepochs: 4 75 | ent_coef: 0.0 76 | 77 | BipedalWalker-v2: 78 | normalize: true 79 | n_envs: 16 80 | n_timesteps: !!float 5e6 81 | policy: 'MlpPolicy' 82 | n_steps: 2048 83 | nminibatches: 32 84 | lam: 0.95 85 | gamma: 0.99 86 | noptepochs: 10 87 | ent_coef: 0.001 88 | learning_rate: !!float 2.5e-4 89 | cliprange: 0.2 90 | 91 | BipedalWalkerHardcore-v2: 92 | normalize: true 93 | n_envs: 16 94 | n_timesteps: !!float 10e7 95 | policy: 'MlpPolicy' 96 | n_steps: 2048 97 | nminibatches: 32 98 | lam: 0.95 99 | gamma: 0.99 100 | noptepochs: 10 101 | ent_coef: 0.001 102 | learning_rate: lin_2.5e-4 103 | cliprange: lin_0.2 104 | 105 | LunarLander-v2: 106 | n_envs: 16 107 | n_timesteps: !!float 1e6 108 | policy: 'MlpPolicy' 109 | n_steps: 1024 110 | nminibatches: 32 111 | lam: 0.98 112 | gamma: 0.999 113 | noptepochs: 4 114 | ent_coef: 0.01 115 | 116 | LunarLanderContinuous-v2: 117 | n_envs: 16 118 | n_timesteps: !!float 1e6 119 | policy: 'MlpPolicy' 120 | n_steps: 1024 121 | nminibatches: 32 122 | lam: 0.98 123 | gamma: 0.999 124 | noptepochs: 4 125 | ent_coef: 0.01 126 | 127 | Walker2DBulletEnv-v0: 128 | env_wrapper: utils.wrappers.TimeFeatureWrapper 129 | normalize: true 130 | n_envs: 4 131 | n_timesteps: !!float 2e6 132 | policy: 'MlpPolicy' 133 | n_steps: 1024 134 | nminibatches: 64 135 | lam: 0.95 136 | gamma: 0.99 137 | noptepochs: 10 138 | ent_coef: 0.0 139 | learning_rate: lin_2.5e-4 140 | cliprange: 0.1 141 | cliprange_vf: -1 142 | 143 | 144 | HalfCheetahBulletEnv-v0: 145 | env_wrapper: utils.wrappers.TimeFeatureWrapper 146 | normalize: true 147 | n_envs: 1 148 | n_timesteps: !!float 2e6 149 | policy: 'MlpPolicy' 150 | n_steps: 2048 151 | nminibatches: 32 152 | lam: 0.95 153 | gamma: 0.99 154 | noptepochs: 10 155 | ent_coef: 0.0 156 | learning_rate: !!float 3e-4 157 | cliprange: 0.2 158 | 159 | HalfCheetah-v2: 160 | normalize: true 161 | n_envs: 1 162 | n_timesteps: !!float 2e6 163 | policy: 'MlpPolicy' 164 | n_steps: 2048 165 | nminibatches: 32 166 | lam: 0.95 167 | gamma: 0.99 168 | noptepochs: 10 169 | ent_coef: 0.0 170 | learning_rate: lin_3e-4 171 | cliprange: 0.2 172 | cliprange_vf: -1 173 | 174 | AntBulletEnv-v0: 175 | normalize: true 176 | n_envs: 8 177 | n_timesteps: !!float 2e6 178 | policy: 'CustomMlpPolicy' 179 | n_steps: 256 180 | nminibatches: 32 181 | lam: 0.95 182 | gamma: 0.99 183 | noptepochs: 10 184 | ent_coef: 0.0 185 | learning_rate: 2.5e-4 186 | cliprange: 0.2 187 | 188 | HopperBulletEnv-v0: 189 | normalize: true 190 | n_envs: 8 191 | n_timesteps: !!float 2e6 192 | policy: 'MlpPolicy' 193 | n_steps: 2048 194 | nminibatches: 128 195 | lam: 0.95 196 | gamma: 0.99 197 | noptepochs: 10 198 | ent_coef: 0.0 199 | learning_rate: 2.5e-4 200 | cliprange: 0.2 201 | 202 | ReacherBulletEnv-v0: 203 | normalize: true 204 | n_envs: 8 205 | n_timesteps: !!float 2e6 206 | policy: 'MlpPolicy' 207 | n_steps: 2048 208 | nminibatches: 32 209 | lam: 0.95 
210 | gamma: 0.99 211 | noptepochs: 10 212 | ent_coef: 0.0 213 | learning_rate: 2.5e-4 214 | cliprange: 0.2 215 | 216 | MinitaurBulletEnv-v0: 217 | normalize: true 218 | n_envs: 8 219 | n_timesteps: !!float 2e6 220 | policy: 'MlpPolicy' 221 | n_steps: 2048 222 | nminibatches: 32 223 | lam: 0.95 224 | gamma: 0.99 225 | noptepochs: 10 226 | ent_coef: 0.0 227 | learning_rate: 2.5e-4 228 | cliprange: 0.2 229 | 230 | MinitaurBulletDuckEnv-v0: 231 | normalize: true 232 | n_envs: 8 233 | n_timesteps: !!float 2e6 234 | policy: 'MlpPolicy' 235 | n_steps: 2048 236 | nminibatches: 32 237 | lam: 0.95 238 | gamma: 0.99 239 | noptepochs: 10 240 | ent_coef: 0.0 241 | learning_rate: 2.5e-4 242 | cliprange: 0.2 243 | 244 | # To be tuned 245 | HumanoidBulletEnv-v0: 246 | normalize: true 247 | n_envs: 8 248 | n_timesteps: !!float 1e7 249 | policy: 'MlpPolicy' 250 | n_steps: 2048 251 | nminibatches: 32 252 | lam: 0.95 253 | gamma: 0.99 254 | noptepochs: 10 255 | ent_coef: 0.0 256 | learning_rate: 2.5e-4 257 | cliprange: 0.2 258 | 259 | InvertedDoublePendulumBulletEnv-v0: 260 | normalize: true 261 | n_envs: 8 262 | n_timesteps: !!float 2e6 263 | policy: 'MlpPolicy' 264 | n_steps: 2048 265 | nminibatches: 32 266 | lam: 0.95 267 | gamma: 0.99 268 | noptepochs: 10 269 | ent_coef: 0.0 270 | learning_rate: 2.5e-4 271 | cliprange: 0.2 272 | 273 | InvertedPendulumSwingupBulletEnv-v0: 274 | normalize: true 275 | n_envs: 8 276 | n_timesteps: !!float 2e6 277 | policy: 'MlpPolicy' 278 | n_steps: 2048 279 | nminibatches: 32 280 | lam: 0.95 281 | gamma: 0.99 282 | noptepochs: 10 283 | ent_coef: 0.0 284 | learning_rate: 2.5e-4 285 | cliprange: 0.2 286 | 287 | # Following https://github.com/lcswillems/rl-starter-files 288 | MiniGrid-DoorKey-5x5-v0: 289 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper # requires --gym-packages gym_minigrid 290 | normalize: true 291 | n_envs: 8 # number of environment copies running in parallel 292 | n_timesteps: !!float 1e5 293 | policy: MlpPolicy 294 | n_steps: 128 # batch size is n_steps * n_env 295 | nminibatches: 32 # Number of training minibatches per update 296 | lam: 0.95 # Factor for trade-off of bias vs variance for Generalized Advantage Estimator 297 | gamma: 0.99 298 | noptepochs: 10 # Number of epoch when optimizing the surrogate 299 | ent_coef: 0.0 # Entropy coefficient for the loss caculation 300 | learning_rate: 2.5e-4 # The learning rate, it can be a function 301 | cliprange: 0.2 # Clipping parameter, it can be a function 302 | 303 | MiniGrid-FourRooms-v0: 304 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper # requires --gym-packages gym_minigrid 305 | normalize: true 306 | n_envs: 8 307 | n_timesteps: !!float 4e6 308 | policy: 'MlpPolicy' 309 | n_steps: 512 310 | nminibatches: 32 311 | lam: 0.95 312 | gamma: 0.99 313 | noptepochs: 10 314 | ent_coef: 0.0 315 | learning_rate: 2.5e-4 316 | cliprange: 0.2 317 | 318 | gym_sacrum_nav:sacrum_nav-v1: 319 | normalize: true 320 | n_envs: 8 321 | n_timesteps: !!float 4e6 322 | policy: 'CnnPolicy' 323 | n_steps: 512 324 | nminibatches: 32 325 | lam: 0.95 326 | gamma: 0.99 327 | noptepochs: 10 328 | ent_coef: 0.0 329 | learning_rate: 2.5e-4 330 | cliprange: 0.2 331 | -------------------------------------------------------------------------------- /utils/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import time 4 | import pickle 5 | import pdb 6 | 7 | def timeit(f): 8 | """ Decorator to time Any Function """ 9 | 10 | def timed(*args, 
**kwargs): 11 | start_time = time.time() 12 | result = f(*args, **kwargs) 13 | end_time = time.time() 14 | seconds = end_time - start_time 15 | print(" [-] %s : %2.5f sec, which is %2.5f mins, which is %2.5f hours" % 16 | (f.__name__, seconds, seconds / 60, seconds / 3600)) 17 | return result 18 | 19 | return timed 20 | 21 | def _debug(operation): 22 | print("Layer_name: " + operation.op.name + " -Output_Shape: " + str(operation.shape.as_list())) 23 | 24 | # Summaries for variables 25 | def variable_summaries(var): 26 | """ 27 | Attach a lot of summaries to a Tensor (for TensorBoard visualization). 28 | :param var: variable to be summarized 29 | :return: None 30 | """ 31 | with tf.name_scope('summaries'): 32 | mean = tf.reduce_mean(var) 33 | tf.summary.scalar('mean', mean) 34 | with tf.name_scope('stddev'): 35 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 36 | tf.summary.scalar('stddev', stddev) 37 | tf.summary.scalar('max', tf.reduce_max(var)) 38 | tf.summary.scalar('min', tf.reduce_min(var)) 39 | tf.summary.histogram('histogram', var) 40 | 41 | def variable_with_weight_decay(kernel_shape, initializer, wd): 42 | """ 43 | Create a variable with L2 Regularization (Weight Decay) 44 | :param kernel_shape: the size of the convolving weight kernel. 45 | :param initializer: The initialization scheme, He et al. normal or Xavier normal are recommended. 46 | :param wd:(weight decay) L2 regularization parameter. 47 | :return: The weights of the kernel initialized. The L2 loss is added to the loss collection. 48 | """ 49 | w = tf.get_variable('weights', kernel_shape, tf.float32, initializer=initializer) 50 | 51 | collection_name = tf.GraphKeys.REGULARIZATION_LOSSES 52 | if wd and (not tf.get_variable_scope().reuse): 53 | weight_decay = tf.multiply(tf.nn.l2_loss(w), wd, name='w_loss') 54 | tf.add_to_collection(collection_name, weight_decay) 55 | #variable_summaries(w) 56 | return w 57 | 58 | 59 | def _residual_block(name, x, filters, pool_first=False, strides=1, dilation=1, bias=-1): 60 | print('Building residual unit: %s' % name) 61 | with tf.variable_scope(name): 62 | # get input channels 63 | in_channel = x.shape.as_list()[-1] 64 | 65 | # Shortcut connection 66 | shortcut = tf.identity(x) 67 | 68 | if pool_first: 69 | if in_channel == filters: 70 | if strides == 1: 71 | shortcut = tf.identity(x) 72 | else: 73 | shortcut= tf.pad(x, tf.constant([[0,0],[1,1],[1,1],[0,0]]), "CONSTANT") 74 | shortcut = tf.nn.max_pool(shortcut, [1, strides, strides, 1], [1, strides, strides, 1], 'VALID') 75 | else: 76 | shortcut = _conv('shortcut_conv', x, padding='VALID', 77 | num_filters=filters, kernel_size=(1, 1), stride=(strides, strides), 78 | bias=bias) 79 | else: 80 | if dilation != 1: 81 | shortcut = _conv('shortcut_conv', x, padding='VALID', 82 | num_filters=filters, kernel_size=(1, 1), dilation=dilation, bias=bias) 83 | 84 | # Residual 85 | x = _conv('conv_1', x, padding=[[0,0],[1,1],[1,1],[0,0]], 86 | num_filters=filters, kernel_size=(3, 3), stride=(strides, strides), bias=bias) 87 | #x = _bn('bn_1', x) 88 | x = _relu('relu_1', x) 89 | x = _conv('conv_2', x, padding=[[0,0],[1,1],[1,1],[0,0]], 90 | num_filters=filters, kernel_size=(3, 3), bias=bias) 91 | #x = _bn('bn_2', x) 92 | 93 | # Merge 94 | x = x + shortcut 95 | x = _relu('relu_2', x) 96 | 97 | print('residual-unit-%s-shape: ' % name + str(x.shape.as_list())) 98 | 99 | return x 100 | 101 | def _conv(name, x, num_filters=16, kernel_size=(3, 3), padding='SAME', stride=(1, 1), 102 | initializer=tf.contrib.layers.xavier_initializer(), 
l2_strength=0.0, dilation=1.0, bias=-1): 103 | 104 | with tf.variable_scope(name): 105 | stride = [1, stride[0], stride[1], 1] 106 | kernel_shape = [kernel_size[0], kernel_size[1], x.shape[-1], num_filters] 107 | 108 | w = variable_with_weight_decay(kernel_shape, initializer, l2_strength) 109 | 110 | #variable_summaries(w) 111 | if dilation > 1: 112 | conv = tf.nn.atrous_conv2d(x, w, dilation, padding) 113 | else: 114 | if type(padding)==type(''): 115 | conv = tf.nn.conv2d(x, w, stride, padding) 116 | else: 117 | conv = tf.pad(x, padding, "CONSTANT") 118 | conv = tf.nn.conv2d(conv, w, stride, padding='VALID') 119 | 120 | if bias != -1: 121 | bias = tf.get_variable('biases', [num_filters], initializer=tf.constant_initializer(bias)) 122 | 123 | #variable_summaries(bias) 124 | conv = tf.nn.bias_add(conv, bias) 125 | 126 | tf.add_to_collection('debug_layers', conv) 127 | 128 | return conv 129 | 130 | def _relu(name, x): 131 | with tf.variable_scope(name): 132 | return tf.nn.relu(x) 133 | 134 | def _fc(name, x, output_dim=128, initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=-1): 135 | 136 | with tf.variable_scope(name): 137 | n_in = x.get_shape()[-1].value 138 | 139 | w = variable_with_weight_decay([n_in, output_dim], initializer, l2_strength) 140 | 141 | #variable_summaries(w) 142 | 143 | if bias != -1 and isinstance(bias, float): 144 | bias = tf.get_variable("biases", [output_dim], tf.float32, tf.constant_initializer(bias)) 145 | output = tf.nn.bias_add(tf.matmul(x, w), bias) 146 | else: 147 | output = tf.matmul(x, w) 148 | 149 | return output 150 | 151 | def _bn(name, x, train_flag): 152 | with tf.variable_scope(name): 153 | moving_average_decay = 0.9 154 | decay = moving_average_decay 155 | 156 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2]) 157 | 158 | mu = tf.get_variable('mu', batch_mean.shape, dtype=tf.float32, 159 | initializer=tf.zeros_initializer(), trainable=False) 160 | tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, mu) 161 | tf.add_to_collection('mu_sigma_bn', mu) 162 | sigma = tf.get_variable('sigma', batch_var.shape, dtype=tf.float32, 163 | initializer=tf.ones_initializer(), trainable=False) 164 | tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, sigma) 165 | tf.add_to_collection('mu_sigma_bn', sigma) 166 | beta = tf.get_variable('beta', batch_mean.shape, dtype=tf.float32, 167 | initializer=tf.zeros_initializer()) 168 | gamma = tf.get_variable('gamma', batch_var.shape, dtype=tf.float32, 169 | initializer=tf.ones_initializer()) 170 | 171 | # BN when training 172 | update = 1.0 - decay 173 | update_mu = mu.assign_sub(update * (mu - batch_mean)) 174 | update_sigma = sigma.assign_sub(update * (sigma - batch_var)) 175 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mu) 176 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_sigma) 177 | 178 | mean, var = tf.cond(train_flag, lambda: (batch_mean, batch_var), lambda: (mu, sigma)) 179 | bn = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5) 180 | 181 | tf.add_to_collection('debug_layers', bn) 182 | 183 | return bn 184 | 185 | 186 | def ResNet18(x_input=None, classes=5, bias=-1, weight_decay=5e-4, test_classification=False): 187 | 188 | with tf.variable_scope('conv1_x'): 189 | print('Building unit: conv1') 190 | conv1 = _conv('conv1', x_input, padding= [[0,0],[3,3],[3,3],[0,0]], 191 | num_filters=64, kernel_size=(7, 7), stride=(2, 2), l2_strength=weight_decay, 192 | bias=bias) 193 | 194 | #conv1 = _bn('bn1', conv1) 195 | 196 | conv1 = _relu('relu1', conv1) 197 | _debug(conv1) 198 | 
conv1= tf.pad(conv1, tf.constant([[0,0],[1,1],[1,1],[0,0]]), "CONSTANT") 199 | conv1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID', 200 | name='max_pool1') 201 | _debug(conv1) 202 | print('conv1-shape: ' + str(conv1.shape.as_list())) 203 | 204 | with tf.variable_scope('conv2_x'): 205 | conv2 = _residual_block('conv2_1', conv1, 64) 206 | _debug(conv2) 207 | conv2 = _residual_block('conv2_2', conv2, 64) 208 | _debug(conv2) 209 | 210 | with tf.variable_scope('conv3_x'): 211 | conv3 = _residual_block('conv3_1', conv2, 128, pool_first=True, strides=2) 212 | _debug(conv3) 213 | conv3 = _residual_block('conv3_2', conv3, 128) 214 | _debug(conv3) 215 | 216 | with tf.variable_scope('conv4_x'): 217 | conv4 = _residual_block('conv4_1', conv3, 256, pool_first=True, strides=2) 218 | _debug(conv4) 219 | conv4 = _residual_block('conv4_2', conv4, 256) 220 | _debug(conv4) 221 | 222 | with tf.variable_scope('conv5_x'): 223 | conv5 = _residual_block('conv5_1', conv4, 512, pool_first=True, strides=2) 224 | _debug(conv5) 225 | conv5 = _residual_block('conv5_2', conv5, 512) 226 | _debug(conv5) 227 | 228 | with tf.variable_scope('resnet_out'): 229 | print('Building unit: logits') 230 | #score = tf.reduce_mean(conv5, axis=[1, 2]) 231 | score = tf.compat.v1.layers.flatten(conv5) 232 | _debug(score) 233 | score = _fc('logits_dense', score, output_dim=classes, l2_strength=weight_decay) 234 | print('logits-shape: ' + str(score.shape.as_list())) 235 | 236 | return score -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import inspect 4 | import glob 5 | import yaml 6 | import importlib 7 | 8 | import gym 9 | try: 10 | import pybullet_envs 11 | except ImportError: 12 | pybullet_envs = None 13 | 14 | from gym.envs.registration import load 15 | 16 | from stable_baselines.deepq.policies import FeedForwardPolicy 17 | from stable_baselines.common.policies import FeedForwardPolicy as BasePolicy 18 | from stable_baselines.common.policies import register_policy 19 | from stable_baselines.sac.policies import FeedForwardPolicy as SACPolicy 20 | from stable_baselines.bench import Monitor 21 | from stable_baselines import logger 22 | from stable_baselines import PPO2, DQN 23 | from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize, \ 24 | VecFrameStack, SubprocVecEnv 25 | from stable_baselines.common import set_global_seeds 26 | 27 | ALGOS = { 28 | 'dqn': DQN, 29 | 'ppo2': PPO2 30 | } 31 | 32 | 33 | # ================== Custom Policies ================= 34 | 35 | class CustomDQNPolicy(FeedForwardPolicy): 36 | def __init__(self, *args, **kwargs): 37 | super(CustomDQNPolicy, self).__init__(*args, **kwargs, 38 | layers=[64], 39 | layer_norm=True, 40 | feature_extraction="mlp") 41 | 42 | 43 | class CustomMlpPolicy(BasePolicy): 44 | def __init__(self, *args, **kwargs): 45 | super(CustomMlpPolicy, self).__init__(*args, **kwargs, 46 | layers=[16], 47 | feature_extraction="mlp") 48 | 49 | 50 | class CustomSACPolicy(SACPolicy): 51 | def __init__(self, *args, **kwargs): 52 | super(CustomSACPolicy, self).__init__(*args, **kwargs, 53 | layers=[256, 256], 54 | feature_extraction="mlp") 55 | 56 | 57 | register_policy('CustomSACPolicy', CustomSACPolicy) 58 | register_policy('CustomDQNPolicy', CustomDQNPolicy) 59 | register_policy('CustomMlpPolicy', CustomMlpPolicy) 60 | 61 | 62 | def flatten_dict_observations(env): 63 | assert 
isinstance(env.observation_space, gym.spaces.Dict) 64 | keys = env.observation_space.spaces.keys() 65 | return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys)) 66 | 67 | 68 | def get_wrapper_class(hyperparams): 69 | """ 70 | Get a Gym environment wrapper class specified as a hyper parameter 71 | "env_wrapper". 72 | e.g. 73 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper 74 | 75 | :param hyperparams: (dict) 76 | :return: a subclass of gym.Wrapper (class object) you can use to 77 | create another Gym env giving an original env. 78 | """ 79 | 80 | def get_module_name(fullname): 81 | return '.'.join(wrapper_name.split('.')[:-1]) 82 | 83 | def get_class_name(fullname): 84 | return wrapper_name.split('.')[-1] 85 | 86 | if 'env_wrapper' in hyperparams.keys(): 87 | wrapper_name = hyperparams.get('env_wrapper') 88 | wrapper_module = importlib.import_module(get_module_name(wrapper_name)) 89 | return getattr(wrapper_module, get_class_name(wrapper_name)) 90 | else: 91 | return None 92 | 93 | 94 | def make_env(env_id, rank=0, seed=0, log_dir=None, wrapper_class=None, **kwargs): 95 | """ 96 | Helper function to multiprocess training 97 | and log the progress. 98 | 99 | :param env_id: (str) 100 | :param rank: (int) 101 | :param seed: (int) 102 | :param log_dir: (str) 103 | :param wrapper: (type) a subclass of gym.Wrapper to wrap the original 104 | env with 105 | """ 106 | if log_dir is None and log_dir != '': 107 | log_dir = "/tmp/gym/{}/".format(int(time.time())) 108 | os.makedirs(log_dir, exist_ok=True) 109 | 110 | def _init(): 111 | set_global_seeds(seed + rank) 112 | env = gym.make(env_id, **kwargs) 113 | 114 | # Dict observation space is currently not supported. 115 | # https://github.com/hill-a/stable-baselines/issues/321 116 | # We allow a Gym env wrapper (a subclass of gym.Wrapper) 117 | if wrapper_class: 118 | env = wrapper_class(env) 119 | 120 | env.seed(seed + rank) 121 | env = Monitor(env, os.path.join(log_dir, str(rank)), allow_early_resets=True) 122 | return env 123 | 124 | return _init 125 | 126 | 127 | def create_test_env(env_id, n_envs=1, is_atari=False, 128 | stats_path=None, seed=0, 129 | log_dir='', should_render=True, hyperparams=None): 130 | """ 131 | Create environment for testing a trained agent 132 | 133 | :param env_id: (str) 134 | :param n_envs: (int) number of processes 135 | :param is_atari: (bool) 136 | :param stats_path: (str) path to folder containing saved running averaged 137 | :param seed: (int) Seed for random number generator 138 | :param log_dir: (str) Where to log rewards 139 | :param should_render: (bool) For Pybullet env, display the GUI 140 | :param env_wrapper: (type) A subclass of gym.Wrapper to wrap the original 141 | env with 142 | :param hyperparams: (dict) Additional hyperparams (ex: n_stack) 143 | :return: (gym.Env) 144 | """ 145 | # HACK to save logs 146 | if log_dir is not None: 147 | os.environ["OPENAI_LOG_FORMAT"] = 'csv' 148 | os.environ["OPENAI_LOGDIR"] = os.path.abspath(log_dir) 149 | os.makedirs(log_dir, exist_ok=True) 150 | logger.configure() 151 | 152 | # Create the environment and wrap it if necessary 153 | env_wrapper = get_wrapper_class(hyperparams) 154 | if 'env_wrapper' in hyperparams.keys(): 155 | del hyperparams['env_wrapper'] 156 | 157 | if is_atari: 158 | print("Using Atari wrapper") 159 | #env = make_atari_env(env_id, num_env=n_envs, seed=seed) 160 | ## Frame-stacking with 4 frames 161 | #env = VecFrameStack(env, n_stack=4) 162 | elif n_envs > 1: 163 | # start_method = 'spawn' for thread safe 164 | env = 
SubprocVecEnv([make_env(env_id, i, seed, log_dir, wrapper_class=env_wrapper) for i in range(n_envs)]) 165 | # Pybullet envs does not follow gym.render() interface 166 | elif "Bullet" in env_id: 167 | spec = gym.envs.registry.env_specs[env_id] 168 | try: 169 | class_ = load(spec.entry_point) 170 | except AttributeError: 171 | # Backward compatibility with gym 172 | class_ = load(spec._entry_point) 173 | # HACK: force SubprocVecEnv for Bullet env that does not 174 | # have a render argument 175 | render_name = None 176 | use_subproc = 'renders' not in inspect.getfullargspec(class_.__init__).args 177 | if not use_subproc: 178 | render_name = 'renders' 179 | # Dev branch of pybullet 180 | # use_subproc = use_subproc and 'render' not in inspect.getfullargspec(class_.__init__).args 181 | # if not use_subproc and render_name is None: 182 | # render_name = 'render' 183 | 184 | # Create the env, with the original kwargs, and the new ones overriding them if needed 185 | def _init(): 186 | # TODO: fix for pybullet locomotion envs 187 | env = class_(**{**spec._kwargs}, **{render_name: should_render}) 188 | env.seed(0) 189 | if log_dir is not None: 190 | env = Monitor(env, os.path.join(log_dir, "0"), allow_early_resets=True) 191 | return env 192 | 193 | if use_subproc: 194 | env = SubprocVecEnv([make_env(env_id, 0, seed, log_dir, wrapper_class=env_wrapper)]) 195 | else: 196 | env = DummyVecEnv([_init]) 197 | else: 198 | env = DummyVecEnv([make_env(env_id, 0, seed, log_dir, wrapper_class=env_wrapper)]) 199 | 200 | # Load saved stats for normalizing input and rewards 201 | # And optionally stack frames 202 | if stats_path is not None: 203 | if hyperparams['normalize']: 204 | print("Loading running average") 205 | print("with params: {}".format(hyperparams['normalize_kwargs'])) 206 | env = VecNormalize(env, training=False, **hyperparams['normalize_kwargs']) 207 | env.load_running_average(stats_path) 208 | 209 | n_stack = hyperparams.get('frame_stack', 0) 210 | if n_stack > 0: 211 | print("Stacking {} frames".format(n_stack)) 212 | env = VecFrameStack(env, n_stack) 213 | return env 214 | 215 | 216 | def linear_schedule(initial_value): 217 | """ 218 | Linear learning rate schedule. 219 | 220 | :param initial_value: (float or str) 221 | :return: (function) 222 | """ 223 | if isinstance(initial_value, str): 224 | initial_value = float(initial_value) 225 | 226 | def func(progress): 227 | """ 228 | Progress will decrease from 1 (beginning) to 0 229 | :param progress: (float) 230 | :return: (float) 231 | """ 232 | return progress * initial_value 233 | 234 | return func 235 | 236 | 237 | def get_trained_models(log_folder): 238 | """ 239 | :param log_folder: (str) Root log folder 240 | :return: (dict) Dict representing the trained agent 241 | """ 242 | algos = os.listdir(log_folder) 243 | trained_models = {} 244 | for algo in algos: 245 | for ext in ['zip', 'pkl']: 246 | for env_id in glob.glob('{}/{}/*.{}'.format(log_folder, algo, ext)): 247 | # Retrieve env name 248 | env_id = env_id.split('/')[-1].split('.{}'.format(ext))[0] 249 | trained_models['{}-{}'.format(algo, env_id)] = (algo, env_id) 250 | return trained_models 251 | 252 | 253 | def get_latest_run_id(log_path, env_id): 254 | """ 255 | Returns the latest run number for the given log name and log path, 256 | by finding the greatest number in the directories. 
257 | 258 | :param log_path: (str) path to log folder 259 | :param env_id: (str) 260 | :return: (int) latest run number 261 | """ 262 | max_run_id = 0 263 | for path in glob.glob(log_path + "/{}_[0-9]*".format(env_id)): 264 | file_name = path.split("/")[-1] 265 | ext = file_name.split("_")[-1] 266 | if env_id == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: 267 | max_run_id = int(ext) 268 | return max_run_id 269 | 270 | 271 | def get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False): 272 | """ 273 | :param stats_path: (str) 274 | :param norm_reward: (bool) 275 | :param test_mode: (bool) 276 | :return: (dict, str) 277 | """ 278 | hyperparams = {} 279 | if not os.path.isdir(stats_path): 280 | stats_path = None 281 | else: 282 | config_file = os.path.join(stats_path, 'config.yml') 283 | if os.path.isfile(config_file): 284 | # Load saved hyperparameters 285 | with open(os.path.join(stats_path, 'config.yml'), 'r') as f: 286 | hyperparams = yaml.load(f) 287 | hyperparams['normalize'] = hyperparams.get('normalize', False) 288 | else: 289 | obs_rms_path = os.path.join(stats_path, 'obs_rms.pkl') 290 | hyperparams['normalize'] = os.path.isfile(obs_rms_path) 291 | 292 | # Load normalization params 293 | if hyperparams['normalize']: 294 | if isinstance(hyperparams['normalize'], str): 295 | normalize_kwargs = eval(hyperparams['normalize']) 296 | if test_mode: 297 | normalize_kwargs['norm_reward'] = norm_reward 298 | else: 299 | normalize_kwargs = {'norm_obs': hyperparams['normalize'], 'norm_reward': norm_reward} 300 | hyperparams['normalize_kwargs'] = normalize_kwargs 301 | return hyperparams, stats_path 302 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import yaml 4 | import argparse 5 | import numpy as np 6 | import tensorflow as tf 7 | from pprint import pprint 8 | import matplotlib.pyplot as plt 9 | from utils.resnet import ResNet18 10 | from collections import OrderedDict 11 | from stable_baselines.ppo2.ppo2 import constfn 12 | from stable_baselines.common import set_global_seeds 13 | from stable_baselines.common.vec_env import VecNormalize 14 | from utils.visualization import reachability_plot, plot2fig 15 | from stable_baselines.deepq.policies import FeedForwardPolicy 16 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 17 | from utils import ALGOS, get_wrapper_class, linear_schedule, make_env 18 | 19 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 20 | 21 | ######################################################################################################################## 22 | # SET PATHS 23 | # data path -> path to dataset | output path -> path to logging and model saving location 24 | DATA_PATH = "" 25 | OUTPUT_PATH = "" 26 | # 27 | ######################################################################################################################## 28 | # SET ENVIRONMENT PARAMETERS 29 | env_params = {'data_path': DATA_PATH, 30 | 'verbose': 1, 31 | 'test_patients': 5, 32 | 'val_patients': 4, 33 | 'prev_actions': 5, 34 | 'prev_frames': 3, 35 | 'val_set': np.array([25, 26, 27, 28]), 36 | 'test_set': np.array([29, 30, 31, 32, 33]), 37 | 'shuffles': 20, 38 | 'chebishev': False, 39 | 'no_nop': False, 40 | 'max_time_steps': True, 41 | 'time_step_limit': 50, 42 | 'reward_goal_correct': 1.0, 43 | 'reward_goal_incorrect': -0.25, 44 | 
'reward_move_closer': 0.05, 45 | 'reward_move_further': -0.1, 46 | 'reward_border_collision': -0.1} 47 | # parameters to define the environment 48 | # 49 | ######################################################################################################################## 50 | 51 | best_mean_reward, n_steps = -np.inf, 0 52 | 53 | def custom_cnn(input, **kwargs): 54 | action_mem_size = 25 55 | action_history = input[:, 0, 0:action_mem_size, -1] 56 | input = input[..., :-1] 57 | 58 | action_values = ResNet18(x_input=input, classes=512) 59 | 60 | return action_values, action_history 61 | 62 | class CustomCnnPolicy(FeedForwardPolicy): 63 | 64 | def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs): 65 | super(CustomCnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse, cnn_extractor=custom_cnn, 66 | feature_extraction="cnn", obs_phs=obs_phs, dueling=dueling, layer_norm=False, **_kwargs) 67 | 68 | class Indicators(): 69 | def __init__(self, num_states): 70 | self.num_states = num_states 71 | self.reset() 72 | 73 | def reset(self): 74 | self.reach_goal = [] 75 | self.efficiency = [] 76 | self.correct_actions = 0 77 | self.total_actions = 0 78 | 79 | def __str__(self): 80 | print("Goal reached {}% of the times".format(0 if not self.reach_goal else np.average(self.reach_goal)*100)) 81 | print("Overall efficiency: {}%".format(np.average(0 if not self.efficiency else self.efficiency)*100)) 82 | return "" 83 | 84 | 85 | def callback(locals_, globals_): 86 | 87 | global n_steps, best_mean_reward 88 | self_ = locals_.get('self') 89 | env_ = self_.env.envs[0] 90 | info_ = locals_.get('info') 91 | writer_ = locals_.get('writer') 92 | episode_ = locals_.get('num_episodes') 93 | 94 | # LOG CORRECT DECISIONS/TOTAL DECISIONS 95 | correct_decision_rate = info_.get('correct_decision_rate') if info_ else None 96 | if correct_decision_rate: 97 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/correct_decision_rate', simple_value=correct_decision_rate)]) 98 | writer_.add_summary(summary, episode_) 99 | 100 | episode_completed = locals_.get('done') 101 | if episode_completed and episode_ % 40 == 0: 102 | print("Done with episode {}".format(episode_)) 103 | actions = [] 104 | state_vals = [] 105 | for state in range(env_.num_states): 106 | 107 | frame = env_.frames[env_.patient_idx][state][0] 108 | frames = np.repeat(frame[:, :, np.newaxis], len(env_.prev_frames), axis=2) 109 | observation = np.dstack((frames, np.zeros_like(frame))) 110 | 111 | action, q_vals, state_val = self_.predict(observation) 112 | actions.append(action) 113 | state_vals.append(state_val) 114 | 115 | # TEST CORRECTNESS ON TRAINING PATIENT 116 | quiver_plot, policy_correctness = env_.quiver_plot(states=list(range(env_.num_states)), actions=actions, state_vals=state_vals) 117 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/policy correctness', simple_value=policy_correctness)]) 118 | writer_.add_summary(summary, episode_) 119 | 120 | # LOG POLICY QUIVER PLOT 121 | if episode_ % 200 == 0: 122 | image = plot2fig(quiver_plot) 123 | summary = tf.Summary(value=[tf.Summary.Value(tag='Regular_training/Policy graph at episode {}'.format(episode_), image=image)]) 124 | writer_.add_summary(summary, episode_) 125 | plt.close(quiver_plot) 126 | 127 | if episode_completed and episode_ % 100 == 0: 128 | avg_val_reachability = 0.0 129 | avg_val_correctness = 0.0 130 | max_time_steps = 20 131 | indicators = 
Indicators(env_.num_states) 132 | val_patients = env_.val_patient_idxs 133 | val_reachabilities = [] 134 | 135 | for val_patient in val_patients: 136 | goals = env_.goals[val_patient] 137 | 138 | for state in range(env_.num_states): 139 | 140 | obs = env_.set(val_patient, state) 141 | prev_state = env_.state 142 | for step in range(max_time_steps): 143 | 144 | if env_.no_nop and env_.state in goals: 145 | done = True 146 | else: 147 | action, q_vals, state_val = model.predict(obs) 148 | obs, reward, done, info = env_.step(action) 149 | 150 | moving = not (prev_state == env_.state) 151 | if moving: 152 | indicators.total_actions += 1 153 | if reward > 0: 154 | indicators.correct_actions += 1 155 | 156 | if done or step == max_time_steps - 1: 157 | if (env_.state in goals and env_.no_nop) or (not env_.no_nop and reward == env_.reward_dict["goal_correct"]): 158 | indicators.reach_goal.append(1) 159 | else: 160 | indicators.reach_goal.append(0) 161 | _ = env_.reset() 162 | break 163 | 164 | prev_state = env_.state 165 | 166 | val_reachabilities.append(np.average(indicators.reach_goal)) 167 | # avg_val_reachability += np.average(indicators.reach_goal)/len(val_patients) if isinstance(val_patients, (list, np.ndarray)) else np.average(indicators.reach_goal) 168 | print("Correctness for test patient {}: {}".format(val_patient, indicators.correct_actions / indicators.total_actions)) 169 | print("Reachability for validation patient {}: {}".format(val_patient, np.average(indicators.reach_goal))) 170 | avg_val_correctness += indicators.correct_actions / indicators.total_actions / len(val_patients) if isinstance(val_patients, (list, np.ndarray)) \ 171 | else indicators.correct_actions / indicators.total_actions 172 | indicators.reset() 173 | 174 | val_median_reachability = np.median(val_reachabilities) 175 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/val_median_reachability', simple_value=val_median_reachability)]) 176 | writer_.add_summary(summary, episode_) 177 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/val_correctness', simple_value=avg_val_correctness)]) 178 | writer_.add_summary(summary, episode_) 179 | 180 | # if env_.val_reachability < avg_val_reachability: 181 | if env_.val_reachability < val_median_reachability: 182 | print("Improved the model at episode {}!".format(episode_)) 183 | self_.save(OUTPUT_PATH + "val_model_episode_{}".format(episode_), cloudpickle=True) 184 | env_.val_reachability = val_median_reachability 185 | 186 | n_steps += 1 187 | 188 | return True 189 | 190 | if __name__ == '__main__': 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('--env', type=str, nargs='+', default=["gym_sacrum_nav:sacrum_nav-v2"], help='environment ID(s)') 193 | parser.add_argument('-tb', '--tensorboard-log', help='Tensorboard log dir', default='./runs/', type=str) 194 | parser.add_argument('-i', '--trained-agent', help='Path to a pretrained agent to continue training', default='', type=str) 195 | parser.add_argument('--algo', help='RL Algorithm', default='dqn', type=str, required=False, choices=list(ALGOS.keys())) 196 | parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=-1, type=int) 197 | parser.add_argument('--log-interval', help='Override log interval (default: -1, no change)', default=-1, type=int) 198 | parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='logs') 199 | parser.add_argument('--data-folder', help='Data folder', type=str, default='./data/') 200 | 
parser.add_argument('--seed', help='Random generator seed', type=int, default=0) 201 | parser.add_argument('--n-trials', help='Number of trials for optimizing hyperparameters', type=int, default=10) 202 | parser.add_argument('--verbose', help='Verbose mode (0: no output, 1: INFO, 2: debug)', default=1, type=int) 203 | parser.add_argument('--gym-packages', type=str, nargs='+', default=[], help='Additional external Gym environemnt package modules to import (e.g. gym_minigrid)') 204 | args = parser.parse_args() 205 | 206 | env_params['data_path'] = args.data_folder 207 | 208 | env_ids = args.env 209 | set_global_seeds(args.seed) 210 | 211 | for env_id in env_ids: 212 | tensorboard_log = None if args.tensorboard_log == '' else os.path.join(args.tensorboard_log, env_id) 213 | print("=" * 10, env_id, "=" * 10) 214 | 215 | # Load hyperparameters from yaml file 216 | with open('hyperparams/{}.yml'.format(args.algo), 'r') as f: 217 | hyperparams_dict = yaml.full_load(f) 218 | if env_id in list(hyperparams_dict.keys()): 219 | hyperparams = hyperparams_dict[env_id] 220 | else: 221 | raise ValueError("Hyperparameters not found for {}-{}".format(args.algo, env_id)) 222 | 223 | saved_hyperparams = OrderedDict([(key, hyperparams[key]) for key in sorted(hyperparams.keys())]) 224 | algo_ = args.algo 225 | 226 | if args.verbose > 0: 227 | pprint(saved_hyperparams) 228 | 229 | n_envs = hyperparams.get('n_envs', 1) 230 | 231 | if args.verbose > 0: 232 | print("Using {} environments".format(n_envs)) 233 | 234 | n_timesteps = int(hyperparams['n_timesteps']) 235 | 236 | # Delete keys so the dict can be pass to the model constructor 237 | if 'n_envs' in hyperparams.keys(): 238 | del hyperparams['n_envs'] 239 | del hyperparams['n_timesteps'] 240 | 241 | env_wrapper = get_wrapper_class(hyperparams) 242 | if 'env_wrapper' in hyperparams.keys(): 243 | del hyperparams['env_wrapper'] 244 | 245 | if algo_ in ["ppo2"]: 246 | for key in ['learning_rate', 'cliprange', 'cliprange_vf']: 247 | if key not in hyperparams: 248 | continue 249 | if isinstance(hyperparams[key], str): 250 | schedule, initial_value = hyperparams[key].split('_') 251 | initial_value = float(initial_value) 252 | hyperparams[key] = linear_schedule(initial_value) 253 | elif isinstance(hyperparams[key], (float, int)): 254 | # Negative value: ignore (ex: for clipping) 255 | if hyperparams[key] < 0: 256 | continue 257 | hyperparams[key] = constfn(float(hyperparams[key])) 258 | else: 259 | raise ValueError('Invalid value for {}: {}'.format(key, hyperparams[key])) 260 | normalize = False 261 | normalize_kwargs = {} 262 | if 'normalize' in hyperparams.keys(): 263 | normalize = hyperparams['normalize'] 264 | if isinstance(normalize, str): 265 | normalize_kwargs = eval(normalize) 266 | normalize = True 267 | del hyperparams['normalize'] 268 | 269 | def create_env(env_params): 270 | global hyperparams 271 | 272 | if algo_ in ['dqn']: 273 | env = gym.make(env_id, env_params=env_params) 274 | env.seed(args.seed) 275 | if env_wrapper is not None: 276 | env = env_wrapper(env) 277 | else: 278 | env = DummyVecEnv([make_env(env_id, 0, args.seed, wrapper_class=env_wrapper, env_params=env_params)]) 279 | if normalize: 280 | if args.verbose > 0: 281 | if len(normalize_kwargs) > 0: 282 | print("Normalization activated: {}".format(normalize_kwargs)) 283 | else: 284 | print("Normalizing input and reward") 285 | env = VecNormalize(env, **normalize_kwargs) 286 | return env 287 | 288 | env = create_env(env_params) 289 | 290 | env = DummyVecEnv([lambda: env]) 291 | 292 | 
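# The model below is built with CustomCnnPolicy: its feature extractor (custom_cnn) drops
# the last channel of the observation, passes the remaining frame channels through ResNet18
# to obtain a 512-dimensional feature vector, and returns the first 25 entries of that
# channel's first row as the action-history vector. ALGOS (utils/utils.py) maps
# 'dqn' -> DQN and 'ppo2' -> PPO2 from the adapted stable-baselines fork.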
print(hyperparams) 293 | model = ALGOS[args.algo](CustomCnnPolicy, 294 | env=env, 295 | tensorboard_log=tensorboard_log, 296 | verbose=args.verbose, 297 | batch_size=64, 298 | **hyperparams) 299 | print("Model loaded!") 300 | model.is_tb_set = False 301 | 302 | kwargs = {} 303 | if args.log_interval > -1: 304 | kwargs = {'log_interval': args.log_interval} 305 | 306 | model.learn(n_timesteps, callback=callback, **kwargs) 307 | 308 | model.save(OUTPUT_PATH + "final_model") 309 | --------------------------------------------------------------------------------
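The repository only ships the training entry point above; none of the files show how to reload the checkpoints that `main.py` writes (`final_model` and `val_model_episode_*`). The following is a minimal, hypothetical evaluation sketch (not part of the repository): it assumes the adapted stable-baselines fork installed via `polyaxonfile.yaml`, whose `predict()` returns `(action, q_values, state_value)` as used in the training callback, and it reuses the `env_params` dictionary and environment ID from `main.py`.

```python
# Hypothetical evaluation loop (not included in the repo): reload a checkpoint
# written by main.py and roll out a few episodes in the sacrum environment.
import gym
from stable_baselines import DQN

from main import env_params  # configuration dict defined at the top of main.py

env = gym.make("gym_sacrum_nav:sacrum_nav-v2", env_params=env_params)
model = DQN.load("final_model", env=env)  # saved by main.py as OUTPUT_PATH + "final_model"

successes, episodes = 0, 10
for _ in range(episodes):
    obs = env.reset()
    for _ in range(env_params['time_step_limit']):
        # Forked API: stock stable-baselines predict() returns (action, state) instead.
        action, q_vals, state_val = model.predict(obs)
        obs, reward, done, info = env.step(action)
        if done:
            successes += int(reward == env_params['reward_goal_correct'])
            break

print("Reached the goal in {}/{} episodes".format(successes, episodes))
```

Reaching a goal state with the correct stopping action yields `reward_goal_correct`, which is why the success check compares against that value; with `no_nop` enabled the check would instead have to inspect the final state, as the validation loop inside the training callback does.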