├── .gitignore ├── LICENSE ├── README.md ├── cfgs ├── dmlp.yaml └── dt.yaml ├── collect_antmaze_data.py ├── collect_pointmaze_data.py ├── model.py ├── requirements.txt ├── train_dmlp.py ├── train_dt.py └── utils ├── __init__.py ├── data.py ├── env.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | #saved binaries 4 | saved/ 5 | 6 | #virtual env 7 | env_* 8 | 9 | #local experiments 10 | local_exp/ 11 | wandb/ 12 | 13 | #dataset 14 | data 15 | 16 | #evaluation 17 | eval.py 18 | 19 | #binaries 20 | *.pkl 21 | *.zip 22 | 23 | #results 24 | dt_runs/ 25 | dqn_runs/ 26 | checkpoints/ 27 | 28 | #jupyter notebooks 29 | notebooks/ 30 | 31 | #vscode 32 | .vscode 33 | 34 | #mujoco 35 | MUJOCO_LOG.TXT -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Raj Ghugare 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Closing the Gap between TD Learning and Supervised Learning -- A Generalisation Point of View](https://arxiv.org/abs/2401.11237) 2 | [Raj Ghugare](https://rajghugare19.github.io/), $\quad$ [Matthieu Geist](https://scholar.google.com/citations?user=ectPLEUAAAAJ), $\quad$ [Glen Berseth](https://neo-x.github.io/)\*, $\quad$ [Benjamin Eysenbach](https://ben-eysenbach.github.io/)\* 3 | 4 | \* Equal advising. 5 | 6 | ## Installation 7 | 8 | Create a virtual environment named `env_stuff` using the command:
9 | ```sh 10 | python3 -m venv env_stuff 11 | ``` 12 | 13 | Activate the environment (e.g. `source env_stuff/bin/activate`) and install all the packages needed to run the code from the `requirements.txt` file:
14 | ```sh 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Training 19 | 20 | To train an RvS (decision-mlp) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:
21 | ```sh 22 | python train_dmlp.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40 23 | ``` 24 | 25 | To train a DT (decision-transformer) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:
26 | ```sh 27 | python train_dt.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40 28 | ``` 29 | 30 | ## Datasets 31 | 32 | To download the pre-collected datasets, visit [this Google Drive folder](https://drive.google.com/drive/folders/1j8Ok2UMYSqfIQReuE6csf1nMoI1s25K-?usp=sharing). 33 | 34 | To collect the pointmaze-large dataset with $10^6$ transitions and seed 1:
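35 | ```sh 36 | python collect_pointmaze_data.py pointmaze-large-v0 1 1000000 37 | ``` 38 | 39 | To collect the antmaze-large dataset with $10^6$ transitions and seed 1 (the script loads a pretrained Stable-Baselines3 SAC goal-reaching controller from `GoalReachAnt_model.zip` in the working directory):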
40 | ```sh 41 | python collect_antmaze_data.py antmaze-large-v0 1 1000000 42 | ``` 43 | 44 | ## Acknowledgment 45 | Our codebase is built on top of the following codebases. We thank the respective authors for their awesome contributions. 46 | - [NanoGPT](https://github.com/karpathy/nanoGPT)
47 | - [min-decision-transformer](https://github.com/nikhilbarhate99/min-decision-transformer)
48 | 49 | ## Correspondence 50 | 51 | If you have any questions or suggestions, please reach out to me via raj.ghugare@mila.quebec. 52 | -------------------------------------------------------------------------------- /cfgs/dmlp.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | device: cuda 3 | agent_name: dmlp 4 | 5 | #eval 6 | num_eval_ep: 10 7 | num_eval_len: 8 | 9 | # env 10 | env_name: 11 | dataset_name: pointmaze-umaze-v0 12 | remote_data: False 13 | num_workers: 4 14 | render: False 15 | 16 | #train 17 | num_updates_per_iter: 40000 18 | max_train_iters: 25 #25 x 40000 -> 1000000 transition batches 19 | 20 | augment_data: False 21 | augment_prob: 0 22 | nclusters: 23 | batch_size: 256 24 | lr: 1e-3 25 | 26 | #logging 27 | log_dir: dt_runs 28 | wandb_log: False 29 | wandb_entity: raj19 30 | wandb_run_name: 31 | wandb_group_name: 32 | wandb_dir: 33 | 34 | #saving 35 | save_snapshot: True 36 | save_snapshot_interval: 25 37 | 38 | #hydra 39 | hydra: 40 | run: 41 | dir: ${log_dir}/${dataset_name}/${now:%Y.%m.%d}_${now:%H.%M.%S} 42 | job: 43 | chdir: False -------------------------------------------------------------------------------- /cfgs/dt.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | device: cuda 3 | agent_name: dt 4 | 5 | #eval 6 | num_eval_ep: 10 7 | num_eval_len: 8 | 9 | # env 10 | env_name: 11 | dataset_name: pointmaze-umaze-v0 12 | remote_data: False 13 | num_workers: 4 14 | render: False 15 | 16 | #train 17 | num_updates_per_iter: 8000 18 | max_train_iters: 25 #25 x 8000 x 5-> 1000000 transition batches 19 | 20 | augment_data: False 21 | augment_prob: 0 22 | nclusters: 23 | warmup_steps: 1000 24 | drop_p: 0.1 25 | batch_size: 256 26 | lr: 1e-3 27 | wt_decay: 1e-4 28 | 29 | #model 30 | context_len: 5 31 | n_blocks: 3 32 | embed_dim: 128 33 | n_heads: 1 34 | 35 | #logging 36 | log_dir: dt_runs 37 | wandb_log: False 38 | wandb_entity: raj19 39 | 
wandb_run_name: 40 | wandb_group_name: 41 | wandb_dir: 42 | 43 | #saving 44 | save_snapshot: True 45 | save_snapshot_interval: 25 46 | 47 | #hydra 48 | hydra: 49 | run: 50 | dir: ${log_dir}/${dataset_name}/${now:%Y.%m.%d}_${now:%H.%M.%S} 51 | job: 52 | chdir: False -------------------------------------------------------------------------------- /collect_antmaze_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import gymnasium as gym 7 | from stable_baselines3 import SAC 8 | from utils import get_maze_map, AntmazeWrapper 9 | 10 | ''' 11 | Code taken from https://github.com/rodrigodelazcano/d4rl-minari-dataset-generation/blob/main/scripts/antmaze/create_antmaze_dataset.py 12 | ''' 13 | 14 | UP = 0 15 | DOWN = 1 16 | LEFT = 2 17 | RIGHT = 3 18 | 19 | EXPLORATION_ACTIONS = {UP: (0, 1), DOWN: (0, -1), LEFT: (-1, 0), RIGHT: (1, 0)} 20 | 21 | class QIteration: 22 | """Solves for optimal policy with Q-Value Iteration. 
23 | 24 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/q_iteration.py 25 | """ 26 | def __init__(self, maze): 27 | self.maze = maze 28 | self.num_states = maze.map_length * maze.map_width 29 | self.num_actions = len(EXPLORATION_ACTIONS.keys()) 30 | self.rew_matrix = np.zeros((self.num_states, self.num_actions)) 31 | self.compute_transition_matrix() 32 | 33 | def generate_path(self, current_cell, goal_cell): 34 | self.compute_reward_matrix(goal_cell) 35 | q_values = self.get_q_values() 36 | current_state = self.cell_to_state(current_cell) 37 | waypoints = {} 38 | while True: 39 | action_id = np.argmax(q_values[current_state]) 40 | next_state, _ = self.get_next_state(current_state, EXPLORATION_ACTIONS[action_id]) 41 | current_cell = self.state_to_cell(current_state) 42 | waypoints[current_cell] = self.state_to_cell(next_state) 43 | if waypoints[current_cell] == goal_cell: 44 | break 45 | 46 | current_state = next_state 47 | 48 | return waypoints 49 | 50 | def reward_function(self, desired_cell, current_cell): 51 | if desired_cell == current_cell: 52 | return 1.0 53 | else: 54 | return 0.0 55 | 56 | def state_to_cell(self, state): 57 | i = int(state/self.maze.map_width) 58 | j = state % self.maze.map_width 59 | return (i, j) 60 | 61 | def cell_to_state(self, cell): 62 | return cell[0] * self.maze.map_width + cell[1] 63 | 64 | def get_q_values(self, num_itrs=50, discount=0.99): 65 | q_fn = np.zeros((self.num_states, self.num_actions)) 66 | for _ in range(num_itrs): 67 | v_fn = np.max(q_fn, axis=1) 68 | q_fn = self.rew_matrix + discount*self.transition_matrix.dot(v_fn) 69 | return q_fn 70 | 71 | def compute_reward_matrix(self, goal_cell): 72 | for state in range(self.num_states): 73 | for action in range(self.num_actions): 74 | next_state, _ = self.get_next_state(state, EXPLORATION_ACTIONS[action]) 75 | next_cell = self.state_to_cell(next_state) 76 | self.rew_matrix[state, action] = self.reward_function(goal_cell, next_cell) 77 | 78 | 
    def compute_transition_matrix(self): 79 |         """Constructs this environment's transition matrix. 80 |         Returns: 81 |             A dS x dA x dS array where the entry transition_matrix[s, a, ns] 82 |             corresponds to the probability of transitioning into state ns after taking 83 |             action a from state s. 84 |         """ 85 |         self.transition_matrix = np.zeros((self.num_states, self.num_actions, self.num_states)) 86 |         for state in range(self.num_states): 87 |             for action_idx, action in EXPLORATION_ACTIONS.items(): 88 |                 next_state, valid = self.get_next_state(state, action) 89 |                 if valid: 90 |                     self.transition_matrix[state, action_idx, next_state] = 1 91 | 92 |     def get_next_state(self, state, action): 93 |         cell = self.state_to_cell(state) 94 | 95 |         next_cell = tuple(map(lambda i, j: int(i + j), cell, action)) 96 |         next_state = self.cell_to_state(next_cell) 97 | 98 |         return next_state, self._check_valid_cell(next_cell) 99 | 100 |     def _check_valid_cell(self, cell): 101 |         # Out of map bounds (negative indices would otherwise wrap around in Python) 102 |         if cell[0] < 0 or cell[0] >= self.maze.map_length: 103 |             return False 104 |         elif cell[1] < 0 or cell[1] >= self.maze.map_width: 105 |             return False 106 |         # Wall collision 107 |         elif self.maze.maze_map[cell[0]][cell[1]] == 1: 108 |             return False 109 |         else: 110 |             return True 111 | 112 | class WaypointController: 113 |     """Generic agent controller to follow waypoints in the maze.
114 | 115 |     Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/waypoint_controller.py 116 |     """ 117 | 118 |     def __init__( 119 |         self, maze, model_callback, waypoint_threshold=0.45 120 |     ): 121 |         self.global_target_xy = np.empty(2) 122 |         self.maze = maze 123 |         self.maze_solver = QIteration(maze=self.maze) 124 | 125 |         self.model_callback = model_callback 126 |         self.waypoint_threshold = waypoint_threshold 127 |         self.waypoint_targets = None 128 | 129 |     def compute_action(self, obs): 130 |         # Check if we need to generate new waypoint path due to change in global target 131 |         if ( 132 |             np.linalg.norm(self.global_target_xy - obs["desired_goal"]) > 1e-3 133 |             or self.waypoint_targets is None 134 |         ): 135 |             # Convert xy to cell id 136 |             achieved_goal_cell = tuple(self.maze.cell_xy_to_rowcol(obs["achieved_goal"])) 137 |             self.global_target_id = tuple( 138 |                 self.maze.cell_xy_to_rowcol(obs["desired_goal"]) 139 |             ) 140 | 141 |             self.global_target_xy = obs["desired_goal"] 142 | 143 |             self.waypoint_targets = self.maze_solver.generate_path( 144 |                 achieved_goal_cell, self.global_target_id 145 |             ) 146 |             # Check if the waypoint dictionary is empty 147 |             # If empty then the ant is already in the target cell location 148 |             if self.waypoint_targets: 149 |                 self.current_control_target_id = self.waypoint_targets[ 150 |                     achieved_goal_cell 151 |                 ] 152 |                 # If target is global goal go directly to goal position 153 |                 if self.current_control_target_id == self.global_target_id: 154 |                     self.current_control_target_xy = obs['desired_goal'] 155 |                 else: 156 |                     self.current_control_target_xy = self.maze.cell_rowcol_to_xy( 157 |                         np.array(self.current_control_target_id) 158 |                     ) - np.random.uniform(size=(2,)) * 0.1 159 |             else: 160 |                 self.waypoint_targets[ 161 |                     self.current_control_target_id 162 |                 ] = self.current_control_target_id 163 |                 self.current_control_target_id = self.global_target_id 164 |                 self.current_control_target_xy = self.global_target_xy 165 | 166 |         # Check if we
need to go to the next waypoint 167 | dist = np.linalg.norm(self.current_control_target_xy - obs["achieved_goal"]) 168 | if ( 169 | dist <= self.waypoint_threshold 170 | and self.current_control_target_id != self.global_target_id 171 | ): 172 | self.current_control_target_id = self.waypoint_targets[ 173 | self.current_control_target_id 174 | ] 175 | # If target is global goal go directly to goal position 176 | if self.current_control_target_id == self.global_target_id: 177 | 178 | self.current_control_target_xy = obs['desired_goal'] 179 | else: 180 | self.current_control_target_xy = ( 181 | self.maze.cell_rowcol_to_xy( 182 | np.array(self.current_control_target_id) 183 | ) 184 | - np.random.uniform(size=(2,)) * 0.1 185 | ) 186 | action = self.model_callback(obs, self.current_control_target_xy) 187 | return action 188 | 189 | def wrap_maze_obs(obs, waypoint_xy): 190 | """Wrap the maze obs into one suitable for GoalReachAnt.""" 191 | goal_direction = waypoint_xy - obs["achieved_goal"] 192 | observation = np.concatenate([obs["observation"][2:], goal_direction]) 193 | return observation 194 | 195 | def get_start_state_goal_pairs(dataset_name): 196 | 197 | if dataset_name == "antmaze-large-sl": 198 | train_start_state_goal = [ 199 | {"goal_cells": np.array([ [5,4], [1,10], [7,10] ], dtype=np.int32), 200 | "reset_cells": np.array([ [7,1] ], dtype=np.int32), 201 | }, 202 | ] 203 | 204 | elif dataset_name == "antmaze-large-v0": 205 | train_start_state_goal = [ 206 | {"goal_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32), 207 | "reset_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32), 208 | }, 209 | 210 | {"goal_cells": np.array([ [3,2], [3,3], [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4], [1,6], [2,6], [1,7], [1,8], [1,9], [1,10], 
[2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [4,6], [5,6], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32), 211 | "reset_cells": np.array([ [3,2], [3,3], [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4], [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32), 212 | }, 213 | ] 214 | 215 | elif dataset_name == 'antmaze-medium-v0': 216 | train_start_state_goal = [ 217 | {'goal_cells': np.array([4,6], dtype=np.int32), 218 | 'reset_cells': np.array([6,5], dtype=np.int32), 219 | 'name' : 'bottom_rightish_to_center', 220 | }, 221 | 222 | {'goal_cells': np.array([6,5], dtype=np.int32), 223 | 'reset_cells': np.array([4,6], dtype=np.int32), 224 | 'name' : 'center_to_bottom_rightish', 225 | }, 226 | 227 | {'goal_cells': np.array([4,4], dtype=np.int32), 228 | 'reset_cells': np.array([6,1], dtype=np.int32), 229 | 'name' : 'bottom_left_to_center', 230 | }, 231 | 232 | {'goal_cells': np.array([6,1], dtype=np.int32), 233 | 'reset_cells': np.array([4,4], dtype=np.int32), 234 | 'name' : 'center_to_bottom_left', 235 | }, 236 | 237 | {'goal_cells': np.array([4,2], dtype=np.int32), 238 | 'reset_cells': np.array([6,1], dtype=np.int32), 239 | 'name' : 'bottom_left_to_center_leftish', 240 | }, 241 | 242 | {'goal_cells': np.array([6,1], dtype=np.int32), 243 | 'reset_cells': np.array([4,2], dtype=np.int32), 244 | 'name' : 'center_leftish_to_bottom_left', 245 | }, 246 | 247 | {'goal_cells': np.array([4,6], dtype=np.int32), 248 | 'reset_cells': np.array([1,6], dtype=np.int32), 249 | 'name' : 'top_right_to_center', 250 | }, 251 | 252 | {'goal_cells': np.array([1,6], dtype=np.int32), 253 | 'reset_cells': np.array([4,6], dtype=np.int32), 254 | 'name' : 'center_to_top_right', 255 | }, 256 | ] 257 | 258 | elif dataset_name == "antmaze-umaze-v0": 259 | train_start_state_goal = [ 260 | {"goal_cells": np.array([ [1,1], [1,2], [1,3] ], 
dtype=np.int32), 261 | "reset_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32), 262 | }, 263 | 264 | {"goal_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32), 265 | "reset_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32), 266 | }, 267 | 268 | {"goal_cells": np.array([ [1,3], [2,3] ], dtype=np.int32), 269 | "reset_cells": np.array([ [1,3], [2,3] ], dtype=np.int32), 270 | }, 271 | 272 | {"goal_cells": np.array([ [3,3], [2,3] ], dtype=np.int32), 273 | "reset_cells": np.array([ [3,3], [2,3] ], dtype=np.int32), 274 | }, 275 | ] 276 | 277 | else: 278 | raise NotImplementedError 279 | 280 | return train_start_state_goal 281 | 282 | def load_policy(policy_file): 283 | data = torch.load(policy_file) 284 | policy = data['exploration/policy'].to('cpu') 285 | env = data['evaluation/env'] 286 | print("Policy loaded") 287 | return policy, env 288 | 289 | def collect_dataset(dataset_name, num_data): 290 | if os.path.isfile("data/"+dataset_name+".pkl"): 291 | print("A dataset with the same name already exists. Doing nothing. 
Delete the file if you want any changes.") 292 | exit() 293 | 294 | if "antmaze-umaze" in dataset_name: 295 | env_name = "AntMaze_UMaze-v4" 296 | elif "antmaze-medium" in dataset_name: 297 | env_name = "AntMaze_Medium-v4" 298 | elif "antmaze-large" in dataset_name: 299 | env_name = "AntMaze_Large-v4" 300 | else: 301 | raise NotImplementedError 302 | 303 | # environment initialisation 304 | env = AntmazeWrapper(gym.make(env_name, continuing_task=False))#, render_mode="human")) 305 | 306 | # data placeholders 307 | observation_data = { 308 | "episode_id": np.zeros(shape=(int(num_data),), dtype=np.int32), 309 | "observation" : np.zeros(shape=(int(num_data), *env.observation_space["observation"].shape), dtype=np.float32), 310 | "achieved_goal" : np.zeros(shape=(int(num_data), *env.observation_space["achieved_goal"].shape), dtype=np.float32), 311 | } 312 | action_data = np.zeros(shape=(int(num_data), *env.action_space.shape), dtype=np.float32) 313 | termination_data = np.zeros(shape = (int(num_data),), dtype=bool) 314 | 315 | # get data collecting start state and goal pairs 316 | train_start_state_goal = get_start_state_goal_pairs(dataset_name) 317 | 318 | # controller 319 | model = SAC.load('GoalReachAnt_model.zip') 320 | def action_callback(obs, waypoint_xy): 321 | return model.predict(wrap_maze_obs(obs, waypoint_xy))[0] 322 | waypoint_controller = WaypointController(env.maze, action_callback) 323 | 324 | terminated = False 325 | truncated = False 326 | data_idx = 0 327 | episode_idx = 0 328 | success_count = 0 329 | 330 | if train_start_state_goal is not None: 331 | num_dc = len(train_start_state_goal) 332 | sg_dict = train_start_state_goal[ episode_idx % num_dc ] 333 | if len(sg_dict["goal_cells"].shape) == 2: 334 | x,_ = sg_dict["goal_cells"].shape 335 | id = np.random.choice(x, size=1)[0] 336 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id] 337 | else: 338 | sg_dict["goal_cell"] = sg_dict["goal_cells"] 339 | 340 | if len(sg_dict["reset_cells"].shape) == 2: 341 | 
x,_ = sg_dict["reset_cells"].shape 342 | id = np.random.choice(x, size=1)[0] 343 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id] 344 | else: 345 | sg_dict["reset_cell"] = sg_dict["reset_cells"] 346 | obs, _ = env.reset(options=sg_dict) 347 | else: 348 | obs, _ = env.reset() 349 | 350 | episode_start_idx = 0 351 | while data_idx < int(num_data)-1: 352 | 353 | action = waypoint_controller.compute_action(obs) 354 | action = np.clip(action, env.action_space.low, env.action_space.high, dtype=np.float32) 355 | 356 | observation_data["episode_id"][data_idx] = episode_idx 357 | observation_data["observation"][data_idx] = obs["observation"] 358 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"] 359 | 360 | action_data[data_idx] = action 361 | termination_data[data_idx] = terminated or truncated 362 | data_idx += 1 363 | 364 | obs, _, terminated, truncated, info = env.step(action) 365 | 366 | if terminated or truncated: 367 | if info['success']: 368 | success_count += 1 369 | 370 | observation_data["episode_id"][data_idx] = episode_idx 371 | observation_data["observation"][data_idx] = obs["observation"] 372 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"] 373 | termination_data[data_idx] = terminated or truncated 374 | 375 | data_idx += 1 376 | episode_idx += 1 377 | episode_start_idx = data_idx 378 | 379 | else: 380 | data_idx = episode_start_idx 381 | 382 | if train_start_state_goal is not None: 383 | sg_dict = train_start_state_goal[ episode_idx % num_dc ] 384 | if len(sg_dict["goal_cells"].shape) == 2: 385 | x,_ = sg_dict["goal_cells"].shape 386 | id = np.random.choice(x, size=1)[0] 387 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id] 388 | else: 389 | sg_dict["goal_cell"] = sg_dict["goal_cells"] 390 | 391 | if len(sg_dict["reset_cells"].shape) == 2: 392 | x,_ = sg_dict["reset_cells"].shape 393 | id = np.random.choice(x, size=1)[0] 394 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id] 395 | else: 396 | sg_dict["reset_cell"] 
= sg_dict["reset_cells"] 397 |                 obs, _ = env.reset(options=sg_dict) 398 |             else: 399 |                 obs, _ = env.reset() 400 | 401 |             terminated = False 402 |             truncated = False 403 | 404 |         if (data_idx + 1) % (num_data // 20) == 0: 405 |             print("STEPS RECORDED:") 406 |             print(data_idx) 407 | 408 |             print("EPISODES RECORDED:") 409 |             print(episode_idx) 410 | 411 |             print("SUCCESS EPISODES RECORDED:") 412 |             print(success_count) 413 | 414 | 415 |     dataset = {"observations": observation_data, 416 |                 "actions": action_data, 417 |                 "terminations": termination_data, 418 |                 "success_count": success_count, 419 |                 "episode_count": episode_idx, 420 |                 } 421 | 422 |     with open("data/"+dataset_name+".pkl", "wb") as fp: 423 |         pickle.dump(dataset, fp) 424 | 425 | if __name__ == "__main__": 426 |     import sys 427 | 428 |     dataset_name = sys.argv[1] 429 | 430 |     seed = sys.argv[2] 431 |     num_data = int(sys.argv[3]) if len(sys.argv) > 3 else int(1e6)  # optional 3rd argument, as used in the README 432 | 433 |     np.random.seed(int(seed)) 434 |     collect_dataset(dataset_name, num_data) -------------------------------------------------------------------------------- /collect_pointmaze_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import gymnasium as gym 5 | from utils import get_maze_map 6 | 7 | UP = 0 8 | DOWN = 1 9 | LEFT = 2 10 | RIGHT = 3 11 | 12 | EXPLORATION_ACTIONS = {UP: (0, 1), DOWN: (0, -1), LEFT: (-1, 0), RIGHT: (1, 0)} 13 | 14 | class QIteration: 15 |     """Solves for optimal policy with Q-Value Iteration.
16 | 17 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/q_iteration.py 18 | """ 19 | def __init__(self, maze): 20 | self.maze = maze 21 | self.num_states = maze.map_length * maze.map_width 22 | self.num_actions = len(EXPLORATION_ACTIONS.keys()) 23 | self.rew_matrix = np.zeros((self.num_states, self.num_actions)) 24 | self.compute_transition_matrix() 25 | 26 | def generate_path(self, current_cell, goal_cell): 27 | self.compute_reward_matrix(goal_cell) 28 | q_values = self.get_q_values() 29 | current_state = self.cell_to_state(current_cell) 30 | waypoints = {} 31 | while True: 32 | action_id = np.argmax(q_values[current_state]) 33 | next_state, _ = self.get_next_state(current_state, EXPLORATION_ACTIONS[action_id]) 34 | current_cell = self.state_to_cell(current_state) 35 | waypoints[current_cell] = self.state_to_cell(next_state) 36 | if waypoints[current_cell] == goal_cell: 37 | break 38 | 39 | current_state = next_state 40 | 41 | return waypoints 42 | 43 | def reward_function(self, desired_cell, current_cell): 44 | if desired_cell == current_cell: 45 | return 1.0 46 | else: 47 | return 0.0 48 | 49 | def state_to_cell(self, state): 50 | i = int(state/self.maze.map_width) 51 | j = state % self.maze.map_width 52 | return (i, j) 53 | 54 | def cell_to_state(self, cell): 55 | return cell[0] * self.maze.map_width + cell[1] 56 | 57 | def get_q_values(self, num_itrs=50, discount=0.99): 58 | q_fn = np.zeros((self.num_states, self.num_actions)) 59 | for _ in range(num_itrs): 60 | v_fn = np.max(q_fn, axis=1) 61 | q_fn = self.rew_matrix + discount*self.transition_matrix.dot(v_fn) 62 | return q_fn 63 | 64 | def compute_reward_matrix(self, goal_cell): 65 | for state in range(self.num_states): 66 | for action in range(self.num_actions): 67 | next_state, _ = self.get_next_state(state, EXPLORATION_ACTIONS[action]) 68 | next_cell = self.state_to_cell(next_state) 69 | self.rew_matrix[state, action] = self.reward_function(goal_cell, next_cell) 70 | 71 | 
    def compute_transition_matrix(self): 72 |         """Constructs this environment's transition matrix. 73 |         Returns: 74 |             A dS x dA x dS array where the entry transition_matrix[s, a, ns] 75 |             corresponds to the probability of transitioning into state ns after taking 76 |             action a from state s. 77 |         """ 78 |         self.transition_matrix = np.zeros((self.num_states, self.num_actions, self.num_states)) 79 |         for state in range(self.num_states): 80 |             for action_idx, action in EXPLORATION_ACTIONS.items(): 81 |                 next_state, valid = self.get_next_state(state, action) 82 |                 if valid: 83 |                     self.transition_matrix[state, action_idx, next_state] = 1 84 | 85 |     def get_next_state(self, state, action): 86 |         cell = self.state_to_cell(state) 87 | 88 |         next_cell = tuple(map(lambda i, j: int(i + j), cell, action)) 89 |         next_state = self.cell_to_state(next_cell) 90 | 91 |         return next_state, self._check_valid_cell(next_cell) 92 | 93 |     def _check_valid_cell(self, cell): 94 |         # Out of map bounds (negative indices would otherwise wrap around in Python) 95 |         if cell[0] < 0 or cell[0] >= self.maze.map_length: 96 |             return False 97 |         elif cell[1] < 0 or cell[1] >= self.maze.map_width: 98 |             return False 99 |         # Wall collision 100 |         elif self.maze.maze_map[cell[0]][cell[1]] == 1: 101 |             return False 102 |         else: 103 |             return True 104 | 105 | class WaypointController: 106 |     """Agent controller to follow waypoints in the maze.
107 | 108 | Inspired by https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/pointmaze/waypoint_controller.py 109 | """ 110 | def __init__(self, maze, gains={"p": 10.0, "d": -1.0}, waypoint_threshold=0.1): 111 | self.global_target_xy = np.empty(2) 112 | self.maze = maze 113 | 114 | self.maze_solver = QIteration(maze=self.maze) 115 | 116 | self.gains = gains 117 | self.waypoint_threshold = waypoint_threshold 118 | self.waypoint_targets = None 119 | 120 | def compute_action(self, obs): 121 | # Check if we need to generate new waypoint path due to change in global target 122 | if np.linalg.norm(self.global_target_xy - obs["desired_goal"]) > 1e-3 or self.waypoint_targets is None: 123 | # Convert xy to cell id 124 | achieved_goal_cell = tuple(self.maze.cell_xy_to_rowcol(obs["achieved_goal"])) 125 | self.global_target_id = tuple(self.maze.cell_xy_to_rowcol(obs["desired_goal"])) 126 | self.global_target_xy = obs["desired_goal"] 127 | 128 | self.waypoint_targets = self.maze_solver.generate_path(achieved_goal_cell, self.global_target_id) 129 | 130 | # Check if the waypoint dictionary is empty 131 | # If empty then the ball is already in the target cell location 132 | if self.waypoint_targets: 133 | self.current_control_target_id = self.waypoint_targets[achieved_goal_cell] 134 | self.current_control_target_xy = self.maze.cell_rowcol_to_xy(np.array(self.current_control_target_id)) 135 | else: 136 | self.waypoint_targets[self.current_control_target_id] = self.current_control_target_id 137 | self.current_control_target_id = self.global_target_id 138 | self.current_control_target_xy = self.global_target_xy 139 | 140 | # Check if we need to go to the next waypoint 141 | dist = np.linalg.norm(self.current_control_target_xy - obs["achieved_goal"]) 142 | if dist <= self.waypoint_threshold and self.current_control_target_id != self.global_target_id: 143 | self.current_control_target_id = self.waypoint_targets[self.current_control_target_id] 144 | # If target is global goal go 
directly to goal position 145 | if self.current_control_target_id == self.global_target_id: 146 | self.current_control_target_xy = self.global_target_xy 147 | else: 148 | self.current_control_target_xy = self.maze.cell_rowcol_to_xy(np.array(self.current_control_target_id)) - np.random.uniform(size=(2,))*0.2 149 | 150 | action = self.gains["p"] * (self.current_control_target_xy - obs["achieved_goal"]) + self.gains["d"] * obs["observation"][2:] 151 | action = np.clip(action, -1, 1) 152 | 153 | return action 154 | 155 | def get_start_state_goal_pairs(dataset_name): 156 | if dataset_name == "pointmaze-large-sl": 157 | train_start_state_goal = [ 158 | {"goal_cells": np.array([ [5,4], [1,10], [7,10] ], dtype=np.int32), 159 | "reset_cells": np.array([ [7,1] ], dtype=np.int32), 160 | }, 161 | ] 162 | 163 | elif dataset_name == "pointmaze-large-v0": 164 | train_start_state_goal = [ 165 | {"goal_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32), 166 | "reset_cells": np.array([ [1,1], [1,2], [1,3], [1,4], [2,1], [2,4], [3,1], [3,2], [3,3], [3,4], [4,1], [5,1], [5,2], [6,2], [7,1], [7,2] ], dtype=np.int32), 167 | }, 168 | 169 | {"goal_cells": np.array([ [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4] ], dtype=np.int32), 170 | "reset_cells": np.array([ [3,4], [3,5], [1,6], [2,6], [3,6], [4,6], [5,6], [6,6], [7,6], [7,5], [7,4], [6,4], [5,4] ], dtype=np.int32), 171 | }, 172 | 173 | {"goal_cells": np.array([ [1,6], [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [5,6], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32), 174 | "reset_cells": np.array([ [1,7], [1,8], [1,9], [1,10], [2,10], [3,10], [3,9], [3,8], [2,8], [4,10], [5,10], [5,9], [5,8], [5,7], [6,8], [7,8], [7,9], [7,10] ], dtype=np.int32), 175 | }, 176 | ] 177 | 178 | elif dataset_name == "pointmaze-medium-v0": 179 | 
train_start_state_goal = [ 180 | {"goal_cells": np.array([ [6,5], [6,6], [5,6], [4,6] ], dtype=np.int32), 181 | "reset_cells": np.array([ [6,5], [6,6], [5,6], [4,6] ], dtype=np.int32), 182 | }, 183 | 184 | {"goal_cells": np.array([ [4,1], [4,2], [5,1], [6,1], [6,2], [6,3], [5,3] ], dtype=np.int32), 185 | "reset_cells": np.array([ [4,1], [4,2], [5,1], [6,1], [6,2], [6,3], [5,3] ], dtype=np.int32), 186 | }, 187 | 188 | {"goal_cells": np.array([ [1,5], [1,6], [2,4], [2,5], [2,6] ], dtype=np.int32), 189 | "reset_cells": np.array([ [1,5], [1,6], [2,4], [2,5], [2,6] ], dtype=np.int32), 190 | }, 191 | 192 | {"goal_cells": np.array([ [1,1], [1,2], [2,1], [2,2] ], dtype=np.int32), 193 | "reset_cells": np.array([ [1,1], [1,2], [2,1], [2,2] ], dtype=np.int32), 194 | }, 195 | 196 | {"goal_cells": np.array([ [5,3], [5,4], [4,2], [4,4], [4,5], [4,6], [3,2], [3,3], [3,4], [2,2], [2,4] ], dtype=np.int32), 197 | "reset_cells": np.array([ [5,3], [5,4], [4,2], [4,4], [4,5], [4,6], [3,2], [3,3], [3,4], [2,2], [2,4] ], dtype=np.int32), 198 | }, 199 | ] 200 | 201 | elif dataset_name == "pointmaze-umaze-v0": 202 | train_start_state_goal = [ 203 | {"goal_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32), 204 | "reset_cells": np.array([ [1,1], [1,2], [1,3] ], dtype=np.int32), 205 | }, 206 | 207 | {"goal_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32), 208 | "reset_cells": np.array([ [3,1], [3,2], [3,3] ], dtype=np.int32), 209 | }, 210 | 211 | {"goal_cells": np.array([ [1,3], [2,3] ], dtype=np.int32), 212 | "reset_cells": np.array([ [1,3], [2,3] ], dtype=np.int32), 213 | }, 214 | 215 | {"goal_cells": np.array([ [3,3], [2,3] ], dtype=np.int32), 216 | "reset_cells": np.array([ [3,3], [2,3] ], dtype=np.int32), 217 | }, 218 | ] 219 | 220 | else: 221 | raise NotImplementedError 222 | 223 | return train_start_state_goal 224 | 225 | def collect_dataset(dataset_name, num_data): 226 | if os.path.isfile("data/"+dataset_name+'-'+str(num_data)+".pkl"): 227 | print("A dataset with the 
same name already exists. Doing nothing. Delete the file if you want any changes.") 228 | exit() 229 | 230 | if "pointmaze-umaze" in dataset_name: 231 | env_name = "PointMaze_UMaze-v3" 232 | elif "pointmaze-medium" in dataset_name: 233 | env_name = "PointMaze_Medium-v3" 234 | elif "pointmaze-large" in dataset_name: 235 | env_name = "PointMaze_Large-v3" 236 | else: 237 | raise NotImplementedError 238 | 239 | # environment initialisation 240 | env = gym.make(env_name, continuing_task=False, max_episode_steps=10 * num_data)#, render_mode="human") 241 | 242 | # data placeholders 243 | observation_data = { 244 | "episode_id": np.zeros(shape=(int(num_data),), dtype=np.int32), 245 | "observation" : np.zeros(shape=(int(num_data), *env.observation_space["observation"].shape), dtype=np.float32), 246 | "achieved_goal" : np.zeros(shape=(int(num_data), *env.observation_space["achieved_goal"].shape), dtype=np.float32), 247 | } 248 | action_data = np.zeros(shape=(int(num_data), *env.action_space.shape), dtype=np.float32) 249 | termination_data = np.zeros(shape = (int(num_data),), dtype=bool) 250 | 251 | # controller initialisation 252 | waypoint_controller = WaypointController(maze=env.maze) 253 | 254 | # get data collecting start state and goal pairs 255 | train_start_state_goal = get_start_state_goal_pairs(dataset_name) 256 | 257 | terminated = False 258 | truncated = False 259 | data_idx = 0 260 | episode_idx = 0 261 | success_count = 0 262 | episode_start_idx = 0 # rewind target used if an episode fails before any episode has succeeded 263 | num_dc = len(train_start_state_goal) 264 | 265 | sg_dict = train_start_state_goal[ episode_idx % num_dc ] 266 | if len(sg_dict["goal_cells"].shape) == 2: 267 | x,_ = sg_dict["goal_cells"].shape 268 | id = np.random.choice(x, size=1)[0] 269 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id] 270 | else: 271 | sg_dict["goal_cell"] = sg_dict["goal_cells"] 272 | 273 | if len(sg_dict["reset_cells"].shape) == 2: 274 | x,_ = sg_dict["reset_cells"].shape 275 | id = np.random.choice(x, size=1)[0] 276 | sg_dict["reset_cell"] = 
sg_dict["reset_cells"][id] 277 | else: 278 | sg_dict["reset_cell"] = sg_dict["reset_cells"] 279 | obs, _ = env.reset(options=sg_dict) 280 | 281 | while data_idx < int(num_data)-1: 282 | action = waypoint_controller.compute_action(obs) 283 | # Add some noise to each step action 284 | action += np.random.randn(*action.shape)*0.5 285 | action = np.clip(action, env.action_space.low, env.action_space.high, dtype=np.float32) 286 | 287 | observation_data["episode_id"][data_idx] = episode_idx 288 | observation_data["observation"][data_idx] = obs["observation"] 289 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"] 290 | action_data[data_idx] = action 291 | termination_data[data_idx] = terminated or truncated 292 | data_idx += 1 293 | 294 | obs, _, terminated, truncated, info = env.step(action) 295 | 296 | if terminated or truncated: 297 | if info['success']: 298 | success_count += 1 299 | 300 | observation_data["episode_id"][data_idx] = episode_idx 301 | observation_data["observation"][data_idx] = obs["observation"] 302 | observation_data["achieved_goal"][data_idx] = obs["achieved_goal"] 303 | termination_data[data_idx] = terminated or truncated 304 | 305 | data_idx += 1 306 | episode_idx += 1 307 | episode_start_idx = data_idx 308 | 309 | else: 310 | data_idx = episode_start_idx 311 | 312 | sg_dict = train_start_state_goal[ episode_idx % num_dc ] 313 | if len(sg_dict["goal_cells"].shape) == 2: 314 | x,_ = sg_dict["goal_cells"].shape 315 | id = np.random.choice(x, size=1)[0] 316 | sg_dict["goal_cell"] = sg_dict["goal_cells"][id] 317 | else: 318 | sg_dict["goal_cell"] = sg_dict["goal_cells"] 319 | 320 | if len(sg_dict["reset_cells"].shape) == 2: 321 | x,_ = sg_dict["reset_cells"].shape 322 | id = np.random.choice(x, size=1)[0] 323 | sg_dict["reset_cell"] = sg_dict["reset_cells"][id] 324 | else: 325 | sg_dict["reset_cell"] = sg_dict["reset_cells"] 326 | 327 | obs, _ = env.reset(options=sg_dict) 328 | 329 | terminated = False 330 | truncated = False 331 | 
332 | if (data_idx + 1) % (num_data // 5) == 0: 333 | print("STEPS RECORDED:") 334 | print(data_idx) 335 | 336 | print("EPISODES RECORDED:") 337 | print(episode_idx) 338 | 339 | print("SUCCESS EPISODES RECORDED:") 340 | print(success_count) 341 | 342 | dataset = {"observations":observation_data, 343 | "actions":action_data, 344 | "terminations":termination_data, 345 | "success_count": success_count, 346 | "episode_count": episode_idx, 347 | } 348 | 349 | with open("data/"+dataset_name+'-'+str(num_data)+".pkl", "wb") as fp: 350 | pickle.dump(dataset, fp) 351 | 352 | if __name__ == "__main__": 353 | import sys 354 | 355 | dataset_name = sys.argv[1] 356 | 357 | seed = sys.argv[2] 358 | num_data = int(sys.argv[3]) 359 | 360 | np.random.seed(int(seed)) 361 | collect_dataset(dataset_name, num_data) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.distributions as td 6 | 7 | def count_parameters(model): 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | 10 | class DecisionMLP(nn.Module): 11 | def __init__(self, env_name, env, goal_dim=2, h_dim=1024): 12 | super().__init__() 13 | 14 | env_name = env_name 15 | state_dim = env.observation_space['observation'].shape[0] 16 | act_dim = env.action_space.shape[0] 17 | 18 | self.mlp = nn.Sequential( 19 | nn.Linear(state_dim + goal_dim, h_dim), 20 | nn.ReLU(), 21 | nn.Linear(h_dim, h_dim), 22 | nn.ReLU(), 23 | nn.Linear(h_dim, act_dim), 24 | nn.Tanh() 25 | ) 26 | 27 | def forward(self, states, goals): 28 | h = torch.cat((states, goals), dim=-1) 29 | action_preds = self.mlp(h) 30 | return action_preds 31 | 32 | class MaskedCausalAttention(nn.Module): 33 | ''' 34 | Thanks https://github.com/nikhilbarhate99/min-decision-transformer/tree/master 35 | ''' 36 | def __init__(self, 
h_dim, max_T, n_heads, drop_p): 37 | super().__init__() 38 | 39 | self.n_heads = n_heads 40 | self.max_T = max_T # kept for reference; causal masking is handled by is_causal=True 41 | 42 | self.q_net = nn.Linear(h_dim, h_dim) 43 | self.k_net = nn.Linear(h_dim, h_dim) 44 | self.v_net = nn.Linear(h_dim, h_dim) 45 | 46 | self.proj_net = nn.Linear(h_dim, h_dim) 47 | 48 | self.dropout = drop_p 49 | self.att_drop = nn.Dropout(drop_p) # unused: attention dropout is applied via dropout_p in scaled_dot_product_attention 50 | self.proj_drop = nn.Dropout(drop_p) 51 | 52 | def forward(self, x): 53 | B,T,C = x.shape # batch size, seq length, h_dim (split evenly across the n_heads heads) 54 | 55 | N, D = self.n_heads, C // self.n_heads # N = num heads, D = per-head dim 56 | 57 | # rearrange q, k, v as (B, N, T, D) 58 | q = self.q_net(x).view(B, T, N, D).transpose(1,2) 59 | k = self.k_net(x).view(B, T, N, D).transpose(1,2) 60 | v = self.v_net(x).view(B, T, N, D).transpose(1,2) 61 | 62 | attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 63 | attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) 64 | 65 | out = self.proj_drop(self.proj_net(attention)) 66 | return out 67 | 68 | class Block(nn.Module): 69 | def __init__(self, h_dim, max_T, n_heads, drop_p): 70 | super().__init__() 71 | self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) 72 | self.mlp = nn.Sequential( 73 | nn.Linear(h_dim, 4*h_dim), 74 | nn.GELU(), 75 | nn.Linear(4*h_dim, h_dim), 76 | nn.Dropout(drop_p), 77 | ) 78 | self.ln1 = nn.LayerNorm(h_dim) 79 | self.ln2 = nn.LayerNorm(h_dim) 80 | 81 | def forward(self, x): 82 | x = x + self.attention(self.ln1(x)) # residual 83 | x = x + self.mlp(self.ln2(x)) # residual 84 | return x 85 | 86 | class DecisionTransformer(nn.Module): 87 | def __init__(self, env_name, env, n_blocks, h_dim, context_len, 88 | n_heads, drop_p, goal_dim=2, max_timestep=4096): 89 | super().__init__() 90 | 91 | self.env_name = env_name 92 | self.state_dim = env.observation_space['observation'].shape[0] 93 | self.act_dim = env.action_space.shape[0] 94 | 
self.goal_dim = goal_dim 95 | self.n_heads = n_heads 96 | self.h_dim = h_dim 97 | 98 | ### transformer blocks 99 | input_seq_len = 3 * context_len 100 | blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] 101 | self.transformer = nn.Sequential(*blocks) 102 | 103 | ### projection heads (project to embedding) 104 | self.embed_timestep = nn.Embedding(max_timestep, h_dim) 105 | self.embed_goal = torch.nn.Linear(goal_dim, h_dim) 106 | self.embed_state = torch.nn.Linear(self.state_dim, h_dim) 107 | self.embed_action = torch.nn.Linear(self.act_dim, h_dim) 108 | 109 | ### prediction heads 110 | self.final_ln = nn.LayerNorm(h_dim) 111 | self.predict_action = nn.Sequential( 112 | *([nn.Linear(h_dim, self.act_dim)] + ([nn.Tanh()])) 113 | ) 114 | 115 | def forward(self, states, actions, goals): 116 | B, T, _ = states.shape 117 | 118 | timesteps = torch.arange(0, T, dtype=torch.long, device=states.device) 119 | time_embeddings = self.embed_timestep(timesteps) 120 | state_embeddings = self.embed_state(states) + time_embeddings #B, T, h_dim 121 | action_embeddings = self.embed_action(actions) + time_embeddings #B, T, h_dim 122 | goal_embeddings = self.embed_goal(goals) + time_embeddings #B, T, h_dim 123 | 124 | h = torch.stack( 125 | (goal_embeddings, state_embeddings, action_embeddings), dim=1 126 | ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) 127 | 128 | # transformer and prediction 129 | h = self.transformer(h) 130 | 131 | h = self.final_ln(h) 132 | 133 | # get h reshaped such that its size = (B , 3 , T , h_dim) and 134 | # h[:, 0, t] is conditioned on the input sequence g_0, s_0, a_0 ... g_t 135 | # h[:, 1, t] is conditioned on the input sequence g_0, s_0, a_0 ... g_t, s_t 136 | # h[:, 2, t] is conditioned on the input sequence g_0, s_0, a_0 ... 
g_t, s_t, a_t 137 | # that is, for each timestep (t) we have 3 output embeddings from the transformer, 138 | # each conditioned on all previous timesteps plus 139 | # the 3 input variables at that timestep (g_t, s_t, a_t) in sequence. 140 | h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) # B, 3, T, h_dim 141 | action_preds = self.predict_action(h[:,1]) 142 | return action_preds -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | antlr4-python3-runtime==4.9.3 2 | appdirs==1.4.4 3 | cachetools==5.3.2 4 | certifi==2023.11.17 5 | charset-normalizer==3.3.2 6 | click==8.1.7 7 | cloudpickle==3.0.0 8 | colorama==0.4.6 9 | contourpy==1.2.0 10 | cycler==0.12.1 11 | docker-pycreds==0.4.0 12 | Farama-Notifications==0.0.4 13 | filelock==3.13.1 14 | fonttools==4.47.2 15 | fsspec==2023.12.2 16 | gitdb==4.0.11 17 | GitPython==3.1.41 18 | google-api-core==2.15.0 19 | google-auth==2.26.2 20 | google-cloud-core==2.4.1 21 | google-cloud-storage==2.5.0 22 | google-crc32c==1.5.0 23 | google-resumable-media==2.7.0 24 | googleapis-common-protos==1.62.0 25 | gymnasium==0.29.1 26 | h5py==3.10.0 27 | hydra-core==1.3.2 28 | idna==3.6 29 | Jinja2==3.1.3 30 | joblib==1.3.2 31 | kiwisolver==1.4.5 32 | markdown-it-py==3.0.0 33 | MarkupSafe==2.1.3 34 | matplotlib==3.8.2 35 | mdurl==0.1.2 36 | minari==0.4.2 37 | mpmath==1.3.0 38 | networkx==3.2.1 39 | numpy==1.26.3 40 | nvidia-cublas-cu12==12.1.3.1 41 | nvidia-cuda-cupti-cu12==12.1.105 42 | nvidia-cuda-nvrtc-cu12==12.1.105 43 | nvidia-cuda-runtime-cu12==12.1.105 44 | nvidia-cudnn-cu12==8.9.2.26 45 | nvidia-cufft-cu12==11.0.2.54 46 | nvidia-curand-cu12==10.3.2.106 47 | nvidia-cusolver-cu12==11.4.5.107 48 | nvidia-cusparse-cu12==12.1.0.106 49 | nvidia-nccl-cu12==2.18.1 50 | nvidia-nvjitlink-cu12==12.3.101 51 | nvidia-nvtx-cu12==12.1.105 52 | omegaconf==2.3.0 53 | packaging==23.1 54 | pandas==2.1.4 55 | 
pillow==10.2.0 56 | portion==2.4.0 57 | protobuf==4.25.2 58 | psutil==5.9.7 59 | pyasn1==0.5.1 60 | pyasn1-modules==0.3.0 61 | Pygments==2.17.2 62 | pyparsing==3.1.1 63 | python-dateutil==2.8.2 64 | pytz==2023.3.post1 65 | PyYAML==6.0.1 66 | requests==2.31.0 67 | rich==13.7.0 68 | rsa==4.9 69 | scikit-learn==1.3.2 70 | scipy==1.11.4 71 | sentry-sdk==1.39.2 72 | setproctitle==1.3.3 73 | shellingham==1.5.4 74 | six==1.16.0 75 | smmap==5.0.1 76 | sortedcontainers==2.4.0 77 | stable-baselines3==2.2.1 78 | sympy==1.12 79 | threadpoolctl==3.2.0 80 | torch==2.1.2 81 | tqdm==4.66.1 82 | triton==2.1.0 83 | typer==0.9.0 84 | typing_extensions==4.9.0 85 | tzdata==2023.4 86 | urllib3==2.1.0 87 | wandb==0.16.2 88 | -------------------------------------------------------------------------------- /train_dmlp.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import wandb 3 | import random 4 | import minari 5 | import numpy as np 6 | import gymnasium as gym 7 | from pathlib import Path 8 | from datetime import datetime 9 | from omegaconf import DictConfig 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from model import DecisionMLP 16 | from utils import MinariEpisodicDataset, convert_remote_to_local, get_test_start_state_goals, get_lr, AntmazeWrapper 17 | 18 | def set_seed(seed): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | def eval_env(cfg, model, device, render=False): 25 | if render: 26 | render_mode = 'human' 27 | else: 28 | render_mode = None 29 | 30 | if "pointmaze" in cfg.dataset_name: 31 | env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode) 32 | elif "antmaze" in cfg.dataset_name: 33 | env = AntmazeWrapper(env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)) 34 | else: 35 | raise NotImplementedError 36 | 37 | 
test_start_state_goal = get_test_start_state_goals(cfg) 38 | 39 | model.eval() 40 | results = dict() 41 | with torch.no_grad(): 42 | cum_reward = 0 43 | for ss_g in test_start_state_goal: 44 | total_reward = 0 45 | total_timesteps = 0 46 | 47 | print(ss_g['name'] + ':') 48 | for _ in range(cfg.num_eval_ep): 49 | obs, _ = env.reset(options=ss_g) 50 | done = False 51 | for t in range(env.spec.max_episode_steps): 52 | total_timesteps += 1 53 | 54 | running_state = torch.tensor(obs['observation'], dtype=torch.float32, device=device).view(1, -1) 55 | target_goal = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device).view(1, -1) 56 | 57 | act_preds = model.forward( 58 | running_state, 59 | target_goal, 60 | ) 61 | act = act_preds[0].detach() 62 | 63 | obs, running_reward, done, _, _ = env.step(act.cpu().numpy()) 64 | 65 | total_reward += running_reward 66 | 67 | if done: 68 | break 69 | 70 | print('Achieved goal: ', tuple(obs['achieved_goal'].tolist())) 71 | print('Desired goal: ', tuple(obs['desired_goal'].tolist())) 72 | 73 | print("=" * 60) 74 | cum_reward += total_reward 75 | results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / cfg.num_eval_ep 76 | results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / cfg.num_eval_ep 77 | 78 | results['eval/avg_reward'] = cum_reward / (cfg.num_eval_ep * len(test_start_state_goal)) 79 | env.close() 80 | return results 81 | 82 | def train(cfg, hydra_cfg): 83 | 84 | #set seed 85 | set_seed(cfg.seed) 86 | 87 | #set device 88 | device = torch.device(cfg.device) 89 | 90 | if cfg.save_snapshot: 91 | checkpoint_path = Path(hydra_cfg['runtime']['output_dir']) / Path('checkpoint') 92 | checkpoint_path.mkdir(exist_ok=True) 93 | best_eval_returns = 0 94 | 95 | start_time = datetime.now().replace(microsecond=0) 96 | time_elapsed = start_time - start_time 97 | start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S") 98 | 99 | if "pointmaze" in cfg.dataset_name: 100 | if "umaze" in 
cfg.dataset_name: 101 | cfg.env_name = 'PointMaze_UMaze-v3' 102 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters 103 | elif "medium" in cfg.dataset_name: 104 | cfg.env_name = 'PointMaze_Medium-v3' 105 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters 106 | elif "large" in cfg.dataset_name: 107 | cfg.env_name = 'PointMaze_Large-v3' 108 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters 109 | env = gym.make(cfg.env_name, continuing_task=False) 110 | 111 | elif "antmaze" in cfg.dataset_name: 112 | if "umaze" in cfg.dataset_name: 113 | cfg.env_name = 'AntMaze_UMaze-v4' 114 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters 115 | elif "medium" in cfg.dataset_name: 116 | cfg.env_name = 'AntMaze_Medium-v4' 117 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters 118 | elif "large" in cfg.dataset_name: 119 | cfg.env_name = 'AntMaze_Large-v4' 120 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters 121 | else: 122 | raise NotImplementedError 123 | env = AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False)) 124 | 125 | else: 126 | raise NotImplementedError 127 | 128 | env.action_space.seed(cfg.seed) 129 | env.observation_space.seed(cfg.seed) 130 | 131 | #create dataset 132 | if cfg.remote_data: 133 | convert_remote_to_local(cfg.dataset_name, env) 134 | 135 | train_dataset = MinariEpisodicDataset(cfg.dataset_name, cfg.remote_data, cfg.augment_data, cfg.augment_prob, cfg.nclusters) 136 | 137 | train_data_loader = DataLoader( 138 | train_dataset, 139 | batch_size=cfg.batch_size, 140 | shuffle=True, 141 | num_workers=cfg.num_workers 142 | ) 143 | train_data_iter = iter(train_data_loader) 144 | 145 | #create model 146 | model = DecisionMLP(cfg.env_name, env, goal_dim=train_dataset.goal_dim).to(device) 147 | 148 | optimizer = torch.optim.Adam( 149 | model.parameters(), 150 | lr=cfg.lr, 151 | ) 152 | 153 | total_updates = 0 154 | for i_train_iter in range(cfg.max_train_iters): 155 | 156 | 
log_action_losses = [] 157 | model.train() 158 | 159 | for i in range(cfg.num_updates_per_iter): 160 | try: 161 | states, goals, actions = next(train_data_iter) 162 | 163 | except StopIteration: 164 | train_data_iter = iter(train_data_loader) 165 | states, goals, actions = next(train_data_iter) 166 | 167 | states = states.to(device) 168 | actions = actions.to(device) 169 | goals = goals.to(device) 170 | 171 | action_preds = model.forward( 172 | states=states, 173 | goals=goals, 174 | ) 175 | action_loss = F.mse_loss(action_preds, actions, reduction='mean') 176 | 177 | optimizer.zero_grad() 178 | action_loss.backward() 179 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25) 180 | optimizer.step() 181 | 182 | log_action_losses.append(action_loss.detach().cpu().item()) 183 | 184 | time = datetime.now().replace(microsecond=0) - start_time - time_elapsed 185 | time_elapsed = datetime.now().replace(microsecond=0) - start_time 186 | 187 | total_updates += cfg.num_updates_per_iter 188 | 189 | mean_action_loss = np.mean(log_action_losses) 190 | 191 | results = eval_env(cfg, model, device, render=cfg.render) 192 | 193 | log_str = ("=" * 60 + '\n' + 194 | "time elapsed: " + str(time_elapsed) + '\n' + 195 | "num of updates: " + str(total_updates) + '\n' + 196 | "train action loss: " + format(mean_action_loss, ".5f") #+ '\n' + 197 | ) 198 | 199 | print(results) 200 | print(log_str) 201 | 202 | if cfg.wandb_log: 203 | log_data = dict() 204 | log_data['time'] = time.total_seconds() 205 | log_data['time_elapsed'] = time_elapsed.total_seconds() 206 | log_data['total_updates'] = total_updates 207 | log_data['mean_action_loss'] = mean_action_loss 208 | log_data['lr'] = get_lr(optimizer) 209 | log_data['training_iter'] = i_train_iter 210 | log_data.update(results) 211 | wandb.log(log_data) 212 | 213 | if cfg.save_snapshot and (1+i_train_iter)%cfg.save_snapshot_interval == 0: 214 | snapshot = Path(checkpoint_path) / Path(str(i_train_iter)+'.pt') 215 | 
torch.save(model.state_dict(), snapshot) 216 | 217 | if cfg.save_snapshot and results['eval/avg_reward'] >= best_eval_returns: 218 | print("=" * 60) 219 | print("saving best model!") 220 | print("=" * 60) 221 | best_eval_returns = results['eval/avg_reward'] 222 | snapshot = Path(checkpoint_path) / 'best.pt' 223 | torch.save(model.state_dict(), snapshot) 224 | 225 | print("=" * 60) 226 | print("finished training!") 227 | print("=" * 60) 228 | end_time = datetime.now().replace(microsecond=0) 229 | time_elapsed = str(end_time - start_time) 230 | end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S") 231 | print("started training at: " + start_time_str) 232 | print("finished training at: " + end_time_str) 233 | print("total training time: " + time_elapsed) 234 | print("=" * 60) 235 | 236 | @hydra.main(config_path='cfgs', config_name='dmlp', version_base=None) 237 | def main(cfg: DictConfig): 238 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 239 | 240 | if cfg.wandb_log: 241 | if cfg.wandb_dir is None: 242 | cfg.wandb_dir = hydra_cfg['runtime']['output_dir'] 243 | 244 | project_name = cfg.dataset_name 245 | wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=cfg.wandb_dir, group=cfg.wandb_group_name) 246 | wandb.run.name = cfg.wandb_run_name 247 | 248 | train(cfg, hydra_cfg) 249 | 250 | if __name__ == "__main__": 251 | main() -------------------------------------------------------------------------------- /train_dt.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import wandb 3 | import random 4 | import minari 5 | import numpy as np 6 | import gymnasium as gym 7 | from pathlib import Path 8 | from datetime import datetime 9 | from omegaconf import DictConfig 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from model import DecisionTransformer 16 | from utils import MinariEpisodicTrajectoryDataset, 
convert_remote_to_local, get_test_start_state_goals, get_lr, AntmazeWrapper 17 | 18 | def set_seed(seed): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | def eval_env(cfg, model, device, render=False): 25 | if render: 26 | render_mode = 'human' 27 | else: 28 | render_mode = None 29 | 30 | if "pointmaze" in cfg.dataset_name: 31 | env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode) 32 | elif "antmaze" in cfg.dataset_name: 33 | env = AntmazeWrapper(env = gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)) 34 | else: 35 | raise NotImplementedError 36 | 37 | test_start_state_goal = get_test_start_state_goals(cfg) 38 | 39 | model.eval() 40 | results = dict() 41 | eval_batch_size = 1 42 | 43 | with torch.no_grad(): 44 | cum_reward = 0 45 | for ss_g in test_start_state_goal: 46 | total_reward = 0 47 | total_timesteps = 0 48 | print(ss_g['name'] + ':') 49 | for _ in range(cfg.num_eval_ep): 50 | # zero-initialised placeholders 51 | m_actions = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.act_dim), 52 | dtype=torch.float32, device=device) 53 | m_states = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.state_dim), 54 | dtype=torch.float32, device=device) 55 | m_goals = torch.zeros((eval_batch_size, env.spec.max_episode_steps, model.goal_dim), 56 | dtype=torch.float32, device=device) 57 | 58 | obs, _ = env.reset(options=ss_g) 59 | done = False 60 | 61 | for t in range(env.spec.max_episode_steps): 62 | total_timesteps += 1 63 | 64 | m_states[0, t] = torch.tensor(obs['observation'], dtype=torch.float32, device=device) 65 | m_goals[0, t] = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device) 66 | 67 | 68 | if t < cfg.context_len: 69 | act_preds = model.forward(m_states[:,:t+1], 70 | m_actions[:,:t+1], 71 | m_goals[:,:t+1]) 72 | 73 | else: 74 | act_preds = model.forward(m_states[:, t-cfg.context_len+1:t+1], 75 | 
m_actions[:, t-cfg.context_len+1:t+1], 76 | m_goals[:, t-cfg.context_len+1:t+1]) 77 | 78 | 79 | act = act_preds[0, -1].detach() 80 | 81 | obs, running_reward, done, _, _ = env.step(act.cpu().numpy()) 82 | 83 | # store the executed action in its placeholder 84 | m_actions[0, t] = act 85 | 86 | total_reward += running_reward 87 | 88 | if done: 89 | break 90 | 91 | print('Achieved goal: ', tuple(obs['achieved_goal'].tolist())) 92 | print('Desired goal: ', tuple(obs['desired_goal'].tolist())) 93 | 94 | print("=" * 60) 95 | cum_reward += total_reward 96 | results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / cfg.num_eval_ep 97 | results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / cfg.num_eval_ep 98 | 99 | results['eval/avg_reward'] = cum_reward / (cfg.num_eval_ep * len(test_start_state_goal)) 100 | env.close() 101 | return results 102 | 103 | def train(cfg, hydra_cfg): 104 | 105 | #set seed 106 | set_seed(cfg.seed) 107 | 108 | #set device 109 | device = torch.device(cfg.device) 110 | 111 | if cfg.save_snapshot: 112 | checkpoint_path = Path(hydra_cfg['runtime']['output_dir']) / Path('checkpoint') 113 | checkpoint_path.mkdir(exist_ok=True) 114 | best_eval_returns = 0 115 | 116 | start_time = datetime.now().replace(microsecond=0) 117 | time_elapsed = start_time - start_time 118 | start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S") 119 | 120 | #create env 121 | if "pointmaze" in cfg.dataset_name: 122 | if "umaze" in cfg.dataset_name: 123 | cfg.env_name = 'PointMaze_UMaze-v3' 124 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters 125 | elif "medium" in cfg.dataset_name: 126 | cfg.env_name = 'PointMaze_Medium-v3' 127 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters 128 | elif "large" in cfg.dataset_name: 129 | cfg.env_name = 'PointMaze_Large-v3' 130 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters 131 | env = gym.make(cfg.env_name, continuing_task=False) 132 | 133 | elif "antmaze" in cfg.dataset_name: 134 
| if "umaze" in cfg.dataset_name: 135 | cfg.env_name = 'AntMaze_UMaze-v4' 136 | cfg.nclusters = 20 if cfg.nclusters is None else cfg.nclusters 137 | elif "medium" in cfg.dataset_name: 138 | cfg.env_name = 'AntMaze_Medium-v4' 139 | cfg.nclusters = 40 if cfg.nclusters is None else cfg.nclusters 140 | elif "large" in cfg.dataset_name: 141 | cfg.env_name = 'AntMaze_Large-v4' 142 | cfg.nclusters = 80 if cfg.nclusters is None else cfg.nclusters 143 | else: 144 | raise NotImplementedError 145 | env = AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False)) 146 | 147 | else: 148 | raise NotImplementedError 149 | 150 | env.action_space.seed(cfg.seed) 151 | env.observation_space.seed(cfg.seed) 152 | 153 | #create dataset 154 | if cfg.remote_data: 155 | convert_remote_to_local(cfg.dataset_name, env) 156 | 157 | print(cfg.nclusters) 158 | 159 | train_dataset = MinariEpisodicTrajectoryDataset(cfg.dataset_name, cfg.remote_data, cfg.context_len, cfg.augment_data, cfg.augment_prob, cfg.nclusters) 160 | 161 | train_data_loader = DataLoader( 162 | train_dataset, 163 | batch_size=cfg.batch_size, 164 | shuffle=True, 165 | num_workers=cfg.num_workers 166 | ) 167 | train_data_iter = iter(train_data_loader) 168 | 169 | #create model 170 | model = DecisionTransformer(cfg.env_name, env, cfg.n_blocks, cfg.embed_dim, cfg.context_len, cfg.n_heads, cfg.drop_p, goal_dim=train_dataset.goal_dim).to(device) 171 | 172 | optimizer = torch.optim.AdamW( 173 | model.parameters(), 174 | lr=cfg.lr, 175 | weight_decay=cfg.wt_decay 176 | ) 177 | 178 | scheduler = torch.optim.lr_scheduler.LambdaLR( 179 | optimizer, 180 | lambda steps: min((steps+1)/cfg.warmup_steps, 1) 181 | ) 182 | 183 | total_updates = 0 184 | for i_train_iter in range(cfg.max_train_iters): 185 | 186 | log_action_losses = [] 187 | model.train() 188 | 189 | for i in range(cfg.num_updates_per_iter): 190 | try: 191 | states, goals, actions = next(train_data_iter) 192 | 193 | except StopIteration: 194 | train_data_iter = 
iter(train_data_loader) 195 | states, goals, actions = next(train_data_iter) 196 | 197 | states = states.to(device) # B x T x state_dim 198 | goals = goals.to(device).repeat(1, cfg.context_len, 1) # B x T x goal_dim 199 | actions = actions.to(device) # B x T x act_dim 200 | 201 | action_preds = model.forward( 202 | states=states, 203 | actions=actions, 204 | goals=goals, 205 | ) 206 | 207 | action_loss = F.mse_loss(action_preds, actions) 208 | 209 | optimizer.zero_grad() 210 | action_loss.backward() 211 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25) 212 | optimizer.step() 213 | scheduler.step() 214 | 215 | log_action_losses.append(action_loss.detach().cpu().item()) 216 | 217 | time = datetime.now().replace(microsecond=0) - start_time - time_elapsed 218 | time_elapsed = datetime.now().replace(microsecond=0) - start_time 219 | 220 | total_updates += cfg.num_updates_per_iter 221 | 222 | mean_action_loss = np.mean(log_action_losses) 223 | 224 | results = eval_env(cfg, model, device, render=cfg.render) 225 | 226 | log_str = ("=" * 60 + '\n' + 227 | "time elapsed: " + str(time_elapsed) + '\n' + 228 | "num of updates: " + str(total_updates) + '\n' + 229 | "train action loss: " + format(mean_action_loss, ".5f") #+ '\n' + 230 | ) 231 | 232 | print(results) 233 | print(log_str) 234 | 235 | if cfg.wandb_log: 236 | log_data = dict() 237 | log_data['time'] = time.total_seconds() 238 | log_data['time_elapsed'] = time_elapsed.total_seconds() 239 | log_data['total_updates'] = total_updates 240 | log_data['mean_action_loss'] = mean_action_loss 241 | log_data['lr'] = get_lr(optimizer) 242 | log_data['training_iter'] = i_train_iter 243 | log_data.update(results) 244 | wandb.log(log_data) 245 | 246 | if cfg.save_snapshot and (1+i_train_iter)%cfg.save_snapshot_interval == 0: 247 | snapshot = Path(checkpoint_path) / Path(str(i_train_iter)+'.pt') 248 | torch.save(model.state_dict(), snapshot) 249 | 250 | if cfg.save_snapshot and results['eval/avg_reward'] >= best_eval_returns: 251 | 
print("=" * 60) 252 | print("saving best model!") 253 | print("=" * 60) 254 | best_eval_returns = results['eval/avg_reward'] 255 | snapshot = Path(checkpoint_path) / 'best.pt' 256 | torch.save(model.state_dict(), snapshot) 257 | 258 | print("=" * 60) 259 | print("finished training!") 260 | print("=" * 60) 261 | end_time = datetime.now().replace(microsecond=0) 262 | time_elapsed = str(end_time - start_time) 263 | end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S") 264 | print("started training at: " + start_time_str) 265 | print("finished training at: " + end_time_str) 266 | print("total training time: " + time_elapsed) 267 | print("=" * 60) 268 | 269 | @hydra.main(config_path='cfgs', config_name='dt', version_base=None) 270 | def main(cfg: DictConfig): 271 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 272 | 273 | if cfg.wandb_log: 274 | if cfg.wandb_dir is None: 275 | cfg.wandb_dir = hydra_cfg['runtime']['output_dir'] 276 | 277 | project_name = cfg.dataset_name 278 | wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=cfg.wandb_dir, group=cfg.wandb_group_name) 279 | wandb.run.name = cfg.wandb_run_name 280 | 281 | train(cfg, hydra_cfg) 282 | 283 | if __name__ == "__main__": 284 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .env import * 3 | from .misc import * -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import minari 5 | import numpy as np 6 | import gymnasium as gym 7 | from datetime import datetime 8 | from sklearn.cluster import KMeans 9 | 10 | from collections import defaultdict 11 | from torch.utils.data import Dataset 12 | 13 | def 
convert_remote_to_local(dataset_name, env): 14 | 15 | if os.path.isfile('data/'+dataset_name+'-remote.pkl'): 16 | print("A dataset with the same name already exists. Using that dataset.") 17 | return 18 | 19 | if "pointmaze" in dataset_name: 20 | if "pointmaze-umaze" in dataset_name: 21 | env_name = 'PointMaze_UMaze-v3' 22 | elif "pointmaze-medium" in dataset_name: 23 | env_name = 'PointMaze_Medium-v3' 24 | elif "pointmaze-large" in dataset_name: 25 | env_name = 'PointMaze_Large-v3' 26 | 27 | env = gym.make(env_name, continuing_task=False) 28 | 29 | else: 30 | raise NotImplementedError 31 | 32 | minari_dataset = minari.load_dataset(dataset_name) 33 | 34 | # data placeholders 35 | observation_data = { 36 | 'episode_id': np.zeros(shape=(int(1e6),), dtype=np.int32), 37 | 'observation' : np.zeros(shape=(int(1e6), *env.observation_space['observation'].shape), dtype=np.float32), 38 | 'achieved_goal' : np.zeros(shape=(int(1e6), *env.observation_space['achieved_goal'].shape), dtype=np.float32), 39 | } 40 | action_data = np.zeros(shape=(int(1e6), *env.action_space.shape), dtype=np.float32) 41 | termination_data = np.zeros(shape = (int(1e6),), dtype=bool) 42 | 43 | data_idx = 0 44 | episode_idx = 0 45 | for episode in minari_dataset: 46 | if data_idx + episode.total_timesteps + 1 > int(1e6): 47 | 48 | observation_data['episode_id'][data_idx: ] = episode_idx 49 | observation_data['observation'][data_idx: ] = episode.observations['observation'][:int(1e6)-data_idx] 50 | observation_data['achieved_goal'][data_idx: ] = episode.observations['achieved_goal'][:int(1e6)-data_idx] 51 | 52 | action_data[data_idx: ] = episode.actions[:int(1e6)-data_idx] 53 | termination_data[data_idx+1: ] = episode.truncations[:int(1e6)-data_idx-1] 54 | 55 | break 56 | 57 | else: 58 | try: 59 | observation_data['episode_id'][data_idx: data_idx+episode.total_timesteps+1] = episode_idx 60 | observation_data['observation'][data_idx: data_idx+episode.total_timesteps+1] = episode.observations['observation'] 
61 | observation_data['achieved_goal'][data_idx: data_idx+episode.total_timesteps+1] = episode.observations['achieved_goal'] 62 | 63 | action_data[data_idx: data_idx+episode.total_timesteps] = episode.actions 64 | termination_data[data_idx+1: data_idx+episode.total_timesteps+1] = episode.truncations 65 | 66 | data_idx = data_idx + episode.total_timesteps + 1 67 | episode_idx = episode_idx + 1 68 | 69 | except ValueError: 70 | # some episodes are malformed (number of observations == total_timesteps instead of total_timesteps + 1), so the slice assignment fails to broadcast; skip them 71 | continue 72 | 73 | if (episode_idx + 1) % 1000 == 0: 74 | print('EPISODES RECORDED = ', episode_idx) 75 | 76 | if data_idx >= int(1e6): 77 | break 78 | 79 | print('Total transitions recorded = ', data_idx) 80 | 81 | dataset = {'observations':observation_data, 82 | 'actions':action_data, 83 | 'terminations':termination_data, 84 | } 85 | 86 | with open('data/'+dataset_name+'-remote.pkl', 'wb') as fp: 87 | pickle.dump(dataset, fp) 88 | 89 | def extract_discrete_id_to_data_id_map(discrete_goals, dones, last_valid_traj): 90 | discrete_goal_to_data_idx = defaultdict(list) 91 | gm = 0 92 | for i, d_g in enumerate(discrete_goals): 93 | 94 | discrete_goal_to_data_idx[d_g].append(i) 95 | gm += 1 96 | 97 | if (i + 1) % 200000 == 0: 98 | print('Goals mapped:', gm) 99 | 100 | if i == last_valid_traj: 101 | break 102 | 103 | for dg, data_idxes in discrete_goal_to_data_idx.items(): 104 | discrete_goal_to_data_idx[dg] = np.array(data_idxes) 105 | 106 | print('Total goals mapped:', gm) 107 | return discrete_goal_to_data_idx 108 | 109 | def extract_done_markers(dones, episode_ids): 110 | """Given a per-timestep dones vector, return (a) the end index of the trajectory each timestep belongs to and (b) the indices of all non-terminal timesteps.""" 111 | 112 | (ends,) = np.where(dones) 113 | return ends[ episode_ids[ : ends[-1] + 1 ] ], np.where(1 - dones[: ends[-1] + 1])[0] 114 | 115 | class MinariEpisodicTrajectoryDataset(Dataset): 116 | def __init__(self, dataset_name, remote_data, context_len, augment_data, augment_prob, nclusters=40): 117 | super().__init__() 118 | if
remote_data: 119 | path = 'data/'+dataset_name+'-remote.pkl' 120 | else: 121 | path = 'data/'+dataset_name+'.pkl' 122 | 123 | with open(path, 'rb') as fp: 124 | self.dataset = pickle.load(fp) 125 | 126 | self.observations = self.dataset['observations']['observation'] 127 | self.achieved_goals = self.dataset['observations']['achieved_goal'] 128 | self.actions = self.dataset['actions'] 129 | 130 | (ends,) = np.where(self.dataset['terminations']) 131 | self.starts = np.concatenate(([0], ends[:-1] + 1)) 132 | self.lengths = ends - self.starts + 1 #length = number of actions taken in an episode + 1 133 | self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ] 134 | 135 | good_idxes = self.lengths > context_len 136 | print('Throwing away', np.sum(self.lengths[~good_idxes] - 1), 'transitions from episodes with at most context_len states') 137 | self.starts = self.starts[good_idxes] #starts will only contain indices of episodes where number of states > context_len 138 | self.lengths = self.lengths[good_idxes] 139 | 140 | self.num_trajectories = len(self.starts) 141 | 142 | if augment_data: 143 | start_time = datetime.now().replace(microsecond=0) 144 | print('starting kmeans ... ') 145 | kmeans = KMeans(n_clusters=nclusters, n_init="auto").fit(self.achieved_goals) 146 | time_elapsed = str(datetime.now().replace(microsecond=0) - start_time) 147 | print('kmeans done!
time taken :', time_elapsed) 148 | 149 | self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(kmeans.labels_, self.dataset['terminations'], self.ends[-1]) 150 | self.achieved_discrete_goals = kmeans.labels_ 151 | kmeans = None 152 | 153 | self.state_dim = self.observations.shape[-1] 154 | self.state_dtype = self.observations.dtype 155 | self.act_dim = self.actions.shape[-1] 156 | self.act_dtype = self.actions.dtype 157 | self.goal_dim = self.achieved_goals.shape[-1] 158 | self.context_len = context_len 159 | self.augment_data = augment_data 160 | self.augment_prob = augment_prob 161 | self.dataset = None 162 | 163 | def __len__(self): 164 | return self.num_trajectories * 100 165 | 166 | def __getitem__(self, idx): 167 | ''' 168 | Reminder: np.random.randint samples from the set [low, high) 169 | ''' 170 | idx = idx % self.num_trajectories 171 | traj_len = self.lengths[idx] - 1 #traj_len = T, traj_len is the number of actions taken in the trajectory 172 | traj_start_i = self.starts[idx] 173 | assert self.ends[traj_start_i] == traj_start_i + traj_len 174 | 175 | 176 | if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob: 177 | correct = False 178 | while not correct: 179 | si = traj_start_i + np.random.randint(0, traj_len) #si can be traj_start_i + [0, T - 1] 180 | gi = np.random.randint(si, traj_start_i + traj_len) + 1 #gi can be [si + 1, traj_start_i + T] 181 | dummy_discrete_goal = self.achieved_discrete_goals[ gi ] 182 | nearby_goal_idx = np.random.choice(self.discrete_goal_to_data_idx[dummy_discrete_goal]) 183 | nearby_goal_idx_ends = self.ends[nearby_goal_idx] 184 | if (gi-si) + (nearby_goal_idx_ends - nearby_goal_idx) + 1 > self.context_len: 185 | correct = True 186 | 187 | if gi - si < self.context_len: 188 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx + self.context_len - (gi - si), nearby_goal_idx_ends + 1) ]).view(1, -1) 189 | state = torch.tensor( np.concatenate( [ self.observations[si:
gi], self.observations[nearby_goal_idx: nearby_goal_idx + self.context_len - (gi - si)] ] ) ) 190 | action = torch.tensor( np.concatenate( [ self.actions[si: gi], self.actions[nearby_goal_idx: nearby_goal_idx + self.context_len - (gi - si)] ] ) ) 191 | else: 192 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx, nearby_goal_idx_ends + 1) ]).view(1, -1) 193 | state = torch.tensor(self.observations[si: si + self.context_len]) 194 | action = torch.tensor(self.actions[si: si + self.context_len]) 195 | 196 | else: 197 | si = traj_start_i + np.random.randint(0, traj_len - self.context_len + 1) #si can be traj_start_i + [0, T-C] 198 | gi = np.random.randint(si + self.context_len, traj_start_i + traj_len + 1) #gi can be [si+C, traj_start_i+T] 199 | 200 | goal = torch.tensor(self.achieved_goals[ gi ]).view(1, -1) 201 | state = torch.tensor(self.observations[si: si + self.context_len]) 202 | action = torch.tensor(self.actions[si: si + self.context_len]) 203 | 204 | return state, goal, action 205 | 206 | class MinariEpisodicDataset(Dataset): 207 | def __init__(self, dataset_name, remote_data, augment_data, augment_prob, nclusters=40): 208 | super().__init__() 209 | if remote_data: 210 | path = 'data/'+dataset_name+'-remote.pkl' 211 | else: 212 | path = 'data/'+dataset_name+'.pkl' 213 | 214 | with open(path, 'rb') as fp: 215 | self.dataset = pickle.load(fp) 216 | 217 | self.episode_ids = self.dataset['observations']['episode_id'] 218 | self.observations = self.dataset['observations']['observation'] 219 | self.achieved_goals = self.dataset['observations']['achieved_goal'] 220 | self.actions = self.dataset['actions'] 221 | 222 | (ends,) = np.where(self.dataset['terminations']) 223 | self.starts = np.concatenate(([0], ends[:-1] + 1)) 224 | self.lengths = ends - self.starts + 1 225 | self.num_trajectories = len(self.starts) 226 | 227 | self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ] 228 | 229 | if augment_data: 230 | 
start_time = datetime.now().replace(microsecond=0) 231 | print('starting kmeans ... ') 232 | kmeans = KMeans(n_clusters=nclusters, n_init="auto").fit(self.observations) 233 | time_elapsed = str(datetime.now().replace(microsecond=0) - start_time) 234 | print('kmeans done! time taken :', time_elapsed) 235 | 236 | self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(kmeans.labels_, self.dataset['terminations'], self.ends[-1]) 237 | self.achieved_discrete_goals = kmeans.labels_ 238 | kmeans = None 239 | 240 | self.goal_dim = self.achieved_goals.shape[-1] 241 | self.augment_data = augment_data 242 | self.augment_prob = augment_prob 243 | self.dataset = None 244 | 245 | def __len__(self): 246 | return self.num_trajectories * 100 247 | 248 | def __getitem__(self, idx): 249 | ''' 250 | Reminder: np.random.randint samples from the set [low, high) 251 | ''' 252 | idx = idx % self.num_trajectories 253 | traj_len = self.lengths[idx] - 1 #traj_len = T, traj_len is the number of actions taken in the trajectory 254 | traj_start_i = self.starts[idx] 255 | assert self.ends[traj_start_i] == traj_start_i + traj_len 256 | 257 | si = np.random.randint(0, traj_len) #si can be [0, T-1] 258 | 259 | state = torch.tensor(self.observations[traj_start_i + si]) 260 | action = torch.tensor(self.actions[traj_start_i + si]) 261 | 262 | if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob: 263 | dummy_discrete_goal = self.achieved_discrete_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ] 264 | nearby_goal_idx = np.random.choice(self.discrete_goal_to_data_idx[dummy_discrete_goal]) 265 | goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx, self.ends[nearby_goal_idx] + 1) ]) 266 | else: 267 | goal = torch.tensor(self.achieved_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ]) 268 | 269 | return state, goal, action -------------------------------------------------------------------------------- /utils/env.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import gymnasium as gym 4 | from typing import Iterable, Any 5 | from gymnasium.spaces import Box 6 | from gymnasium.error import DependencyNotInstalled 7 | 8 | def get_lr(optimizer): 9 | for param_group in optimizer.param_groups: 10 | return param_group['lr'] 11 | 12 | def get_parameters(modules: Iterable[nn.Module]): 13 | model_parameters = [] 14 | for module in modules: 15 | model_parameters += list(module.parameters()) 16 | return model_parameters 17 | 18 | 19 | class PreprocessObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): 20 | 21 | def __init__(self, env, shape, grayscale=False) -> None: 22 | """Resizes the 'pixels' image observations to the shape given by :attr:`shape`, optionally converting them to grayscale (one channel) first. 23 | 24 | Args: 25 | env: The environment to apply the wrapper 26 | shape: The shape of the resized observations 27 | """ 28 | gym.utils.RecordConstructorArgs.__init__(self, shape=shape, grayscale=grayscale) 29 | gym.ObservationWrapper.__init__(self, env) 30 | 31 | if isinstance(shape, int): 32 | shape = (shape, shape) 33 | assert len(shape) == 2 and all( 34 | x > 0 for x in shape 35 | ), f"Expected shape to be a 2-tuple of positive integers, got: {shape}" 36 | 37 | self.grayscale = grayscale 38 | if grayscale: 39 | obs_shape = env.observation_space['pixels'].shape 40 | env.observation_space['pixels'] = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8) 41 | 42 | self.shape = tuple(shape) 43 | obs_shape = self.shape + env.observation_space['pixels'].shape[2:] 44 | self.observation_space['pixels'] = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) 45 | 46 | def observation(self, observation): 47 | try: 48 | import cv2 49 | except ImportError as e: 50 | raise DependencyNotInstalled( 51 | "opencv (cv2) is not installed, run `pip install gymnasium[other]`" 52 | ) from e 53 | 54 | if self.grayscale: 55 | observation['pixels'] =
cv2.cvtColor(observation['pixels'], cv2.COLOR_RGB2GRAY) 56 | observation['pixels'] = cv2.resize( 57 | observation['pixels'], self.shape[::-1], interpolation=cv2.INTER_AREA 58 | ).reshape(self.observation_space['pixels'].shape) 59 | 60 | return observation 61 | 62 | class AntmazeWrapper(gym.ObservationWrapper): 63 | def __init__(self, env): 64 | """Constructor for the observation wrapper.""" 65 | gym.ObservationWrapper.__init__(self, env) 66 | 67 | self.observation_space['observation'] = gym.spaces.Box( 68 | low=-np.inf, high=np.inf, shape=(self.observation_space['observation'].shape[0] + 2,), dtype=np.float64 69 | ) 70 | 71 | def reset( 72 | self, *, seed=None, options=None, 73 | ): 74 | obs, info = self.env.reset(seed=seed, options=options) 75 | return self.observation(obs), info 76 | 77 | def step( 78 | self, action 79 | ): 80 | """Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.""" 81 | observation, reward, terminated, truncated, info = self.env.step(action) 82 | return self.observation(observation), reward, terminated, truncated, info 83 | 84 | def observation(self, observation): 85 | """Returns a modified observation. 
86 | 87 | Args: 88 | observation: The :attr:`env` observation 89 | 90 | Returns: 91 | The modified observation 92 | """ 93 | observation['observation'] = np.concatenate((observation['achieved_goal'], observation['observation']), axis=0) 94 | return observation 95 | 96 | def get_test_start_state_goals(cfg): 97 | if cfg.dataset_name in ["pointmaze-umaze-v0", "antmaze-umaze-v0"]: 98 | test_start_state_goal = [ 99 | {'goal_cell': np.array([1,1], dtype=np.int32), 100 | 'reset_cell': np.array([3,1], dtype=np.int32), 101 | 'name' : 'bottom_to_top', 102 | }, 103 | 104 | {'goal_cell': np.array([3,1], dtype=np.int32), 105 | 'reset_cell': np.array([1,1], dtype=np.int32), 106 | 'name' : 'top_to_bottom', 107 | }, 108 | ] 109 | 110 | elif cfg.dataset_name in ["pointmaze-medium-v0"]: 111 | test_start_state_goal = [ 112 | {'goal_cell': np.array([6,3], dtype=np.int32), 113 | 'reset_cell': np.array([6,6], dtype=np.int32), 114 | 'name' : 'bottom_right_to_bottom_center', 115 | }, 116 | 117 | {'goal_cell': np.array([6,1], dtype=np.int32), 118 | 'reset_cell': np.array([6,6], dtype=np.int32), 119 | 'name' : 'bottom_right_to_bottom_left', 120 | }, 121 | 122 | {'goal_cell': np.array([6,3], dtype=np.int32), 123 | 'reset_cell': np.array([6,5], dtype=np.int32), 124 | 'name' : 'bottom_rightish_to_bottom_center', 125 | }, 126 | 127 | {'goal_cell': np.array([6,1], dtype=np.int32), 128 | 'reset_cell': np.array([6,5], dtype=np.int32), 129 | 'name' : 'bottom_rightish_to_bottom_left', 130 | }, 131 | 132 | {'goal_cell': np.array([6,5], dtype=np.int32), 133 | 'reset_cell': np.array([1,1], dtype=np.int32), 134 | 'name' : 'top_left_to_bottom_rightish', 135 | }, 136 | 137 | {'goal_cell': np.array([6,6], dtype=np.int32), 138 | 'reset_cell': np.array([1,6], dtype=np.int32), 139 | 'name' : 'top_right_to_bottom_right', 140 | }, 141 | ] 142 | 143 | elif cfg.dataset_name in ["antmaze-medium-v0"]: 144 | test_start_state_goal = [ 145 | {'goal_cell': np.array([6,5], dtype=np.int32), 146 | 'reset_cell': 
np.array([6,1], dtype=np.int32), 147 | 'name' : 'bottom_left_to_bottom_rightish', 148 | }, 149 | 150 | {'goal_cell': np.array([1,6], dtype=np.int32), 151 | 'reset_cell': np.array([6,1], dtype=np.int32), 152 | 'name' : 'bottom_left_to_top_right', 153 | }, 154 | 155 | {'goal_cell': np.array([6,1], dtype=np.int32), 156 | 'reset_cell': np.array([6,5], dtype=np.int32), 157 | 'name' : 'bottom_rightish_to_bottom_left', 158 | }, 159 | 160 | {'goal_cell': np.array([1,6], dtype=np.int32), 161 | 'reset_cell': np.array([6,5], dtype=np.int32), 162 | 'name' : 'bottom_rightish_to_top_right', 163 | }, 164 | 165 | {'goal_cell': np.array([6,1], dtype=np.int32), 166 | 'reset_cell': np.array([1,6], dtype=np.int32), 167 | 'name' : 'top_right_to_bottom_left', 168 | }, 169 | 170 | {'goal_cell': np.array([6,5], dtype=np.int32), 171 | 'reset_cell': np.array([1,6], dtype=np.int32), 172 | 'name' : 'top_right_to_bottom_rightish', 173 | }, 174 | 175 | ] 176 | 177 | elif cfg.dataset_name in ["pointmaze-large-v0"]: 178 | test_start_state_goal = [ 179 | {'goal_cell': np.array([7,4], dtype=np.int32), 180 | 'reset_cell': np.array([7,1], dtype=np.int32), 181 | 'name' : 'bottom_left_to_bottom_center', 182 | }, 183 | 184 | {'goal_cell': np.array([7,10], dtype=np.int32), 185 | 'reset_cell': np.array([7,1], dtype=np.int32), 186 | 'name' : 'bottom_left_to_bottom_right', 187 | }, 188 | 189 | {'goal_cell': np.array([1,10], dtype=np.int32), 190 | 'reset_cell': np.array([7,1], dtype=np.int32), 191 | 'name' : 'bottom_left_to_top_right', 192 | }, 193 | 194 | {'goal_cell': np.array([7,1], dtype=np.int32), 195 | 'reset_cell': np.array([7,4], dtype=np.int32), 196 | 'name' : 'bottom_center_to_bottom_left', 197 | }, 198 | 199 | {'goal_cell': np.array([7,10], dtype=np.int32), 200 | 'reset_cell': np.array([7,4], dtype=np.int32), 201 | 'name' : 'bottom_center_to_bottom_right', 202 | }, 203 | 204 | {'goal_cell': np.array([1,1], dtype=np.int32), 205 | 'reset_cell': np.array([7,4], dtype=np.int32), 206 | 'name' : 
'bottom_center_to_top_left', 207 | }, 208 | 209 | {'goal_cell': np.array([7,1], dtype=np.int32), 210 | 'reset_cell': np.array([7,10], dtype=np.int32), 211 | 'name' : 'bottom_right_to_bottom_left', 212 | } 213 | ] 214 | 215 | elif cfg.dataset_name in ["antmaze-large-v0"]: 216 | test_start_state_goal = [ 217 | {'goal_cell': np.array([7,4], dtype=np.int32), 218 | 'reset_cell': np.array([7,1], dtype=np.int32), 219 | 'name' : 'bottom_left_to_bottom_center', 220 | }, 221 | 222 | {'goal_cell': np.array([7,10], dtype=np.int32), 223 | 'reset_cell': np.array([7,1], dtype=np.int32), 224 | 'name' : 'bottom_left_to_bottom_right', 225 | }, 226 | 227 | {'goal_cell': np.array([1,10], dtype=np.int32), 228 | 'reset_cell': np.array([7,1], dtype=np.int32), 229 | 'name' : 'bottom_left_to_top_right', 230 | }, 231 | 232 | {'goal_cell': np.array([7,1], dtype=np.int32), 233 | 'reset_cell': np.array([7,4], dtype=np.int32), 234 | 'name' : 'bottom_center_to_bottom_left', 235 | }, 236 | 237 | {'goal_cell': np.array([1,1], dtype=np.int32), 238 | 'reset_cell': np.array([7,4], dtype=np.int32), 239 | 'name' : 'bottom_center_to_top_left', 240 | }, 241 | 242 | {'goal_cell': np.array([7,1], dtype=np.int32), 243 | 'reset_cell': np.array([7,10], dtype=np.int32), 244 | 'name' : 'bottom_right_to_bottom_left', 245 | } 246 | ] 247 | 248 | else: 249 | raise NotImplementedError 250 | 251 | return test_start_state_goal 252 | 253 | def get_maze_map(dataset_name): 254 | return None 255 | 256 | def cell_to_state(cell, maze): 257 | return cell[:, 0] * maze.map_width + cell[:, 1] 258 | 259 | 260 | def cell_xy_to_rowcol(maze, xy_pos: np.ndarray) -> np.ndarray: 261 | """Converts continuous (x, y) positions of shape (N, 2) to `(i, j)` = (row, column) cell indices (returned as floats).""" 262 | 263 | i = np.reshape(np.floor((maze.y_map_center - xy_pos[:, 1]) / maze.maze_size_scaling), (-1, 1)) 264 | j = np.reshape(np.floor((xy_pos[:, 0] + maze.x_map_center) / maze.maze_size_scaling), (-1, 1)) 265 | 266 | return np.concatenate([i,j], axis=-1)
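A minimal sketch of how `cell_xy_to_rowcol` and `cell_to_state` compose: continuous (x, y) positions become (row, col) cells, which are then flattened into one discrete state id. It assumes a maze object exposing `map_width`, `x_map_center`, `y_map_center` and `maze_size_scaling` (the attributes these helpers read); the `SimpleNamespace` maze and its 5x5 unit-cell geometry are hypothetical stand-ins, not the real gymnasium-robotics maze. Unlike the helper above, which returns floats, this sketch casts the cells to integers.

```python
import numpy as np
from types import SimpleNamespace

# Hypothetical stand-in for the maze object (illustrative values only):
# a 5x5 grid of 1-unit cells centred on the origin.
maze = SimpleNamespace(
    map_length=5,           # number of rows
    map_width=5,            # number of columns
    maze_size_scaling=1.0,  # side length of one cell
    x_map_center=2.5,       # map_width * maze_size_scaling / 2
    y_map_center=2.5,       # map_length * maze_size_scaling / 2
)

def cell_xy_to_rowcol(maze, xy_pos):
    """Map (x, y) positions of shape (N, 2) to integer (row, col) cells."""
    # Rows grow downwards (decreasing y); columns grow rightwards (increasing x).
    i = np.floor((maze.y_map_center - xy_pos[:, 1]) / maze.maze_size_scaling)
    j = np.floor((xy_pos[:, 0] + maze.x_map_center) / maze.maze_size_scaling)
    return np.stack([i, j], axis=-1).astype(np.int64)

def cell_to_state(cell, maze):
    """Flatten (row, col) cells into a single row-major state id."""
    return cell[:, 0] * maze.map_width + cell[:, 1]

xy = np.array([[0.0, 0.0], [-2.0, 2.0], [2.0, -2.0]])  # centre, top-left, bottom-right
cells = cell_xy_to_rowcol(maze, xy)
states = cell_to_state(cells, maze)
print(cells.tolist())   # [[2, 2], [0, 0], [4, 4]]
print(states.tolist())  # [12, 0, 24]
```

Row-major flattening keeps the mapping invertible (`row = state // map_width`, `col = state % map_width`), which is convenient when tabulating per-cell statistics.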
-------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Iterable 3 | 4 | def weight_init(m): 5 | if isinstance(m, nn.Linear): 6 | nn.init.orthogonal_(m.weight.data) 7 | if hasattr(m.bias, 'data'): 8 | m.bias.data.fill_(0.0) 9 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 10 | gain = nn.init.calculate_gain('relu') 11 | nn.init.orthogonal_(m.weight.data, gain) 12 | if hasattr(m.bias, 'data'): 13 | m.bias.data.fill_(0.0) 14 | 15 | def soft_update(target, source, tau): 16 | for target_param, param in zip(target.parameters(), source.parameters()): 17 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 18 | 19 | def hard_update(target, source): 20 | for target_param, param in zip(target.parameters(), source.parameters()): 21 | target_param.data.copy_(param.data) 22 | 23 | def get_lr(optimizer): 24 | for param_group in optimizer.param_groups: 25 | return param_group['lr'] 26 | 27 | def get_parameters(modules: Iterable[nn.Module]): 28 | model_parameters = [] 29 | for module in modules: 30 | model_parameters += list(module.parameters()) 31 | return model_parameters --------------------------------------------------------------------------------
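The index bookkeeping in `utils/data.py` (`MinariEpisodicTrajectoryDataset` and `MinariEpisodicDataset`) can be illustrated on a tiny buffer. This is a sketch under stated assumptions: the toy `terminations`/`episode_id` arrays are hypothetical, and `sample_state_goal` is a hypothetical helper mirroring the non-augmented goal sampling (a state index `si` and a strictly later achieved-goal index `gi` from the same episode). It uses NumPy's `Generator.integers`, which, like `np.random.randint`, excludes the upper bound.

```python
import numpy as np

# Hypothetical flattened buffer: three episodes with 3, 2 and 4 observations.
# terminations[t] is True on the last observation of each episode,
# mirroring how convert_remote_to_local fills termination_data.
terminations = np.array([0, 0, 1, 0, 1, 0, 0, 0, 1], dtype=bool)
episode_id = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])

(ends,) = np.where(terminations)                 # last index of each episode
starts = np.concatenate(([0], ends[:-1] + 1))    # first index of each episode
lengths = ends - starts + 1                      # observations per episode (= actions + 1)
per_step_end = ends[episode_id[: ends[-1] + 1]]  # episode end index for every timestep

def sample_state_goal(traj_idx, rng):
    """Hypothetical helper: draw a state and a strictly later goal from one episode."""
    T = lengths[traj_idx] - 1                    # number of actions in the episode
    start = starts[traj_idx]
    si = start + rng.integers(0, T)              # si in [start, start + T - 1]
    gi = rng.integers(si, start + T) + 1         # gi in [si + 1, start + T]
    return si, gi

rng = np.random.default_rng(0)
print(per_step_end.tolist())  # [2, 2, 2, 4, 4, 8, 8, 8, 8]
si, gi = sample_state_goal(2, rng)
print(si < gi <= ends[2])     # True
```

The per-timestep `ends` lookup is what lets the augmented branch ask, for an arbitrary `nearby_goal_idx` returned by the k-means map, how many states remain in that episode without searching for its boundary.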