├── .gitignore
├── LICENSE
├── README.md
├── cfgs
│   ├── dmlp.yaml
│   └── dt.yaml
├── collect_antmaze_data.py
├── collect_pointmaze_data.py
├── model.py
├── requirements.txt
├── train_dmlp.py
├── train_dt.py
└── utils
    ├── __init__.py
    ├── data.py
    ├── env.py
    └── misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
3 | #saved binaries
4 | saved/
5 |
6 | #virtual env
7 | env_*
8 |
9 | #local experiments
10 | local_exp/
11 | wandb/
12 |
13 | #dataset
14 | data
15 |
16 | #evaluation
17 | eval.py
18 |
19 | #binaries
20 | *.pkl
21 | *.zip
22 |
23 | #results
24 | dt_runs/
25 | dqn_runs/
26 | checkpoints/
27 |
28 | #jupyter notebooks
29 | notebooks/
30 |
31 | #vscode
32 | .vscode
33 |
34 | #mujoco
35 | MUJOCO_LOG.TXT
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Raj Ghugare
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Closing the Gap between TD Learning and Supervised Learning -- A Generalisation Point of View](https://arxiv.org/abs/2401.11237)
2 | [Raj Ghugare](https://rajghugare19.github.io/), $\quad$ [Matthieu Geist](https://scholar.google.com/citations?user=ectPLEUAAAAJ), $\quad$ [Glen Berseth](https://neo-x.github.io/)\*, $\quad$ [Benjamin Eysenbach](https://ben-eysenbach.github.io/)\*
3 |
4 | \* Equal advising.
5 |
6 | ## Installation
7 |
8 | Create a virtual environment named `env_stuff`:
9 | ```sh
10 | python3 -m venv env_stuff
11 | ```
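
Activate it before installing anything (shown for Linux/macOS; on Windows run `env_stuff\Scripts\activate` instead):
```sh
source env_stuff/bin/activate
```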
12 |
13 | Install the required packages from the `requirements.txt` file:
14 | ```sh
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ## Training
19 |
20 | To train an RvS (decision-mlp) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:
21 | ```sh
22 | python train_dmlp.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40
23 | ```
24 |
25 | To train a DT (decision-transformer) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:
26 | ```sh
27 | python train_dt.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40
28 | ```
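
Both training scripts are configured through Hydra, so any other field in `cfgs/dmlp.yaml` or `cfgs/dt.yaml` can be overridden with the same `key=value` syntax. For example (the values below are only illustrative):
```sh
python train_dmlp.py dataset_name=pointmaze-medium-v0 seed=3 batch_size=512 wandb_log=True wandb_run_name=dmlp_medium_seed3
```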
29 |
30 | ## Datasets
31 |
32 | To download the pre-collected datasets, visit [this Google Drive link](https://drive.google.com/drive/folders/1j8Ok2UMYSqfIQReuE6csf1nMoI1s25K-?usp=sharing).
33 |
34 | To collect the pointmaze-large dataset with $10^6$ transitions and seed 1:
35 | ```sh
36 | python collect_pointmaze_data.py pointmaze-large-v0 1 1000000
37 | ```
38 |
39 | To collect the antmaze-large dataset with $10^6$ transitions and seed 1:
40 | ```sh
41 | python collect_antmaze_data.py antmaze-large-v0 1 1000000
42 | ```
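
The collected datasets are plain pickle files written to `data/` (for pointmaze the filename also includes the transition count, e.g. `data/pointmaze-large-v0-1000000.pkl`). A minimal sketch for inspecting one, assuming the dictionary layout written by `collect_pointmaze_data.py`:
```python
import pickle

import numpy as np

# Path produced by the pointmaze collection command above.
with open("data/pointmaze-large-v0-1000000.pkl", "rb") as fp:
    dataset = pickle.load(fp)

obs = dataset["observations"]        # dict of per-step arrays
print(obs["observation"].shape)      # (num_data, obs_dim)
print(obs["achieved_goal"].shape)    # (num_data, 2)
print(dataset["actions"].shape)      # (num_data, act_dim)
print(dataset["episode_count"], dataset["success_count"])

# Episode boundaries can be recovered from the per-step episode ids.
episode_ids = obs["episode_id"]
first_episode = np.where(episode_ids == 0)[0]
print("first episode length:", first_episode.size)
```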
43 |
44 | ## Acknowledgment
45 | Our codebase is built on top of the following codebases. We thank the respective authors for their awesome contributions.
46 | - [NanoGPT](https://github.com/karpathy/nanoGPT)
47 | - [min-decision-transformer](https://github.com/nikhilbarhate99/min-decision-transformer)
48 |
49 | ## Correspondence
50 |
51 | If you have any questions or suggestions, please reach out to me via raj.ghugare@mila.quebec.
52 |
--------------------------------------------------------------------------------
/cfgs/dmlp.yaml:
--------------------------------------------------------------------------------
1 | seed: 1
2 | device: cuda
3 | agent_name: dmlp
4 |
5 | #eval
6 | num_eval_ep: 10
7 | num_eval_len:
8 |
9 | # env
10 | env_name:
11 | dataset_name: pointmaze-umaze-v0
12 | remote_data: False
13 | num_workers: 4
14 | render: False
15 |
16 | #train
17 | num_updates_per_iter: 40000
18 | max_train_iters: 25 #25 x 40000 -> 1000000 transition batches
19 |
20 | augment_data: False
21 | augment_prob: 0
22 | nclusters:
23 | batch_size: 256
24 | lr: 1e-3
25 |
26 | #logging
27 | log_dir: dt_runs
28 | wandb_log: False
29 | wandb_entity: raj19
30 | wandb_run_name:
31 | wandb_group_name:
32 | wandb_dir:
33 |
34 | #saving
35 | save_snapshot: True
36 | save_snapshot_interval: 25
37 |
38 | #hydra
39 | hydra:
40 | run:
41 | dir: ${log_dir}/${dataset_name}/${now:%Y.%m.%d}_${now:%H.%M.%S}
42 | job:
43 | chdir: False
--------------------------------------------------------------------------------
/cfgs/dt.yaml:
--------------------------------------------------------------------------------
1 | seed: 1
2 | device: cuda
3 | agent_name: dt
4 |
5 | #eval
6 | num_eval_ep: 10
7 | num_eval_len:
8 |
9 | # env
10 | env_name:
11 | dataset_name: pointmaze-umaze-v0
12 | remote_data: False
13 | num_workers: 4
14 | render: False
15 |
16 | #train
17 | num_updates_per_iter: 8000
18 | max_train_iters: 25 #25 x 8000 x 5-> 1000000 transition batches
19 |
20 | augment_data: False
21 | augment_prob: 0
22 | nclusters:
23 | warmup_steps: 1000
24 | drop_p: 0.1
25 | batch_size: 256
26 | lr: 1e-3
27 | wt_decay: 1e-4
28 |
29 | #model
30 | context_len: 5
31 | n_blocks: 3
32 | embed_dim: 128
33 | n_heads: 1
34 |
35 | #logging
36 | log_dir: dt_runs
37 | wandb_log: False
38 | wandb_entity: raj19
39 | wandb_run_name:
40 | wandb_group_name:
41 | wandb_dir:
42 |
43 | #saving
44 | save_snapshot: True
45 | save_snapshot_interval: 25
46 |
47 | #hydra
48 | hydra:
49 | run:
50 | dir: ${log_dir}/${dataset_name}/${now:%Y.%m.%d}_${now:%H.%M.%S}
51 | job:
52 | chdir: False
--------------------------------------------------------------------------------
/collect_antmaze_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import pickle
5 | import numpy as np
6 | import gymnasium as gym
7 | from stable_baselines3 import SAC
8 | from utils import get_maze_map, AntmazeWrapper
9 |
10 | '''
11 | Code taken from https://github.com/rodrigodelazcano/d4rl-minari-dataset-generation/blob/main/scripts/antmaze/create_antmaze_dataset.py
12 | '''
13 |
14 | UP = 0
15 | DOWN = 1
16 | LEFT = 2
17 | RIGHT = 3
18 |
19 | EXPLORATION_ACTIONS = {UP: (0, 1), DOWN: (0, -1), LEFT: (-1, 0), RIGHT: (1, 0)}
20 |
21 | class QIteration:
22 | """Solves for optimal policy with Q-Value Iteration.
23 |
24 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/q_iteration.py
25 | """
26 | def __init__(self, maze):
27 | self.maze = maze
28 | self.num_states = maze.map_length * maze.map_width
29 | self.num_actions = len(EXPLORATION_ACTIONS.keys())
30 | self.rew_matrix = np.zeros((self.num_states, self.num_actions))
31 | self.compute_transition_matrix()
32 |
33 | def generate_path(self, current_cell, goal_cell):
34 | self.compute_reward_matrix(goal_cell)
35 | q_values = self.get_q_values()
36 | current_state = self.cell_to_state(current_cell)
37 | waypoints = {}
38 | while True:
39 | action_id = np.argmax(q_values[current_state])
40 | next_state, _ = self.get_next_state(current_state, EXPLORATION_ACTIONS[action_id])
41 | current_cell = self.state_to_cell(current_state)
42 | waypoints[current_cell] = self.state_to_cell(next_state)
43 | if waypoints[current_cell] == goal_cell:
44 | break
45 |
46 | current_state = next_state
47 |
48 | return waypoints
49 |
50 | def reward_function(self, desired_cell, current_cell):
51 | if desired_cell == current_cell:
52 | return 1.0
53 | else:
54 | return 0.0
55 |
56 | def state_to_cell(self, state):
57 | i = int(state/self.maze.map_width)
58 | j = state % self.maze.map_width
59 | return (i, j)
60 |
61 | def cell_to_state(self, cell):
62 | return cell[0] * self.maze.map_width + cell[1]
63 |
64 | def get_q_values(self, num_itrs=50, discount=0.99):
65 | q_fn = np.zeros((self.num_states, self.num_actions))
66 | for _ in range(num_itrs):
67 | v_fn = np.max(q_fn, axis=1)
68 | q_fn = self.rew_matrix + discount*self.transition_matrix.dot(v_fn)
69 | return q_fn
70 |
71 | def compute_reward_matrix(self, goal_cell):
72 | for state in range(self.num_states):
73 | for action in range(self.num_actions):
74 | next_state, _ = self.get_next_state(state, EXPLORATION_ACTIONS[action])
75 | next_cell = self.state_to_cell(next_state)
76 | self.rew_matrix[state, action] = self.reward_function(goal_cell, next_cell)
77 |
78 | def compute_transition_matrix(self):
79 | """Constructs this environment's transition matrix.
80 | Returns:
81 | A dS x dA x dS array where the entry transition_matrix[s, a, ns]
82 | corresponds to the probability of transitioning into state ns after taking
83 | action a from state s.
84 | """
85 | self.transition_matrix = np.zeros((self.num_states, self.num_actions, self.num_states))
86 | for state in range(self.num_states):
87 | for action_idx, action in EXPLORATION_ACTIONS.items():
88 | next_state, valid = self.get_next_state(state, action)
89 | if valid:
90 | self.transition_matrix[state, action_idx, next_state] = 1
91 |
92 | def get_next_state(self, state, action):
93 | cell = self.state_to_cell(state)
94 |
95 | next_cell = tuple(map(lambda i, j: int(i + j), cell, action))
96 | next_state = self.cell_to_state(next_cell)
97 |
98 | return next_state, self._check_valid_cell(next_cell)
99 |
100 | def _check_valid_cell(self, cell):
101 | # Out of map bounds
102 | if cell[0] >= self.maze.map_length:
103 | return False
104 | elif cell[1] >= self.maze.map_width:
105 | return False
106 | # Wall collision
107 | elif self.maze.maze_map[cell[0]][cell[1]] == 1:
108 | return False
109 | else:
110 | return True
111 |
112 | class WaypointController:
113 | """Generic agent controller to follow waypoints in the maze.
114 |
115 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/waypoint_controller.py
116 | """
117 |
118 | def __init__(
119 | self, maze, model_callback, waypoint_threshold=0.45
120 | ):
121 | self.global_target_xy = np.empty(2)
122 | self.maze = maze
123 | self.maze_solver = QIteration(maze=self.maze)
124 |
125 | self.model_callback = model_callback
126 | self.waypoint_threshold = waypoint_threshold
127 | self.waypoint_targets = None
128 |
129 | def compute_action(self, obs):
130 | # Check if we need to generate new waypoint path due to change in global target
131 | if (
132 | np.linalg.norm(self.global_target_xy - obs["desired_goal"]) > 1e-3
133 | or self.waypoint_targets is None
134 | ):
135 | # Convert xy to cell id
136 | achived_goal_cell = tuple(self.maze.cell_xy_to_rowcol(obs["achieved_goal"]))
137 | self.global_target_id = tuple(
138 | self.maze.cell_xy_to_rowcol(obs["desired_goal"])
139 | )
140 |
141 | self.global_target_xy = obs["desired_goal"]
142 |
143 | self.waypoint_targets = self.maze_solver.generate_path(
144 | achived_goal_cell, self.global_target_id
145 | )
146 | # Check if the waypoint dictionary is empty
147 | # If empty then the ball is already in the target cell location
148 | if self.waypoint_targets:
149 | self.current_control_target_id = self.waypoint_targets[
150 | achived_goal_cell
151 | ]
152 | # If target is global goal go directly to goal position
153 | if self.current_control_target_id == self.global_target_id:
154 | self.current_control_target_xy = obs['desired_goal']
155 | else:
156 | self.current_control_target_xy = self.maze.cell_rowcol_to_xy(
157 | np.array(self.current_control_target_id)
158 | ) - np.random.uniform(size=(2,)) * 0.1
159 | else:
160 | self.waypoint_targets[
161 | self.current_control_target_id
162 | ] = self.current_control_target_id
163 | self.current_control_target_id = self.global_target_id
164 | self.current_control_target_xy = self.global_target_xy
165 |
166 | # Check if we need to go to the next waypoint
167 | dist = np.linalg.norm(self.current_control_target_xy - obs["achieved_goal"])
168 | if (
169 | dist <= self.waypoint_threshold
170 | and self.current_control_target_id != self.global_target_id
171 | ):
172 | self.current_control_target_id = self.waypoint_targets[
173 | self.current_control_target_id
174 | ]
175 | # If target is global goal go directly to goal position
176 | if self.current_control_target_id == self.global_target_id:
177 |
178 | self.current_control_target_xy = obs['desired_goal']
179 | else:
180 | self.current_control_target_xy = (
181 | self.maze.cell_rowcol_to_xy(
182 | np.array(self.current_control_target_id)
183 | )
184 | - np.random.uniform(size=(2,)) * 0.1
185 | )
186 | action = self.model_callback(obs, self.current_control_target_xy)
187 | return action
188 |
189 | def wrap_maze_obs(obs, waypoint_xy):
190 | """Wrap the maze obs into one suitable for GoalReachAnt."""
191 | goal_direction = waypoint_xy - obs["achieved_goal"]
192 | observation = np.concatenate([obs["observation"][2:], goal_direction])
193 | return observation
194 |
195 | def get_start_state_goal_pairs(dataset_name):
196 |
197 | if dataset_name == "antmaze-large-sl":
198 | train_start_state_goal = [
199 | {"goal_cells": np.array([ [5,4], [1,10], [7,10] ], dtype=np.int32),
200 | "reset_cells": np.array([ [7,1] ], dtype=np.int32),
201 | },
202 | ]
203 |
204 | elif dataset_name == "antmaze-large-v0":
205 | train_start_state_goal = [
206 | {"goal_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32),
207 | "reset_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32),
208 | },
209 |
210 | {"goal_cells": np.array([ [3,2], [3,3], [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4], [1,6], [2,6], [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [4,6], [5,6], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32),
211 | "reset_cells": np.array([ [3,2], [3,3], [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4], [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32),
212 | },
213 | ]
214 |
215 | elif dataset_name == 'antmaze-medium-v0':
216 | train_start_state_goal = [
217 | {'goal_cells': np.array([4,6], dtype=np.int32),
218 | 'reset_cells': np.array([6,5], dtype=np.int32),
219 | 'name' : 'bottom_rightish_to_center',
220 | },
221 |
222 | {'goal_cells': np.array([6,5], dtype=np.int32),
223 | 'reset_cells': np.array([4,6], dtype=np.int32),
224 | 'name' : 'center_to_bottom_rightish',
225 | },
226 |
227 | {'goal_cells': np.array([4,4], dtype=np.int32),
228 | 'reset_cells': np.array([6,1], dtype=np.int32),
229 | 'name' : 'bottom_left_to_center',
230 | },
231 |
232 | {'goal_cells': np.array([6,1], dtype=np.int32),
233 | 'reset_cells': np.array([4,4], dtype=np.int32),
234 | 'name' : 'center_to_bottom_left',
235 | },
236 |
237 | {'goal_cells': np.array([4,2], dtype=np.int32),
238 | 'reset_cells': np.array([6,1], dtype=np.int32),
239 | 'name' : 'bottom_left_to_center_leftish',
240 | },
241 |
242 | {'goal_cells': np.array([6,1], dtype=np.int32),
243 | 'reset_cells': np.array([4,2], dtype=np.int32),
244 | 'name' : 'center_leftish_to_bottom_left',
245 | },
246 |
247 | {'goal_cells': np.array([4,6], dtype=np.int32),
248 | 'reset_cells': np.array([1,6], dtype=np.int32),
249 | 'name' : 'top_right_to_center',
250 | },
251 |
252 | {'goal_cells': np.array([1,6], dtype=np.int32),
253 | 'reset_cells': np.array([4,6], dtype=np.int32),
254 | 'name' : 'center_to_top_right',
255 | },
256 | ]
257 |
258 | elif dataset_name == "antmaze-umaze-v0":
259 | train_start_state_goal = [
260 | {"goal_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32),
261 | "reset_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32),
262 | },
263 |
264 | {"goal_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32),
265 | "reset_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32),
266 | },
267 |
268 | {"goal_cells": np.array([ [1,3], [2,3] ], dtype=np.int32),
269 | "reset_cells": np.array([ [1,3], [2,3] ], dtype=np.int32),
270 | },
271 |
272 | {"goal_cells": np.array([ [3,3], [2,3] ], dtype=np.int32),
273 | "reset_cells": np.array([ [3,3], [2,3] ], dtype=np.int32),
274 | },
275 | ]
276 |
277 | else:
278 | raise NotImplementedError
279 |
280 | return train_start_state_goal
281 |
282 | def load_policy(policy_file):
283 | data = torch.load(policy_file)
284 | policy = data['exploration/policy'].to('cpu')
285 | env = data['evaluation/env']
286 | print("Policy loaded")
287 | return policy, env
288 |
289 | def collect_dataset(dataset_name, num_data):
290 | if os.path.isfile("data/"+dataset_name+".pkl"):
291 | print("A dataset with the same name already exists. Doing nothing. Delete the file if you want any changes.")
292 | exit()
293 |
294 | if "antmaze-umaze" in dataset_name:
295 | env_name = "AntMaze_UMaze-v4"
296 | elif "antmaze-medium" in dataset_name:
297 | env_name = "AntMaze_Medium-v4"
298 | elif "antmaze-large" in dataset_name:
299 | env_name = "AntMaze_Large-v4"
300 | else:
301 | raise NotImplementedError
302 |
303 | # environment initialisation
304 | env = AntmazeWrapper(gym.make(env_name, continuing_task=False))#, render_mode="human"))
305 |
306 | # data placeholders
307 | observation_data = {
308 | "episode_id": np.zeros(shape=(int(num_data),), dtype=np.int32),
309 | "observation" : np.zeros(shape=(int(num_data), *env.observation_space["observation"].shape), dtype=np.float32),
310 | "achieved_goal" : np.zeros(shape=(int(num_data), *env.observation_space["achieved_goal"].shape), dtype=np.float32),
311 | }
312 | action_data = np.zeros(shape=(int(num_data), *env.action_space.shape), dtype=np.float32)
313 | termination_data = np.zeros(shape = (int(num_data),), dtype=bool)
314 |
315 | # get data collecting start state and goal pairs
316 | train_start_state_goal = get_start_state_goal_pairs(dataset_name)
317 |
318 | # controller
319 | model = SAC.load('GoalReachAnt_model.zip')
320 | def action_callback(obs, waypoint_xy):
321 | return model.predict(wrap_maze_obs(obs, waypoint_xy))[0]
322 | waypoint_controller = WaypointController(env.maze, action_callback)
323 |
324 | terminated = False
325 | truncated = False
326 | data_idx = 0
327 | episode_idx = 0
328 | success_count = 0
329 |
330 | if train_start_state_goal is not None:
331 | num_dc = len(train_start_state_goal)
332 | sg_dict = train_start_state_goal[ episode_idx % num_dc ]
333 | if len(sg_dict["goal_cells"].shape) == 2:
334 | x,_ = sg_dict["goal_cells"].shape
335 | id = np.random.choice(x, size=1)[0]
336 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id]
337 | else:
338 | sg_dict["goal_cell"] = sg_dict["goal_cells"]
339 |
340 | if len(sg_dict["reset_cells"].shape) == 2:
341 | x,_ = sg_dict["reset_cells"].shape
342 | id = np.random.choice(x, size=1)[0]
343 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id]
344 | else:
345 | sg_dict["reset_cell"] = sg_dict["reset_cells"]
346 | obs, _ = env.reset(options=sg_dict)
347 | else:
348 | obs, _ = env.reset()
349 |
350 | episode_start_idx = 0
351 | while data_idx < int(num_data)-1:
352 |
353 | action = waypoint_controller.compute_action(obs)
354 | action = np.clip(action, env.action_space.low, env.action_space.high, dtype=np.float32)
355 |
356 | observation_data["episode_id"][data_idx] = episode_idx
357 | observation_data["observation"][data_idx] = obs["observation"]
358 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"]
359 |
360 | action_data[data_idx] = action
361 | termination_data[data_idx] = terminated or truncated
362 | data_idx += 1
363 |
364 | obs, _, terminated, truncated, info = env.step(action)
365 |
366 | if terminated or truncated:
367 | if info['success']:
368 | success_count += 1
369 |
370 | observation_data["episode_id"][data_idx] = episode_idx
371 | observation_data["observation"][data_idx] = obs["observation"]
372 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"]
373 | termination_data[data_idx] = terminated or truncated
374 |
375 | data_idx += 1
376 | episode_idx += 1
377 | episode_start_idx = data_idx
378 |
379 | else:
380 | data_idx = episode_start_idx
381 |
382 | if train_start_state_goal is not None:
383 | sg_dict = train_start_state_goal[ episode_idx % num_dc ]
384 | if len(sg_dict["goal_cells"].shape) == 2:
385 | x,_ = sg_dict["goal_cells"].shape
386 | id = np.random.choice(x, size=1)[0]
387 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id]
388 | else:
389 | sg_dict["goal_cell"] = sg_dict["goal_cells"]
390 |
391 | if len(sg_dict["reset_cells"].shape) == 2:
392 | x,_ = sg_dict["reset_cells"].shape
393 | id = np.random.choice(x, size=1)[0]
394 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id]
395 | else:
396 | sg_dict["reset_cell"] = sg_dict["reset_cells"]
397 | obs, _ = env.reset(options=sg_dict)
398 | else:
399 | obs, _ = env.reset()
400 |
401 | terminated = False
402 | truncated = False
403 |
404 | if (data_idx + 1) % (num_data // 20) == 0:
405 | print("STEPS RECORDED:")
406 | print(data_idx)
407 |
408 | print("EPISODES RECORDED:")
409 | print(episode_idx)
410 |
411 | print("SUCCESS EPISODES RECORDED:")
412 | print(success_count)
413 |
414 |
415 | dataset = {"observations": observation_data,
416 | "actions": action_data,
417 | "terminations": termination_data,
418 | "success_count": success_count,
419 | "episode_count": episode_idx,
420 | }
421 |
422 | with open("data/"+dataset_name+".pkl", "wb") as fp:
423 | pickle.dump(dataset, fp)
424 |
425 | if __name__ == "__main__":
426 | import sys
427 |
428 | dataset_name = sys.argv[1]
429 |
430 | seed = sys.argv[2]
431 | num_data = int(1e6)
432 |
433 | np.random.seed(int(seed))
434 | collect_dataset(dataset_name, num_data)
--------------------------------------------------------------------------------
/collect_pointmaze_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import gymnasium as gym
5 | from utils import get_maze_map
6 |
7 | UP = 0
8 | DOWN = 1
9 | LEFT = 2
10 | RIGHT = 3
11 |
12 | EXPLORATION_ACTIONS = {UP: (0, 1), DOWN: (0, -1), LEFT: (-1, 0), RIGHT: (1, 0)}
13 |
14 | class QIteration:
15 | """Solves for optimal policy with Q-Value Iteration.
16 |
17 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/q_iteration.py
18 | """
19 | def __init__(self, maze):
20 | self.maze = maze
21 | self.num_states = maze.map_length * maze.map_width
22 | self.num_actions = len(EXPLORATION_ACTIONS.keys())
23 | self.rew_matrix = np.zeros((self.num_states, self.num_actions))
24 | self.compute_transition_matrix()
25 |
26 | def generate_path(self, current_cell, goal_cell):
27 | self.compute_reward_matrix(goal_cell)
28 | q_values = self.get_q_values()
29 | current_state = self.cell_to_state(current_cell)
30 | waypoints = {}
31 | while True:
32 | action_id = np.argmax(q_values[current_state])
33 | next_state, _ = self.get_next_state(current_state, EXPLORATION_ACTIONS[action_id])
34 | current_cell = self.state_to_cell(current_state)
35 | waypoints[current_cell] = self.state_to_cell(next_state)
36 | if waypoints[current_cell] == goal_cell:
37 | break
38 |
39 | current_state = next_state
40 |
41 | return waypoints
42 |
43 | def reward_function(self, desired_cell, current_cell):
44 | if desired_cell == current_cell:
45 | return 1.0
46 | else:
47 | return 0.0
48 |
49 | def state_to_cell(self, state):
50 | i = int(state/self.maze.map_width)
51 | j = state % self.maze.map_width
52 | return (i, j)
53 |
54 | def cell_to_state(self, cell):
55 | return cell[0] * self.maze.map_width + cell[1]
56 |
57 | def get_q_values(self, num_itrs=50, discount=0.99):
58 | q_fn = np.zeros((self.num_states, self.num_actions))
59 | for _ in range(num_itrs):
60 | v_fn = np.max(q_fn, axis=1)
61 | q_fn = self.rew_matrix + discount*self.transition_matrix.dot(v_fn)
62 | return q_fn
63 |
64 | def compute_reward_matrix(self, goal_cell):
65 | for state in range(self.num_states):
66 | for action in range(self.num_actions):
67 | next_state, _ = self.get_next_state(state, EXPLORATION_ACTIONS[action])
68 | next_cell = self.state_to_cell(next_state)
69 | self.rew_matrix[state, action] = self.reward_function(goal_cell, next_cell)
70 |
71 | def compute_transition_matrix(self):
72 | """Constructs this environment's transition matrix.
73 | Returns:
74 | A dS x dA x dS array where the entry transition_matrix[s, a, ns]
75 | corresponds to the probability of transitioning into state ns after taking
76 | action a from state s.
77 | """
78 | self.transition_matrix = np.zeros((self.num_states, self.num_actions, self.num_states))
79 | for state in range(self.num_states):
80 | for action_idx, action in EXPLORATION_ACTIONS.items():
81 | next_state, valid = self.get_next_state(state, action)
82 | if valid:
83 | self.transition_matrix[state, action_idx, next_state] = 1
84 |
85 | def get_next_state(self, state, action):
86 | cell = self.state_to_cell(state)
87 |
88 | next_cell = tuple(map(lambda i, j: int(i + j), cell, action))
89 | next_state = self.cell_to_state(next_cell)
90 |
91 | return next_state, self._check_valid_cell(next_cell)
92 |
93 | def _check_valid_cell(self, cell):
94 | # Out of map bounds
95 | if cell[0] >= self.maze.map_length:
96 | return False
97 | elif cell[1] >= self.maze.map_width:
98 | return False
99 | # Wall collision
100 | elif self.maze.maze_map[cell[0]][cell[1]] == 1:
101 | return False
102 | else:
103 | return True
104 |
105 | class WaypointController:
106 | """Agent controller to follow waypoints in the maze.
107 |
108 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/waypoint_controller.py
109 | """
110 | def __init__(self, maze, gains={"p": 10.0, "d": -1.0}, waypoint_threshold=0.1):
111 | self.global_target_xy = np.empty(2)
112 | self.maze = maze
113 |
114 | self.maze_solver = QIteration(maze=self.maze)
115 |
116 | self.gains = gains
117 | self.waypoint_threshold = waypoint_threshold
118 | self.waypoint_targets = None
119 |
120 | def compute_action(self, obs):
121 | # Check if we need to generate new waypoint path due to change in global target
122 | if np.linalg.norm(self.global_target_xy - obs["desired_goal"]) > 1e-3 or self.waypoint_targets is None:
123 | # Convert xy to cell id
124 | achieved_goal_cell = tuple(self.maze.cell_xy_to_rowcol(obs["achieved_goal"]))
125 | self.global_target_id = tuple(self.maze.cell_xy_to_rowcol(obs["desired_goal"]))
126 | self.global_target_xy = obs["desired_goal"]
127 |
128 | self.waypoint_targets = self.maze_solver.generate_path(achieved_goal_cell, self.global_target_id)
129 |
130 | # Check if the waypoint dictionary is empty
131 | # If empty then the ball is already in the target cell location
132 | if self.waypoint_targets:
133 | self.current_control_target_id = self.waypoint_targets[achieved_goal_cell]
134 | self.current_control_target_xy = self.maze.cell_rowcol_to_xy(np.array(self.current_control_target_id))
135 | else:
136 | self.waypoint_targets[self.current_control_target_id] = self.current_control_target_id
137 | self.current_control_target_id = self.global_target_id
138 | self.current_control_target_xy = self.global_target_xy
139 |
140 | # Check if we need to go to the next waypoint
141 | dist = np.linalg.norm(self.current_control_target_xy - obs["achieved_goal"])
142 | if dist <= self.waypoint_threshold and self.current_control_target_id != self.global_target_id:
143 | self.current_control_target_id = self.waypoint_targets[self.current_control_target_id]
144 | # If target is global goal go directly to goal position
145 | if self.current_control_target_id == self.global_target_id:
146 | self.current_control_target_xy = self.global_target_xy
147 | else:
148 | self.current_control_target_xy = self.maze.cell_rowcol_to_xy(np.array(self.current_control_target_id)) - np.random.uniform(size=(2,))*0.2
149 |
150 | action = self.gains["p"] * (self.current_control_target_xy - obs["achieved_goal"]) + self.gains["d"] * obs["observation"][2:]
151 | action = np.clip(action, -1, 1)
152 |
153 | return action
154 |
155 | def get_start_state_goal_pairs(dataset_name):
156 | if dataset_name == "pointmaze-large-sl":
157 | train_start_state_goal = [
158 | {"goal_cells": np.array([ [5,4], [1,10], [7,10] ], dtype=np.int32),
159 | "reset_cells": np.array([ [7,1] ], dtype=np.int32),
160 | },
161 | ]
162 |
163 | elif dataset_name == "pointmaze-large-v0":
164 | train_start_state_goal = [
165 | {"goal_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32),
166 | "reset_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32),
167 | },
168 |
169 | {"goal_cells": np.array([ [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4] ], dtype=np.int32),
170 | "reset_cells": np.array([ [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4] ], dtype=np.int32),
171 | },
172 |
173 | {"goal_cells": np.array([ [1,6], [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [5,6], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32),
174 | "reset_cells": np.array([ [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32),
175 | },
176 | ]
177 |
178 | elif dataset_name == "pointmaze-medium-v0":
179 | train_start_state_goal = [
180 | {"goal_cells": np.array([ [6,5], [6,6], [5,6], [4,6] ], dtype=np.int32),
181 | "reset_cells": np.array([ [6,5], [6,6], [5,6], [4,6] ], dtype=np.int32),
182 | },
183 |
184 | {"goal_cells": np.array([ [4,1], [4,2], [5,1], [6,1], [6,2], [6,3], [5,3] ], dtype=np.int32),
185 | "reset_cells": np.array([ [4,1], [4,2], [5,1], [6,1], [6,2], [6,3], [5,3] ], dtype=np.int32),
186 | },
187 |
188 | {"goal_cells": np.array([ [1,5], [1,6], [2,4], [2,5], [2,6] ], dtype=np.int32),
189 | "reset_cells": np.array([ [1,5], [1,6], [2,4], [2,5], [2,6] ], dtype=np.int32),
190 | },
191 |
192 | {"goal_cells": np.array([ [1,1], [1,2], [2,1], [2,2] ], dtype=np.int32),
193 | "reset_cells": np.array([ [1,1], [1,2], [2,1], [2,2] ], dtype=np.int32),
194 | },
195 |
196 | {"goal_cells": np.array([ [5,3], [5,4], [4,2], [4,4], [4,5], [4,6], [3,2], [3,3], [3,4], [2,2], [2,4] ], dtype=np.int32),
197 | "reset_cells": np.array([ [5,3], [5,4], [4,2], [4,4], [4,5], [4,6], [3,2], [3,3], [3,4], [2,2], [2,4] ], dtype=np.int32),
198 | },
199 | ]
200 |
201 | elif dataset_name == "pointmaze-umaze-v0":
202 | train_start_state_goal = [
203 | {"goal_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32),
204 | "reset_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32),
205 | },
206 |
207 | {"goal_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32),
208 | "reset_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32),
209 | },
210 |
211 | {"goal_cells": np.array([ [1,3], [2,3] ], dtype=np.int32),
212 | "reset_cells": np.array([ [1,3], [2,3] ], dtype=np.int32),
213 | },
214 |
215 | {"goal_cells": np.array([ [3,3], [2,3] ], dtype=np.int32),
216 | "reset_cells": np.array([ [3,3], [2,3] ], dtype=np.int32),
217 | },
218 | ]
219 |
220 | else:
221 | raise NotImplementedError
222 |
223 | return train_start_state_goal
224 |
225 | def collect_dataset(dataset_name, num_data):
226 | if os.path.isfile("data/"+dataset_name+'-'+str(num_data)+".pkl"):
227 | print("A dataset with the same name already exists. Doing nothing. Delete the file if you want any changes.")
228 | exit()
229 |
230 | if "pointmaze-umaze" in dataset_name:
231 | env_name = "PointMaze_UMaze-v3"
232 | elif "pointmaze-medium" in dataset_name:
233 | env_name = "PointMaze_Medium-v3"
234 | elif "pointmaze-large" in dataset_name:
235 | env_name = "PointMaze_Large-v3"
236 | else:
237 | raise NotImplementedError
238 |
239 | # environment initialisation
240 | env = gym.make(env_name, continuing_task=False, max_episode_steps=10 * num_data)#, render_mode="human")
241 |
242 | # data placeholders
243 | observation_data = {
244 | "episode_id": np.zeros(shape=(int(num_data),), dtype=np.int32),
245 | "observation" : np.zeros(shape=(int(num_data), *env.observation_space["observation"].shape), dtype=np.float32),
246 | "achieved_goal" : np.zeros(shape=(int(num_data), *env.observation_space["achieved_goal"].shape), dtype=np.float32),
247 | }
248 | action_data = np.zeros(shape=(int(num_data), *env.action_space.shape), dtype=np.float32)
249 | termination_data = np.zeros(shape = (int(num_data),), dtype=bool)
250 |
251 | # controller initialisation
252 | waypoint_controller = WaypointController(maze=env.maze)
253 |
254 | # get data collecting start state and goal pairs
255 | train_start_state_goal = get_start_state_goal_pairs(dataset_name)
256 |
257 | terminated = False
258 | truncated = False
259 | data_idx = 0
260 | episode_idx = 0
261 | success_count = 0
262 |
263 | num_dc = len(train_start_state_goal)
264 |
265 | sg_dict = train_start_state_goal[ episode_idx % num_dc ]
266 | if len(sg_dict["goal_cells"].shape) == 2:
267 | x,_ = sg_dict["goal_cells"].shape
268 | id = np.random.choice(x, size=1)[0]
269 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id]
270 | else:
271 | sg_dict["goal_cell"] = sg_dict["goal_cells"]
272 |
273 | if len(sg_dict["reset_cells"].shape) == 2:
274 | x,_ = sg_dict["reset_cells"].shape
275 | id = np.random.choice(x, size=1)[0]
276 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id]
277 | else:
278 | sg_dict["reset_cell"] = sg_dict["reset_cells"]
279 | obs, _ = env.reset(options=sg_dict)
280 | episode_start_idx = 0  # index where the current episode's data starts (rewound if the episode fails)
281 | while data_idx < int(num_data)-1:
282 | action = waypoint_controller.compute_action(obs)
283 | # Add some noise to each step action
284 | action += np.random.randn(*action.shape)*0.5
285 | action = np.clip(action, env.action_space.low, env.action_space.high, dtype=np.float32)
286 |
287 | observation_data["episode_id"][data_idx] = episode_idx
288 | observation_data["observation"][data_idx] = obs["observation"]
289 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"]
290 | action_data[data_idx] = action
291 | termination_data[data_idx] = terminated or truncated
292 | data_idx += 1
293 |
294 | obs, _, terminated, truncated, info = env.step(action)
295 |
296 | if terminated or truncated:
297 | if info['success']:
298 | success_count += 1
299 |
300 | observation_data["episode_id"][data_idx] = episode_idx
301 | observation_data["observation"][data_idx] = obs["observation"]
302 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"]
303 | termination_data[data_idx] = terminated or truncated
304 |
305 | data_idx += 1
306 | episode_idx += 1
307 | episode_start_idx = data_idx
308 |
309 | else:
310 | data_idx = episode_start_idx
311 |
312 | sg_dict = train_start_state_goal[ episode_idx % num_dc ]
313 | if len(sg_dict["goal_cells"].shape) == 2:
314 | x,_ = sg_dict["goal_cells"].shape
315 | id = np.random.choice(x, size=1)[0]
316 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id]
317 | else:
318 | sg_dict["goal_cell"] = sg_dict["goal_cells"]
319 |
320 | if len(sg_dict["reset_cells"].shape) == 2:
321 | x,_ = sg_dict["reset_cells"].shape
322 | id = np.random.choice(x, size=1)[0]
323 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id]
324 | else:
325 | sg_dict["reset_cell"] = sg_dict["reset_cells"]
326 |
327 | obs, _ = env.reset(options=sg_dict)
328 |
329 | terminated = False
330 | truncated = False
331 |
332 | if (data_idx + 1) % (num_data // 5) == 0:
333 | print("STEPS RECORDED:")
334 | print(data_idx)
335 |
336 | print("EPISODES RECORDED:")
337 | print(episode_idx)
338 |
339 | print("SUCCESS EPISODES RECORDED:")
340 | print(success_count)
341 |
342 | dataset = {"observations":observation_data,
343 | "actions":action_data,
344 | "terminations":termination_data,
345 | "success_count": success_count,
346 | "episode_count": episode_idx,
347 | }
348 |
349 | with open("data/"+dataset_name+'-'+str(num_data)+".pkl", "wb") as fp:
350 | pickle.dump(dataset, fp)
351 |
352 | if __name__ == "__main__":
353 | import sys
354 |
355 | dataset_name = sys.argv[1]
356 |
357 | seed = sys.argv[2]
358 | num_data = int(sys.argv[3])
359 |
360 | np.random.seed(int(seed))
361 | collect_dataset(dataset_name, num_data)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.distributions as td
6 |
7 | def count_parameters(model):
8 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
9 |
10 | class DecisionMLP(nn.Module):
11 | def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
12 | super().__init__()
13 |
14 | self.env_name = env_name
15 | state_dim = env.observation_space['observation'].shape[0]
16 | act_dim = env.action_space.shape[0]
17 |
18 | self.mlp = nn.Sequential(
19 | nn.Linear(state_dim + goal_dim, h_dim),
20 | nn.ReLU(),
21 | nn.Linear(h_dim, h_dim),
22 | nn.ReLU(),
23 | nn.Linear(h_dim, act_dim),
24 | nn.Tanh()
25 | )
26 |
27 | def forward(self, states, goals):
28 | h = torch.cat((states, goals), dim=-1)
29 | action_preds = self.mlp(h)
30 | return action_preds
31 |
32 | class MaskedCausalAttention(nn.Module):
33 | '''
34 | Thanks https://github.com/nikhilbarhate99/min-decision-transformer/tree/master
35 | '''
36 | def __init__(self, h_dim, max_T, n_heads, drop_p):
37 | super().__init__()
38 |
39 | self.n_heads = n_heads
40 | self.max_T = max_T
41 |
42 | self.q_net = nn.Linear(h_dim, h_dim)
43 | self.k_net = nn.Linear(h_dim, h_dim)
44 | self.v_net = nn.Linear(h_dim, h_dim)
45 |
46 | self.proj_net = nn.Linear(h_dim, h_dim)
47 |
48 | self.dropout = drop_p
49 | self.att_drop = nn.Dropout(drop_p)
50 | self.proj_drop = nn.Dropout(drop_p)
51 |
52 | def forward(self, x):
53 | B, T, C = x.shape # batch size, seq length, h_dim
54 |
55 | N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim
56 |
57 | # rearrange q, k, v as (B, N, T, D)
58 | q = self.q_net(x).view(B, T, N, D).transpose(1,2)
59 | k = self.k_net(x).view(B, T, N, D).transpose(1,2)
60 | v = self.v_net(x).view(B, T, N, D).transpose(1,2)
61 |
62 | attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
63 | attention = attention.transpose(1, 2).contiguous().view(B,T,N*D)
64 |
65 | out = self.proj_drop(self.proj_net(attention))
66 | return out
67 |
68 | class Block(nn.Module):
69 | def __init__(self, h_dim, max_T, n_heads, drop_p):
70 | super().__init__()
71 | self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
72 | self.mlp = nn.Sequential(
73 | nn.Linear(h_dim, 4*h_dim),
74 | nn.GELU(),
75 | nn.Linear(4*h_dim, h_dim),
76 | nn.Dropout(drop_p),
77 | )
78 | self.ln1 = nn.LayerNorm(h_dim)
79 | self.ln2 = nn.LayerNorm(h_dim)
80 |
81 | def forward(self, x):
82 | x = x + self.attention(self.ln1(x)) # residual
83 | x = x + self.mlp(self.ln2(x)) # residual
84 | return x
85 |
86 | class DecisionTransformer(nn.Module):
87 | def __init__(self, env_name, env, n_blocks, h_dim, context_len,
88 | n_heads, drop_p, goal_dim=2, max_timestep=4096):
89 | super().__init__()
90 |
91 | self.env_name = env_name
92 | self.state_dim = env.observation_space['observation'].shape[0]
93 | self.act_dim = env.action_space.shape[0]
94 | self.goal_dim = goal_dim
95 | self.n_heads = n_heads
96 | self.h_dim = h_dim
97 |
98 | ### transformer blocks
99 | input_seq_len = 3 * context_len
100 | blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
101 | self.transformer = nn.Sequential(*blocks)
102 |
103 | ### projection heads (project to embedding)
104 | self.embed_timestep = nn.Embedding(max_timestep, h_dim)
105 | self.embed_goal = torch.nn.Linear(goal_dim, h_dim)
106 | self.embed_state = torch.nn.Linear(self.state_dim, h_dim)
107 | self.embed_action = torch.nn.Linear(self.act_dim, h_dim)
108 |
109 | ### prediction heads
110 | self.final_ln = nn.LayerNorm(h_dim)
111 | self.predict_action = nn.Sequential(
112 | nn.Linear(h_dim, self.act_dim), nn.Tanh()
113 | )
114 |
115 | def forward(self, states, actions, goals):
116 | B, T, _ = states.shape
117 |
118 | timesteps = torch.arange(0, T, dtype=torch.long, device=states.device)
119 | time_embeddings = self.embed_timestep(timesteps)
120 | state_embeddings = self.embed_state(states) + time_embeddings #B, T, h_dim
121 | action_embeddings = self.embed_action(actions) + time_embeddings #B, T, h_dim
122 | goal_embeddings = self.embed_goal(goals) + time_embeddings #B, T, h_dim
123 |
124 | h = torch.stack(
125 | (goal_embeddings, state_embeddings, action_embeddings), dim=1
126 | ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
127 |
128 | # transformer and prediction
129 | h = self.transformer(h)
130 |
131 | h = self.final_ln(h)
132 |
133 | # get h reshaped such that its size = (B , 3 , T , h_dim) and
134 | # h[:, 0, t] is conditioned on the input sequence g_0, s_0, a_0 ... g_t
135 | # h[:, 1, t] is conditioned on the input sequence g_0, s_0, a_0 ... g_t, s_t
136 | # h[:, 2, t] is conditioned on the input sequence g_0, s_0, a_0 ... g_t, s_t, a_t
137 | # that is, for each timestep (t) we have 3 output embeddings from the transformer,
138 | # each conditioned on all previous timesteps plus
139 | # the 3 input variables at that timestep (g_t, s_t, a_t) in sequence.
140 | h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) # B, 3, T, h_dim
141 | action_preds = self.predict_action(h[:,1])
142 | return action_preds
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | antlr4-python3-runtime==4.9.3
2 | appdirs==1.4.4
3 | cachetools==5.3.2
4 | certifi==2023.11.17
5 | charset-normalizer==3.3.2
6 | click==8.1.7
7 | cloudpickle==3.0.0
8 | colorama==0.4.6
9 | contourpy==1.2.0
10 | cycler==0.12.1
11 | docker-pycreds==0.4.0
12 | Farama-Notifications==0.0.4
13 | filelock==3.13.1
14 | fonttools==4.47.2
15 | fsspec==2023.12.2
16 | gitdb==4.0.11
17 | GitPython==3.1.41
18 | google-api-core==2.15.0
19 | google-auth==2.26.2
20 | google-cloud-core==2.4.1
21 | google-cloud-storage==2.5.0
22 | google-crc32c==1.5.0
23 | google-resumable-media==2.7.0
24 | googleapis-common-protos==1.62.0
25 | gymnasium==0.29.1
26 | h5py==3.10.0
27 | hydra-core==1.3.2
28 | idna==3.6
29 | Jinja2==3.1.3
30 | joblib==1.3.2
31 | kiwisolver==1.4.5
32 | markdown-it-py==3.0.0
33 | MarkupSafe==2.1.3
34 | matplotlib==3.8.2
35 | mdurl==0.1.2
36 | minari==0.4.2
37 | mpmath==1.3.0
38 | networkx==3.2.1
39 | numpy==1.26.3
40 | nvidia-cublas-cu12==12.1.3.1
41 | nvidia-cuda-cupti-cu12==12.1.105
42 | nvidia-cuda-nvrtc-cu12==12.1.105
43 | nvidia-cuda-runtime-cu12==12.1.105
44 | nvidia-cudnn-cu12==8.9.2.26
45 | nvidia-cufft-cu12==11.0.2.54
46 | nvidia-curand-cu12==10.3.2.106
47 | nvidia-cusolver-cu12==11.4.5.107
48 | nvidia-cusparse-cu12==12.1.0.106
49 | nvidia-nccl-cu12==2.18.1
50 | nvidia-nvjitlink-cu12==12.3.101
51 | nvidia-nvtx-cu12==12.1.105
52 | omegaconf==2.3.0
53 | packaging==23.1
54 | pandas==2.1.4
55 | pillow==10.2.0
56 | portion==2.4.0
57 | protobuf==4.25.2
58 | psutil==5.9.7
59 | pyasn1==0.5.1
60 | pyasn1-modules==0.3.0
61 | Pygments==2.17.2
62 | pyparsing==3.1.1
63 | python-dateutil==2.8.2
64 | pytz==2023.3.post1
65 | PyYAML==6.0.1
66 | requests==2.31.0
67 | rich==13.7.0
68 | rsa==4.9
69 | scikit-learn==1.3.2
70 | scipy==1.11.4
71 | sentry-sdk==1.39.2
72 | setproctitle==1.3.3
73 | shellingham==1.5.4
74 | six==1.16.0
75 | smmap==5.0.1
76 | sortedcontainers==2.4.0
77 | stable-baselines3==2.2.1
78 | sympy==1.12
79 | threadpoolctl==3.2.0
80 | torch==2.1.2
81 | tqdm==4.66.1
82 | triton==2.1.0
83 | typer==0.9.0
84 | typing_extensions==4.9.0
85 | tzdata==2023.4
86 | urllib3==2.1.0
87 | wandb==0.16.2
88 |
--------------------------------------------------------------------------------
/train_dmlp.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import wandb
3 | import random
4 | import minari
5 | import numpy as np
6 | import gymnasium as gym
7 | from pathlib import Path
8 | from datetime import datetime
9 | from omegaconf import DictConfig
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 |
15 | from model import DecisionMLP
16 | from utils import MinariEpisodicDataset, convert_remote_to_local, get_test_start_state_goals, get_lr, AntmazeWrapper
17 |
18 | def set_seed(seed):
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 |
24 | def eval_env(cfg, model, device, render=False):
25 | if render:
26 | render_mode = 'human'
27 | else:
28 | render_mode = None
29 |
30 | if "pointmaze" in cfg.dataset_name:
31 | env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)
32 | elif "antmaze" in cfg.dataset_name:
33 | env = AntmazeWrapper(env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode))
34 | else:
35 | raise NotImplementedError
36 |
37 | test_start_state_goal = get_test_start_state_goals(cfg)
38 |
39 | model.eval()
40 | results = dict()
41 | with torch.no_grad():
42 | cum_reward = 0
43 | for ss_g in test_start_state_goal:
44 | total_reward = 0
45 | total_timesteps = 0
46 |
47 | print(ss_g['name'] + ':')
48 | for _ in range(cfg.num_eval_ep):
49 | obs, _ = env.reset(options=ss_g)
50 | done = False
51 | for t in range(env.spec.max_episode_steps):
52 | total_timesteps += 1
53 |
54 | running_state = torch.tensor(obs['observation'], dtype=torch.float32, device=device).view(1, -1)
55 | target_goal = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device).view(1, -1)
56 |
57 | act_preds = model.forward(
58 | running_state,
59 | target_goal,
60 | )
61 | act = act_preds[0].detach()
62 |
63 | obs, running_reward, done, _, _ = env.step(act.cpu().numpy())
64 |
65 | total_reward += running_reward
66 |
67 | if done:
68 | break
69 |
70 | print('Achieved goal: ', tuple(obs['achieved_goal'].tolist()))
71 | print('Desired goal: ', tuple(obs['desired_goal'].tolist()))
72 |
73 | print("=" * 60)
74 | cum_reward += total_reward
75 | results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / cfg.num_eval_ep
76 | results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / cfg.num_eval_ep
77 |
78 | results['eval/avg_reward'] = cum_reward / (cfg.num_eval_ep * len(test_start_state_goal))
79 | env.close()
80 | return results
81 |
82 | def train(cfg, hydra_cfg):
83 |
84 | #set seed
85 | set_seed(cfg.seed)
86 |
87 | #set device
88 | device = torch.device(cfg.device)
89 |
90 | if cfg.save_snapshot:
91 | checkpoint_path = Path(hydra_cfg['runtime']['output_dir']) / Path('checkpoint')
92 | checkpoint_path.mkdir(exist_ok=True)
93 | best_eval_returns = 0
94 |
95 | start_time = datetime.now().replace(microsecond=0)
96 | time_elapsed = start_time - start_time
97 | start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")
98 |
99 | if "pointmaze" in cfg.dataset_name:
100 | if "umaze" in cfg.dataset_name:
101 | cfg.env_name = 'PointMaze_UMaze-v3'
102 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters
103 | elif "medium" in cfg.dataset_name:
104 | cfg.env_name = 'PointMaze_Medium-v3'
105 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters
106 | elif "large" in cfg.dataset_name:
107 | cfg.env_name = 'PointMaze_Large-v3'
108 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters
109 | env = gym.make(cfg.env_name, continuing_task=False)
110 |
111 | elif "antmaze" in cfg.dataset_name:
112 | if "umaze" in cfg.dataset_name:
113 | cfg.env_name = 'AntMaze_UMaze-v4'
114 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters
115 | elif "medium" in cfg.dataset_name:
116 | cfg.env_name = 'AntMaze_Medium-v4'
117 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters
118 | elif "large" in cfg.dataset_name:
119 | cfg.env_name = 'AntMaze_Large-v4'
120 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters
121 | else:
122 | raise NotImplementedError
123 | env = AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False))
124 |
125 | else:
126 | raise NotImplementedError
127 |
128 | env.action_space.seed(cfg.seed)
129 | env.observation_space.seed(cfg.seed)
130 |
131 | #create dataset
132 | if cfg.remote_data:
133 | convert_remote_to_local(cfg.dataset_name, env)
134 |
135 | train_dataset = MinariEpisodicDataset(cfg.dataset_name, cfg.remote_data, cfg.augment_data, cfg.augment_prob, cfg.nclusters)
136 |
137 | train_data_loader = DataLoader(
138 | train_dataset,
139 | batch_size=cfg.batch_size,
140 | shuffle=True,
141 | num_workers=cfg.num_workers
142 | )
143 | train_data_iter = iter(train_data_loader)
144 |
145 | #create model
146 | model = DecisionMLP(cfg.env_name, env, goal_dim=train_dataset.goal_dim).to(device)
147 |
148 | optimizer = torch.optim.Adam(
149 | model.parameters(),
150 | lr=cfg.lr,
151 | )
152 |
153 | total_updates = 0
154 | for i_train_iter in range(cfg.max_train_iters):
155 |
156 | log_action_losses = []
157 | model.train()
158 |
159 | for i in range(cfg.num_updates_per_iter):
160 | try:
161 | states, goals, actions = next(train_data_iter)
162 |
163 | except StopIteration:
164 | train_data_iter = iter(train_data_loader)
165 | states, goals, actions = next(train_data_iter)
166 |
167 | states = states.to(device)
168 | actions = actions.to(device)
169 | goals = goals.to(device)
170 |
171 | action_preds = model.forward(
172 | states=states,
173 | goals=goals,
174 | )
175 | action_loss = F.mse_loss(action_preds, actions, reduction='mean')
176 |
177 | optimizer.zero_grad()
178 | action_loss.backward()
179 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
180 | optimizer.step()
181 |
182 | log_action_losses.append(action_loss.detach().cpu().item())
183 |
184 | time = datetime.now().replace(microsecond=0) - start_time - time_elapsed
185 | time_elapsed = datetime.now().replace(microsecond=0) - start_time
186 |
187 | total_updates += cfg.num_updates_per_iter
188 |
189 | mean_action_loss = np.mean(log_action_losses)
190 |
191 | results = eval_env(cfg, model, device, render=cfg.render)
192 |
193 | log_str = ("=" * 60 + '\n' +
194 | "time elapsed: " + str(time_elapsed) + '\n' +
195 | "num of updates: " + str(total_updates) + '\n' +
196 | "train action loss: " + format(mean_action_loss, ".5f") #+ '\n' +
197 | )
198 |
199 | print(results)
200 | print(log_str)
201 |
202 | if cfg.wandb_log:
203 | log_data = dict()
204 | log_data['time'] = time.total_seconds()
205 | log_data['time_elapsed'] = time_elapsed.total_seconds()
206 | log_data['total_updates'] = total_updates
207 | log_data['mean_action_loss'] = mean_action_loss
208 | log_data['lr'] = get_lr(optimizer)
209 | log_data['training_iter'] = i_train_iter
210 | log_data.update(results)
211 | wandb.log(log_data)
212 |
213 | if cfg.save_snapshot and (1+i_train_iter)%cfg.save_snapshot_interval == 0:
214 | snapshot = Path(checkpoint_path) / Path(str(i_train_iter)+'.pt')
215 | torch.save(model.state_dict(), snapshot)
216 |
217 | if cfg.save_snapshot and results['eval/avg_reward'] >= best_eval_returns:
218 | print("=" * 60)
219 | print("saving best model!")
220 | print("=" * 60)
221 | best_eval_returns = results['eval/avg_reward']
222 | snapshot = Path(checkpoint_path) / 'best.pt'
223 | torch.save(model.state_dict(), snapshot)
224 |
225 | print("=" * 60)
226 | print("finished training!")
227 | print("=" * 60)
228 | end_time = datetime.now().replace(microsecond=0)
229 | time_elapsed = str(end_time - start_time)
230 | end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
231 | print("started training at: " + start_time_str)
232 | print("finished training at: " + end_time_str)
233 | print("total training time: " + time_elapsed)
234 | print("=" * 60)
235 |
236 | @hydra.main(config_path='cfgs', config_name='dmlp', version_base=None)
237 | def main(cfg: DictConfig):
238 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
239 |
240 | if cfg.wandb_log:
241 | if cfg.wandb_dir is None:
242 | cfg.wandb_dir = hydra_cfg['runtime']['output_dir']
243 |
244 | project_name = cfg.dataset_name
245 | wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=cfg.wandb_dir, group=cfg.wandb_group_name)
246 | wandb.run.name = cfg.wandb_run_name
247 |
248 | train(cfg, hydra_cfg)
249 |
250 | if __name__ == "__main__":
251 | main()
--------------------------------------------------------------------------------
/train_dt.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import wandb
3 | import random
4 | import minari
5 | import numpy as np
6 | import gymnasium as gym
7 | from pathlib import Path
8 | from datetime import datetime
9 | from omegaconf import DictConfig
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 |
15 | from model import DecisionTransformer
16 | from utils import MinariEpisodicTrajectoryDataset, convert_remote_to_local, get_test_start_state_goals, get_lr, AntmazeWrapper
17 |
18 | def set_seed(seed):
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 |
24 | def eval_env(cfg, model, device, render=False):
25 | if render:
26 | render_mode = 'human'
27 | else:
28 | render_mode = None
29 |
30 | if "pointmaze" in cfg.dataset_name:
31 | env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)
32 | elif "antmaze" in cfg.dataset_name:
33 | env = AntmazeWrapper(env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode))
34 | else:
35 | raise NotImplementedError
36 |
37 | test_start_state_goal = get_test_start_state_goals(cfg)
38 |
39 | model.eval()
40 | results = dict()
41 | eval_batch_size = 1
42 |
43 | with torch.no_grad():
44 | cum_reward = 0
45 | for ss_g in test_start_state_goal:
46 | total_reward = 0
47 | total_timesteps = 0
48 | print(ss_g['name'] + ':')
49 | for _ in range(cfg.num_eval_ep):
50 | # zeros place holders
51 | m_actions = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.act_dim),
52 | dtype=torch.float32, device=device)
53 | m_states = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.state_dim),
54 | dtype=torch.float32, device=device)
55 | m_goals = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.goal_dim),
56 | dtype=torch.float32, device=device)
57 |
58 | obs, _ = env.reset(options=ss_g)
59 | done = False
60 |
61 | for t in range(env.spec.max_episode_steps):
62 | total_timesteps += 1
63 |
64 | m_states[0, t] = torch.tensor(obs['observation'], dtype=torch.float32, device=device)
65 | m_goals[0, t] = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device)
66 |
67 |
68 | if t < cfg.context_len:
69 | act_preds = model.forward(m_states[:,:t+1],
70 | m_actions[:,:t+1],
71 | m_goals[:,:t+1])
72 |
73 | else:
74 | act_preds = model.forward(m_states[:, t-cfg.context_len+1:t+1],
75 | m_actions[:, t-cfg.context_len+1:t+1],
76 | m_goals[:, t-cfg.context_len+1:t+1])
77 |
78 |
79 | act = act_preds[0, -1].detach()
80 |
81 | obs, running_reward, done, _, _ = env.step(act.cpu().numpy())
82 |
83 | # add action in placeholder
84 | m_actions[0, t] = act
85 |
86 | total_reward += running_reward
87 |
88 | if done:
89 | break
90 |
91 | print('Achieved goal: ', tuple(obs['achieved_goal'].tolist()))
92 | print('Desired goal: ', tuple(obs['desired_goal'].tolist()))
93 |
94 | print("=" * 60)
95 | cum_reward += total_reward
96 | results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / cfg.num_eval_ep
97 | results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / cfg.num_eval_ep
98 |
99 | results['eval/avg_reward'] = cum_reward / (cfg.num_eval_ep * len(test_start_state_goal))
100 | env.close()
101 | return results
102 |
103 | def train(cfg, hydra_cfg):
104 |
105 | #set seed
106 | set_seed(cfg.seed)
107 |
108 | #set device
109 | device = torch.device(cfg.device)
110 |
111 | if cfg.save_snapshot:
112 | checkpoint_path = Path(hydra_cfg['runtime']['output_dir']) / Path('checkpoint')
113 | checkpoint_path.mkdir(exist_ok=True)
114 | best_eval_returns = 0
115 |
116 | start_time = datetime.now().replace(microsecond=0)
117 | time_elapsed = start_time - start_time
118 | start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")
119 |
120 | #create env
121 | if "pointmaze" in cfg.dataset_name:
122 | if "umaze" in cfg.dataset_name:
123 | cfg.env_name = 'PointMaze_UMaze-v3'
124 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters
125 | elif "medium" in cfg.dataset_name:
126 | cfg.env_name = 'PointMaze_Medium-v3'
127 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters
128 | elif "large" in cfg.dataset_name:
129 | cfg.env_name = 'PointMaze_Large-v3'
130 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters
131 | env = gym.make(cfg.env_name, continuing_task=False)
132 |
133 | elif "antmaze" in cfg.dataset_name:
134 | if "umaze" in cfg.dataset_name:
135 | cfg.env_name = 'AntMaze_UMaze-v4'
136 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters
137 | elif "medium" in cfg.dataset_name:
138 | cfg.env_name = 'AntMaze_Medium-v4'
139 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters
140 | elif "large" in cfg.dataset_name:
141 | cfg.env_name = 'AntMaze_Large-v4'
142 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters
143 | else:
144 | raise NotImplementedError
145 | env = AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False))
146 |
147 | else:
148 | raise NotImplementedError
149 |
150 | env.action_space.seed(cfg.seed)
151 | env.observation_space.seed(cfg.seed)
152 |
153 | #create dataset
154 | if cfg.remote_data:
155 | convert_remote_to_local(cfg.dataset_name, env)
156 |
157 |     print('nclusters:', cfg.nclusters)
158 |
159 | train_dataset = MinariEpisodicTrajectoryDataset(cfg.dataset_name, cfg.remote_data, cfg.context_len, cfg.augment_data, cfg.augment_prob, cfg.nclusters)
160 |
161 | train_data_loader = DataLoader(
162 | train_dataset,
163 | batch_size=cfg.batch_size,
164 | shuffle=True,
165 | num_workers=cfg.num_workers
166 | )
167 | train_data_iter = iter(train_data_loader)
168 |
169 | #create model
170 | model = DecisionTransformer(cfg.env_name, env, cfg.n_blocks, cfg.embed_dim, cfg.context_len, cfg.n_heads, cfg.drop_p, goal_dim=train_dataset.goal_dim).to(device)
171 |
172 | optimizer = torch.optim.AdamW(
173 | model.parameters(),
174 | lr=cfg.lr,
175 | weight_decay=cfg.wt_decay
176 | )
177 |
178 | scheduler = torch.optim.lr_scheduler.LambdaLR(
179 | optimizer,
180 | lambda steps: min((steps+1)/cfg.warmup_steps, 1)
181 | )
182 |
183 | total_updates = 0
184 | for i_train_iter in range(cfg.max_train_iters):
185 |
186 | log_action_losses = []
187 | model.train()
188 |
189 | for i in range(cfg.num_updates_per_iter):
190 | try:
191 | states, goals, actions = next(train_data_iter)
192 |
193 | except StopIteration:
194 | train_data_iter = iter(train_data_loader)
195 | states, goals, actions = next(train_data_iter)
196 |
197 | states = states.to(device) # B x T x state_dim
198 | goals = goals.to(device).repeat(1, cfg.context_len, 1) # B x T x goal_dim
199 | actions = actions.to(device) # B x T
200 |
201 | action_preds = model.forward(
202 | states=states,
203 | actions=actions,
204 | goals=goals,
205 | )
206 |
207 | action_loss = F.mse_loss(action_preds, actions)
208 |
209 | optimizer.zero_grad()
210 | action_loss.backward()
211 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
212 | optimizer.step()
213 | scheduler.step()
214 |
215 | log_action_losses.append(action_loss.detach().cpu().item())
216 |
217 | time = datetime.now().replace(microsecond=0) - start_time - time_elapsed
218 | time_elapsed = datetime.now().replace(microsecond=0) - start_time
219 |
220 | total_updates += cfg.num_updates_per_iter
221 |
222 | mean_action_loss = np.mean(log_action_losses)
223 |
224 | results = eval_env(cfg, model, device, render=cfg.render)
225 |
226 | log_str = ("=" * 60 + '\n' +
227 | "time elapsed: " + str(time_elapsed) + '\n' +
228 | "num of updates: " + str(total_updates) + '\n' +
229 | "train action loss: " + format(mean_action_loss, ".5f") #+ '\n' +
230 | )
231 |
232 | print(results)
233 | print(log_str)
234 |
235 | if cfg.wandb_log:
236 | log_data = dict()
237 | log_data['time'] = time.total_seconds()
238 | log_data['time_elapsed'] = time_elapsed.total_seconds()
239 | log_data['total_updates'] = total_updates
240 | log_data['mean_action_loss'] = mean_action_loss
241 | log_data['lr'] = get_lr(optimizer)
242 | log_data['training_iter'] = i_train_iter
243 | log_data.update(results)
244 | wandb.log(log_data)
245 |
246 | if cfg.save_snapshot and (1+i_train_iter)%cfg.save_snapshot_interval == 0:
247 | snapshot = Path(checkpoint_path) / Path(str(i_train_iter)+'.pt')
248 | torch.save(model.state_dict(), snapshot)
249 |
250 | if cfg.save_snapshot and results['eval/avg_reward'] >= best_eval_returns:
251 | print("=" * 60)
252 | print("saving best model!")
253 | print("=" * 60)
254 | best_eval_returns = results['eval/avg_reward']
255 | snapshot = Path(checkpoint_path) / 'best.pt'
256 | torch.save(model.state_dict(), snapshot)
257 |
258 | print("=" * 60)
259 | print("finished training!")
260 | print("=" * 60)
261 | end_time = datetime.now().replace(microsecond=0)
262 | time_elapsed = str(end_time - start_time)
263 | end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
264 | print("started training at: " + start_time_str)
265 | print("finished training at: " + end_time_str)
266 | print("total training time: " + time_elapsed)
267 | print("=" * 60)
268 |
269 | @hydra.main(config_path='cfgs', config_name='dt', version_base=None)
270 | def main(cfg: DictConfig):
271 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
272 |
273 | if cfg.wandb_log:
274 | if cfg.wandb_dir is None:
275 | cfg.wandb_dir = hydra_cfg['runtime']['output_dir']
276 |
277 | project_name = cfg.dataset_name
278 | wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=cfg.wandb_dir, group=cfg.wandb_group_name)
279 | wandb.run.name = cfg.wandb_run_name
280 |
281 | train(cfg, hydra_cfg)
282 |
283 | if __name__ == "__main__":
284 | main()
--------------------------------------------------------------------------------
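A minimal, self-contained sketch (toy tensors only, no model or environment; `context_len=5` and `T=12` are assumed values) of the sliding context window used by `eval_env` in `train_dt.py` above: for the first `context_len` steps the full prefix is fed to the model, and afterwards only the most recent `context_len` timesteps are.

```py
import torch

context_len, T = 5, 12
m_states = torch.arange(T, dtype=torch.float32).view(1, T, 1)  # stand-in for the state placeholder

for t in range(T):
    if t < context_len:
        window = m_states[:, :t + 1]                      # full prefix while it still fits
    else:
        window = m_states[:, t - context_len + 1:t + 1]   # only the last context_len steps
    assert window.shape[1] == min(t + 1, context_len)
```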
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import *
2 | from .env import *
3 | from .misc import *
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import minari
5 | import numpy as np
6 | import gymnasium as gym
7 | from datetime import datetime
8 | from sklearn.cluster import KMeans
9 |
10 | from collections import defaultdict
11 | from torch.utils.data import Dataset
12 |
13 | def convert_remote_to_local(dataset_name, env):
14 |
15 | if os.path.isfile('data/'+dataset_name+'-remote.pkl'):
16 | print("A dataset with the same name already exists. Using that dataset.")
17 | return
18 |
19 | if "pointmaze" in dataset_name:
20 | if "pointmaze-umaze" in dataset_name:
21 | env_name = 'PointMaze_UMaze-v3'
22 | elif "pointmaze-medium" in dataset_name:
23 | env_name = 'PointMaze_Medium-v3'
24 | elif "pointmaze-large" in dataset_name:
25 | env_name = 'PointMaze_Large-v3'
26 |
27 | env = gym.make(env_name, continuing_task=False)
28 |
29 | else:
30 | raise NotImplementedError
31 |
32 | minari_dataset = minari.load_dataset(dataset_name)
33 |
34 | # data placeholders
35 | observation_data = {
36 | 'episode_id': np.zeros(shape=(int(1e6),), dtype=np.int32),
37 | 'observation' : np.zeros(shape=(int(1e6), *env.observation_space['observation'].shape), dtype=np.float32),
38 | 'achieved_goal' : np.zeros(shape=(int(1e6), *env.observation_space['achieved_goal'].shape), dtype=np.float32),
39 | }
40 | action_data = np.zeros(shape=(int(1e6), *env.action_space.shape), dtype=np.float32)
41 | termination_data = np.zeros(shape = (int(1e6),), dtype=bool)
42 |
43 | data_idx = 0
44 | episode_idx = 0
45 | for episode in minari_dataset:
46 | if data_idx + episode.total_timesteps + 1 > int(1e6):
47 |
48 | observation_data['episode_id'][data_idx: ] = episode_idx
49 | observation_data['observation'][data_idx: ] = episode.observations['observation'][:int(1e6)-data_idx]
50 | observation_data['achieved_goal'][data_idx: ] = episode.observations['achieved_goal'][:int(1e6)-data_idx]
51 |
52 | action_data[data_idx: ] = episode.actions[:int(1e6)-data_idx]
53 | termination_data[data_idx+1: ] = episode.truncations[:int(1e6)-data_idx-1]
54 |
55 | break
56 |
57 | else:
58 | try:
59 | observation_data['episode_id'][data_idx: data_idx+episode.total_timesteps+1] = episode_idx
60 | observation_data['observation'][data_idx: data_idx+episode.total_timesteps+1] = episode.observations['observation']
61 | observation_data['achieved_goal'][data_idx: data_idx+episode.total_timesteps+1] = episode.observations['achieved_goal']
62 |
63 | action_data[data_idx: data_idx+episode.total_timesteps] = episode.actions
64 | termination_data[data_idx+1: data_idx+episode.total_timesteps+1] = episode.truncations
65 |
66 | data_idx = data_idx + episode.total_timesteps + 1
67 | episode_idx = episode_idx + 1
68 |
69 |             except Exception:
70 |                 # skip malformed episodes where the number of observations equals total_timesteps (it should be total_timesteps + 1)
71 |                 continue
72 |
73 | if (episode_idx + 1) % 1000 == 0:
74 | print('EPISODES RECORDED = ', episode_idx)
75 |
76 | if data_idx >= int(1e6):
77 | break
78 |
79 | print('Total transitions recorded = ', data_idx)
80 |
81 | dataset = {'observations':observation_data,
82 | 'actions':action_data,
83 | 'terminations':termination_data,
84 | }
85 |
86 | with open('data/'+dataset_name+'-remote.pkl', 'wb') as fp:
87 | pickle.dump(dataset, fp)
88 |
89 | def extract_discrete_id_to_data_id_map(discrete_goals, dones, last_valid_traj):
90 | discrete_goal_to_data_idx = defaultdict(list)
91 | gm = 0
92 | for i, d_g in enumerate(discrete_goals):
93 |
94 | discrete_goal_to_data_idx[d_g].append(i)
95 | gm += 1
96 |
97 | if (i + 1) % 200000 == 0:
98 | print('Goals mapped:', gm)
99 |
100 | if i == last_valid_traj:
101 | break
102 |
103 | for dg, data_idxes in discrete_goal_to_data_idx.items():
104 | discrete_goal_to_data_idx[dg] = np.array(data_idxes)
105 |
106 | print('Total goals mapped:', gm)
107 | return discrete_goal_to_data_idx
108 |
109 | def extract_done_markers(dones, episode_ids):
110 |     """Given per-timestep dones and episode ids, return each timestep's episode-end index and the indices of non-terminal timesteps."""
111 |
112 | (ends,) = np.where(dones)
113 | return ends[ episode_ids[ : ends[-1] + 1 ] ], np.where(1 - dones[: ends[-1] + 1])[0]
114 |
115 | class MinariEpisodicTrajectoryDataset(Dataset):
116 | def __init__(self, dataset_name, remote_data, context_len, augment_data, augment_prob, nclusters=40):
117 | super().__init__()
118 | if remote_data:
119 | path = 'data/'+dataset_name+'-remote.pkl'
120 | else:
121 | path = 'data/'+dataset_name+'.pkl'
122 |
123 | with open(path, 'rb') as fp:
124 | self.dataset = pickle.load(fp)
125 |
126 | self.observations = self.dataset['observations']['observation']
127 | self.achieved_goals = self.dataset['observations']['achieved_goal']
128 | self.actions = self.dataset['actions']
129 |
130 | (ends,) = np.where(self.dataset['terminations'])
131 | self.starts = np.concatenate(([0], ends[:-1] + 1))
132 | self.lengths = ends - self.starts + 1 #length = number of actions taken in an episode + 1
133 | self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]
134 |
135 | good_idxes = self.lengths > context_len
136 |         print('Throwing away', np.sum(self.lengths[~good_idxes] - 1), 'transitions from episodes shorter than the context length')
137 | self.starts = self.starts[good_idxes] #starts will only contain indices of episodes where number of states > context_len
138 | self.lengths = self.lengths[good_idxes]
139 |
140 | self.num_trajectories = len(self.starts)
141 |
142 | if augment_data:
143 | start_time = datetime.now().replace(microsecond=0)
144 | print('starting kmeans ... ')
145 | kmeans = KMeans(n_clusters=nclusters, n_init="auto").fit(self.achieved_goals)
146 | time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
147 | print('kmeans done! time taken :', time_elapsed)
148 |
149 | self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(kmeans.labels_, self.dataset['terminations'], self.ends[-1])
150 | self.achieved_discrete_goals = kmeans.labels_
151 | kmeans = None
152 |
153 | self.state_dim = self.observations.shape[-1]
154 | self.state_dtype = self.observations.dtype
155 | self.act_dim = self.actions.shape[-1]
156 | self.act_dtype = self.actions.dtype
157 | self.goal_dim = self.achieved_goals.shape[-1]
158 | self.context_len = context_len
159 | self.augment_data = augment_data
160 | self.augment_prob = augment_prob
161 | self.dataset = None
162 |
163 | def __len__(self):
164 | return self.num_trajectories * 100
165 |
166 | def __getitem__(self, idx):
167 | '''
168 | Reminder: np.random.randint samples from the set [low, high)
169 | '''
170 | idx = idx % self.num_trajectories
171 | traj_len = self.lengths[idx] - 1 #traj_len = T, traj_len is the number of actions taken in the trajectory
172 | traj_start_i = self.starts[idx]
173 | assert self.ends[traj_start_i] == traj_start_i + traj_len
174 |
175 |
176 | if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob:
177 | correct = False
178 | while not correct:
179 | si = traj_start_i + np.random.randint(0, traj_len) #si can be traj_start_i + [0, T - 1]
180 |                 gi = np.random.randint(si, traj_start_i + traj_len) + 1 #gi can be [si + 1, traj_start_i + T]
181 | dummy_discrete_goal = self.achieved_discrete_goals[ gi ]
182 | nearby_goal_idx = np.random.choice(self.discrete_goal_to_data_idx[dummy_discrete_goal])
183 | nearby_goal_idx_ends = self.ends[nearby_goal_idx]
184 | if (gi-si) + (nearby_goal_idx_ends - nearby_goal_idx) + 1 > self.context_len:
185 | correct = True
186 |
187 | if gi - si < self.context_len:
188 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx + self.context_len - (gi - si), nearby_goal_idx_ends + 1) ]).view(1, -1)
189 | state = torch.tensor( np.concatenate( [ self.observations[si: gi], self.observations[nearby_goal_idx: nearby_goal_idx + self.context_len - (gi - si)] ] ) )
190 | action = torch.tensor( np.concatenate( [ self.actions[si: gi], self.actions[nearby_goal_idx: nearby_goal_idx + self.context_len - (gi - si)] ] ) )
191 | else:
192 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx, nearby_goal_idx_ends + 1) ]).view(1, -1)
193 | state = torch.tensor(self.observations[si: si + self.context_len])
194 | action = torch.tensor(self.actions[si: si + self.context_len])
195 |
196 | else:
197 | si = traj_start_i + np.random.randint(0, traj_len - self.context_len + 1) #si can be traj_start_i + [0, T-C]
198 | gi = np.random.randint(si + self.context_len, traj_start_i + traj_len + 1) #gi can be [si+C, traj_start_i+T]
199 |
200 | goal = torch.tensor(self.achieved_goals[ gi ]).view(1, -1)
201 | state = torch.tensor(self.observations[si: si + self.context_len])
202 | action = torch.tensor(self.actions[si: si + self.context_len])
203 |
204 | return state, goal, action
205 |
206 | class MinariEpisodicDataset(Dataset):
207 | def __init__(self, dataset_name, remote_data, augment_data, augment_prob, nclusters=40):
208 | super().__init__()
209 | if remote_data:
210 | path = 'data/'+dataset_name+'-remote.pkl'
211 | else:
212 | path = 'data/'+dataset_name+'.pkl'
213 |
214 | with open(path, 'rb') as fp:
215 | self.dataset = pickle.load(fp)
216 |
217 | self.episode_ids = self.dataset['observations']['episode_id']
218 | self.observations = self.dataset['observations']['observation']
219 | self.achieved_goals = self.dataset['observations']['achieved_goal']
220 | self.actions = self.dataset['actions']
221 |
222 | (ends,) = np.where(self.dataset['terminations'])
223 | self.starts = np.concatenate(([0], ends[:-1] + 1))
224 | self.lengths = ends - self.starts + 1
225 | self.num_trajectories = len(self.starts)
226 |
227 | self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]
228 |
229 | if augment_data:
230 | start_time = datetime.now().replace(microsecond=0)
231 | print('starting kmeans ... ')
232 | kmeans = KMeans(n_clusters=nclusters, n_init="auto").fit(self.observations)
233 | time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
234 | print('kmeans done! time taken :', time_elapsed)
235 |
236 | self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(kmeans.labels_, self.dataset['terminations'], self.ends[-1])
237 | self.achieved_discrete_goals = kmeans.labels_
238 | kmeans = None
239 |
240 | self.goal_dim = self.achieved_goals.shape[-1]
241 | self.augment_data = augment_data
242 | self.augment_prob = augment_prob
243 | self.dataset = None
244 |
245 | def __len__(self):
246 | return self.num_trajectories * 100
247 |
248 | def __getitem__(self, idx):
249 | '''
250 | Reminder: np.random.randint samples from the set [low, high)
251 | '''
252 | idx = idx % self.num_trajectories
253 | traj_len = self.lengths[idx] - 1 #traj_len = T, traj_len is the number of actions taken in the trajectory
254 | traj_start_i = self.starts[idx]
255 | assert self.ends[traj_start_i] == traj_start_i + traj_len
256 |
257 | si = np.random.randint(0, traj_len) #si can be [0, T-1]
258 |
259 | state = torch.tensor(self.observations[traj_start_i + si])
260 | action = torch.tensor(self.actions[traj_start_i + si])
261 |
262 | if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob:
263 | dummy_discrete_goal = self.achieved_discrete_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ]
264 | nearby_goal_idx = np.random.choice(self.discrete_goal_to_data_idx[dummy_discrete_goal])
265 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx, self.ends[nearby_goal_idx] + 1) ])
266 | else:
267 | goal = torch.tensor(self.achieved_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ])
268 |
269 | return state, goal, action
--------------------------------------------------------------------------------
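A minimal sketch (hand-made toy `terminations` vector, not a real dataset) of the episode bookkeeping that both dataset classes above derive from the per-timestep termination flags.

```py
import numpy as np

# Three toy episodes with 3, 4 and 2 states respectively.
terminations = np.array([0, 0, 1, 0, 0, 0, 1, 0, 1], dtype=bool)

(ends,) = np.where(terminations)               # last state index of each episode -> [2 6 8]
starts = np.concatenate(([0], ends[:-1] + 1))  # first state index of each episode -> [0 3 7]
lengths = ends - starts + 1                    # states per episode -> [3 4 2]

print(starts, ends, lengths)
```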
/utils/env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import gymnasium as gym
4 | from typing import Iterable
5 | from gymnasium.spaces import Box
6 | from gymnasium.error import DependencyNotInstalled
7 |
8 | def get_lr(optimizer):
9 | for param_group in optimizer.param_groups:
10 | return param_group['lr']
11 |
12 | def get_parameters(modules: Iterable[nn.Module]):
13 | model_parameters = []
14 | for module in modules:
15 | model_parameters += list(module.parameters())
16 | return model_parameters
17 |
18 |
19 | class PreprocessObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
20 |
21 | def __init__(self, env, shape, grayscale=False) -> None:
22 | """Resizes image observations to shape given by :attr:`shape`.
23 |
24 | Args:
25 | env: The environment to apply the wrapper
26 | shape: The shape of the resized observations
27 | """
28 | gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
29 | gym.ObservationWrapper.__init__(self, env)
30 |
31 | if isinstance(shape, int):
32 | shape = (shape, shape)
33 | assert len(shape) == 2 and all(
34 | x > 0 for x in shape
35 | ), f"Expected shape to be a 2-tuple of positive integers, got: {shape}"
36 |
37 | self.grayscale = grayscale
38 | if grayscale:
39 | obs_shape = env.observation_space['pixels'].shape
40 | env.observation_space['pixels'] = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
41 |
42 | self.shape = tuple(shape)
43 | obs_shape = self.shape + env.observation_space['pixels'].shape[2:]
44 | self.observation_space['pixels'] = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
45 |
46 | def observation(self, observation):
47 | try:
48 | import cv2
49 | except ImportError as e:
50 | raise DependencyNotInstalled(
51 | "opencv (cv2) is not installed, run `pip install gymnasium[other]`"
52 | ) from e
53 |
54 | if self.grayscale:
55 | observation['pixels'] = cv2.cvtColor(observation['pixels'], cv2.COLOR_RGB2GRAY)
56 | observation['pixels'] = cv2.resize(
57 | observation['pixels'], self.shape[::-1], interpolation=cv2.INTER_AREA
58 | ).reshape(self.observation_space['pixels'].shape)
59 |
60 | return observation
61 |
62 | class AntmazeWrapper(gym.ObservationWrapper):
63 | def __init__(self, env):
64 | """Constructor for the observation wrapper."""
65 | gym.ObservationWrapper.__init__(self, env)
66 |
67 | self.observation_space['observation'] = gym.spaces.Box(
68 | low=-np.inf, high=np.inf, shape=(self.observation_space['observation'].shape[0] + 2,), dtype=np.float64
69 | )
70 |
71 | def reset(
72 | self, *, seed=None, options=None,
73 | ):
74 | obs, info = self.env.reset(seed=seed, options=options)
75 | return self.observation(obs), info
76 |
77 | def step(
78 | self, action
79 | ):
80 | """Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations."""
81 | observation, reward, terminated, truncated, info = self.env.step(action)
82 | return self.observation(observation), reward, terminated, truncated, info
83 |
84 | def observation(self, observation):
85 | """Returns a modified observation.
86 |
87 | Args:
88 | observation: The :attr:`env` observation
89 |
90 | Returns:
91 | The modified observation
92 | """
93 | observation['observation'] = np.concatenate((observation['achieved_goal'], observation['observation']), axis=0)
94 | return observation
95 |
96 | def get_test_start_state_goals(cfg):
97 | if cfg.dataset_name in ["pointmaze-umaze-v0", "antmaze-umaze-v0"]:
98 | test_start_state_goal = [
99 | {'goal_cell': np.array([1,1], dtype=np.int32),
100 | 'reset_cell': np.array([3,1], dtype=np.int32),
101 | 'name' : 'bottom_to_top',
102 | },
103 |
104 | {'goal_cell': np.array([3,1], dtype=np.int32),
105 | 'reset_cell': np.array([1,1], dtype=np.int32),
106 | 'name' : 'top_to_bottom',
107 | },
108 | ]
109 |
110 | elif cfg.dataset_name in ["pointmaze-medium-v0"]:
111 | test_start_state_goal = [
112 | {'goal_cell': np.array([6,3], dtype=np.int32),
113 | 'reset_cell': np.array([6,6], dtype=np.int32),
114 | 'name' : 'bottom_right_to_bottom_center',
115 | },
116 |
117 | {'goal_cell': np.array([6,1], dtype=np.int32),
118 | 'reset_cell': np.array([6,6], dtype=np.int32),
119 | 'name' : 'bottom_right_to_bottom_left',
120 | },
121 |
122 | {'goal_cell': np.array([6,3], dtype=np.int32),
123 | 'reset_cell': np.array([6,5], dtype=np.int32),
124 | 'name' : 'bottom_rightish_to_bottom_center',
125 | },
126 |
127 | {'goal_cell': np.array([6,1], dtype=np.int32),
128 | 'reset_cell': np.array([6,5], dtype=np.int32),
129 | 'name' : 'bottom_rightish_to_bottom_left',
130 | },
131 |
132 | {'goal_cell': np.array([6,5], dtype=np.int32),
133 | 'reset_cell': np.array([1,1], dtype=np.int32),
134 | 'name' : 'top_left_to_bottom_rightish',
135 | },
136 |
137 | {'goal_cell': np.array([6,6], dtype=np.int32),
138 | 'reset_cell': np.array([1,6], dtype=np.int32),
139 | 'name' : 'top_right_to_bottom_right',
140 | },
141 | ]
142 |
143 | elif cfg.dataset_name in ["antmaze-medium-v0"]:
144 | test_start_state_goal = [
145 | {'goal_cell': np.array([6,5], dtype=np.int32),
146 | 'reset_cell': np.array([6,1], dtype=np.int32),
147 | 'name' : 'bottom_left_to_bottom_rightish',
148 | },
149 |
150 | {'goal_cell': np.array([1,6], dtype=np.int32),
151 | 'reset_cell': np.array([6,1], dtype=np.int32),
152 | 'name' : 'bottom_left_to_top_right',
153 | },
154 |
155 | {'goal_cell': np.array([6,1], dtype=np.int32),
156 | 'reset_cell': np.array([6,5], dtype=np.int32),
157 | 'name' : 'bottom_rightish_to_bottom_left',
158 | },
159 |
160 | {'goal_cell': np.array([1,6], dtype=np.int32),
161 | 'reset_cell': np.array([6,5], dtype=np.int32),
162 | 'name' : 'bottom_rightish_to_top_right',
163 | },
164 |
165 | {'goal_cell': np.array([6,1], dtype=np.int32),
166 | 'reset_cell': np.array([1,6], dtype=np.int32),
167 | 'name' : 'top_right_to_bottom_left',
168 | },
169 |
170 | {'goal_cell': np.array([6,5], dtype=np.int32),
171 | 'reset_cell': np.array([1,6], dtype=np.int32),
172 | 'name' : 'top_right_to_bottom_rightish',
173 | },
174 |
175 | ]
176 |
177 | elif cfg.dataset_name in ["pointmaze-large-v0"]:
178 | test_start_state_goal = [
179 | {'goal_cell': np.array([7,4], dtype=np.int32),
180 | 'reset_cell': np.array([7,1], dtype=np.int32),
181 | 'name' : 'bottom_left_to_bottom_center',
182 | },
183 |
184 | {'goal_cell': np.array([7,10], dtype=np.int32),
185 | 'reset_cell': np.array([7,1], dtype=np.int32),
186 | 'name' : 'bottom_left_to_bottom_right',
187 | },
188 |
189 | {'goal_cell': np.array([1,10], dtype=np.int32),
190 | 'reset_cell': np.array([7,1], dtype=np.int32),
191 | 'name' : 'bottom_left_to_top_right',
192 | },
193 |
194 | {'goal_cell': np.array([7,1], dtype=np.int32),
195 | 'reset_cell': np.array([7,4], dtype=np.int32),
196 | 'name' : 'bottom_center_to_bottom_left',
197 | },
198 |
199 | {'goal_cell': np.array([7,10], dtype=np.int32),
200 | 'reset_cell': np.array([7,4], dtype=np.int32),
201 | 'name' : 'bottom_center_to_bottom_right',
202 | },
203 |
204 | {'goal_cell': np.array([1,1], dtype=np.int32),
205 | 'reset_cell': np.array([7,4], dtype=np.int32),
206 | 'name' : 'bottom_center_to_top_left',
207 | },
208 |
209 | {'goal_cell': np.array([7,1], dtype=np.int32),
210 | 'reset_cell': np.array([7,10], dtype=np.int32),
211 | 'name' : 'bottom_right_to_bottom_left',
212 | }
213 | ]
214 |
215 | elif cfg.dataset_name in ["antmaze-large-v0"]:
216 | test_start_state_goal = [
217 | {'goal_cell': np.array([7,4], dtype=np.int32),
218 | 'reset_cell': np.array([7,1], dtype=np.int32),
219 | 'name' : 'bottom_left_to_bottom_center',
220 | },
221 |
222 | {'goal_cell': np.array([7,10], dtype=np.int32),
223 | 'reset_cell': np.array([7,1], dtype=np.int32),
224 | 'name' : 'bottom_left_to_bottom_right',
225 | },
226 |
227 | {'goal_cell': np.array([1,10], dtype=np.int32),
228 | 'reset_cell': np.array([7,1], dtype=np.int32),
229 | 'name' : 'bottom_left_to_top_right',
230 | },
231 |
232 | {'goal_cell': np.array([7,1], dtype=np.int32),
233 | 'reset_cell': np.array([7,4], dtype=np.int32),
234 | 'name' : 'bottom_center_to_bottom_left',
235 | },
236 |
237 | {'goal_cell': np.array([1,1], dtype=np.int32),
238 | 'reset_cell': np.array([7,4], dtype=np.int32),
239 | 'name' : 'bottom_center_to_top_left',
240 | },
241 |
242 | {'goal_cell': np.array([7,1], dtype=np.int32),
243 | 'reset_cell': np.array([7,10], dtype=np.int32),
244 | 'name' : 'bottom_right_to_bottom_left',
245 | }
246 | ]
247 |
248 | else:
249 | raise NotImplementedError
250 |
251 | return test_start_state_goal
252 |
253 | def get_maze_map(dataset_name):
254 | return None
255 |
256 | def cell_to_state(cell, maze):
257 | return cell[:, 0] * maze.map_width + cell[:, 1]
258 |
259 |
260 | def cell_xy_to_rowcol(maze, xy_pos: np.ndarray) -> np.ndarray:
261 |     """Converts cell (x, y) coordinates to maze `(i, j)` row/column indices."""
262 |
263 | i = np.reshape(np.floor((maze.y_map_center - xy_pos[:, 1]) / maze.maze_size_scaling), (-1, 1))
264 | j = np.reshape(np.floor((xy_pos[:, 0] + maze.x_map_center) / maze.maze_size_scaling), (-1, 1))
265 |
266 | return np.concatenate([i,j], axis=-1)
--------------------------------------------------------------------------------
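A minimal sketch of what `AntmazeWrapper.observation` does, using a dummy observation dict rather than a real AntMaze environment (the 27-dimensional proprioceptive state is an assumption made only for illustration).

```py
import numpy as np

obs = {
    'observation': np.zeros(27),             # proprioceptive ant state (size assumed)
    'achieved_goal': np.array([1.0, 2.0]),   # current (x, y) position
    'desired_goal': np.array([3.0, 4.0]),
}

# The wrapper prepends the achieved (x, y) position to the state, which is why the
# wrapped observation space above is widened by 2 dimensions.
obs['observation'] = np.concatenate((obs['achieved_goal'], obs['observation']), axis=0)
print(obs['observation'].shape)  # (29,)
```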
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Iterable
3 |
4 | def weight_init(m):
5 | if isinstance(m, nn.Linear):
6 | nn.init.orthogonal_(m.weight.data)
7 | if hasattr(m.bias, 'data'):
8 | m.bias.data.fill_(0.0)
9 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
10 | gain = nn.init.calculate_gain('relu')
11 | nn.init.orthogonal_(m.weight.data, gain)
12 | if hasattr(m.bias, 'data'):
13 | m.bias.data.fill_(0.0)
14 |
15 | def soft_update(target, source, tau):
16 | for target_param, param in zip(target.parameters(), source.parameters()):
17 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
18 |
19 | def hard_update(target, source):
20 | for target_param, param in zip(target.parameters(), source.parameters()):
21 | target_param.data.copy_(param.data)
22 |
23 | def get_lr(optimizer):
24 | for param_group in optimizer.param_groups:
25 | return param_group['lr']
26 |
27 | def get_parameters(modules: Iterable[nn.Module]):
28 | model_parameters = []
29 | for module in modules:
30 | model_parameters += list(module.parameters())
31 | return model_parameters
--------------------------------------------------------------------------------
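A minimal usage sketch of the target-network helpers defined in `utils/misc.py` above, with two small linear layers standing in for real networks.

```py
import torch.nn as nn
from utils.misc import hard_update, soft_update

source, target = nn.Linear(4, 2), nn.Linear(4, 2)

hard_update(target, source)              # copy the source parameters exactly
soft_update(target, source, tau=0.005)   # Polyak-average the target toward the source
```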