├── .gitignore ├── github_teaser.png ├── scripts ├── eval.sh ├── train.sh └── replay_turn_faucet_trajectories.sh ├── maniskill2_patches ├── pick_cube.py ├── peg_insertion_side.py ├── turn_faucet.py └── record.py ├── src ├── train_utils.py ├── vec_env.py ├── train.py ├── data.py ├── eval.py └── model.py ├── README.md └── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | src/__pycache__
2 | models
3 | wandb
4 | path.py
5 | 
--------------------------------------------------------------------------------
/github_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SeanJia/CoTPC/HEAD/github_teaser.png
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Run from the `scripts/` directory: evaluates checkpoint `--from_ckpt` of
4 | # `--model_name` on StackCube-v0 (see src/eval.py for what each flag means).
5 | cd ../src || exit
6 | 
7 | python eval.py \
8 |     --num_traj=500 --eval_max_steps=200 \
9 |     --key_states=abc --key_state_loss=0 \
10 |     --from_ckpt=1_800_000 --task=StackCube-v0 \
11 |     --model_name=some_model_name
12 | 
--------------------------------------------------------------------------------
/maniskill2_patches/pick_cube.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the pick_cube env in ManiSkill2 so that it allows additional metrics for eval
3 | and flags for obtaining key states in training.
Please simply replace the `evaluate` 4 | function in (with the correct level of indentation): 5 | 6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/pick_and_place/pick_cube.py 7 | """ 8 | 9 | def evaluate(self, **kwargs): 10 | is_obj_placed = self.check_obj_placed() 11 | is_robot_static = self.check_robot_static() 12 | is_grasped = self.agent.check_grasp(self.obj, max_angle=30) 13 | return dict( 14 | is_obj_placed=is_obj_placed, 15 | is_robot_static=is_robot_static, 16 | is_grasped=is_grasped, 17 | success=is_obj_placed and is_robot_static, 18 | ) -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../src && 4 | 5 | # Example script for PickCube training (with a good set of hyper-parameters). 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --model_name=some_model_name \ 8 | --num_traj=500 --n_iters=1_600_000 \ 9 | --context_length=60 --model_type=s+a+cot \ 10 | --task=PickCube-v0 --key_state_coeff=0.1 \ 11 | --key_state_loss=0 --key_states=ab \ 12 | --init_lr=5e-4 --num_workers=20 13 | 14 | # CUDA_VISIBLE_DEVICES=0 python train.py \ 15 | # --model_name=some_model_name \ 16 | # --num_traj=500 --n_iters=1_600_000 \ 17 | # --context_length=60 --model_type=s+a+cot \ 18 | # --task=TurnFaucet-v0 --key_state_coeff=0.1 \ 19 | # --key_state_loss=0 --key_states=ab \ 20 | # --init_lr=5e-4 --num_workers=20 21 | -------------------------------------------------------------------------------- /src/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for the cosine decay learning rate with linear warmup. 
3 | """ 4 | 5 | import math 6 | import functools 7 | import torch 8 | 9 | 10 | def _cosine_decay_warmup(iteration, warmup_iterations, total_iterations): 11 | """ 12 | Linear warmup from 0 --> 1.0, then decay using cosine decay to 0.1 13 | """ 14 | if iteration <= warmup_iterations: 15 | multiplier = iteration / warmup_iterations 16 | else: 17 | multiplier = (iteration - warmup_iterations) / (total_iterations - warmup_iterations) 18 | multiplier = max(0.1, 0.5 * (1 + math.cos(math.pi * multiplier))) 19 | return multiplier 20 | 21 | 22 | def CosineAnnealingLRWarmup(optimizer, T_max, T_warmup): 23 | _decay_func = functools.partial( 24 | _cosine_decay_warmup, 25 | warmup_iterations=T_warmup, total_iterations=T_max 26 | ) 27 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _decay_func) 28 | return scheduler 29 | -------------------------------------------------------------------------------- /scripts/replay_turn_faucet_trajectories.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Assume that the *.h5 and *.json are in `../data/rigid_body_envs/TurnFaucet-v0/raw`, 4 | # replay the trajectories with a subset of a total of 10 faucet models. 
5 | for s in 5002 5021 5023 5028 5029 5045 5047 5051 5056 5063 6 | do 7 | python -m mani_skill2.trajectory.replay_trajectory \ 8 | --traj-path ../data/rigid_body_envs/TurnFaucet-v0/raw/$s.h5 \ 9 | --save-traj --target-control-mode pd_joint_delta_pos \ 10 | --obs-mode state --num-procs 20 11 | done 12 | 13 | mv ../data/rigid_body_envs/TurnFaucet-v0/raw/*.state.pd_joint_delta_pos.h5 \ 14 | ../data/rigid_body_envs/TurnFaucet-v0/merged/ 15 | mv ../data/rigid_body_envs/TurnFaucet-v0/raw/*.state.pd_joint_delta_pos.json \ 16 | ../data/rigid_body_envs/TurnFaucet-v0/merged/ 17 | 18 | python -m mani_skill2.trajectory.merge_trajectory \ 19 | -i ../data/rigid_body_envs/TurnFaucet-v0/merged -p *.h5 \ 20 | -o ../data/rigid_body_envs/TurnFaucet-v0/trajectory.state.pd_joint_delta_pos.h5 -------------------------------------------------------------------------------- /maniskill2_patches/peg_insertion_side.py: -------------------------------------------------------------------------------- 1 | # I updated part of the environment file in ManiSkill2 so that it allows metrics 2 | # computation of either reaching an intermediate key state during testing. 3 | # Simply replace the `evaluate` function. 4 | 5 | """ 6 | Patch the peg_insertion_side env in ManiSkill2 so that it allows additional metrics for 7 | eval and flags for obtaining key states in training. 
Please simply replace the `evaluate`
8 | function in the file below (keeping the correct level of indentation):
9 | 
10 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/assembly/peg_insertion_side.py
11 | """
12 | 
13 | def evaluate(self, **kwargs):
14 |     """
15 |     Evaluate PegInsertionSide and report the intermediate `pre_inserted` flag.
16 | 
17 |     `pre_inserted` is True only when the peg is grasped (max_angle=20) and both
18 |     the peg head and the peg (`self.peg.pose`) lie within 0.01 of the goal
19 |     frame's x axis (norm of their y-z coordinates in the goal frame).
20 |     NOTE(review): presumably this means "aligned with the hole, not yet inserted".
21 | 
22 |     Returns a dict with `success` and `peg_head_pos_at_hole` (both from
23 |     `self.has_peg_inserted()`), `pre_inserted`, and `is_grasped`.
24 |     """
25 |     is_grasped = self.agent.check_grasp(self.peg, max_angle=20)
26 | 
27 |     pre_inserted = False
28 |     if is_grasped:
29 |         # Express both poses in the goal frame; p[1:] are the y and z coordinates,
30 |         # so their norm is the distance from the goal frame's x axis.
31 |         peg_head_wrt_goal = self.goal_pose.inv() * self.peg_head_pose
32 |         peg_head_wrt_goal_yz_dist = np.linalg.norm(peg_head_wrt_goal.p[1:])
33 |         peg_wrt_goal = self.goal_pose.inv() * self.peg.pose
34 |         peg_wrt_goal_yz_dist = np.linalg.norm(peg_wrt_goal.p[1:])
35 |         if peg_head_wrt_goal_yz_dist < 0.01 and peg_wrt_goal_yz_dist < 0.01:
36 |             pre_inserted = True
37 | 
38 |     success, peg_head_pos_at_hole = self.has_peg_inserted()
39 |     return dict(
40 |         success=success,
41 |         pre_inserted=pre_inserted,
42 |         peg_head_pos_at_hole=peg_head_pos_at_hole,
43 |         is_grasped=is_grasped,
44 |     )
--------------------------------------------------------------------------------
/maniskill2_patches/turn_faucet.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the turn_faucet env in ManiSkill2 so that it allows additional metrics for eval
3 | and flags for obtaining key states in training. Please replace the `evaluate` function
4 | in the file below (keeping the correct level of indentation):
5 | 
6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/misc/turn_faucet.py
7 | 
8 | Moreover, replace the `_get_obs_extra` function (and add the `_get_curr_target_link_pos`
9 | helper it calls) to customize the state representation so that it is easier to
10 | distinguish among different faucet models.
11 | 
12 | Note:
13 | The 10 faucet models (a simpler subset of all faucets in ManiSkill2) we use in the
14 | CoTPC paper have the ids: 5002,5021,5023,5028,5029,5045,5047,5051,5056,5063.
15 | """
16 | 
17 | def _get_curr_target_link_pos(self):
18 |     """
19 |     Access the current position of the target link (i.e., the handle of the faucet).
20 | 
21 |     Returns the position part (`.p`) of the link's center-of-mass pose, i.e. the
22 |     link pose composed with its `cmass_local_pose`.
23 |     """
24 |     cmass_pose = self.target_link.pose * self.target_link.cmass_local_pose
25 |     return cmass_pose.p
26 | 
27 | def _get_obs_extra(self):
28 |     """
29 |     Extra state observations. Adds `curr_target_link_pos` (the handle's current
30 |     position) to the default entries; `angle_dist` is only included for the
31 |     "state" and "state_dict" observation modes.
32 |     """
33 |     obs = OrderedDict(
34 |         tcp_pose=vectorize_pose(self.tcp.pose),
35 |         target_angle_diff=self.target_angle_diff,
36 |         target_joint_axis=self.target_joint_axis,
37 |         target_link_pos=self.target_link_pos,
38 |         curr_target_link_pos=self._get_curr_target_link_pos(),  # Added code.
39 |     )
40 |     if self._obs_mode in ["state", "state_dict"]:
41 |         angle_dist = self.target_angle - self.current_angle
42 |         obs["angle_dist"] = angle_dist
43 |     return obs
44 | 
45 | def evaluate(self, **kwargs):
46 |     """
47 |     Return `success` (the current angle exceeds the target angle, i.e.
48 |     `angle_dist < 0`), the signed `angle_dist`, and `is_contacted` (True if any
49 |     entry of `self.agent.check_contact_fingers(self.target_link)` is truthy).
50 |     """
51 |     is_contacted = any(self.agent.check_contact_fingers(self.target_link))
52 |     angle_dist = self.target_angle - self.current_angle
53 |     return dict(
54 |         success=angle_dist < 0,
55 |         angle_dist=angle_dist,
56 |         is_contacted=is_contacted)
--------------------------------------------------------------------------------
/maniskill2_patches/record.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the record utility in ManiSkill2 so that it records additional metrics for eval
3 | and flags for obtaining key states in training. Please patch the `flush_trajectory`
4 | function in the file below (keeping the correct level of indentation):
5 | 
6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/utils/wrappers/record.py
7 | """
8 | 
9 | def flush_trajectory(self, **args):
10 |     # NOTE: partial listing. "# some code here ..." marks unchanged upstream
11 |     # code that must be kept; only the ADDED CODE sections are new.
12 | 
13 |     # some code here ...
14 | 
15 |     ########################### ADDED CODE #############################
16 |     # Append info (boolean flags) to the recorded trajectories.
17 |     # This tells you what info to store in the trajs.
18 |     # Keys come from the info dict of the last recorded entry; a value counts
19 |     # as a flag if it is a numpy object with a bool dtype or a Python bool.
20 |     info_bool_keys = []
21 |     for k, v in self._episode_data[-1]['info'].items():
22 |         if type(v).__module__ == np.__name__ and v.dtype == 'bool':
23 |             info_bool_keys.append(k)
24 |         elif isinstance(v, bool):
25 |             info_bool_keys.append(k)
26 | 
27 |     # This info only appears in some trajs.
28 |     if 'TimeLimit.truncated' in info_bool_keys:
29 |         info_bool_keys.remove('TimeLimit.truncated')
30 |     ####################################################################
31 | 
32 |     if len(self._episode_data) == 1:
33 |         # some code here ...
34 | 
35 |         ########################### ADDED CODE #############################
36 |         # With a single entry, `_episode_data[1:]` is empty: store empty flag arrays.
37 |         infos_bool = {k: np.empty(shape=(0,), dtype=bool) for k in info_bool_keys}
38 |         ####################################################################
39 |     else:
40 |         # some code here ...
41 | 
42 |         ########################### ADDED CODE #############################
43 |         # Stack each flag over entries 1..end (entry 0 is skipped, as for rewards below).
44 |         infos_bool = {k: [] for k in info_bool_keys}
45 |         for x in self._episode_data[1:]:
46 |             for k in info_bool_keys:
47 |                 infos_bool[k].append(x['info'][k])
48 |         for k in infos_bool:
49 |             infos_bool[k] = np.stack(infos_bool[k])
50 |         ####################################################################
51 | 
52 |     # some code here ...
53 | 
54 |     ########################### ADDED CODE #############################
55 |     # Dump the additional entries to the demo trajectories.
56 |     # NOTE(review): assumes `group` is the h5 group created for this trajectory
57 |     # in the elided upstream code above -- confirm against record.py.
58 |     rewards = np.array([x['r'] for x in self._episode_data[1:]], dtype=np.float32)
59 |     group.create_dataset("rewards", data=rewards, dtype=np.float32)
60 |     for k, v in infos_bool.items():
61 |         group.create_dataset(f"infos/{k}", data=v, dtype=bool)
62 |     ####################################################################
63 | 
64 |     # Handle JSON
65 |     # some code here ...
-------------------------------------------------------------------------------- /src/vec_env.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import cloudpickle 3 | import numpy as np 4 | from multiprocessing import Pipe, Process 5 | import gym 6 | 7 | from transforms3d.euler import euler2quat 8 | from transforms3d.quaternions import qmult 9 | 10 | import sapien.core as sapien 11 | 12 | 13 | def disturb(env, kwargs): 14 | if 'peg' in kwargs: 15 | dx, dy, dr = kwargs['peg'] 16 | pose = env.peg.get_pose() 17 | quat = euler2quat(0, 0, dr) 18 | env.peg.set_pose(sapien.Pose( 19 | p=pose.p+[dx,dy,0], q=qmult(quat, pose.q))) 20 | if 'box' in kwargs: 21 | dx, dy, dr = kwargs['box'] 22 | pose = env.box.get_pose() 23 | quat = euler2quat(0, 0, dr) 24 | env.box.set_pose(sapien.Pose( 25 | p=pose.p+[dx,dy,0], q=qmult(quat, pose.q))) 26 | 27 | def get_mp_envs(env_id, n_env, start_idx=0, **env_kwargs): 28 | def env_fn(rank): 29 | def fn(): 30 | env = gym.make(env_id, **env_kwargs) 31 | return env 32 | return fn 33 | return VecEnv([env_fn(i + start_idx) for i in range(n_env)]) 34 | 35 | 36 | class CloudpickleWrapper(object): 37 | def __init__(self, x): 38 | self.x = x 39 | 40 | def __getstate__(self): 41 | return cloudpickle.dumps(self.x) 42 | 43 | def __setstate__(self, ob): 44 | self.x = pickle.loads(ob) 45 | 46 | def __call__(self): 47 | return self.x() 48 | 49 | def worker(remote, parent_remote, env_fn): 50 | parent_remote.close() 51 | env = env_fn() 52 | while True: 53 | cmd, data = remote.recv() 54 | 55 | if cmd == 'step': 56 | ob, reward, done, info = env.step(data) 57 | # if done: ob = env.reset() # We ignore the done signal here. 
58 | remote.send((ob, reward, done, info)) 59 | elif cmd == 'reset': 60 | ob = env.reset(**data) 61 | remote.send(ob) 62 | elif cmd == 'render': 63 | remote.send(env.render()) 64 | elif cmd == 'close': 65 | remote.close() 66 | break 67 | elif cmd == 'disturb': 68 | disturb(env, data) 69 | else: 70 | raise NameError('NotImplentedError') 71 | 72 | class VecEnv(): 73 | def __init__(self, env_fns): 74 | self.waiting = False 75 | self.closed = False 76 | no_envs = len(env_fns) 77 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(no_envs)]) 78 | self.ps = [] 79 | 80 | for wrk, rem, fn in zip(self.work_remotes, self.remotes, env_fns): 81 | p = Process(target=worker, args=(wrk, rem, CloudpickleWrapper(fn))) 82 | self.ps.append(p) 83 | 84 | for p in self.ps: 85 | p.daemon = True 86 | p.start() 87 | 88 | for remote in self.work_remotes: 89 | remote.close() 90 | 91 | def step_async(self, actions): 92 | if self.waiting: 93 | raise NameError('AlreadySteppingError') 94 | self.waiting = True 95 | for remote, action in zip(self.remotes, actions): 96 | remote.send(('step', action)) 97 | 98 | def step_wait(self): 99 | if not self.waiting: 100 | raise NameError('NotSteppingError') 101 | self.waiting = False 102 | results = [remote.recv() for remote in self.remotes] 103 | obs, rews, dones, infos = zip(*results) 104 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 105 | 106 | def step(self, actions): 107 | self.step_async(actions) 108 | return self.step_wait() 109 | 110 | def reset(self, kwargs_list): 111 | for remote, kwargs in zip(self.remotes, kwargs_list): 112 | remote.send(('reset', kwargs)) 113 | return np.stack([remote.recv() for remote in self.remotes]) 114 | 115 | def disturb(self, kwargs_list): 116 | for remote, kwargs in zip(self.remotes, kwargs_list): 117 | remote.send(('disturb', kwargs)) 118 | 119 | def close(self): 120 | if self.closed: 121 | return 122 | if self.waiting: 123 | for remote in self.remotes: 124 | remote.recv() 125 | for remote 
in self.remotes: 126 | remote.send(('close', None)) 127 | for p in self.ps: 128 | p.join() 129 | self.closed = True 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chain-of-Thought Predictive Control (CoTPC) 2 | This is the official repository for CoTPC, a powerful hierarchical imitation learning model presented in the following paper: 3 | 4 | ### **[Chain-of-Thought Predictive Control](https://zjia.eng.ucsd.edu/cotpc)**
5 | Zhiwei Jia, Fangchen Liu, Vineet Thumuluri, Linghao Chen, Zhiao Huang, Hao Su
6 | UC San Diego, UC Berkeley, Zhejiang University
7 | 8 |

9 |
10 | [arXiv][website] 11 |

12 | 13 | ### Tasks 14 | Currently the code supports five tasks from the [ManiSkill2](https://github.com/haosulab/ManiSkill2) benchmark: 15 | `PickCube-v0`, `StackCube-v0`, `PegInsertionSide-v0`, `TurnFaucet-v0`, and `PushChair-v1`. 16 | 17 | ### Demonstration Data 18 | The state-based demo trajectories used in the paper are stored in this Google Drive [folder](https://drive.google.com/drive/folders/1VdunXUlzqAvy-D8MniQ4anhV5LLBfNbJ). 19 | Each folder has a `*.h5` file (for actual trajectories) and a `*.json` file (for metadata regarding the trajectories). 20 | Each task has over 1000 demo trajectories. 21 | Each trajectory comes with a different env configuration (i.e., env seed, which influences object poses, object geometries, etc.). 22 | These demos are generated by replaying the official ManiSkill2 [demos](https://github.com/haosulab/ManiSkill2#demonstrations) or the ones adapted from [this](https://github.com/caiqi/Silver-Bullet-3D/tree/master/No_Restriction) repo with several patches to the ManiSkill2 code (see `CoTPC/maniskill2_patches`). 23 | Specifically, we add additional flags to the tasks so that the key states (the Chain-of-Thought) can be obtained with privileged information from the simulator. 24 | For the task `TurnFaucet-v0`, we use a subset of 10 faucet models for the demos (see `CoTPC/scripts/replay_turn_faucet_trajectories.sh`). 25 | For the task `PushChair-v1`, we use a subset of 5 chair models for the demos. 26 | If you want to generate vision-based demos, please refer to the official ManiSkill2 guidance [here](https://github.com/haosulab/ManiSkill2#demonstrations). 27 | 28 | ### Data Loader 29 | The data loader samples (contiguous) subarrays of demo trajectories by specifying the min and max sample lengths. 30 | In CoTPC, we simply use a fixed value for both min and max (e.g., set the context size to 60 for all of the 5 tasks from ManiSkill2).
31 | With a fixed random seed `seed` and a fixed `num_traj`, the data loader always samples the same subset of all trajectories (by default, we use `seed=0`). 32 | For TurnFaucet and PushChair, due to the variations of the different object models, the loader performs sampling such that the number of trajs 33 | per model is the same (hopefully this data balance eases model training). 34 | The key states (a.k.a. the Chain-of-Thought) can be obtained with the function `get_key_states(data_idx)` that accesses privileged info available during training. 35 | 36 | ### Evaluation 37 | We patch the environment code from ManiSkill2 (see `CoTPC/maniskill2_patches`) to provide additional evaluation metrics (intermediate success rates) for each task execution. 38 | We report evaluation results for models using both the seen env configurations and the unseen env configurations. 39 | For tasks involving geometric variations, we also evaluate using unseen objects (i.e., zero-shot transfer). 40 | Please refer to `CoTPC/src/eval.py` for details. 41 | 42 | 43 | 44 | 45 | ### Sample Training Scripts 46 | Please see `CoTPC/scripts/train.sh` for examples. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | 8 | from collections import deque 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from data import MS2Demos, get_padding_fn 13 | from model import GPTConfig, GPTWithCoT 14 | from train_utils import CosineAnnealingLRWarmup 15 | 16 | try: 17 | # Use might need this for wandb to work due to protobuf issues. 18 | os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' 19 | 20 | import wandb 21 | USE_WANDB = True 22 | PROJECT_NAME = 'CoTPC' # Please specify the project name. 23 | except ImportError: 24 | print('Do not use wandb since it is not found.') 25 | USE_WANDB = False 26 | 27 | # Please specify MODEL_PATH (the base folder for storing models) in `path.py`. 28 | from path import MODEL_PATH 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser() 32 | 33 | # Training hyper-parameters. 
34 | parser.add_argument("--n_iters", default=1_600_000, type=int, help="Number of training iterations.") 35 | parser.add_argument("--batch_size", default=256, type=int, help="Batch size.") 36 | parser.add_argument("--init_lr", default='5e-4', type=str, help="The initial learning rate.") 37 | parser.add_argument("--weight_decay", default='0', type=str, help="Weight decay coefficient.") 38 | parser.add_argument("--beta1", default='0.9', type=str, help="Beta1 in the Adam optimizer.") 39 | parser.add_argument("--beta2", default='0.95', type=str, help="Beta2 in the Adam optimizer.") 40 | parser.add_argument("--dropout", default='0.0', type=str, help="Dropout probability.") 41 | parser.add_argument("--lr_schedule", default='cos_decay_with_warmup', type=str, 42 | help="The learning rate schedule.") 43 | 44 | # Hyper-parameters regarding CoTPC. 45 | parser.add_argument("--key_state_coeff", default=0.0, type=float, 46 | help="Coefficient for the key state prediction loss.") 47 | parser.add_argument('--model_type', type=str, default='s+a+cot', 48 | help="Model type for the CoTPC model (see GPTConfig).") 49 | parser.add_argument('--key_states', type=str, default='a', 50 | help="Which key states to use (see GPTConfig for the spec. format).") 51 | parser.add_argument("--key_state_loss", default='', type=str, 52 | help="Features out of what attention layers to use for key state prediction " + 53 | "losses (see GPTConfig for the spec. 
format).") 54 | parser.add_argument('--cot_decoder', type=str, default='256', help="Specs of the CoT decoder.") 55 | 56 | # General hyper-parameters regarding model loading and saving 57 | parser.add_argument("--model_name", default='', type=str, help="Model name (for storing ckpts).") 58 | parser.add_argument("--from_model_name", default='', type=str, help="Name of the pretrained model.") 59 | parser.add_argument("--from_ckpt", default=-1, type=int, help="Ckpt of pretrained model.") 60 | 61 | # Hyper-parameters regarding the demo dataset 62 | parser.add_argument('--task', type=str, default='PickCube-v0', help="Task (env-id) in ManiSkill2.") 63 | parser.add_argument('--control_mode', type=str, default='pd_joint_delta_pos', 64 | help="Control mode used in envs from ManiSkill2.") 65 | parser.add_argument('--obs_mode', type=str, default='state', 66 | help="State mode used in envs from ManiSkill2.") 67 | parser.add_argument("--seed", default=0, type=int,help="Random seed for data spliting.") 68 | parser.add_argument("--num_traj", default=-1, type=int, help="Number of training trajectories.") 69 | parser.add_argument('--context_length', type=int, default=60, 70 | help="Context size of CoTPC (the maximium length of sequences " + 71 | "sampled from demo trajectories in training).") 72 | parser.add_argument('--min_seq_length', type=int, default=60, 73 | help="Mininum length of sequences sampled from demo trajectories in training.") 74 | 75 | # Save and log frequencies. 76 | parser.add_argument("--save_every", default=40000, type=int, help="Save model every # iters.") 77 | parser.add_argument("--log_every", default=2000, type=int, help="log metrics every # iters.") 78 | 79 | # General hyper-parameters for the GPT architecture. 
80 | parser.add_argument("--n_layer", default=4, type=int, help="Number of attention layers.") 81 | parser.add_argument("--n_head", default=8, type=int, help="Number of attention heads.") 82 | parser.add_argument("--n_embd", default=128, type=int, help="Hidden feature dimension.") 83 | 84 | # For faster data loader. 85 | parser.add_argument("--num_workers", default=2, type=int, 86 | help="A positive number for fast async data loading.") 87 | parser.add_argument('--multiplier', type=int, default=20, 88 | help="Duplicate the dataset to reduce data loader overhead.") 89 | 90 | return parser.parse_args() 91 | 92 | 93 | def mse_loss_with_weights(preds, targets, weights=None): 94 | losses = torch.mean((preds - targets) ** 2, -1) 95 | if weights is None: 96 | return torch.mean(losses) 97 | else: 98 | assert losses.shape == weights.shape, losses.shape 99 | return torch.mean(losses * weights) 100 | 101 | 102 | def get_loss(preds, targets, lengths): 103 | # If we have sequences of varied lengths, use masks so we do not compute loss 104 | # over padded values. If we set max_seq_length=min_seq_length, then it should 105 | # not matter since all sequences have the same length. 106 | B = preds.shape[0] 107 | max_len = torch.max(lengths) # Max length of the current mini-batch. 108 | lengths = lengths[:, None] # B x 1 109 | temp = torch.arange(0, max_len)[None].expand(B, -1).cuda() # B x max_len 110 | masks = (temp < lengths.expand(B, max_len)).float() # B x max_len 111 | 112 | loss = mse_loss_with_weights( 113 | preds.reshape(-1, preds.size(-1)), 114 | targets.reshape(-1, targets.size(-1)), 115 | masks.reshape(-1)) 116 | return loss 117 | 118 | 119 | if __name__ == "__main__": 120 | 121 | args = parse_args() 122 | assert args.model_name != '', 'Should specify --model_name' 123 | print('Model name:', args.model_name) 124 | 125 | if 'cot' in args.model_type: 126 | assert args.key_states, 'Should specify --key_states.' 
127 | 128 | train_dataset = MS2Demos( 129 | control_mode=args.control_mode, 130 | obs_mode=args.obs_mode, 131 | length=args.num_traj, seed=args.seed, 132 | min_seq_length=args.min_seq_length, 133 | max_seq_length=args.context_length, 134 | with_key_states='cot' in args.model_type, 135 | task=args.task, multiplier=args.multiplier) 136 | print('Training data size:', len(train_dataset)) 137 | print('Max steps:', train_dataset.max_steps) 138 | 139 | input_dict = ['s', 'a', 't'] 140 | input_dict += ['k'] if 'cot' in args.model_type else [] 141 | collate_fn = get_padding_fn(input_dict) 142 | train_data = DataLoader( 143 | dataset=train_dataset, 144 | batch_size=args.batch_size, 145 | shuffle=True, 146 | pin_memory=True, # Faster data loading if using GPU. 147 | num_workers=args.num_workers, 148 | persistent_workers=True, # Faster data loader resets. 149 | collate_fn=collate_fn, 150 | drop_last=True, 151 | ) 152 | data_iter = iter(train_data) 153 | 154 | state_dim, action_dim = train_dataset.info() 155 | conf = GPTConfig( 156 | args.context_length, 157 | n_layer=args.n_layer, 158 | n_head=args.n_head, 159 | n_embd=args.n_embd, 160 | model_type=args.model_type, 161 | key_states=args.key_states, 162 | key_state_loss=args.key_state_loss, 163 | max_timestep=train_dataset.max_steps, 164 | embd_pdrop=float(args.dropout), 165 | resid_pdrop=float(args.dropout), 166 | attn_pdrop=float(args.dropout), 167 | cot_decoder=args.cot_decoder, 168 | ) 169 | model = GPTWithCoT(conf, state_dim=state_dim, action_dim=action_dim).cuda() 170 | optimizer = model.configure_adamw_optimizers({ 171 | 'init_lr': float(args.init_lr), 172 | 'weight_decay': float(args.weight_decay), 173 | 'beta1': float(args.beta1), 174 | 'beta2': float(args.beta2), 175 | }) 176 | 177 | # Learning rate schedules (which might require more tuning). 
178 | if args.lr_schedule == 'cos_decay_with_warmup': 179 | lr_scheduler = CosineAnnealingLRWarmup( 180 | optimizer, T_max=args.n_iters, T_warmup=1000) 181 | else: 182 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( # Same name as above: the loop calls lr_scheduler.step(). 183 | optimizer, milestones=[78000], gamma=0.1) 184 | 185 | model_path = os.path.join(MODEL_PATH, args.model_name) 186 | os.makedirs(model_path, exist_ok=True) 187 | 188 | # If loaded from pretrained model first. 189 | if args.from_ckpt > 0: 190 | if args.from_model_name: 191 | path = os.path.join( 192 | MODEL_PATH, args.from_model_name, f'{args.from_ckpt}.pth') 193 | else: 194 | path = os.path.join(model_path, f'{args.from_ckpt}.pth') 195 | model.load_state_dict(torch.load(path)['model'], strict=True) # Ckpts are saved as {'model': ..., 'metadata': ...}. 196 | print(f'Pretrained model loaded from {path}.') 197 | 198 | log_path = os.path.join(model_path, 'log.txt') 199 | if USE_WANDB: 200 | wandb.init( 201 | project=PROJECT_NAME, name=args.model_name, config=args, 202 | config_exclude_keys=['model_name', 'save_every', 'log_every'], 203 | ) 204 | wandb.run.log_code(".") # Need to first enable it on wandb web UI. 205 | 206 | losses_act_pred = deque(maxlen=1000) 207 | losses_key_states = deque(maxlen=1000) 208 | 209 | # Convert key states to integers. 210 | key_states = [ord(c) - ord('a') for c in args.key_states] 211 | 212 | # Main training loop. 213 | for idx in tqdm(range(args.n_iters + 1)): 214 | 215 | # Adjust lr schedule when loaded from pretrained models. 216 | if args.from_ckpt > 0 and idx <= args.from_ckpt: 217 | lr_scheduler.step() 218 | continue 219 | 220 | # Obtain the current mini-batch (infinite loop). 221 | try: 222 | batch = next(data_iter) 223 | except StopIteration: 224 | data_iter = iter(train_data) 225 | batch = next(data_iter) 226 | batch = {k: v.cuda() for k, v in batch.items()} 227 | 228 | # Forward pass. 229 | act_pred, key_states_pred = model(batch['s'], batch['t'], batch['a']) 230 | 231 | # Obtain training losses.
232 | loss_act_pred = get_loss(act_pred, batch['a'], batch['lengths']) 233 | total_loss = loss_act_pred 234 | 235 | loss_key_states = torch.tensor(-1) # -1 means N/A. 236 | if 'cot' in args.model_type: 237 | ks_gt = torch.stack( 238 | [batch['k'][:, k_idx] for k_idx in key_states], 1) 239 | loss_key_states = torch.mean(torch.stack( 240 | [F.mse_loss(ks_pred, ks_gt) for ks_pred in key_states_pred])) 241 | if args.key_state_coeff > 0: 242 | total_loss = total_loss + args.key_state_coeff * loss_key_states # Not `+=`: in-place add would also mutate loss_act_pred (same tensor) and corrupt its logged value. 243 | 244 | losses_act_pred.append(loss_act_pred.item()) 245 | losses_key_states.append(loss_key_states.item()) 246 | optimizer.zero_grad() 247 | total_loss.backward() 248 | optimizer.step() 249 | 250 | if idx % args.log_every == 0: 251 | with open(log_path, 'a' if os.path.exists(log_path) else 'w') as f: 252 | avg_loss_act_pred = np.mean(losses_act_pred) 253 | avg_loss_key_states = np.mean(losses_key_states) 254 | print(f'Iteration {idx}: {avg_loss_act_pred}, {avg_loss_key_states}') 255 | f.write(f'{idx},{avg_loss_act_pred},{avg_loss_key_states}\n') 256 | if USE_WANDB: 257 | log_dict = { 258 | 'n_iter': idx, 259 | 'loss_actions': avg_loss_act_pred, 260 | 'loss_sum': avg_loss_act_pred, 261 | } 262 | if 'cot' in args.model_type: 263 | log_dict['loss_key_states'] = avg_loss_key_states 264 | log_dict['loss_sum'] = avg_loss_act_pred + avg_loss_key_states 265 | wandb.log(log_dict) 266 | 267 | if idx > 0 and idx % args.save_every == 0: 268 | save_path = os.path.join(model_path, f'{idx}.pth') 269 | torch.save({ 270 | 'model': model.state_dict(), 271 | 'metadata': vars(args) 272 | }, save_path) 273 | 274 | # Update learning rate.
275 | lr_scheduler.step() 276 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | 5 | from torch.utils.data.dataset import Dataset 6 | from torch.nn.utils.rnn import pad_sequence 7 | import torch 8 | 9 | # Please specify the DATA_PATH (the base folder for storing data) in `path.py`. 10 | from path import DATA_PATH 11 | 12 | 13 | class MS2Demos(Dataset): 14 | def __init__(self, 15 | data_split='train', 16 | task='PickCube-v0', 17 | obs_mode='state', 18 | control_mode='pd_joint_delta_pos', 19 | length=-1, 20 | min_seq_length=None, 21 | max_seq_length=None, 22 | with_key_states=False, 23 | multiplier=20, # Used for faster data loading. 24 | seed=None): # seed for train/test spliting. 25 | super().__init__() 26 | self.task = task 27 | self.data_split = data_split 28 | self.seed = seed 29 | self.min_seq_length = min_seq_length # For sampling trajectories. 30 | self.max_seq_length = max_seq_length # For sampling trajectories. 31 | self.with_key_states = with_key_states # Whether output key states. 32 | self.multiplier = multiplier 33 | 34 | # Usually set min and max traj length to be the same value. 35 | self.max_steps = -1 # Maximum timesteps across all trajectories. 36 | traj_path = os.path.join(DATA_PATH, 37 | f'{task}/trajectory.{obs_mode}.{control_mode}.h5') 38 | print('Traj path:', traj_path) 39 | self.data = self.load_demo_dataset(traj_path, length) 40 | 41 | # Cache key states for faster data loading. 42 | if self.with_key_states: 43 | self.idx_to_key_states = dict() 44 | 45 | def __len__(self): 46 | return len(self.data['env_states']) 47 | 48 | def __getitem__(self, index): 49 | # Offset by one since the last obs does not have a corresponding action. 50 | l = len(self.data['obs'][index]) - 1 51 | 52 | # Sample starting and ending index given the min and max traj length. 
53 | if self.min_seq_length is None and self.max_seq_length is None: 54 | s_idx, e_idx = 0, l 55 | else: 56 | min_length = 0 if self.min_seq_length is None else self.min_seq_length 57 | max_length = l if self.max_seq_length is None else self.max_seq_length 58 | assert min_length <= max_length 59 | if min_length == max_length: 60 | length = min_length 61 | else: 62 | length = np.random.randint(min_length, max_length + 1, 1)[0] # randint's upper bound is exclusive; +1 makes max_length reachable. 63 | if length <= l: 64 | s_idx = np.random.randint(0, l - length + 1, 1)[0] 65 | e_idx = s_idx + length 66 | else: 67 | s_idx, e_idx = 0, l 68 | assert e_idx <= l, f'{e_idx}, {l}' 69 | 70 | # Call get_key_states() if you want to use the key states. 71 | # Here `s` is the state observation, `a` is the action, 72 | # `env_states` not used during training (can be used to reconstruct env for debugging). 73 | # `t` is used for positional embedding as in Decision Transformer. 74 | data_dict = { 75 | 's': self.data['obs'][index][s_idx:e_idx].astype(np.float32), 76 | 'a': self.data['actions'][index][s_idx:e_idx].astype(np.float32), 77 | 't': np.array([s_idx]).astype(np.float32), 78 | # 'env_states': self.data['env_states'][index][s_idx:e_idx].astype(np.float32), 79 | } 80 | if self.with_key_states: 81 | if f'key_states_{index}' not in self.idx_to_key_states: 82 | self.idx_to_key_states[f'key_states_{index}'] = self.get_key_states(index) 83 | data_dict['k'] = self.idx_to_key_states[f'key_states_{index}'] 84 | return data_dict 85 | 86 | def info(self): # Get observation and action shapes. 87 | return self.data['obs'][0].shape[-1], self.data['actions'][0].shape[-1] 88 | 89 | def load_demo_dataset(self, path, length): 90 | dataset = {} 91 | traj_all = h5py.File(path, 'r') # Explicit read-only mode; older h5py defaulted to append mode. 92 | if length == -1: 93 | length = len(traj_all) 94 | np.random.seed(self.seed) # Fix the random seed for train/test data split. 95 | 96 | # Since TurnFaucet uses 10 different faucet models, we shuffle the data 97 | # such that the resulting sampled data are evenly sampled across faucet models.
98 | if self.task == 'TurnFaucet-v0': 99 | ids = [] 100 | for i in range(10): # Hard-code the 10 data splits for permutation. 101 | t_ids = np.random.permutation(len(traj_all)//10)[:length//10] 102 | t_ids += i*len(traj_all)//10 103 | ids.append(t_ids) 104 | ids = np.concatenate(ids) 105 | # Since PushChair uses 5 different chair models, we shuffle the data 106 | # such that the resulting sampled data are evenly sampled across chair models. 107 | elif self.task == 'PushChair-v1': 108 | ids = [] 109 | for i in range(5): # Hard-code the 5 data splits for permutation. 110 | t_ids = np.random.permutation(len(traj_all)//5)[:length//5] 111 | t_ids += i*len(traj_all)//5 112 | ids.append(t_ids) 113 | ids = np.concatenate(ids) 114 | else: 115 | ids = np.random.permutation(len(traj_all))[:length] 116 | 117 | ids = ids.tolist() * self.multiplier # Duplicate the data for faster loading (each duplicate id is read from the file again below). 118 | 119 | # Note that the size of `env_states` and `obs` is that of the others + 1. 120 | # And most `infos` is for the next obs rather than the current obs. 121 | 122 | # `env_states` is used for resetting the env (might be helpful for eval) 123 | dataset['env_states'] = [np.array( 124 | traj_all[f"traj_{i}"]['env_states']) for i in ids] 125 | # `obs` is the observation of each step. 126 | dataset['obs'] = [np.array(traj_all[f"traj_{i}"]["obs"]) for i in ids] 127 | dataset['actions'] = [np.array(traj_all[f"traj_{i}"]["actions"]) for i in ids] 128 | 129 | # actions = np.concatenate(dataset['actions']) 130 | # actions_std = np.std(actions, 0) 131 | # dataset['actions'] = [ 132 | # np.array(traj_all[f"traj_{i}"]["actions"]) / (actions_std + 1e-7) for i in ids] 133 | 134 | # `rewards` is not currently used in CoTPC training. 135 | dataset['rewards'] = [np.array(traj_all[f"traj_{i}"]["rewards"]) for i in ids] 136 | for k in traj_all['traj_0']['infos'].keys(): 137 | dataset[f'infos/{k}'] = [np.array( 138 | traj_all[f"traj_{i}"]["infos"][k]) for i in ids] 139 | if k == 'info': # For PushChair.
140 | for kk in traj_all['traj_0']['infos'][k].keys(): 141 | dataset[f'infos/demo_{kk}'] = [np.array( 142 | traj_all[f"traj_{i}"]["infos"][k][kk]) for i in ids] 143 | 144 | self.max_steps = np.max([len(s) for s in dataset['env_states']]) 145 | 146 | return dataset 147 | 148 | def get_key_states(self, idx): 149 | # Note that `infos` is for the next obs rather than the current obs. 150 | # Thus, we need to offset the `step_idx`` by one. 151 | key_states = [] 152 | 153 | # If TurnFaucet (two key states) 154 | # key state I: is_contacted -> true 155 | # key state II: end of the trajectory 156 | if self.task == 'TurnFaucet-v0': 157 | for step_idx, key in enumerate(self.data['infos/is_contacted'][idx]): 158 | if key: break 159 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 160 | 161 | # If PegInsertion (three key states) 162 | # key state I: is_grasped -> true 163 | # key state II: pre_inserted -> true 164 | # key state III: end of the trajectory 165 | if self.task == 'PegInsertionSide-v0': 166 | for step_idx, key in enumerate(self.data['infos/is_grasped'][idx]): 167 | if key: break 168 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 169 | for step_idx, key in enumerate(self.data['infos/pre_inserted'][idx]): 170 | if key: break 171 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 172 | 173 | # If PickCube (two key states) 174 | # key state I: is_grasped -> true 175 | # key state II: end of the trajectory 176 | if self.task == 'PickCube-v0': 177 | for step_idx, key in enumerate(self.data['infos/is_grasped'][idx]): 178 | if key: break 179 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 180 | 181 | # If StackCube (three key states) 182 | # key state I: is_cubaA_grasped -> true 183 | # key state II: the last state of is_cubeA_on_cubeB -> true 184 | # right before is_cubaA_grasped -> false 185 | # key state III: end of the trajectory 186 | if self.task == 'StackCube-v0': 
187 | for step_idx, key in enumerate(self.data['infos/is_cubaA_grasped'][idx]): 188 | if key: break 189 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 190 | for step_idx, k1 in enumerate(self.data['infos/is_cubeA_on_cubeB'][idx]): 191 | k2 = self.data['infos/is_cubaA_grasped'][idx][step_idx] 192 | if k1 and not k2: break 193 | # Right before such a state and so we do not use step_idx+1. 194 | key_states.append(self.data['obs'][idx][step_idx].astype(np.float32)) 195 | 196 | # If PushChair (four key states): 197 | # key state I: right before demo_rotate -> true 198 | # key state II: right before demo_move -> true 199 | # key state III: when chair_close_to_target & chair_standing -> true 200 | # key state IV: end of the trajectory 201 | lengths = [] 202 | # In PushChair, demo_* indicate the current state (not the next). 203 | if self.task == 'PushChair-v1': 204 | for step_idx, key in enumerate(self.data['infos/demo_rotate'][idx]): 205 | if key: break 206 | lengths.append(step_idx) 207 | key_states.append(self.data['obs'][idx][step_idx].astype(np.float32)) 208 | for step_idx, key in enumerate(self.data['infos/demo_move'][idx]): 209 | if key: break 210 | lengths.append(step_idx - np.sum(lengths)) 211 | key_states.append(self.data['obs'][idx][step_idx].astype(np.float32)) 212 | for step_idx, key in enumerate(np.bitwise_and( 213 | self.data['infos/chair_close_to_target'][idx], 214 | self.data['infos/chair_standing'][idx])): 215 | if key: break 216 | lengths.append(step_idx + 1 - np.sum(lengths)) 217 | key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32)) 218 | lengths.append(len(self.data['infos/success'][idx]) - np.sum(lengths)) 219 | 220 | # Always append the last state in the trajectory as the last key state. 
221 | key_states.append(self.data['obs'][idx][-1].astype(np.float32)) 222 | 223 | key_states = np.stack(key_states, 0).astype(np.float32) 224 | assert len(key_states) > 0, self.task 225 | return key_states 226 | 227 | 228 | # To obtain the padding function for sequences. 229 | def get_padding_fn(data_names): 230 | assert 's' in data_names, 'Should at least include `s` in data_names.' 231 | 232 | def pad_collate(*args): 233 | assert len(args) == 1 234 | output = {k: [] for k in data_names} 235 | for b in args[0]: # Batches 236 | for k in data_names: 237 | output[k].append(torch.from_numpy(b[k])) 238 | 239 | # Include the actual length of each sequence sampled from a trajectory. 240 | # If we set max_seq_length=min_seq_length, this is a constant across samples. 241 | output['lengths'] = torch.tensor([len(s) for s in output['s']]) 242 | 243 | # Padding all the sequences. 244 | for k in data_names: 245 | output[k] = pad_sequence(output[k], batch_first=True, padding_value=0) 246 | 247 | return output 248 | 249 | return pad_collate 250 | 251 | 252 | # Sample code for the data loader. 253 | if __name__ == "__main__": 254 | 255 | from torch.utils.data import DataLoader 256 | 257 | # The default values for CoTPC for tasks in ManiSkill2. 
258 | batch_size, num_traj, seed, min_seq_length, max_seq_length, task = \ 259 | 256, 500, 0, 60, 60, 'PickCube-v0' 260 | batch_size, num_traj, seed, min_seq_length, max_seq_length, task = \ 261 | 256, 500, 0, 60, 60, 'PushChair-v1' 262 | 263 | train_dataset = MS2Demos( 264 | # control_mode='pd_joint_delta_pos', 265 | control_mode='base_pd_joint_vel_arm_pd_joint_vel', 266 | length=num_traj, seed=seed, 267 | min_seq_length=min_seq_length, 268 | max_seq_length=max_seq_length, 269 | with_key_states=True, 270 | task=task) 271 | 272 | collate_fn = get_padding_fn(['s', 'a', 't', 'k']) 273 | train_data = DataLoader( 274 | dataset=train_dataset, 275 | batch_size=batch_size, 276 | collate_fn=collate_fn) 277 | 278 | data_iter = iter(train_data) 279 | data = next(data_iter) 280 | # print(len(data)) # 4 281 | # for k, v in data.items(): 282 | # print(k, v.shape) 283 | # 's', [256, 60, 51] 284 | # 'a', [256, 60, 8] 285 | # 't', [256, 1] 286 | # 'k', [256, 2, 51] 287 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | 7 | from mani_skill2.utils.io_utils import load_json 8 | import mani_skill2.envs # Load ManiSkill2 envs. 9 | import torch # Load pytorch after maniskill2 to avoid some import error. 10 | 11 | from model import GPTConfig, GPTWithCoT 12 | 13 | from vec_env import get_mp_envs # Used for parallel evaluation. 14 | 15 | try: 16 | # Use might need this for wandb to work due to protobuf issues. 17 | os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' 18 | import wandb 19 | assert wandb.__version__ 20 | USE_WANDB = True 21 | PROJECT_NAME = 'CoTPC' # Please specify the project name. 
22 | except Exception: 23 | print('Do not use wandb since it is not found.') 24 | USE_WANDB = False 25 | 26 | # Please specify MODEL_PATH and DATA_PATH (both are base folders) in `path.py`. 27 | from path import MODEL_PATH, DATA_PATH 28 | 29 | 30 | @torch.no_grad() 31 | def predict(model, action_hist, state_hist, t): 32 | assert model.model_type in ['s', 's+a', 's+a+cot'] 33 | 34 | timesteps = torch.from_numpy(t)[:, None].cuda() 35 | if not action_hist: # The first step. 36 | actions = None 37 | else: 38 | actions = torch.stack(action_hist, 1).float().cuda() 39 | states = torch.stack(state_hist, 1).float().cuda() 40 | 41 | if 'cot' in model.model_type: 42 | # T is the max sequence size; S is the current number of steps. 43 | B, T = states.shape[0], model.block_size + model.len_key_states 44 | n_head, S = model.config.n_head, states.shape[1] - 1 # Exclude the init state. 45 | 46 | # Masks for the all-to-all key state query tokens in attention layers. 47 | # The built-in masks for causal (auto-regressive) tokens are in `model.py`. 48 | key_state_mask = torch.zeros([B, n_head, T, T], dtype=bool) 49 | m1 = torch.arange(0, T).repeat(B, 1) 50 | m2 = torch.ones([B, 1]) * (S * 2 + model.len_key_states) 51 | m3 = m1 > m2 # Tokens in the future are masked out. 52 | m3 = m3[:, None, None, :].repeat(1, n_head, model.len_key_states, 1) 53 | key_state_mask[:, :, :model.len_key_states, :] = m3 54 | key_state_mask = key_state_mask.cuda() 55 | preds, _ = model( 56 | states, timesteps, actions=actions, key_state_mask=key_state_mask) 57 | else: 58 | preds, _ = model(states, timesteps, actions=actions) 59 | 60 | return preds[:, -1] # Only output the last action predictions. 61 | 62 | 63 | def update(model, action_hist, state_hist, actions, states, t): 64 | # A function used to update the state and action history. 
65 | assert model.model_type in ['s', 's+a', 's+a+cot'] 66 | 67 | actions = torch.from_numpy(actions) 68 | if len(state_hist) == model.block_size // 2: # The context buffer is full. 69 | assert len(action_hist) == model.block_size // 2 - 1 70 | state_hist = state_hist[1:] + [states] 71 | action_hist = action_hist[1:] + [actions] 72 | t += 1 73 | else: 74 | state_hist.append(states) 75 | action_hist.append(actions) 76 | return action_hist, state_hist, t 77 | 78 | 79 | def parse_args(): 80 | parser = argparse.ArgumentParser() 81 | 82 | # Hyper-parameters regarding the demo dataset (used to gather eval_ids) 83 | parser.add_argument('--task', type=str, default='PickCube-v0', help="Task (env-id) in ManiSkill2.") 84 | parser.add_argument('--control_mode', type=str, default='pd_joint_delta_pos', 85 | help="Control mode used in envs from ManiSkill2.") 86 | parser.add_argument('--obs_mode', type=str, default='state', 87 | help="State mode used in envs from ManiSkill2.") 88 | parser.add_argument("--seed", default=0, type=int,help="Random seed for data spliting.") 89 | 90 | # Hyper-parameters regarding the model. 91 | parser.add_argument("--model_name", default='', type=str, help="Model name to be loaded.") 92 | parser.add_argument("--from_ckpt", default=-1, type=int, help="Ckpt of the model to be loaded.") 93 | 94 | parser.add_argument("--eval_max_steps", default=200, type=int, help="Max steps allowed in eval.") 95 | parser.add_argument('--cot_decoder', type=str, default='256', help="Specs of the CoT decoder.") 96 | 97 | return parser.parse_args() 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | args = parse_args() 103 | assert args.model_name, 'Should specify --model_name' 104 | assert args.from_ckpt > 0, 'Should specify --from_ckpt' 105 | 106 | # Load the model. 107 | path = os.path.join(MODEL_PATH, f'{args.model_name}/{args.from_ckpt}.pth') 108 | # Load to cpu first to avoid cuda related errors from ManiSkill2. 
109 | ckpt = torch.load(path, map_location=torch.device('cpu')) 110 | state_dict_from_ckpt, params = ckpt['model'], ckpt['metadata'] 111 | state_dim = state_dict_from_ckpt['state_encoder.net.0.weight'].shape[1] 112 | action_dim = state_dict_from_ckpt['action_encoder.net.0.weight'].shape[1] 113 | max_timestep = state_dict_from_ckpt['global_pos_emb'].shape[1] 114 | print('Loaded ckpt from:', path) 115 | 116 | # Load demos to fetch the env. seeds used in training. 117 | json_path = os.path.join( 118 | DATA_PATH, f'{args.task}/trajectory.{args.obs_mode}.{args.control_mode}.json') 119 | json_data = load_json(json_path) 120 | env_kwargs = json_data["env_info"]["env_kwargs"] 121 | env_kwargs["obs_mode"] = args.obs_mode 122 | env_kwargs["control_mode"] = args.control_mode 123 | np.random.seed(args.seed) 124 | if args.task == 'TurnFaucet-v0': 125 | length_all = len(json_data["episodes"]) 126 | ids = [] 127 | for i in range(10): # Hard-code the 10 data splits for permutation. 128 | t_ids = np.random.permutation( 129 | length_all // 10)[:params['num_traj'] // 10] 130 | t_ids += i * length_all // 10 131 | ids.append(t_ids) 132 | eval_ids = np.concatenate(ids) 133 | elif args.task == 'PushChair-v1': 134 | length_all = len(json_data["episodes"]) 135 | ids = [] 136 | for i in range(5): # Hard-code the 5 data splits for permutation. 137 | t_ids = np.random.permutation(length_all // 5)[:params['num_traj'] // 5][:100] # Only the per-split training subset (num_traj // 5), capped at 100. 138 | t_ids += i * length_all // 5 139 | ids.append(t_ids) 140 | eval_ids = np.concatenate(ids) 141 | else: 142 | # Only evaluate at most 500 scene configs. 143 | eval_ids = np.random.permutation( 144 | len(json_data["episodes"]))[:params['num_traj']][:500] 145 | 146 | n_env = 25 # Number of parallel environments. 147 | assert len(eval_ids) % n_env == 0, f'{len(eval_ids)}' 148 | envs = get_mp_envs(args.task, n_env, **env_kwargs) 149 | 150 | # Load the ckpt after envs init to avoid cuda related errors from ManiSkill2.
151 | cot_decoder = params['cot_decoder'] if 'cot_decoder' in params else args.cot_decoder 152 | conf = GPTConfig( 153 | params['context_length'], 154 | n_layer=params['n_layer'], 155 | n_head=params['n_head'], 156 | n_embd=params['n_embd'], 157 | model_type=params['model_type'], 158 | key_states=params['key_states'], # Rules for the CoT. 159 | key_state_loss=params['key_state_loss'], # Layers used for CoT modeling. 160 | cot_decoder=cot_decoder, 161 | max_timestep=max_timestep, 162 | ) 163 | model = GPTWithCoT(conf, state_dim=state_dim, action_dim=action_dim).cuda() 164 | model.load_state_dict(state_dict_from_ckpt, strict=False) 165 | model.eval() 166 | 167 | if USE_WANDB: 168 | wandb.init(project=PROJECT_NAME, name=f'eval/{args.model_name}', 169 | id=f'wandb_metrics_{args.model_name}', resume='auto') 170 | 171 | output_str, output_dict = '', dict() 172 | 173 | # Seen scene configurations. 174 | metric_dict = defaultdict(lambda: [[] for _ in range(len(eval_ids))]) 175 | for start_idx in tqdm(range(0, len(eval_ids), n_env)): 176 | reset_args_list = [] 177 | for i in range(start_idx, min(start_idx + n_env, len(eval_ids))): 178 | reset_kwargs = json_data["episodes"][eval_ids[i]]['reset_kwargs'] 179 | reset_args_list.append(reset_kwargs) 180 | 181 | s = torch.from_numpy(envs.reset(reset_args_list)).float() 182 | state_hist, action_hist, t = [s], [], np.zeros([n_env]) 183 | 184 | for step in range(args.eval_max_steps): 185 | a = predict(model, action_hist, state_hist, t).cpu().numpy() 186 | 187 | s, _, _, infos = envs.step(a) 188 | s = torch.from_numpy(s).float() 189 | 190 | action_hist, state_hist, t = update( 191 | model, action_hist, state_hist, a, s, t) 192 | 193 | # Update metrics. 194 | for i, info in enumerate(infos): 195 | j = start_idx + i 196 | # You might want to use these additional metrics. 
197 | if args.task == 'PickCube-v0': 198 | metric_dict['is_grasped'][j].append(info['is_grasped']) 199 | if args.task == 'StackCube-v0': 200 | metric_dict['is_cubaA_grasped'][j].append(info['is_cubaA_grasped']) 201 | metric_dict['is_cubeA_on_cubeB'][j].append(info['is_cubeA_on_cubeB']) 202 | if args.task == 'PegInsertionSide-v0': 203 | metric_dict['is_grasped'][j].append(info['is_grasped']) 204 | metric_dict['pre_inserted'][j].append(info['pre_inserted']) 205 | if args.task == 'TurnFaucet-v0': 206 | metric_dict['is_contacted'][j].append(info['is_contacted']) 207 | if args.task == 'PushChair-v1': 208 | metric_dict['close_to_target'][j].append(info['chair_close_to_target']) 209 | metric_dict['static_at_last'][j].append( 210 | info['chair_close_to_target'] and info['chair_static']) 211 | metric_dict['success'][j].append(info['success']) 212 | 213 | for k, v in metric_dict.items(): 214 | v = np.mean([np.any(vv) for vv in v]) * 100 215 | output_str += f'{k} {v:.2f}, ' 216 | output_dict[k] = v 217 | output_str = output_str[:-2] 218 | print(output_str) 219 | 220 | # Unseen scene configurations. 221 | # Unseen objects for peg insertion and seen objects otherwise. 222 | all_reset_kwargs = [] 223 | if args.task == 'TurnFaucet-v0': 224 | length_all = len(json_data["episodes"]) 225 | ids = [] 226 | for i in range(10): # Hard-code the 10 data splits for permutation. 227 | t_ids = np.random.permutation(length_all // 10) 228 | t_ids = t_ids[params['num_traj']//10:params['num_traj']//10+10] 229 | t_ids += i * length_all // 10 230 | ids.append(t_ids) 231 | eval_ids = np.concatenate(ids) 232 | for eval_id in eval_ids: 233 | all_reset_kwargs.append(json_data["episodes"][eval_id]['reset_kwargs']) 234 | elif args.task == 'PushChair-v1': 235 | length_all = len(json_data["episodes"]) 236 | ids = [] 237 | for i in range(5): # Hard-code the 5 data splits for permutation. 
238 | t_ids = np.random.permutation(length_all // 5) 239 | t_ids = t_ids[params['num_traj']//5:params['num_traj']//5+50] 240 | t_ids += i * length_all // 5 241 | ids.append(t_ids) 242 | eval_ids = np.concatenate(ids) 243 | for eval_id in eval_ids: 244 | all_reset_kwargs.append(json_data["episodes"][eval_id]['reset_kwargs']) 245 | elif args.task == 'PegInsertionSide-v0': 246 | for i in range(400): 247 | all_reset_kwargs.append({'seed': i + 2000}) 248 | else: 249 | for i in range(100): 250 | all_reset_kwargs.append({'seed': i + 2000}) 251 | metric_dict = defaultdict(lambda: [[] for _ in range(len(all_reset_kwargs))]) 252 | 253 | for start_idx in tqdm(range(0, len(all_reset_kwargs), n_env)): 254 | reset_args_list = [] 255 | for i in range(start_idx, min(start_idx + n_env, len(all_reset_kwargs))): 256 | reset_args_list.append(all_reset_kwargs[i]) 257 | 258 | s = torch.from_numpy(envs.reset(reset_args_list)).float() 259 | state_hist, action_hist, t = [s], [], np.zeros([n_env]) 260 | 261 | for step in range(args.eval_max_steps): 262 | a = predict(model, action_hist, state_hist, t).cpu().numpy() 263 | s, _, _, infos = envs.step(a) 264 | s = torch.from_numpy(s).float() 265 | 266 | action_hist, state_hist, t = update( 267 | model, action_hist, state_hist, a, s, t) 268 | 269 | # Update metrics. 270 | for i, info in enumerate(infos): 271 | j = start_idx + i 272 | # You might want to use these additional metrics. 
273 | if args.task == 'PickCube-v0': 274 | metric_dict['test/is_grasped'][j].append(info['is_grasped']) 275 | if args.task == 'StackCube-v0': 276 | metric_dict['test/is_cubaA_grasped'][j].append(info['is_cubaA_grasped']) 277 | metric_dict['test/is_cubeA_on_cubeB'][j].append(info['is_cubeA_on_cubeB']) 278 | if args.task == 'PegInsertionSide-v0': 279 | metric_dict['test/is_grasped'][j].append(info['is_grasped']) 280 | metric_dict['test/pre_inserted'][j].append(info['pre_inserted']) 281 | if args.task == 'TurnFaucet-v0': 282 | metric_dict['test/is_contacted'][j].append(info['is_contacted']) 283 | if args.task == 'PushChair-v1': 284 | metric_dict['test/close_to_target'][j].append(info['chair_close_to_target']) 285 | metric_dict['test/static_at_last'][j].append( 286 | info['chair_close_to_target'] and info['chair_static']) 287 | metric_dict['test/success'][j].append(info['success']) 288 | 289 | output_str = '' 290 | for k, v in metric_dict.items(): 291 | v = np.mean([np.any(vv) for vv in v]) * 100 292 | output_str += f'{k} {v:.2f}, ' 293 | output_dict[k] = v 294 | output_str = output_str[:-2] 295 | print(output_str) 296 | 297 | # Unseen scene configurations with unseen objects (zero-shot). 
298 | all_reset_kwargs = [] 299 | if args.task == 'TurnFaucet-v0': 300 | model_ids = [ 301 | 5014, 5037, 5053, 5062, 302 | ] 303 | elif args.task == 'PushChair-v1': 304 | model_ids = [ 305 | 3003, 3013, 3020, 306 | ] 307 | else: 308 | model_ids = [] 309 | for model_id in model_ids: 310 | for i in range(100): 311 | all_reset_kwargs.append({'seed': i + 2000, 'model_id': str(model_id)}) 312 | metric_dict = defaultdict(lambda: [[] for _ in range(len(all_reset_kwargs))]) 313 | 314 | for start_idx in tqdm(range(0, len(all_reset_kwargs), n_env)): 315 | reset_args_list = [] 316 | for i in range(start_idx, min(start_idx + n_env, len(all_reset_kwargs))): 317 | reset_args_list.append(all_reset_kwargs[i]) 318 | 319 | s = torch.from_numpy(envs.reset(reset_args_list)).float() 320 | state_hist, action_hist, t = [s], [], np.zeros([n_env]) 321 | 322 | for step in range(args.eval_max_steps): 323 | a = predict(model, action_hist, state_hist, t).cpu().numpy() 324 | 325 | s, _, _, infos = envs.step(a) 326 | s = torch.from_numpy(s).float() 327 | 328 | action_hist, state_hist, t = update( 329 | model, action_hist, state_hist, a, s, t) 330 | 331 | # Update metrics. 
332 | for i, info in enumerate(infos): 333 | j = start_idx + i 334 | if args.task == 'PushChair-v1': 335 | metric_dict['test_h/close_to_target'][j].append(info['chair_close_to_target']) 336 | metric_dict['test_h/static_at_last'][j].append( 337 | info['chair_close_to_target'] and info['chair_static']) 338 | if args.task == 'TurnFaucet-v0': 339 | metric_dict['test_h/is_contacted'][j].append(info['is_contacted']) 340 | metric_dict['test_h/success'][j].append(info['success']) 341 | 342 | if all_reset_kwargs: 343 | output_str = '' 344 | for k, v in metric_dict.items(): 345 | v = np.mean([np.any(vv) for vv in v]) * 100 346 | output_str += f'{k} {v:.2f}, ' 347 | output_dict[k] = v 348 | output_str = output_str[:-2] 349 | print(output_str) 350 | 351 | if USE_WANDB: 352 | output_dict['n_iter'] = args.from_ckpt 353 | wandb.log(output_dict) -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for the model architecture of CoTPC, based on the GPT implementation. 3 | Some of the key hyper-parameters are explained in GPTConfig. 
4 | 5 | References: 6 | (1) https://github.com/karpathy/minGPT 7 | (2) https://github.com/kzl/decision-transformer 8 | """ 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | import numpy as np 15 | 16 | 17 | class MLP(nn.Module): 18 | def __init__(self, input_dim, output_dim, hidden_dims=[], act_fn='relu'): 19 | super().__init__() 20 | assert act_fn in ['relu', 'tanh', None, ''] 21 | dims = [input_dim] + hidden_dims + [output_dim] 22 | layers = [] 23 | for i, j in zip(dims[:-1], dims[1:]): 24 | layers.append(nn.Linear(i, j)) 25 | if act_fn == 'relu': 26 | layers.append(nn.ReLU()) 27 | if act_fn == 'tanh': 28 | layers.append(nn.Tanh()) 29 | self.net = nn.Sequential(*layers[:-1]) 30 | 31 | def forward(self, x): 32 | return self.net(x) 33 | 34 | 35 | class GELU(nn.Module): 36 | def forward(self, input): 37 | return F.gelu(input) 38 | 39 | 40 | class GPTConfig: 41 | """ base GPT config, params common to all GPT versions """ 42 | 43 | embd_pdrop = 0.0 44 | resid_pdrop = 0.0 45 | attn_pdrop = 0.0 46 | 47 | def __init__(self, block_size, **kwargs): 48 | assert kwargs['model_type'] in ['s', 's+a', 's+cot', 's+a+cot'], \ 49 | f"Unsupported model_type: {kwargs['model_type']}" 50 | 51 | if '+a' in kwargs['model_type']: # If the action history is used. 52 | self.block_size = block_size * 2 53 | else: 54 | self.block_size = block_size 55 | 56 | if 'cot' in kwargs['model_type']: 57 | # `key_states` specifies which of the key states should be used for CoT. 58 | assert 'key_states' in kwargs, 'Should specify `key_states`' 59 | # It is in the form of 'acd...' that represents whether the key 60 | # state x is used. e.g., here a,c,d is used while b is skipped. 
61 | assert kwargs['key_states'] not in ['', None] and \ 62 | np.all([ord('z') >= ord(g) >= ord('a') for g in kwargs['key_states']]) 63 | 64 | # `key_state_loss` specifies which layer's features in GPT should be used 65 | # for for the auxiliary key state prediction losses. 66 | assert 'key_state_loss' in kwargs, 'Should specify `key_state_loss`' 67 | # It is in the form of e.g., '023', meaning the features out of attention 68 | # layers of idx 0, 2, 3 are used for key state prediction losses. 69 | assert kwargs['key_state_loss'] not in ['', None] and \ 70 | np.all([l.isnumeric() for l in kwargs['key_state_loss']]) 71 | 72 | self.key_states = kwargs['key_states'] 73 | self.key_state_loss = kwargs['key_state_loss'] 74 | self.len_key_states = len(kwargs['key_states']) 75 | else: 76 | self.len_key_states = 0 77 | 78 | # Set up other attributes. 79 | for k,v in kwargs.items(): 80 | setattr(self, k, v) 81 | 82 | 83 | class CausalSelfAttentionWithCoT(nn.Module): 84 | """ 85 | A multi-head masked self-attention layer equipped with key state query tokens for 86 | chain-of-thought predictive control. It is adapted from the minGPT repo. 
87 | """ 88 | 89 | def __init__(self, config): 90 | super().__init__() 91 | assert config.n_embd % config.n_head == 0 92 | 93 | # key, query, value projections for all heads 94 | self.key = nn.Linear(config.n_embd, config.n_embd) 95 | self.query = nn.Linear(config.n_embd, config.n_embd) 96 | self.value = nn.Linear(config.n_embd, config.n_embd) 97 | 98 | # regularization 99 | self.attn_drop = nn.Dropout(config.attn_pdrop) 100 | self.resid_drop = nn.Dropout(config.resid_pdrop) 101 | 102 | # output projection 103 | self.proj = nn.Linear(config.n_embd, config.n_embd) 104 | 105 | # causal mask to ensure that attention is only applied to the left in the input sequence 106 | block_size = config.block_size + config.len_key_states 107 | self.register_buffer("mask", 108 | torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) 109 | 110 | self.n_head = config.n_head 111 | self.model_type = config.model_type 112 | self.len_key_states = config.len_key_states 113 | 114 | # For the learnable key state query tokens, they are actually all-to-all, meaning 115 | # they can access to all future tokens during inference, and up to a future step 116 | # randomly selected during training (see `key_state_mask` in forward(...)). 
117 | self.mask[:,:,:self.len_key_states] = 1.0 118 | 119 | def forward(self, x, key_state_mask=None): 120 | B, T, C = x.size() 121 | 122 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 123 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 124 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 125 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 126 | 127 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 128 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 129 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) # Masked attention 130 | 131 | # Masks used for the learnable key state query tokens, which are not causal (auto-regressive). 132 | if 'cot' in self.model_type: 133 | assert key_state_mask is not None 134 | att = att.masked_fill(key_state_mask, float('-inf')) 135 | 136 | att = F.softmax(att, dim=-1) 137 | att = self.attn_drop(att) 138 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 139 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 140 | 141 | # output projection 142 | y = self.resid_drop(self.proj(y)) 143 | return y 144 | 145 | 146 | class Block(nn.Module): 147 | """ 148 | A Transformer block with masks specified for the learnable key state query tokens. 
149 | """ 150 | 151 | def __init__(self, config): 152 | super().__init__() 153 | self.ln1 = nn.LayerNorm(config.n_embd) 154 | self.ln2 = nn.LayerNorm(config.n_embd) 155 | self.attn = CausalSelfAttentionWithCoT(config) 156 | self.mlp = nn.Sequential( 157 | nn.Linear(config.n_embd, 4 * config.n_embd), 158 | GELU(), 159 | nn.Linear(4 * config.n_embd, config.n_embd), 160 | nn.Dropout(config.resid_pdrop), 161 | ) 162 | 163 | def forward(self, x, key_state_mask=None): 164 | x = x + self.attn(self.ln1(x), key_state_mask=key_state_mask) 165 | x = x + self.mlp(self.ln2(x)) 166 | return x 167 | 168 | 169 | class BlocksWithCoT(nn.Module): 170 | """ 171 | A wrapper class for a sequence of Transformer blocks with masks specified for 172 | the learnable key state query tokens. 173 | """ 174 | 175 | def __init__(self, config): 176 | super().__init__() 177 | # Register all the individual blocks. 178 | self.block_list = nn.ModuleList(Block(config) for _ in range(config.n_layer)) 179 | self.model_type = config.model_type 180 | self.n_head = config.n_head 181 | self.len_key_states = config.len_key_states 182 | 183 | def forward(self, x, key_state_mask=None): 184 | B, T, _ = x.shape 185 | 186 | # During training the `key_state_mask` is not specified and we apply random 187 | # masking such that the first t tokens after the key state query tokens are 188 | # 0's and otherwise 1's, where t is uniformly sampled from 0 to traj length. 189 | # Here 1's mean no attention over the underlying masked tokens. 190 | # During inference, the evaluator should specify key state masks. 191 | if key_state_mask is None: 192 | # If use both state and action history, 193 | # make sure masks for s and a has the same length. 194 | if '+a' in self.model_type: 195 | r = torch.randint(0, (T - self.len_key_states) // 2, [B])[:, None] * 2 196 | else: 197 | r = torch.randint(0, T - self.len_key_states, [B])[:, None] 198 | # When '+a', we always ignore the last action token (set mask=1). 
199 | mask = torch.arange(0, T).repeat(B, 1) > r + self.len_key_states 200 | key_state_mask = torch.zeros( 201 | [B, self.n_head, T, T], dtype=bool, device=x.device) 202 | key_state_mask[:, :, :self.len_key_states, :] = \ 203 | mask[:, None, None, :].repeat(1, self.n_head, self.len_key_states, 1) 204 | 205 | output = [] # Also keep the intermediate results. 206 | for block in self.block_list: 207 | x = block(x, key_state_mask=key_state_mask) 208 | output.append(x) 209 | 210 | return x, output 211 | 212 | 213 | class GPTWithCoT(nn.Module): 214 | """ 215 | GPT implementation with the support of the learnable key state query tokens, 216 | which is used for the chain-of-thought predictive control. Here, the context size 217 | is specified as block_size, which does not count the key state query tokens. 218 | """ 219 | 220 | def __init__(self, config, state_dim=-1, action_dim=-1): 221 | super().__init__() 222 | 223 | assert state_dim > 0 and action_dim > 0 224 | self.config = config 225 | self.state_dim = state_dim 226 | self.action_dim = action_dim 227 | self.model_type = config.model_type 228 | self.key_states = config.key_states 229 | self.key_state_loss = config.key_state_loss 230 | self.len_key_states = config.len_key_states 231 | self.block_size = config.block_size 232 | self.cot_decoder = config.cot_decoder 233 | 234 | # Set up learnable position embedding synchronized for s and a tokens, as proposed 235 | # in Decision Transformer. We use a similar global+local position embedding design. 
236 | p_size = config.block_size // 2 if '+a' in self.model_type else config.block_size 237 | self.local_pos_emb = nn.Parameter(torch.zeros(1, p_size, config.n_embd)) 238 | self.global_pos_emb = nn.Parameter( 239 | torch.zeros(1, config.max_timestep, config.n_embd)) 240 | 241 | self.drop = nn.Dropout(config.embd_pdrop) 242 | 243 | if 'cot' in self.model_type: 244 | self.key_state_pos_emb = nn.Parameter( 245 | torch.zeros(1, self.len_key_states, config.n_embd)) 246 | 247 | # Transformer (attention layers) with CoT. 248 | self.blocks = BlocksWithCoT(config) 249 | 250 | # State embeddings. 251 | self.state_encoder = MLP(self.state_dim, config.n_embd, hidden_dims=[256]) 252 | 253 | # Action embeddings. 254 | if '+a' in self.model_type: 255 | self.action_encoder = MLP(self.action_dim, config.n_embd, hidden_dims=[256]) 256 | 257 | # Action predictor. 258 | self.ln = nn.LayerNorm(config.n_embd) 259 | self.action_predictor = MLP(config.n_embd, action_dim, hidden_dims=[256,256]) 260 | 261 | # Key state predictors. By default, we only use one predictor which takes 262 | # features from one attention layer. 263 | if 'cot' in self.model_type: 264 | key_state_predictors = [] 265 | for _ in self.key_state_loss: 266 | key_state_predictor = MLP( 267 | config.n_embd, self.state_dim, hidden_dims=[int(self.cot_decoder)]) 268 | key_state_predictors.append(key_state_predictor) 269 | # Register all the key state predictors. 
270 | self.key_state_predictors = nn.ModuleList(key_state_predictors) 271 | 272 | self.apply(self._init_weights) 273 | print(f"Total # of parameters: {sum(p.numel() for p in self.parameters())}") 274 | 275 | def _init_weights(self, module): 276 | if isinstance(module, (nn.Linear, nn.Embedding)): 277 | module.weight.data.normal_(mean=0.0, std=0.02) 278 | if isinstance(module, nn.Linear) and module.bias is not None: 279 | module.bias.data.zero_() 280 | elif isinstance(module, nn.LayerNorm): 281 | module.bias.data.zero_() 282 | module.weight.data.fill_(1.0) 283 | 284 | # Given state (and action) history, predict actions (and key states as CoT). 285 | # `timesteps` is used for the global+local position embedding design similar 286 | # to the one in Decision Transformer. `key_state_mask` is used so that the 287 | # (all-to-all) key state query tokens can attend to later tokens. 288 | def forward(self, states, timesteps, actions=None, key_state_mask=None): 289 | B, T = states.shape[0], states.shape[1] 290 | state_embeddings = self.state_encoder(states) 291 | 292 | # Embeddings for state (action, and key state query) tokens. 293 | token_embeddings = torch.zeros([B, self.block_size, self.config.n_embd], 294 | dtype=torch.float32, device=states.device) 295 | 296 | # If using action history as inputs: during training, all actions are 297 | # specified; during inference, only actions in the past are specified. 298 | # That is, the first action prediction has no action history as inputs. 299 | if '+a' in self.model_type: 300 | token_embeddings[:,:T*2:2,:] = state_embeddings 301 | if actions is not None: 302 | # Assume the last action is not used as inputs during training. 303 | token_embeddings[:,1:T*2-1:2,:] = self.action_encoder(actions[:,:T-1]) 304 | 305 | else: 306 | token_embeddings[:,:T,:] = state_embeddings 307 | 308 | # Set up position embeddings similar to that in Decision Transformer. 
309 | global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) 310 | timesteps_rp = torch.repeat_interleave(timesteps[:, None], self.config.n_embd, dim=-1) 311 | global_pos_emb = torch.gather( 312 | global_pos_emb, 1, timesteps_rp.long()) # BS x 1 x D 313 | local_pos_emb = torch.repeat_interleave(self.local_pos_emb, 2, dim=1) \ 314 | if '+a' in self.model_type else self.local_pos_emb 315 | 316 | x = token_embeddings + global_pos_emb + local_pos_emb 317 | if 'cot' in self.model_type: 318 | key_state_embeddings = self.key_state_pos_emb.repeat(B, 1, 1) 319 | x = torch.cat([key_state_embeddings, x], 1) 320 | 321 | x = self.drop(x) 322 | x, intermediate_feats = self.blocks(x, key_state_mask=key_state_mask) 323 | x = self.ln(x) 324 | act_preds = self.action_predictor(x) 325 | 326 | if 'cot' in self.model_type: 327 | key_state_preds = [] 328 | for idx, loss_layer_idx in enumerate([int(c) for c in self.key_state_loss]): 329 | key_state_preds.append(self.key_state_predictors[idx]( 330 | intermediate_feats[loss_layer_idx][:,:self.len_key_states])) 331 | 332 | # Get rid of dims for key state query tokens. 333 | act_preds = torch.split( 334 | act_preds, [self.len_key_states, self.block_size], dim=1)[1] 335 | else: 336 | key_state_preds = None 337 | 338 | # Get rid of dims for action tokens. 339 | if '+a' in self.model_type: 340 | # Remove the extra tokens when in eval mode. 341 | act_preds = act_preds[:,:T*2:2] 342 | 343 | return act_preds, key_state_preds 344 | 345 | def configure_adamw_optimizers(self, config): 346 | """ 347 | This long function is unfortunately doing something very simple and is being very defensive: 348 | We are separating out all parameters of the model into two buckets: those that will experience 349 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 350 | We are then returning the PyTorch optimizer object. 
351 | """ 352 | 353 | # separate out all parameters to those that will and won't experience regularizing weight decay 354 | decay = set() 355 | no_decay = set() 356 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) 357 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 358 | for mn, m in self.named_modules(): 359 | for pn, _ in m.named_parameters(): 360 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 361 | 362 | if pn.endswith('bias'): 363 | # all biases will not be decayed 364 | no_decay.add(fpn) 365 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 366 | # weights of whitelist modules will be weight decayed 367 | decay.add(fpn) 368 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 369 | # weights of blacklist modules will NOT be weight decayed 370 | no_decay.add(fpn) 371 | 372 | # special case the position embedding parameter in the root GPT module as not decayed 373 | no_decay.add('local_pos_emb') 374 | no_decay.add('global_pos_emb') 375 | if 'cot' in self.model_type: 376 | no_decay.add('key_state_pos_emb') 377 | 378 | # validate that we considered every parameter 379 | param_dict = {pn: p for pn, p in self.named_parameters()} 380 | inter_params = decay & no_decay 381 | union_params = decay | no_decay 382 | assert len(inter_params) == 0, \ 383 | "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 384 | assert len(param_dict.keys() - union_params) == 0, \ 385 | "parameters %s were not separated into either decay/no_decay set!" 
\ 386 | % (str(param_dict.keys() - union_params), ) 387 | 388 | # create the pytorch optimizer object 389 | optim_groups = [ 390 | {"params": [param_dict[pn] for pn in sorted(list(decay))], 391 | "weight_decay": config['weight_decay']}, 392 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], 393 | "weight_decay": 0.0}, 394 | ] 395 | optimizer = torch.optim.AdamW( 396 | optim_groups, 397 | lr=config['init_lr'], 398 | betas=(config['beta1'], config['beta2']) 399 | ) 400 | return optimizer 401 | --------------------------------------------------------------------------------