├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── vcs.xml
│   ├── modules.xml
│   ├── misc.xml
│   └── spinal_navigation_rl.iml
├── .gitmodules
├── utils
│   ├── __pycache__
│   │   ├── utils.cpython-37.pyc
│   │   ├── resnet.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── visualization.cpython-37.pyc
│   ├── __init__.py
│   ├── visualization.py
│   ├── resnet.py
│   └── utils.py
├── requirements.txt
├── LICENSE
├── hyperparams
│   ├── dqn.yml
│   └── ppo2.yml
├── polyaxonfile.yaml
├── README.md
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /runs/
3 |
4 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "gym_sacrum_env"]
2 | path = gym_sacrum_env
3 | url = https://github.com/hhase/gym_sacrum_env
4 |
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/visualization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhase/spinal-navigation-rl/HEAD/utils/__pycache__/visualization.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import make_env, ALGOS, linear_schedule, create_test_env,\
2 | get_trained_models, CustomDQNPolicy, get_latest_run_id,\
3 | get_saved_hyperparams, get_wrapper_class
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow==6.2.0
2 | numpy==1.16.4
3 | scipy==1.3.0
4 | matplotlib==3.1.0
5 | polyaxon-client==0.5.6
6 | gym==0.14.0
7 | optuna==0.18.1
8 | scikit-image==0.15.0
9 | opencv-python==4.1.1.26
10 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/spinal_navigation_rl.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hannes Hase
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/hyperparams/dqn.yml:
--------------------------------------------------------------------------------
1 | atari:
2 | policy: 'CnnPolicy'
3 | n_timesteps: !!float 1e7
4 | buffer_size: 10000
5 | learning_rate: !!float 1e-4
6 | learning_starts: 10000
7 | target_network_update_freq: 1000
8 | train_freq: 4
9 | exploration_final_eps: 0.01
10 | exploration_fraction: 0.1
11 | prioritized_replay_alpha: 0.6
12 | prioritized_replay: True
13 |
14 | CartPole-v1:
15 | n_timesteps: !!float 1e5
16 | policy: 'CustomDQNPolicy'
17 | learning_rate: !!float 1e-3
18 | buffer_size: 50000
19 | exploration_fraction: 0.1
20 | exploration_final_eps: 0.02
21 | prioritized_replay: True
22 |
23 | MountainCar-v0:
24 | n_timesteps: 100000
25 | policy: 'CustomDQNPolicy'
26 | learning_rate: !!float 1e-3
27 | buffer_size: 50000
28 | exploration_fraction: 0.1
29 | exploration_final_eps: 0.1
30 | param_noise: True
31 |
32 | LunarLander-v2:
33 | n_timesteps: !!float 2e5
34 | policy: 'CustomDQNPolicy'
35 | learning_rate: !!float 1e-3
36 | buffer_size: 100000
37 | exploration_fraction: 0.1
38 | exploration_final_eps: 0.05
39 | prioritized_replay: True
40 |
41 | Acrobot-v1:
42 | n_timesteps: !!float 1e5
43 | policy: 'CustomDQNPolicy'
44 | learning_rate: !!float 1e-3
45 | buffer_size: 50000
46 | exploration_fraction: 0.1
47 | exploration_final_eps: 0.02
48 | prioritized_replay: True
49 |
50 | # DQN PARAMS
51 | gym_sacrum_nav:sacrum_nav-v0:
52 | n_timesteps: !!float 1e5
53 | #policy: 'CnnPolicy'
54 | learning_rate: !!float 1e-3
55 | buffer_size: 50000
56 | exploration_fraction: 0.1
57 | exploration_final_eps: 0.02
58 | prioritized_replay: True
59 |
60 | # Previous action mem params
61 | gym_sacrum_nav:sacrum_nav-v2:
62 | n_timesteps: !!float 2e5 #4e3
63 | #policy: 'CnnPolicy'
64 | learning_rate: !!float 1e-3
65 | buffer_size: 5000 #50000
66 | exploration_fraction: 0.3
67 | exploration_final_eps: 0.02
68 | prioritized_replay: True
69 |
--------------------------------------------------------------------------------
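
The two `gym_sacrum_nav` entries above are the ones used for training: `main.py` loads this file with `yaml.full_load`, looks up the block matching the environment ID, reads and removes `n_timesteps`, and forwards the remaining keys to the DQN constructor. A minimal sketch of that flow, assuming the `gym_sacrum_nav` package and the adapted stable-baselines fork are installed; note that `main.py` uses its own `CustomCnnPolicy` and `batch_size=64`, while `'CnnPolicy'` here is a stand-in.

```
import gym
import yaml
from stable_baselines import DQN

from main import env_params          # environment configuration dict from the top of main.py

env_id = "gym_sacrum_nav:sacrum_nav-v2"

# Load the hyperparameter block matching the environment ID.
with open("hyperparams/dqn.yml", "r") as f:
    hyperparams = yaml.full_load(f)[env_id]

n_timesteps = int(hyperparams.pop("n_timesteps"))

# Assumes DATA_PATH in main.py points at the downloaded data set.
env = gym.make(env_id, env_params=env_params)

# learning_rate, buffer_size, exploration_* and prioritized_replay are forwarded
# directly as DQN keyword arguments; main.py passes its CustomCnnPolicy (dueling,
# ResNet feature extractor) instead of the plain 'CnnPolicy'.
model = DQN("CnnPolicy", env=env, **hyperparams)
model.learn(n_timesteps)
```
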
/polyaxonfile.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | version: 1
3 |
4 | kind: experiment
5 |
6 | framework: tensorflow
7 |
8 | tags: [sacrum_navigation]
9 |
10 | build:
11 | image: tensorflow/tensorflow:1.15.0-gpu-py3
12 | build_steps:
13 | - pip install -r requirements.txt
14 |
15 | environment:
16 | resources:
17 | cpu:
18 | requests: 3
19 | limits: 4
20 | memory:
21 | requests: 8048
22 | limits: 32784
23 | gpu:
24 | requests: 1
25 | limits: 1
26 |
27 | declarations:
28 | run_name: "Gym_Sacrum_Env"
29 | env: "gym_sacrum_nav:sacrum_nav-v2"
30 | tensorboard_log: "./runs/"
31 |
32 | # Framework parameters
33 | trained_agent: ""
34 | algo: "dqn"
35 | log_interval: -1
36 | log_folder: "logs"
37 | data_folder: "./data/"
38 | seed: 0
39 | verbose: 1
40 |
41 | run:
42 | cmd: apt update && apt install -y git &&
43 | apt install -y libopenmpi-dev &&
44 | pip install --upgrade pip &&
45 | pip install mpi4py &&
46 | apt-get install -y libsm6 libxext6 libxrender-dev &&
47 | pip install -e gym_sacrum_nav &&
48 | pip install -e git+git://github.com/hhase/stable-baselines#egg=stable-baselines &&
49 | python -u main_prev_actions.py --env={{ env }}\
50 | -tb={{ tensorboard_log }}\
51 | -i={{ trained_agent }}\
52 | --algo={{ algo }}\
53 | --log-interval={{ log_interval }}\
54 | -f={{ log_folder }}\
55 | --data-folder={{ data_folder }}\
56 | --seed={{ seed }}\
57 | --n-trials={{ n_trials }}\
58 | --n-jobs={{ n_jobs }}\
59 | --sampler={{ sampler }}\
60 | --pruner={{ pruner }}\
61 | --verbose={{ verbose }}
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ultrasound-Guided Robotic Navigation with Deep Reinforcement Learning
2 |
3 |
4 | Code for:
5 | ```
6 | @misc{hase2020ultrasoundguided,
7 | title={Ultrasound-Guided Robotic Navigation with Deep Reinforcement Learning},
8 | author={Hannes Hase and Mohammad Farid Azampour and Maria Tirindelli and Magdalini Paschali and Walter Simson and Emad Fatemizadeh and Nassir Navab},
9 | year={2020},
10 | eprint={2003.13321},
11 | archivePrefix={arXiv},
12 | primaryClass={cs.LG}
13 | }
14 | ```
15 | # The project
16 | This project learns a policy for autonomously navigating to the sacrum in simulated lower-back environments built from volunteer data. The deep reinforcement learning agent is a double dueling DQN with a prioritized replay memory.
17 | 
18 | The implementation builds on the [rl-zoo](https://github.com/araffin/rl-baselines-zoo) framework, a slightly adapted [stable-baselines](https://github.com/hhase/stable-baselines) library, and an [environment](https://github.com/hhase/gym_sacrum_env) built with the [gym](https://gym.openai.com/) toolkit.
19 |
20 | # Setup
21 | To run the code, the following parameters must first be set at the top of `main.py` (see the configuration sketch further below):
22 |
23 | - `DATA_PATH`: location of the [dataset](https://github.com/hhase/sacrum_data-set).
24 | - `OUTPUT_PATH`: path for the output (logging and model-saving location).
25 | - `test_patients`: number of test environments.
26 | - `val_patients`: number of validation environments.
27 | - `prev_actions`: size of the action memory.
28 | - `prev_frames`: size of the previous-frame memory.
29 | - `val_set`: if defined, sets the environments used for validation.
30 | - `test_set`: if defined, sets the environments used for testing.
31 | - `shuffles`: number of random shuffles for train/val/test set creation. Only relevant if the test and validation sets are not defined.
32 | - `chebishev`: boolean that enables diagonal movements (Chebyshev connectivity).
33 | - `no_nop`: boolean that removes the stopping action from the action space. Used for the MS-DQN architecture.
34 | - `max_time_steps`: boolean that enables resetting the agent when it takes too long to reach a goal state.
35 | - `time_step_limit`: the number of time steps the agent has to reach a goal state.
36 | - `reward_[action]`: sets the rewards given to the agent depending on its actions in the environment.
37 |
--------------------------------------------------------------------------------
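
The parameters listed in the Setup section above live at the top of `main.py`. A sketch of that configuration block with placeholder paths; the numeric values mirror the defaults in `main.py` further below.

```
import numpy as np

DATA_PATH = "/path/to/sacrum_data-set/"   # location of the dataset (placeholder)
OUTPUT_PATH = "/path/to/output/"          # logging and model-saving location (placeholder)

env_params = {
    'data_path': DATA_PATH,
    'verbose': 1,
    'test_patients': 5,                          # number of test environments
    'val_patients': 4,                           # number of validation environments
    'prev_actions': 5,                           # size of the action memory
    'prev_frames': 3,                            # size of the previous-frame memory
    'val_set': np.array([25, 26, 27, 28]),       # fixed validation environments
    'test_set': np.array([29, 30, 31, 32, 33]),  # fixed test environments
    'shuffles': 20,                              # only used if val/test sets are undefined
    'chebishev': False,                          # enable diagonal movements
    'no_nop': False,                             # drop the stopping action (MS-DQN)
    'max_time_steps': True,                      # reset the agent after too many steps
    'time_step_limit': 50,                       # steps allowed to reach a goal state
    'reward_goal_correct': 1.0,
    'reward_goal_incorrect': -0.25,
    'reward_move_closer': 0.05,
    'reward_move_further': -0.1,
    'reward_border_collision': -0.1,
}
```
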
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import io
2 | import numpy as np
3 | import tensorflow as tf
4 | import matplotlib.pyplot as plt
5 | from mpl_toolkits.axes_grid1 import make_axes_locatable
6 |
7 | def reachability_plot(env, patients, reach_maps):
8 | single_env = isinstance(patients, (int, np.int64))
9 | envs = 1 if single_env else len(patients)
10 | goals = env.goals
11 | goal_rows = []
12 | goal_cols = []
13 | numerator = np.zeros([envs, env.num_rows * 2 + 1, env.num_cols * 2 + 1])
14 | denominator = np.zeros_like(numerator)
15 |
16 | plot_size = 5
17 | fig, ax = plt.subplots(1, 1, figsize=(plot_size, plot_size))
18 | for i in range(envs):
19 | goal = goals[patients] if single_env else goals[patients[i]]
20 | goal_row, goal_col = env.val_to_coords(goal[-1]) if isinstance(goal, list) else env.val_to_coords(goal)
21 | goal_rows.append(goal_row)
22 | goal_cols.append(goal_col)
23 | reach_map = reach_maps[i, :, :]
24 | x_shift = env.num_cols - goal_col
25 | y_shift = env.num_rows - goal_row
26 | numerator[i, :, :] = np.pad(reach_map, [[y_shift, env.num_rows - y_shift + 1], [x_shift, env.num_cols - x_shift + 1]], mode="constant")
27 | denominator[i, :, :] = np.pad(np.ones_like(reach_map), [[y_shift, env.num_rows - y_shift + 1], [x_shift, env.num_cols - x_shift + 1]], mode="constant")
28 |
29 | denominator = np.sum(denominator, axis=0)
30 | rows = np.any(denominator, axis=1)
31 | cols = np.any(denominator, axis=0)
32 | first_row, last_row = np.where(rows)[0][[0, -1]]
33 | first_col, last_col = np.where(cols)[0][[0, -1]]
34 |
35 | denominator += (denominator == 0) * 1
36 | im = ax.matshow(np.sum(numerator, axis=0)[first_row:last_row + 1, first_col:last_col + 1] / denominator[first_row:last_row + 1, first_col:last_col + 1], cmap='Greens')
37 | ax.scatter(np.max(goal_cols), np.max(goal_rows), marker='s', c='red', s=100)
38 |
39 | divider = make_axes_locatable(ax)
40 | cax = divider.append_axes("right", size="5%", pad=0.05)
41 | fig.colorbar(im, cax=cax)
42 | ax.set_title("Average reachability: {}".format(np.sum(reach_maps)/np.prod(reach_maps.shape)))
43 |
44 | return fig
45 |
46 | def plot2fig(fig):
47 |     """Render the given matplotlib figure to a PNG buffer and return it as a TF image tensor."""
48 |     buf = io.BytesIO()
49 |     fig.savefig(buf, format='png')
50 |     plt.close(fig)
51 |     buf.seek(0)
52 |     image = tf.image.decode_png(buf.getvalue(), channels=4)
53 |     image = tf.expand_dims(image, 0)
54 |     buf.close()
55 |     return image
56 |
--------------------------------------------------------------------------------
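
`plot2fig` renders a matplotlib figure to PNG and returns it as a `[1, H, W, 4]` uint8 image tensor; the training callback in `main.py` uses it to push policy and reachability plots into TensorBoard. A small standalone sketch of evaluating and writing such an image, assuming TensorFlow 1.x graph mode (the polyaxonfile pins `tensorflow:1.15.0`); the figure itself is purely illustrative.

```
import matplotlib.pyplot as plt
import tensorflow as tf

from utils.visualization import plot2fig

# Any figure works here; this one is illustrative only.
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

image_tensor = plot2fig(fig)                      # [1, H, W, 4] uint8 tensor
summary_op = tf.summary.image("example_figure", image_tensor)

with tf.Session() as sess:
    writer = tf.summary.FileWriter("./runs/example")
    writer.add_summary(sess.run(summary_op), global_step=0)
    writer.close()
```
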
/hyperparams/ppo2.yml:
--------------------------------------------------------------------------------
1 | atari:
2 | policy: 'CnnPolicy'
3 | n_envs: 8
4 | n_steps: 128
5 | noptepochs: 4
6 | nminibatches: 4
7 | n_timesteps: !!float 1e7
8 | learning_rate: lin_2.5e-4
9 | cliprange: lin_0.1
10 | vf_coef: 0.5
11 | ent_coef: 0.01
12 | cliprange_vf: -1
13 |
14 | Pendulum-v0:
15 | n_envs: 8
16 | n_timesteps: !!float 2e6
17 | policy: 'MlpPolicy'
18 | n_steps: 2048
19 | nminibatches: 32
20 | lam: 0.95
21 | gamma: 0.99
22 | noptepochs: 10
23 | ent_coef: 0.0
24 | learning_rate: !!float 3e-4
25 | cliprange: 0.2
26 |
27 | # Tuned
28 | CartPole-v1:
29 | n_envs: 8
30 | n_timesteps: !!float 1e5
31 | policy: 'MlpPolicy'
32 | n_steps: 32
33 | nminibatches: 1
34 | lam: 0.8
35 | gamma: 0.98
36 | noptepochs: 20
37 | ent_coef: 0.0
38 | learning_rate: lin_0.001
39 | cliprange: lin_0.2
40 |
41 | MountainCar-v0:
42 | normalize: true
43 | n_envs: 16
44 | n_timesteps: !!float 1e6
45 | policy: 'MlpPolicy'
46 | n_steps: 16
47 | nminibatches: 1
48 | lam: 0.98
49 | gamma: 0.99
50 | noptepochs: 4
51 | ent_coef: 0.0
52 |
53 | MountainCarContinuous-v0:
54 | normalize: true
55 | n_envs: 16
56 | n_timesteps: !!float 1e6
57 | policy: 'MlpPolicy'
58 | n_steps: 256
59 | nminibatches: 8
60 | lam: 0.94
61 | gamma: 0.99
62 | noptepochs: 4
63 | ent_coef: 0.0
64 |
65 | Acrobot-v1:
66 | normalize: true
67 | n_envs: 16
68 | n_timesteps: !!float 1e6
69 | policy: 'MlpPolicy'
70 | n_steps: 256
71 | nminibatches: 8
72 | lam: 0.94
73 | gamma: 0.99
74 | noptepochs: 4
75 | ent_coef: 0.0
76 |
77 | BipedalWalker-v2:
78 | normalize: true
79 | n_envs: 16
80 | n_timesteps: !!float 5e6
81 | policy: 'MlpPolicy'
82 | n_steps: 2048
83 | nminibatches: 32
84 | lam: 0.95
85 | gamma: 0.99
86 | noptepochs: 10
87 | ent_coef: 0.001
88 | learning_rate: !!float 2.5e-4
89 | cliprange: 0.2
90 |
91 | BipedalWalkerHardcore-v2:
92 | normalize: true
93 | n_envs: 16
94 | n_timesteps: !!float 10e7
95 | policy: 'MlpPolicy'
96 | n_steps: 2048
97 | nminibatches: 32
98 | lam: 0.95
99 | gamma: 0.99
100 | noptepochs: 10
101 | ent_coef: 0.001
102 | learning_rate: lin_2.5e-4
103 | cliprange: lin_0.2
104 |
105 | LunarLander-v2:
106 | n_envs: 16
107 | n_timesteps: !!float 1e6
108 | policy: 'MlpPolicy'
109 | n_steps: 1024
110 | nminibatches: 32
111 | lam: 0.98
112 | gamma: 0.999
113 | noptepochs: 4
114 | ent_coef: 0.01
115 |
116 | LunarLanderContinuous-v2:
117 | n_envs: 16
118 | n_timesteps: !!float 1e6
119 | policy: 'MlpPolicy'
120 | n_steps: 1024
121 | nminibatches: 32
122 | lam: 0.98
123 | gamma: 0.999
124 | noptepochs: 4
125 | ent_coef: 0.01
126 |
127 | Walker2DBulletEnv-v0:
128 | env_wrapper: utils.wrappers.TimeFeatureWrapper
129 | normalize: true
130 | n_envs: 4
131 | n_timesteps: !!float 2e6
132 | policy: 'MlpPolicy'
133 | n_steps: 1024
134 | nminibatches: 64
135 | lam: 0.95
136 | gamma: 0.99
137 | noptepochs: 10
138 | ent_coef: 0.0
139 | learning_rate: lin_2.5e-4
140 | cliprange: 0.1
141 | cliprange_vf: -1
142 |
143 |
144 | HalfCheetahBulletEnv-v0:
145 | env_wrapper: utils.wrappers.TimeFeatureWrapper
146 | normalize: true
147 | n_envs: 1
148 | n_timesteps: !!float 2e6
149 | policy: 'MlpPolicy'
150 | n_steps: 2048
151 | nminibatches: 32
152 | lam: 0.95
153 | gamma: 0.99
154 | noptepochs: 10
155 | ent_coef: 0.0
156 | learning_rate: !!float 3e-4
157 | cliprange: 0.2
158 |
159 | HalfCheetah-v2:
160 | normalize: true
161 | n_envs: 1
162 | n_timesteps: !!float 2e6
163 | policy: 'MlpPolicy'
164 | n_steps: 2048
165 | nminibatches: 32
166 | lam: 0.95
167 | gamma: 0.99
168 | noptepochs: 10
169 | ent_coef: 0.0
170 | learning_rate: lin_3e-4
171 | cliprange: 0.2
172 | cliprange_vf: -1
173 |
174 | AntBulletEnv-v0:
175 | normalize: true
176 | n_envs: 8
177 | n_timesteps: !!float 2e6
178 | policy: 'CustomMlpPolicy'
179 | n_steps: 256
180 | nminibatches: 32
181 | lam: 0.95
182 | gamma: 0.99
183 | noptepochs: 10
184 | ent_coef: 0.0
185 | learning_rate: 2.5e-4
186 | cliprange: 0.2
187 |
188 | HopperBulletEnv-v0:
189 | normalize: true
190 | n_envs: 8
191 | n_timesteps: !!float 2e6
192 | policy: 'MlpPolicy'
193 | n_steps: 2048
194 | nminibatches: 128
195 | lam: 0.95
196 | gamma: 0.99
197 | noptepochs: 10
198 | ent_coef: 0.0
199 | learning_rate: 2.5e-4
200 | cliprange: 0.2
201 |
202 | ReacherBulletEnv-v0:
203 | normalize: true
204 | n_envs: 8
205 | n_timesteps: !!float 2e6
206 | policy: 'MlpPolicy'
207 | n_steps: 2048
208 | nminibatches: 32
209 | lam: 0.95
210 | gamma: 0.99
211 | noptepochs: 10
212 | ent_coef: 0.0
213 | learning_rate: 2.5e-4
214 | cliprange: 0.2
215 |
216 | MinitaurBulletEnv-v0:
217 | normalize: true
218 | n_envs: 8
219 | n_timesteps: !!float 2e6
220 | policy: 'MlpPolicy'
221 | n_steps: 2048
222 | nminibatches: 32
223 | lam: 0.95
224 | gamma: 0.99
225 | noptepochs: 10
226 | ent_coef: 0.0
227 | learning_rate: 2.5e-4
228 | cliprange: 0.2
229 |
230 | MinitaurBulletDuckEnv-v0:
231 | normalize: true
232 | n_envs: 8
233 | n_timesteps: !!float 2e6
234 | policy: 'MlpPolicy'
235 | n_steps: 2048
236 | nminibatches: 32
237 | lam: 0.95
238 | gamma: 0.99
239 | noptepochs: 10
240 | ent_coef: 0.0
241 | learning_rate: 2.5e-4
242 | cliprange: 0.2
243 |
244 | # To be tuned
245 | HumanoidBulletEnv-v0:
246 | normalize: true
247 | n_envs: 8
248 | n_timesteps: !!float 1e7
249 | policy: 'MlpPolicy'
250 | n_steps: 2048
251 | nminibatches: 32
252 | lam: 0.95
253 | gamma: 0.99
254 | noptepochs: 10
255 | ent_coef: 0.0
256 | learning_rate: 2.5e-4
257 | cliprange: 0.2
258 |
259 | InvertedDoublePendulumBulletEnv-v0:
260 | normalize: true
261 | n_envs: 8
262 | n_timesteps: !!float 2e6
263 | policy: 'MlpPolicy'
264 | n_steps: 2048
265 | nminibatches: 32
266 | lam: 0.95
267 | gamma: 0.99
268 | noptepochs: 10
269 | ent_coef: 0.0
270 | learning_rate: 2.5e-4
271 | cliprange: 0.2
272 |
273 | InvertedPendulumSwingupBulletEnv-v0:
274 | normalize: true
275 | n_envs: 8
276 | n_timesteps: !!float 2e6
277 | policy: 'MlpPolicy'
278 | n_steps: 2048
279 | nminibatches: 32
280 | lam: 0.95
281 | gamma: 0.99
282 | noptepochs: 10
283 | ent_coef: 0.0
284 | learning_rate: 2.5e-4
285 | cliprange: 0.2
286 |
287 | # Following https://github.com/lcswillems/rl-starter-files
288 | MiniGrid-DoorKey-5x5-v0:
289 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper # requires --gym-packages gym_minigrid
290 | normalize: true
291 | n_envs: 8 # number of environment copies running in parallel
292 | n_timesteps: !!float 1e5
293 | policy: MlpPolicy
294 | n_steps: 128 # batch size is n_steps * n_env
295 | nminibatches: 32 # Number of training minibatches per update
296 | lam: 0.95 # Factor for trade-off of bias vs variance for Generalized Advantage Estimator
297 | gamma: 0.99
298 | noptepochs: 10 # Number of epoch when optimizing the surrogate
299 |   ent_coef: 0.0 # Entropy coefficient for the loss calculation
300 | learning_rate: 2.5e-4 # The learning rate, it can be a function
301 | cliprange: 0.2 # Clipping parameter, it can be a function
302 |
303 | MiniGrid-FourRooms-v0:
304 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper # requires --gym-packages gym_minigrid
305 | normalize: true
306 | n_envs: 8
307 | n_timesteps: !!float 4e6
308 | policy: 'MlpPolicy'
309 | n_steps: 512
310 | nminibatches: 32
311 | lam: 0.95
312 | gamma: 0.99
313 | noptepochs: 10
314 | ent_coef: 0.0
315 | learning_rate: 2.5e-4
316 | cliprange: 0.2
317 |
318 | gym_sacrum_nav:sacrum_nav-v1:
319 | normalize: true
320 | n_envs: 8
321 | n_timesteps: !!float 4e6
322 | policy: 'CnnPolicy'
323 | n_steps: 512
324 | nminibatches: 32
325 | lam: 0.95
326 | gamma: 0.99
327 | noptepochs: 10
328 | ent_coef: 0.0
329 | learning_rate: 2.5e-4
330 | cliprange: 0.2
331 |
--------------------------------------------------------------------------------
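
Values prefixed with `lin_` (e.g. `learning_rate: lin_2.5e-4`, `cliprange: lin_0.2`) are not literal floats: `main.py` splits off the prefix and wraps the number in `utils.linear_schedule`, which returns a function of the remaining training progress (1 at the start, 0 at the end). A short sketch of that behaviour:

```
from utils import linear_schedule

value = "lin_2.5e-4"   # as read from the YAML file

# main.py: schedule, initial_value = value.split('_'); then linear_schedule(float(initial_value))
schedule, initial_value = value.split("_")
lr = linear_schedule(float(initial_value))

print(lr(1.0))   # 0.00025  -> at the start of training
print(lr(0.5))   # 0.000125 -> halfway through
print(lr(0.0))   # 0.0      -> at the end
```
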
/utils/resnet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import time
4 | import pickle
5 | import pdb
6 |
7 | def timeit(f):
8 | """ Decorator to time Any Function """
9 |
10 | def timed(*args, **kwargs):
11 | start_time = time.time()
12 | result = f(*args, **kwargs)
13 | end_time = time.time()
14 | seconds = end_time - start_time
15 | print(" [-] %s : %2.5f sec, which is %2.5f mins, which is %2.5f hours" %
16 | (f.__name__, seconds, seconds / 60, seconds / 3600))
17 | return result
18 |
19 | return timed
20 |
21 | def _debug(operation):
22 | print("Layer_name: " + operation.op.name + " -Output_Shape: " + str(operation.shape.as_list()))
23 |
24 | # Summaries for variables
25 | def variable_summaries(var):
26 | """
27 | Attach a lot of summaries to a Tensor (for TensorBoard visualization).
28 | :param var: variable to be summarized
29 | :return: None
30 | """
31 | with tf.name_scope('summaries'):
32 | mean = tf.reduce_mean(var)
33 | tf.summary.scalar('mean', mean)
34 | with tf.name_scope('stddev'):
35 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
36 | tf.summary.scalar('stddev', stddev)
37 | tf.summary.scalar('max', tf.reduce_max(var))
38 | tf.summary.scalar('min', tf.reduce_min(var))
39 | tf.summary.histogram('histogram', var)
40 |
41 | def variable_with_weight_decay(kernel_shape, initializer, wd):
42 | """
43 | Create a variable with L2 Regularization (Weight Decay)
44 | :param kernel_shape: the size of the convolving weight kernel.
45 | :param initializer: The initialization scheme, He et al. normal or Xavier normal are recommended.
46 | :param wd:(weight decay) L2 regularization parameter.
47 | :return: The weights of the kernel initialized. The L2 loss is added to the loss collection.
48 | """
49 | w = tf.get_variable('weights', kernel_shape, tf.float32, initializer=initializer)
50 |
51 | collection_name = tf.GraphKeys.REGULARIZATION_LOSSES
52 | if wd and (not tf.get_variable_scope().reuse):
53 | weight_decay = tf.multiply(tf.nn.l2_loss(w), wd, name='w_loss')
54 | tf.add_to_collection(collection_name, weight_decay)
55 | #variable_summaries(w)
56 | return w
57 |
58 |
59 | def _residual_block(name, x, filters, pool_first=False, strides=1, dilation=1, bias=-1):
60 | print('Building residual unit: %s' % name)
61 | with tf.variable_scope(name):
62 | # get input channels
63 | in_channel = x.shape.as_list()[-1]
64 |
65 | # Shortcut connection
66 | shortcut = tf.identity(x)
67 |
68 | if pool_first:
69 | if in_channel == filters:
70 | if strides == 1:
71 | shortcut = tf.identity(x)
72 | else:
73 | shortcut= tf.pad(x, tf.constant([[0,0],[1,1],[1,1],[0,0]]), "CONSTANT")
74 | shortcut = tf.nn.max_pool(shortcut, [1, strides, strides, 1], [1, strides, strides, 1], 'VALID')
75 | else:
76 | shortcut = _conv('shortcut_conv', x, padding='VALID',
77 | num_filters=filters, kernel_size=(1, 1), stride=(strides, strides),
78 | bias=bias)
79 | else:
80 | if dilation != 1:
81 | shortcut = _conv('shortcut_conv', x, padding='VALID',
82 | num_filters=filters, kernel_size=(1, 1), dilation=dilation, bias=bias)
83 |
84 | # Residual
85 | x = _conv('conv_1', x, padding=[[0,0],[1,1],[1,1],[0,0]],
86 | num_filters=filters, kernel_size=(3, 3), stride=(strides, strides), bias=bias)
87 | #x = _bn('bn_1', x)
88 | x = _relu('relu_1', x)
89 | x = _conv('conv_2', x, padding=[[0,0],[1,1],[1,1],[0,0]],
90 | num_filters=filters, kernel_size=(3, 3), bias=bias)
91 | #x = _bn('bn_2', x)
92 |
93 | # Merge
94 | x = x + shortcut
95 | x = _relu('relu_2', x)
96 |
97 | print('residual-unit-%s-shape: ' % name + str(x.shape.as_list()))
98 |
99 | return x
100 |
101 | def _conv(name, x, num_filters=16, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
102 | initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, dilation=1.0, bias=-1):
103 |
104 | with tf.variable_scope(name):
105 | stride = [1, stride[0], stride[1], 1]
106 | kernel_shape = [kernel_size[0], kernel_size[1], x.shape[-1], num_filters]
107 |
108 | w = variable_with_weight_decay(kernel_shape, initializer, l2_strength)
109 |
110 | #variable_summaries(w)
111 | if dilation > 1:
112 | conv = tf.nn.atrous_conv2d(x, w, dilation, padding)
113 | else:
114 | if type(padding)==type(''):
115 | conv = tf.nn.conv2d(x, w, stride, padding)
116 | else:
117 | conv = tf.pad(x, padding, "CONSTANT")
118 | conv = tf.nn.conv2d(conv, w, stride, padding='VALID')
119 |
120 | if bias != -1:
121 | bias = tf.get_variable('biases', [num_filters], initializer=tf.constant_initializer(bias))
122 |
123 | #variable_summaries(bias)
124 | conv = tf.nn.bias_add(conv, bias)
125 |
126 | tf.add_to_collection('debug_layers', conv)
127 |
128 | return conv
129 |
130 | def _relu(name, x):
131 | with tf.variable_scope(name):
132 | return tf.nn.relu(x)
133 |
134 | def _fc(name, x, output_dim=128, initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=-1):
135 |
136 | with tf.variable_scope(name):
137 | n_in = x.get_shape()[-1].value
138 |
139 | w = variable_with_weight_decay([n_in, output_dim], initializer, l2_strength)
140 |
141 | #variable_summaries(w)
142 |
143 | if bias != -1 and isinstance(bias, float):
144 | bias = tf.get_variable("biases", [output_dim], tf.float32, tf.constant_initializer(bias))
145 | output = tf.nn.bias_add(tf.matmul(x, w), bias)
146 | else:
147 | output = tf.matmul(x, w)
148 |
149 | return output
150 |
151 | def _bn(name, x, train_flag):
152 | with tf.variable_scope(name):
153 | moving_average_decay = 0.9
154 | decay = moving_average_decay
155 |
156 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
157 |
158 | mu = tf.get_variable('mu', batch_mean.shape, dtype=tf.float32,
159 | initializer=tf.zeros_initializer(), trainable=False)
160 | tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, mu)
161 | tf.add_to_collection('mu_sigma_bn', mu)
162 | sigma = tf.get_variable('sigma', batch_var.shape, dtype=tf.float32,
163 | initializer=tf.ones_initializer(), trainable=False)
164 | tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, sigma)
165 | tf.add_to_collection('mu_sigma_bn', sigma)
166 | beta = tf.get_variable('beta', batch_mean.shape, dtype=tf.float32,
167 | initializer=tf.zeros_initializer())
168 | gamma = tf.get_variable('gamma', batch_var.shape, dtype=tf.float32,
169 | initializer=tf.ones_initializer())
170 |
171 | # BN when training
172 | update = 1.0 - decay
173 | update_mu = mu.assign_sub(update * (mu - batch_mean))
174 | update_sigma = sigma.assign_sub(update * (sigma - batch_var))
175 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mu)
176 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_sigma)
177 |
178 | mean, var = tf.cond(train_flag, lambda: (batch_mean, batch_var), lambda: (mu, sigma))
179 | bn = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
180 |
181 | tf.add_to_collection('debug_layers', bn)
182 |
183 | return bn
184 |
185 |
186 | def ResNet18(x_input=None, classes=5, bias=-1, weight_decay=5e-4, test_classification=False):
187 |
188 | with tf.variable_scope('conv1_x'):
189 | print('Building unit: conv1')
190 | conv1 = _conv('conv1', x_input, padding= [[0,0],[3,3],[3,3],[0,0]],
191 | num_filters=64, kernel_size=(7, 7), stride=(2, 2), l2_strength=weight_decay,
192 | bias=bias)
193 |
194 | #conv1 = _bn('bn1', conv1)
195 |
196 | conv1 = _relu('relu1', conv1)
197 | _debug(conv1)
198 | conv1= tf.pad(conv1, tf.constant([[0,0],[1,1],[1,1],[0,0]]), "CONSTANT")
199 | conv1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID',
200 | name='max_pool1')
201 | _debug(conv1)
202 | print('conv1-shape: ' + str(conv1.shape.as_list()))
203 |
204 | with tf.variable_scope('conv2_x'):
205 | conv2 = _residual_block('conv2_1', conv1, 64)
206 | _debug(conv2)
207 | conv2 = _residual_block('conv2_2', conv2, 64)
208 | _debug(conv2)
209 |
210 | with tf.variable_scope('conv3_x'):
211 | conv3 = _residual_block('conv3_1', conv2, 128, pool_first=True, strides=2)
212 | _debug(conv3)
213 | conv3 = _residual_block('conv3_2', conv3, 128)
214 | _debug(conv3)
215 |
216 | with tf.variable_scope('conv4_x'):
217 | conv4 = _residual_block('conv4_1', conv3, 256, pool_first=True, strides=2)
218 | _debug(conv4)
219 | conv4 = _residual_block('conv4_2', conv4, 256)
220 | _debug(conv4)
221 |
222 | with tf.variable_scope('conv5_x'):
223 | conv5 = _residual_block('conv5_1', conv4, 512, pool_first=True, strides=2)
224 | _debug(conv5)
225 | conv5 = _residual_block('conv5_2', conv5, 512)
226 | _debug(conv5)
227 |
228 | with tf.variable_scope('resnet_out'):
229 | print('Building unit: logits')
230 | #score = tf.reduce_mean(conv5, axis=[1, 2])
231 | score = tf.compat.v1.layers.flatten(conv5)
232 | _debug(score)
233 | score = _fc('logits_dense', score, output_dim=classes, l2_strength=weight_decay)
234 | print('logits-shape: ' + str(score.shape.as_list()))
235 |
236 | return score
--------------------------------------------------------------------------------
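
`ResNet18` builds a TensorFlow-1-style ResNet-18 graph and returns the final dense layer; `main.py` calls it with `classes=512` inside `custom_cnn`, so the "logits" act as a 512-dimensional feature vector for the DQN policy. A standalone forward pass, assuming TensorFlow 1.x and an arbitrary, illustrative input size:

```
import numpy as np
import tensorflow as tf

from utils.resnet import ResNet18

# Illustrative input: a batch of 64x64 images with 3 channels.
x = tf.placeholder(tf.float32, shape=[None, 64, 64, 3], name="frames")
features = ResNet18(x_input=x, classes=512)       # shape [batch, 512]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(features, feed_dict={x: np.zeros((1, 64, 64, 3), np.float32)})
    print(out.shape)                              # (1, 512)
```
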
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import inspect
4 | import glob
5 | import yaml
6 | import importlib
7 |
8 | import gym
9 | try:
10 | import pybullet_envs
11 | except ImportError:
12 | pybullet_envs = None
13 |
14 | from gym.envs.registration import load
15 |
16 | from stable_baselines.deepq.policies import FeedForwardPolicy
17 | from stable_baselines.common.policies import FeedForwardPolicy as BasePolicy
18 | from stable_baselines.common.policies import register_policy
19 | from stable_baselines.sac.policies import FeedForwardPolicy as SACPolicy
20 | from stable_baselines.bench import Monitor
21 | from stable_baselines import logger
22 | from stable_baselines import PPO2, DQN
23 | from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize, \
24 | VecFrameStack, SubprocVecEnv
25 | from stable_baselines.common import set_global_seeds
26 |
27 | ALGOS = {
28 | 'dqn': DQN,
29 | 'ppo2': PPO2
30 | }
31 |
32 |
33 | # ================== Custom Policies =================
34 |
35 | class CustomDQNPolicy(FeedForwardPolicy):
36 | def __init__(self, *args, **kwargs):
37 | super(CustomDQNPolicy, self).__init__(*args, **kwargs,
38 | layers=[64],
39 | layer_norm=True,
40 | feature_extraction="mlp")
41 |
42 |
43 | class CustomMlpPolicy(BasePolicy):
44 | def __init__(self, *args, **kwargs):
45 | super(CustomMlpPolicy, self).__init__(*args, **kwargs,
46 | layers=[16],
47 | feature_extraction="mlp")
48 |
49 |
50 | class CustomSACPolicy(SACPolicy):
51 | def __init__(self, *args, **kwargs):
52 | super(CustomSACPolicy, self).__init__(*args, **kwargs,
53 | layers=[256, 256],
54 | feature_extraction="mlp")
55 |
56 |
57 | register_policy('CustomSACPolicy', CustomSACPolicy)
58 | register_policy('CustomDQNPolicy', CustomDQNPolicy)
59 | register_policy('CustomMlpPolicy', CustomMlpPolicy)
60 |
61 |
62 | def flatten_dict_observations(env):
63 | assert isinstance(env.observation_space, gym.spaces.Dict)
64 | keys = env.observation_space.spaces.keys()
65 | return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
66 |
67 |
68 | def get_wrapper_class(hyperparams):
69 | """
70 | Get a Gym environment wrapper class specified as a hyper parameter
71 | "env_wrapper".
72 | e.g.
73 | env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
74 |
75 | :param hyperparams: (dict)
76 | :return: a subclass of gym.Wrapper (class object) you can use to
77 | create another Gym env giving an original env.
78 | """
79 |
80 |     def get_module_name(fullname):
81 |         return '.'.join(fullname.split('.')[:-1])
82 | 
83 |     def get_class_name(fullname):
84 |         return fullname.split('.')[-1]
85 |
86 | if 'env_wrapper' in hyperparams.keys():
87 | wrapper_name = hyperparams.get('env_wrapper')
88 | wrapper_module = importlib.import_module(get_module_name(wrapper_name))
89 | return getattr(wrapper_module, get_class_name(wrapper_name))
90 | else:
91 | return None
92 |
93 |
94 | def make_env(env_id, rank=0, seed=0, log_dir=None, wrapper_class=None, **kwargs):
95 | """
96 | Helper function to multiprocess training
97 | and log the progress.
98 |
99 | :param env_id: (str)
100 | :param rank: (int)
101 | :param seed: (int)
102 | :param log_dir: (str)
103 | :param wrapper: (type) a subclass of gym.Wrapper to wrap the original
104 | env with
105 | """
106 | if log_dir is None and log_dir != '':
107 | log_dir = "/tmp/gym/{}/".format(int(time.time()))
108 | os.makedirs(log_dir, exist_ok=True)
109 |
110 | def _init():
111 | set_global_seeds(seed + rank)
112 | env = gym.make(env_id, **kwargs)
113 |
114 | # Dict observation space is currently not supported.
115 | # https://github.com/hill-a/stable-baselines/issues/321
116 | # We allow a Gym env wrapper (a subclass of gym.Wrapper)
117 | if wrapper_class:
118 | env = wrapper_class(env)
119 |
120 | env.seed(seed + rank)
121 | env = Monitor(env, os.path.join(log_dir, str(rank)), allow_early_resets=True)
122 | return env
123 |
124 | return _init
125 |
126 |
127 | def create_test_env(env_id, n_envs=1, is_atari=False,
128 | stats_path=None, seed=0,
129 | log_dir='', should_render=True, hyperparams=None):
130 | """
131 | Create environment for testing a trained agent
132 |
133 | :param env_id: (str)
134 | :param n_envs: (int) number of processes
135 | :param is_atari: (bool)
136 |     :param stats_path: (str) path to folder containing saved running averages
137 | :param seed: (int) Seed for random number generator
138 | :param log_dir: (str) Where to log rewards
139 | :param should_render: (bool) For Pybullet env, display the GUI
140 | :param env_wrapper: (type) A subclass of gym.Wrapper to wrap the original
141 | env with
142 | :param hyperparams: (dict) Additional hyperparams (ex: n_stack)
143 | :return: (gym.Env)
144 | """
145 | # HACK to save logs
146 | if log_dir is not None:
147 | os.environ["OPENAI_LOG_FORMAT"] = 'csv'
148 | os.environ["OPENAI_LOGDIR"] = os.path.abspath(log_dir)
149 | os.makedirs(log_dir, exist_ok=True)
150 | logger.configure()
151 |
152 | # Create the environment and wrap it if necessary
153 | env_wrapper = get_wrapper_class(hyperparams)
154 | if 'env_wrapper' in hyperparams.keys():
155 | del hyperparams['env_wrapper']
156 |
157 | if is_atari:
158 | print("Using Atari wrapper")
159 | #env = make_atari_env(env_id, num_env=n_envs, seed=seed)
160 | ## Frame-stacking with 4 frames
161 | #env = VecFrameStack(env, n_stack=4)
162 | elif n_envs > 1:
163 | # start_method = 'spawn' for thread safe
164 | env = SubprocVecEnv([make_env(env_id, i, seed, log_dir, wrapper_class=env_wrapper) for i in range(n_envs)])
165 | # Pybullet envs does not follow gym.render() interface
166 | elif "Bullet" in env_id:
167 | spec = gym.envs.registry.env_specs[env_id]
168 | try:
169 | class_ = load(spec.entry_point)
170 | except AttributeError:
171 | # Backward compatibility with gym
172 | class_ = load(spec._entry_point)
173 | # HACK: force SubprocVecEnv for Bullet env that does not
174 | # have a render argument
175 | render_name = None
176 | use_subproc = 'renders' not in inspect.getfullargspec(class_.__init__).args
177 | if not use_subproc:
178 | render_name = 'renders'
179 | # Dev branch of pybullet
180 | # use_subproc = use_subproc and 'render' not in inspect.getfullargspec(class_.__init__).args
181 | # if not use_subproc and render_name is None:
182 | # render_name = 'render'
183 |
184 | # Create the env, with the original kwargs, and the new ones overriding them if needed
185 | def _init():
186 | # TODO: fix for pybullet locomotion envs
187 | env = class_(**{**spec._kwargs}, **{render_name: should_render})
188 | env.seed(0)
189 | if log_dir is not None:
190 | env = Monitor(env, os.path.join(log_dir, "0"), allow_early_resets=True)
191 | return env
192 |
193 | if use_subproc:
194 | env = SubprocVecEnv([make_env(env_id, 0, seed, log_dir, wrapper_class=env_wrapper)])
195 | else:
196 | env = DummyVecEnv([_init])
197 | else:
198 | env = DummyVecEnv([make_env(env_id, 0, seed, log_dir, wrapper_class=env_wrapper)])
199 |
200 | # Load saved stats for normalizing input and rewards
201 | # And optionally stack frames
202 | if stats_path is not None:
203 | if hyperparams['normalize']:
204 | print("Loading running average")
205 | print("with params: {}".format(hyperparams['normalize_kwargs']))
206 | env = VecNormalize(env, training=False, **hyperparams['normalize_kwargs'])
207 | env.load_running_average(stats_path)
208 |
209 | n_stack = hyperparams.get('frame_stack', 0)
210 | if n_stack > 0:
211 | print("Stacking {} frames".format(n_stack))
212 | env = VecFrameStack(env, n_stack)
213 | return env
214 |
215 |
216 | def linear_schedule(initial_value):
217 | """
218 | Linear learning rate schedule.
219 |
220 | :param initial_value: (float or str)
221 | :return: (function)
222 | """
223 | if isinstance(initial_value, str):
224 | initial_value = float(initial_value)
225 |
226 | def func(progress):
227 | """
228 | Progress will decrease from 1 (beginning) to 0
229 | :param progress: (float)
230 | :return: (float)
231 | """
232 | return progress * initial_value
233 |
234 | return func
235 |
236 |
237 | def get_trained_models(log_folder):
238 | """
239 | :param log_folder: (str) Root log folder
240 | :return: (dict) Dict representing the trained agent
241 | """
242 | algos = os.listdir(log_folder)
243 | trained_models = {}
244 | for algo in algos:
245 | for ext in ['zip', 'pkl']:
246 | for env_id in glob.glob('{}/{}/*.{}'.format(log_folder, algo, ext)):
247 | # Retrieve env name
248 | env_id = env_id.split('/')[-1].split('.{}'.format(ext))[0]
249 | trained_models['{}-{}'.format(algo, env_id)] = (algo, env_id)
250 | return trained_models
251 |
252 |
253 | def get_latest_run_id(log_path, env_id):
254 | """
255 | Returns the latest run number for the given log name and log path,
256 | by finding the greatest number in the directories.
257 |
258 | :param log_path: (str) path to log folder
259 | :param env_id: (str)
260 | :return: (int) latest run number
261 | """
262 | max_run_id = 0
263 | for path in glob.glob(log_path + "/{}_[0-9]*".format(env_id)):
264 | file_name = path.split("/")[-1]
265 | ext = file_name.split("_")[-1]
266 | if env_id == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
267 | max_run_id = int(ext)
268 | return max_run_id
269 |
270 |
271 | def get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False):
272 | """
273 | :param stats_path: (str)
274 | :param norm_reward: (bool)
275 | :param test_mode: (bool)
276 | :return: (dict, str)
277 | """
278 | hyperparams = {}
279 | if not os.path.isdir(stats_path):
280 | stats_path = None
281 | else:
282 | config_file = os.path.join(stats_path, 'config.yml')
283 | if os.path.isfile(config_file):
284 | # Load saved hyperparameters
285 | with open(os.path.join(stats_path, 'config.yml'), 'r') as f:
286 | hyperparams = yaml.load(f)
287 | hyperparams['normalize'] = hyperparams.get('normalize', False)
288 | else:
289 | obs_rms_path = os.path.join(stats_path, 'obs_rms.pkl')
290 | hyperparams['normalize'] = os.path.isfile(obs_rms_path)
291 |
292 | # Load normalization params
293 | if hyperparams['normalize']:
294 | if isinstance(hyperparams['normalize'], str):
295 | normalize_kwargs = eval(hyperparams['normalize'])
296 | if test_mode:
297 | normalize_kwargs['norm_reward'] = norm_reward
298 | else:
299 | normalize_kwargs = {'norm_obs': hyperparams['normalize'], 'norm_reward': norm_reward}
300 | hyperparams['normalize_kwargs'] = normalize_kwargs
301 | return hyperparams, stats_path
302 |
--------------------------------------------------------------------------------
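
Two small usage notes for the helpers above: `make_env` returns a thunk so it can be handed to a `DummyVecEnv`/`SubprocVecEnv`, and `get_latest_run_id` numbers consecutive runs by scanning `<log_path>/<env_id>_<n>` folders. A sketch using `CartPole-v1` purely as a stand-in environment:

```
from stable_baselines.common.vec_env import DummyVecEnv

from utils import get_latest_run_id, make_env

# make_env(...) returns an _init() thunk; the vectorized wrapper calls it lazily.
env = DummyVecEnv([make_env("CartPole-v1", rank=0, seed=0)])

obs = env.reset()
obs, rewards, dones, infos = env.step([env.action_space.sample()])
env.close()

# Returns the largest <n> among existing '<log_path>/CartPole-v1_<n>' folders (0 if none).
print(get_latest_run_id("./logs/dqn", "CartPole-v1"))
```
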
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import yaml
4 | import argparse
5 | import numpy as np
6 | import tensorflow as tf
7 | from pprint import pprint
8 | import matplotlib.pyplot as plt
9 | from utils.resnet import ResNet18
10 | from collections import OrderedDict
11 | from stable_baselines.ppo2.ppo2 import constfn
12 | from stable_baselines.common import set_global_seeds
13 | from stable_baselines.common.vec_env import VecNormalize
14 | from utils.visualization import reachability_plot, plot2fig
15 | from stable_baselines.deepq.policies import FeedForwardPolicy
16 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
17 | from utils import ALGOS, get_wrapper_class, linear_schedule, make_env
18 |
19 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
20 |
21 | ########################################################################################################################
22 | # SET PATHS
23 | # data path -> path to dataset | output path -> path to logging and model saving location
24 | DATA_PATH = ""
25 | OUTPUT_PATH = ""
26 | #
27 | ########################################################################################################################
28 | # SET ENVIRONMENT PARAMETERS
29 | env_params = {'data_path': DATA_PATH,
30 | 'verbose': 1,
31 | 'test_patients': 5,
32 | 'val_patients': 4,
33 | 'prev_actions': 5,
34 | 'prev_frames': 3,
35 | 'val_set': np.array([25, 26, 27, 28]),
36 | 'test_set': np.array([29, 30, 31, 32, 33]),
37 | 'shuffles': 20,
38 | 'chebishev': False,
39 | 'no_nop': False,
40 | 'max_time_steps': True,
41 | 'time_step_limit': 50,
42 | 'reward_goal_correct': 1.0,
43 | 'reward_goal_incorrect': -0.25,
44 | 'reward_move_closer': 0.05,
45 | 'reward_move_further': -0.1,
46 | 'reward_border_collision': -0.1}
47 | # parameters to define the environment
48 | #
49 | ########################################################################################################################
50 |
51 | best_mean_reward, n_steps = -np.inf, 0
52 |
53 | def custom_cnn(input, **kwargs):
54 | action_mem_size = 25
55 | action_history = input[:, 0, 0:action_mem_size, -1]
56 | input = input[..., :-1]
57 |
58 | action_values = ResNet18(x_input=input, classes=512)
59 |
60 | return action_values, action_history
61 |
62 | class CustomCnnPolicy(FeedForwardPolicy):
63 |
64 | def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs):
65 | super(CustomCnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse, cnn_extractor=custom_cnn,
66 | feature_extraction="cnn", obs_phs=obs_phs, dueling=dueling, layer_norm=False, **_kwargs)
67 |
68 | class Indicators():
69 | def __init__(self, num_states):
70 | self.num_states = num_states
71 | self.reset()
72 |
73 | def reset(self):
74 | self.reach_goal = []
75 | self.efficiency = []
76 | self.correct_actions = 0
77 | self.total_actions = 0
78 |
79 | def __str__(self):
80 | print("Goal reached {}% of the times".format(0 if not self.reach_goal else np.average(self.reach_goal)*100))
81 | print("Overall efficiency: {}%".format(np.average(0 if not self.efficiency else self.efficiency)*100))
82 | return ""
83 |
84 |
85 | def callback(locals_, globals_):
86 |
87 | global n_steps, best_mean_reward
88 | self_ = locals_.get('self')
89 | env_ = self_.env.envs[0]
90 | info_ = locals_.get('info')
91 | writer_ = locals_.get('writer')
92 | episode_ = locals_.get('num_episodes')
93 |
94 | # LOG CORRECT DECISIONS/TOTAL DECISIONS
95 | correct_decision_rate = info_.get('correct_decision_rate') if info_ else None
96 | if correct_decision_rate:
97 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/correct_decision_rate', simple_value=correct_decision_rate)])
98 | writer_.add_summary(summary, episode_)
99 |
100 | episode_completed = locals_.get('done')
101 | if episode_completed and episode_ % 40 == 0:
102 | print("Done with episode {}".format(episode_))
103 | actions = []
104 | state_vals = []
105 | for state in range(env_.num_states):
106 |
107 | frame = env_.frames[env_.patient_idx][state][0]
108 | frames = np.repeat(frame[:, :, np.newaxis], len(env_.prev_frames), axis=2)
109 | observation = np.dstack((frames, np.zeros_like(frame)))
110 |
111 | action, q_vals, state_val = self_.predict(observation)
112 | actions.append(action)
113 | state_vals.append(state_val)
114 |
115 | # TEST CORRECTNESS ON TRAINING PATIENT
116 | quiver_plot, policy_correctness = env_.quiver_plot(states=list(range(env_.num_states)), actions=actions, state_vals=state_vals)
117 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/policy correctness', simple_value=policy_correctness)])
118 | writer_.add_summary(summary, episode_)
119 |
120 | # LOG POLICY QUIVER PLOT
121 | if episode_ % 200 == 0:
122 | image = plot2fig(quiver_plot)
123 | summary = tf.Summary(value=[tf.Summary.Value(tag='Regular_training/Policy graph at episode {}'.format(episode_), image=image)])
124 | writer_.add_summary(summary, episode_)
125 | plt.close(quiver_plot)
126 |
127 | if episode_completed and episode_ % 100 == 0:
128 | avg_val_reachability = 0.0
129 | avg_val_correctness = 0.0
130 | max_time_steps = 20
131 | indicators = Indicators(env_.num_states)
132 | val_patients = env_.val_patient_idxs
133 | val_reachabilities = []
134 |
135 | for val_patient in val_patients:
136 | goals = env_.goals[val_patient]
137 |
138 | for state in range(env_.num_states):
139 |
140 | obs = env_.set(val_patient, state)
141 | prev_state = env_.state
142 | for step in range(max_time_steps):
143 |
144 | if env_.no_nop and env_.state in goals:
145 | done = True
146 | else:
147 | action, q_vals, state_val = model.predict(obs)
148 | obs, reward, done, info = env_.step(action)
149 |
150 | moving = not (prev_state == env_.state)
151 | if moving:
152 | indicators.total_actions += 1
153 | if reward > 0:
154 | indicators.correct_actions += 1
155 |
156 | if done or step == max_time_steps - 1:
157 | if (env_.state in goals and env_.no_nop) or (not env_.no_nop and reward == env_.reward_dict["goal_correct"]):
158 | indicators.reach_goal.append(1)
159 | else:
160 | indicators.reach_goal.append(0)
161 | _ = env_.reset()
162 | break
163 |
164 | prev_state = env_.state
165 |
166 | val_reachabilities.append(np.average(indicators.reach_goal))
167 | # avg_val_reachability += np.average(indicators.reach_goal)/len(val_patients) if isinstance(val_patients, (list, np.ndarray)) else np.average(indicators.reach_goal)
168 | print("Correctness for test patient {}: {}".format(val_patient, indicators.correct_actions / indicators.total_actions))
169 | print("Reachability for validation patient {}: {}".format(val_patient, np.average(indicators.reach_goal)))
170 | avg_val_correctness += indicators.correct_actions / indicators.total_actions / len(val_patients) if isinstance(val_patients, (list, np.ndarray)) \
171 | else indicators.correct_actions / indicators.total_actions
172 | indicators.reset()
173 |
174 | val_median_reachability = np.median(val_reachabilities)
175 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/val_median_reachability', simple_value=val_median_reachability)])
176 | writer_.add_summary(summary, episode_)
177 | summary = tf.Summary(value=[tf.Summary.Value(tag='callback_logging/val_correctness', simple_value=avg_val_correctness)])
178 | writer_.add_summary(summary, episode_)
179 |
180 | # if env_.val_reachability < avg_val_reachability:
181 | if env_.val_reachability < val_median_reachability:
182 | print("Improved the model at episode {}!".format(episode_))
183 | self_.save(OUTPUT_PATH + "val_model_episode_{}".format(episode_), cloudpickle=True)
184 | env_.val_reachability = val_median_reachability
185 |
186 | n_steps += 1
187 |
188 | return True
189 |
190 | if __name__ == '__main__':
191 | parser = argparse.ArgumentParser()
192 | parser.add_argument('--env', type=str, nargs='+', default=["gym_sacrum_nav:sacrum_nav-v2"], help='environment ID(s)')
193 | parser.add_argument('-tb', '--tensorboard-log', help='Tensorboard log dir', default='./runs/', type=str)
194 | parser.add_argument('-i', '--trained-agent', help='Path to a pretrained agent to continue training', default='', type=str)
195 | parser.add_argument('--algo', help='RL Algorithm', default='dqn', type=str, required=False, choices=list(ALGOS.keys()))
196 | parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=-1, type=int)
197 | parser.add_argument('--log-interval', help='Override log interval (default: -1, no change)', default=-1, type=int)
198 | parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='logs')
199 | parser.add_argument('--data-folder', help='Data folder', type=str, default='./data/')
200 | parser.add_argument('--seed', help='Random generator seed', type=int, default=0)
201 | parser.add_argument('--n-trials', help='Number of trials for optimizing hyperparameters', type=int, default=10)
202 | parser.add_argument('--verbose', help='Verbose mode (0: no output, 1: INFO, 2: debug)', default=1, type=int)
203 |     parser.add_argument('--gym-packages', type=str, nargs='+', default=[], help='Additional external Gym environment package modules to import (e.g. gym_minigrid)')
204 | args = parser.parse_args()
205 |
206 | env_params['data_path'] = args.data_folder
207 |
208 | env_ids = args.env
209 | set_global_seeds(args.seed)
210 |
211 | for env_id in env_ids:
212 | tensorboard_log = None if args.tensorboard_log == '' else os.path.join(args.tensorboard_log, env_id)
213 | print("=" * 10, env_id, "=" * 10)
214 |
215 | # Load hyperparameters from yaml file
216 | with open('hyperparams/{}.yml'.format(args.algo), 'r') as f:
217 | hyperparams_dict = yaml.full_load(f)
218 | if env_id in list(hyperparams_dict.keys()):
219 | hyperparams = hyperparams_dict[env_id]
220 | else:
221 | raise ValueError("Hyperparameters not found for {}-{}".format(args.algo, env_id))
222 |
223 | saved_hyperparams = OrderedDict([(key, hyperparams[key]) for key in sorted(hyperparams.keys())])
224 | algo_ = args.algo
225 |
226 | if args.verbose > 0:
227 | pprint(saved_hyperparams)
228 |
229 | n_envs = hyperparams.get('n_envs', 1)
230 |
231 | if args.verbose > 0:
232 | print("Using {} environments".format(n_envs))
233 |
234 | n_timesteps = int(hyperparams['n_timesteps'])
235 |
236 |         # Delete keys so the dict can be passed to the model constructor
237 | if 'n_envs' in hyperparams.keys():
238 | del hyperparams['n_envs']
239 | del hyperparams['n_timesteps']
240 |
241 | env_wrapper = get_wrapper_class(hyperparams)
242 | if 'env_wrapper' in hyperparams.keys():
243 | del hyperparams['env_wrapper']
244 |
245 | if algo_ in ["ppo2"]:
246 | for key in ['learning_rate', 'cliprange', 'cliprange_vf']:
247 | if key not in hyperparams:
248 | continue
249 | if isinstance(hyperparams[key], str):
250 | schedule, initial_value = hyperparams[key].split('_')
251 | initial_value = float(initial_value)
252 | hyperparams[key] = linear_schedule(initial_value)
253 | elif isinstance(hyperparams[key], (float, int)):
254 | # Negative value: ignore (ex: for clipping)
255 | if hyperparams[key] < 0:
256 | continue
257 | hyperparams[key] = constfn(float(hyperparams[key]))
258 | else:
259 | raise ValueError('Invalid value for {}: {}'.format(key, hyperparams[key]))
260 | normalize = False
261 | normalize_kwargs = {}
262 | if 'normalize' in hyperparams.keys():
263 | normalize = hyperparams['normalize']
264 | if isinstance(normalize, str):
265 | normalize_kwargs = eval(normalize)
266 | normalize = True
267 | del hyperparams['normalize']
268 |
269 | def create_env(env_params):
270 | global hyperparams
271 |
272 | if algo_ in ['dqn']:
273 | env = gym.make(env_id, env_params=env_params)
274 | env.seed(args.seed)
275 | if env_wrapper is not None:
276 | env = env_wrapper(env)
277 | else:
278 | env = DummyVecEnv([make_env(env_id, 0, args.seed, wrapper_class=env_wrapper, env_params=env_params)])
279 | if normalize:
280 | if args.verbose > 0:
281 | if len(normalize_kwargs) > 0:
282 | print("Normalization activated: {}".format(normalize_kwargs))
283 | else:
284 | print("Normalizing input and reward")
285 | env = VecNormalize(env, **normalize_kwargs)
286 | return env
287 |
288 | env = create_env(env_params)
289 |
290 | env = DummyVecEnv([lambda: env])
291 |
292 | print(hyperparams)
293 | model = ALGOS[args.algo](CustomCnnPolicy,
294 | env=env,
295 | tensorboard_log=tensorboard_log,
296 | verbose=args.verbose,
297 | batch_size=64,
298 | **hyperparams)
299 | print("Model loaded!")
300 | model.is_tb_set = False
301 |
302 | kwargs = {}
303 | if args.log_interval > -1:
304 | kwargs = {'log_interval': args.log_interval}
305 |
306 | model.learn(n_timesteps, callback=callback, **kwargs)
307 |
308 | model.save(OUTPUT_PATH + "final_model")
309 |
--------------------------------------------------------------------------------