├── .gitignore
├── github_teaser.png
├── scripts
├── eval.sh
├── train.sh
└── replay_turn_faucet_trajectories.sh
├── maniskill2_patches
├── pick_cube.py
├── peg_insertion_side.py
├── turn_faucet.py
└── record.py
├── src
├── train_utils.py
├── vec_env.py
├── train.py
├── data.py
├── eval.py
└── model.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | src/__pycache__
2 | models
3 | wandb
4 | path.py
5 |
--------------------------------------------------------------------------------
/github_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SeanJia/CoTPC/HEAD/github_teaser.png
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate a trained CoTPC checkpoint (replace `some_model_name` / `--from_ckpt`
# with the model you trained).

# Run from src/ regardless of the directory this script is invoked from.
cd "$(dirname "$0")/../src" || exit 1

python eval.py --num_traj=500 --eval_max_steps=200 \
    --key_states=abc --key_state_loss=0 \
    --from_ckpt=1_800_000 --task=StackCube-v0 \
    --model_name=some_model_name
9 |
--------------------------------------------------------------------------------
/maniskill2_patches/pick_cube.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the pick_cube env in ManiSkill2 so that it allows additional metrics for eval
3 | and flags for obtaining key states in training. Please simply replace the `evaluate`
4 | function in (with the correct level of indentation):
5 |
6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/pick_and_place/pick_cube.py
7 | """
8 |
def evaluate(self, **kwargs):
    """Report PickCube metrics, extended with a grasp flag.

    Returns a dict with ``is_obj_placed`` and ``is_robot_static`` (from the env's
    own checks), ``is_grasped`` (``self.agent.check_grasp`` on the cube with a
    30-degree max angle), and ``success`` (placed and static).
    """
    placed = self.check_obj_placed()
    static = self.check_robot_static()
    grasped = self.agent.check_grasp(self.obj, max_angle=30)

    metrics = {
        "is_obj_placed": placed,
        "is_robot_static": static,
        "is_grasped": grasped,
    }
    metrics["success"] = placed and static
    return metrics
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Run from src/ regardless of the directory this script is invoked from.
cd "$(dirname "$0")/../src" || exit 1

# Example script for PickCube training (with a good set of hyper-parameters).
CUDA_VISIBLE_DEVICES=0 python train.py \
    --model_name=some_model_name \
    --num_traj=500 --n_iters=1_600_000 \
    --context_length=60 --model_type=s+a+cot \
    --task=PickCube-v0 --key_state_coeff=0.1 \
    --key_state_loss=0 --key_states=ab \
    --init_lr=5e-4 --num_workers=20

# The same settings for TurnFaucet-v0:
# CUDA_VISIBLE_DEVICES=0 python train.py \
#     --model_name=some_model_name \
#     --num_traj=500 --n_iters=1_600_000 \
#     --context_length=60 --model_type=s+a+cot \
#     --task=TurnFaucet-v0 --key_state_coeff=0.1 \
#     --key_state_loss=0 --key_states=ab \
#     --init_lr=5e-4 --num_workers=20
21 |
--------------------------------------------------------------------------------
/src/train_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for the cosine decay learning rate with linear warmup.
3 | """
4 |
5 | import math
6 | import functools
7 | import torch
8 |
9 |
10 | def _cosine_decay_warmup(iteration, warmup_iterations, total_iterations):
11 | """
12 | Linear warmup from 0 --> 1.0, then decay using cosine decay to 0.1
13 | """
14 | if iteration <= warmup_iterations:
15 | multiplier = iteration / warmup_iterations
16 | else:
17 | multiplier = (iteration - warmup_iterations) / (total_iterations - warmup_iterations)
18 | multiplier = max(0.1, 0.5 * (1 + math.cos(math.pi * multiplier)))
19 | return multiplier
20 |
21 |
def CosineAnnealingLRWarmup(optimizer, T_max, T_warmup):
    """Build a LambdaLR scheduler with linear warmup followed by cosine decay.

    Args:
        optimizer: the torch optimizer whose learning rates get scaled.
        T_max: total number of iterations (where the cosine decay ends).
        T_warmup: number of linear warmup iterations.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` whose multiplier is given by
        ``_cosine_decay_warmup``.
    """
    # A functools.partial (rather than a lambda) is kept on purpose: LambdaLR's
    # state_dict only records callable objects, not plain functions.
    lr_lambda = functools.partial(
        _cosine_decay_warmup,
        total_iterations=T_max,
        warmup_iterations=T_warmup,
    )
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
29 |
--------------------------------------------------------------------------------
/scripts/replay_turn_faucet_trajectories.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Assume that the *.h5 and *.json are in `../data/rigid_body_envs/TurnFaucet-v0/raw`
# (relative to this scripts/ folder), replay the trajectories with a subset of a total
# of 10 faucet models, then merge the replayed trajectories into a single file.

# Resolve the relative ../data paths against this script's folder, not the caller's cwd.
cd "$(dirname "$0")" || exit 1

DATA_DIR=../data/rigid_body_envs/TurnFaucet-v0

for s in 5002 5021 5023 5028 5029 5045 5047 5051 5056 5063
do
    python -m mani_skill2.trajectory.replay_trajectory \
        --traj-path "$DATA_DIR/raw/$s.h5" \
        --save-traj --target-control-mode pd_joint_delta_pos \
        --obs-mode state --num-procs 20
done

# `mv` with several sources fails if the destination folder does not exist yet.
mkdir -p "$DATA_DIR/merged"
mv "$DATA_DIR"/raw/*.state.pd_joint_delta_pos.h5 "$DATA_DIR/merged/"
mv "$DATA_DIR"/raw/*.state.pd_joint_delta_pos.json "$DATA_DIR/merged/"

# Quote the pattern so it reaches merge_trajectory literally instead of being
# expanded by the shell against the current directory.
python -m mani_skill2.trajectory.merge_trajectory \
    -i "$DATA_DIR/merged" -p "*.h5" \
    -o "$DATA_DIR/trajectory.state.pd_joint_delta_pos.h5"
--------------------------------------------------------------------------------
/maniskill2_patches/peg_insertion_side.py:
--------------------------------------------------------------------------------
1 | # I updated part of the environment file in ManiSkill2 so that it allows computing metrics
2 | # for reaching an intermediate key state during testing.
3 | # Simply replace the `evaluate` function.
4 |
5 | """
6 | Patch the peg_insertion_side env in ManiSkill2 so that it allows additional metrics for
7 | eval and flags for obtaining key states in training. Please simply replace the `evaluate`
8 | function in (with the correct level of indentation):
9 |
10 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/assembly/peg_insertion_side.py
11 | """
12 |
def evaluate(self, **kwargs):
    """Report success plus intermediate flags for PegInsertionSide.

    Returns a dict with:
        success, peg_head_pos_at_hole: the two values from ``self.has_peg_inserted()``.
        pre_inserted: True only when the peg is grasped and both the peg head and the
            peg body are within 0.01 of the goal frame's x-axis, measured as the norm
            of the y/z components of their poses expressed in the goal frame.
        is_grasped: ``self.agent.check_grasp(self.peg, max_angle=20)``.
    """
    is_grasped = self.agent.check_grasp(self.peg, max_angle=20)

    pre_inserted = False
    if is_grasped:
        # Express both poses relative to the goal; the y/z offsets measure misalignment.
        to_goal = self.goal_pose.inv()
        head_offset = np.linalg.norm((to_goal * self.peg_head_pose).p[1:])
        body_offset = np.linalg.norm((to_goal * self.peg.pose).p[1:])
        pre_inserted = bool(head_offset < 0.01 and body_offset < 0.01)

    success, peg_head_pos_at_hole = self.has_peg_inserted()
    return {
        "success": success,
        "pre_inserted": pre_inserted,
        "peg_head_pos_at_hole": peg_head_pos_at_hole,
        "is_grasped": is_grasped,
    }
--------------------------------------------------------------------------------
/maniskill2_patches/turn_faucet.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the turn_faucet env in ManiSkill2 so that it allows additional metrics for eval
3 | and flags for obtaining key states in training. Please replace the `evaluate` function
4 | in (with the correct level of indentation):
5 |
6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/envs/misc/turn_faucet.py
7 |
8 | Moreover, replace the `_get_obs_extra` function to customize the state representation
9 | so that it is easier to distinguish among different faucet models.
10 |
11 | Note:
12 | The 10 faucet models (a simpler subset of all faucets in ManiSkill2) we use in the
13 | CoTPC paper have the ids: 5002,5021,5023,5028,5029,5045,5047,5051,5056,5063.
14 | """
15 |
16 | def _get_curr_target_link_pos(self):
17 | """
18 | Access the current pose of the target link (i.e., the handle of the faucet).
19 | """
20 | cmass_pose = self.target_link.pose * self.target_link.cmass_local_pose
21 | return cmass_pose.p
22 |
def _get_obs_extra(self):
    """Build the extra (non-agent) part of the observation for TurnFaucet.

    On top of the upstream entries, ``curr_target_link_pos`` (the handle's current
    center-of-mass position) is added so different faucet models are easier to tell
    apart. In "state"/"state_dict" observation modes, ``angle_dist``
    (target angle minus current angle) is appended as well.
    """
    extra = OrderedDict()
    extra["tcp_pose"] = vectorize_pose(self.tcp.pose)
    extra["target_angle_diff"] = self.target_angle_diff
    extra["target_joint_axis"] = self.target_joint_axis
    extra["target_link_pos"] = self.target_link_pos
    extra["curr_target_link_pos"] = self._get_curr_target_link_pos()  # Added code.
    if self._obs_mode in ("state", "state_dict"):
        extra["angle_dist"] = self.target_angle - self.current_angle
    return extra
35 |
def evaluate(self, **kwargs):
    """Report TurnFaucet metrics.

    ``angle_dist`` is ``target_angle - current_angle``; ``success`` is whether it is
    below zero. ``is_contacted`` is True when any entry returned by
    ``self.agent.check_contact_fingers(self.target_link)`` is truthy.
    """
    touching = any(self.agent.check_contact_fingers(self.target_link))
    remaining = self.target_angle - self.current_angle
    return {
        "success": remaining < 0,
        "angle_dist": remaining,
        "is_contacted": touching,
    }
--------------------------------------------------------------------------------
/maniskill2_patches/record.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch the record utility in ManiSkill2 so that it records additional metrics for eval
3 | and flags for obtaining key states in training. Please patch the `flush_trajectory`
4 | function in (with the correct level of indentation):
5 |
6 | https://github.com/haosulab/ManiSkill2/blob/main/mani_skill2/utils/wrappers/record.py
7 | """
8 |
def flush_trajectory(self, **args):
    """Excerpt of ManiSkill2's record-wrapper `flush_trajectory` showing only the added code.

    Each `# some code here ...` marks upstream code that is left unchanged and elided here.
    The added code writes two extra kinds of per-step data into the trajectory file:
    a float32 `rewards` dataset and one boolean dataset per boolean-valued entry of the
    env `info` dict (stored under `infos/<key>`), used later as key-state flags.
    """

    # some code here ...

    ########################### ADDED CODE #############################
    # Append info (boolean flags) to the recorded trajectories.
    # This tells you what info to store in the trajs.
    # The key set is taken from the info dict of the LAST recorded entry; a key is kept
    # when its value is a NumPy object with a bool dtype, or a plain Python bool.
    info_bool_keys = []
    for k, v in self._episode_data[-1]['info'].items():
        if type(v).__module__ == np.__name__ and v.dtype == 'bool':
            info_bool_keys.append(k)
        elif isinstance(v, bool):
            info_bool_keys.append(k)

    # This info only appears in some trajs, so it is dropped to keep the stored keys
    # consistent across trajectories.
    if 'TimeLimit.truncated' in info_bool_keys:
        info_bool_keys.remove('TimeLimit.truncated')
    ####################################################################

    if len(self._episode_data) == 1:
        # some code here ...

        ########################### ADDED CODE #############################
        # Only one entry was recorded (presumably the reset entry): store empty flag arrays.
        infos_bool = {k: np.empty(shape=(0,), dtype=bool) for k in info_bool_keys}
        ####################################################################
    else:
        # some code here ...

        ########################### ADDED CODE #############################
        # Collect each flag over every entry after the first (presumably the reset entry)
        # and stack it into one array per key.
        # NOTE(review): raises KeyError if an earlier entry's info lacks a key that is
        # present in the last entry's info.
        infos_bool = {k: [] for k in info_bool_keys}
        for x in self._episode_data[1:]:
            for k in info_bool_keys:
                infos_bool[k].append(x['info'][k])
        for k in infos_bool:
            infos_bool[k] = np.stack(infos_bool[k])
        ####################################################################

    # some code here ...

    ########################### ADDED CODE #############################
    # Dump the additional entries to the demo trajectories.
    # NOTE(review): `group` is not defined in this excerpt; presumably it is the HDF5
    # group for the current episode created by the elided upstream code.
    rewards = np.array([x['r'] for x in self._episode_data[1:]], dtype=np.float32)
    group.create_dataset("rewards", data=rewards, dtype=np.float32)
    for k, v in infos_bool.items():
        group.create_dataset(f"infos/{k}", data=v, dtype=bool)
    ####################################################################

    # Handle JSON
    # some code here ...
--------------------------------------------------------------------------------
/src/vec_env.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import cloudpickle
3 | import numpy as np
4 | from multiprocessing import Pipe, Process
5 | import gym
6 |
7 | from transforms3d.euler import euler2quat
8 | from transforms3d.quaternions import qmult
9 |
10 | import sapien.core as sapien
11 |
12 |
def disturb(env, kwargs):
    """Perturb the pose of the env's `peg` and/or `box` actor in place.

    For each of the keys 'peg' and 'box' present in `kwargs` (processed in that
    order), the value is a `(dx, dy, dr)` triple: the actor's position is shifted
    by `(dx, dy, 0)` and its orientation is pre-multiplied by a rotation of `dr`
    about the z-axis (via `euler2quat(0, 0, dr)`).
    """
    for name in ('peg', 'box'):
        if name not in kwargs:
            continue
        dx, dy, dr = kwargs[name]
        actor = getattr(env, name)
        current = actor.get_pose()
        rotation = euler2quat(0, 0, dr)
        actor.set_pose(sapien.Pose(
            p=current.p + [dx, dy, 0],
            q=qmult(rotation, current.q),
        ))
26 |
def get_mp_envs(env_id, n_env, start_idx=0, **env_kwargs):
    """Create a `VecEnv` of `n_env` copies of `gym.make(env_id, **env_kwargs)`.

    Each env is built lazily inside its worker process by a zero-argument thunk.
    NOTE(review): the per-env rank (`start_idx + offset`) is passed to the thunk
    factory but is currently unused (e.g. not used for seeding).
    """
    def make_thunk(rank):
        def thunk():
            return gym.make(env_id, **env_kwargs)
        return thunk

    thunks = [make_thunk(start_idx + offset) for offset in range(n_env)]
    return VecEnv(thunks)
34 |
35 |
class CloudpickleWrapper:
    """Make an arbitrary callable (e.g. a closure) picklable for `multiprocessing`.

    Pickling goes through ``cloudpickle.dumps`` and unpickling through the standard
    ``pickle.loads``; calling the wrapper calls the wrapped callable with no arguments.
    """

    def __init__(self, x):
        # `x` is the wrapped zero-argument callable.
        self.x = x

    def __getstate__(self):
        # cloudpickle can serialize lambdas/closures that the stdlib pickle rejects.
        return cloudpickle.dumps(self.x)

    def __setstate__(self, payload):
        self.x = pickle.loads(payload)

    def __call__(self):
        wrapped = self.x
        return wrapped()
48 |
def worker(remote, parent_remote, env_fn):
    """Serve commands for a single env; used as the target of each `VecEnv` process.

    Args:
        remote: the child's end of the pipe; commands arrive as ``(cmd, data)`` tuples.
        parent_remote: the parent's end of the same pipe, inherited by the child and
            closed immediately because the child never uses it.
        env_fn: zero-argument callable that builds the env.

    Commands:
        'step'    -- data is the action; replies ``(ob, reward, done, info)``.
        'reset'   -- data is a kwargs dict for ``env.reset``; replies the observation.
        'render'  -- replies ``env.render()``.
        'disturb' -- data is passed to `disturb`; no reply is sent.
        'close'   -- closes the env and the pipe, then returns.

    Raises:
        NotImplementedError: on an unrecognized command.
    """
    parent_remote.close()
    env = env_fn()
    while True:
        cmd, data = remote.recv()

        if cmd == 'step':
            # Episodes are not auto-reset on `done`; the caller decides when to reset.
            ob, reward, done, info = env.step(data)
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset(**data))
        elif cmd == 'render':
            remote.send(env.render())
        elif cmd == 'close':
            # Release the env's resources before shutting the worker down.
            env.close()
            remote.close()
            break
        elif cmd == 'disturb':
            disturb(env, data)
        else:
            raise NotImplementedError(f'Unknown command: {cmd!r}')
71 |
class VecEnv:
    """Run several envs in daemon subprocesses and step them in lockstep.

    Each env lives in its own process running `worker`; the parent talks to it
    through one duplex `multiprocessing.Pipe` per env (`self.remotes`).
    """

    def __init__(self, env_fns):
        if not env_fns:
            raise ValueError('VecEnv requires at least one env_fn.')
        self.waiting = False
        self.closed = False
        no_envs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(no_envs)])
        self.ps = []

        for wrk, rem, fn in zip(self.work_remotes, self.remotes, env_fns):
            # The parent end is passed too, so the child can close its inherited copy.
            p = Process(target=worker, args=(wrk, rem, CloudpickleWrapper(fn)))
            self.ps.append(p)

        for p in self.ps:
            p.daemon = True  # Children are terminated when the parent exits.
            p.start()

        # The parent only uses `self.remotes`; drop its copies of the child ends.
        for remote in self.work_remotes:
            remote.close()

    def step_async(self, actions):
        """Send one action to each env without waiting for the results."""
        if self.waiting:
            raise NameError('AlreadySteppingError')
        self.waiting = True
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))

    def step_wait(self):
        """Collect the step results; returns stacked obs/rewards/dones and a tuple of infos."""
        if not self.waiting:
            raise NameError('NotSteppingError')
        self.waiting = False
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def step(self, actions):
        """Synchronous step: `step_async` followed by `step_wait`."""
        self.step_async(actions)
        return self.step_wait()

    def reset(self, kwargs_list):
        """Reset each env with its own kwargs dict and return the stacked observations."""
        for remote, kwargs in zip(self.remotes, kwargs_list):
            remote.send(('reset', kwargs))
        return np.stack([remote.recv() for remote in self.remotes])

    def disturb(self, kwargs_list):
        """Ask each env to apply `disturb` with its kwargs; the workers send no reply."""
        for remote, kwargs in zip(self.remotes, kwargs_list):
            remote.send(('disturb', kwargs))

    def close(self):
        """Shut down all workers and release the parent's pipe ends (idempotent)."""
        if self.closed:
            return
        if self.waiting:
            # Consume the outstanding step results before shutting down.
            for remote in self.remotes:
                remote.recv()
            self.waiting = False
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        for remote in self.remotes:
            remote.close()
        self.closed = True
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chain-of-Thought Predictive Control (CoTPC)
2 | This is the official repository for CoTPC, a powerful hierarchical imitation learning model presented in the following paper:
3 |
4 | ### **[Chain-of-Thought Predictive Control](https://zjia.eng.ucsd.edu/cotpc)**
5 | Zhiwei Jia, Fangchen Liu, Vineet Thumuluri, Linghao Chen, Zhiao Huang, Hao Su
6 | UC San Diego, UC Berkeley, Zhejiang University
7 |
8 |
9 | 
10 | [arXiv] [website]
11 |
12 |
13 | ### Tasks
14 | Currently the code supports five tasks from the [ManiSkill2](https://github.com/haosulab/ManiSkill2) benchmark:
15 | `PickCube-v0`, `StackCube-v0`, `PegInsertionSide-v0`, `TurnFaucet-v0`, and `PushChair-v1`.
16 |
17 | ### Demonstration Data
18 | The state-based demo trajectories used in the paper are stored in this Google Drive [folder](https://drive.google.com/drive/folders/1VdunXUlzqAvy-D8MniQ4anhV5LLBfNbJ).
19 | Each folder has a `*.h5` file (for actual trajectories) and a `*.json` file (for metadata regarding the trajectories).
20 | Each task has over 1000 demo trajectories.
21 | Each trajectory comes with a different env configuration (i.e., env seed, which influences object poses, object geometries, etc.).
22 | These demos are generated by replaying the official ManiSkill2 [demos](https://github.com/haosulab/ManiSkill2#demonstrations) or the ones adapted from [this](https://github.com/caiqi/Silver-Bullet-3D/tree/master/No_Restriction) repo with several patches to the ManiSkill2 code (see `CoTPC/maniskill2_patches`).
23 | Specifically, we add additional flags to the tasks so that the key states (the Chain-of-Thought) can be obtained with privileged information from the simulator.
24 | For the task `TurnFaucet-v0`, we use a subset of 10 faucet models for the demos (see `CoTPC/scripts/replay_turn_faucet_trajectories.sh`).
25 | For the task `PushChair-v1`, we use a subset of 5 chair models for the demos.
26 | If you want to generate visual-based demos, please refer to the official ManiSkill2 guidance [here](https://github.com/haosulab/ManiSkill2#demonstrations).
27 |
28 | ### Data Loader
29 | The data loader samples (contiguous) subarrays of demo trajectories by specifying the min and max sample lengths.
30 | In CoTPC, we simply use a fixed value for both min and max (e.g., set the context size as 60 for all of the 5 tasks from ManiSkill2).
31 | With a fixed random seed `seed` and `num_traj`, each time the data loader samples a fixed subset of all trajectories (by default, we use `seed=0`).
32 | For TurnFaucet and PushChair, due to the variations of the different object models, the loader performs sampling such that the number of trajs
33 | per model is the same (hopefully this data balance eases model training).
34 | The key states (a.k.a the Chain-of-Thought) can be obtained with the function `get_key_states(data_idx)` that accesses privileged info available during training.
35 |
36 | ### Evaluation
37 | We patch the environment code from ManiSkill2 (see `CoTPC/maniskill2_patches`) to provide additional evaluation metrics (intermediate success rates) for each task execution.
38 | We report evaluation results for models using both the seen env configurations and the unseen env configurations.
39 | For tasks involving geometric variations we also evaluate using unseen objects (i.e., zero-shot transfer).
40 | Please refer to `CoTPC/src/eval.py` for details.
41 |
42 |
43 |
44 |
45 | ### Sample Training Scripts
46 | Please see `CoTPC/scripts/train.sh` for examples.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 |
8 | from collections import deque
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from data import MS2Demos, get_padding_fn
13 | from model import GPTConfig, GPTWithCoT
14 | from train_utils import CosineAnnealingLRWarmup
15 |
# You might need this for wandb to work due to protobuf issues.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
PROJECT_NAME = 'CoTPC'  # Please specify the project name.

# wandb is optional: only the import itself is expected to fail.
try:
    import wandb
    USE_WANDB = True
except ImportError:
    print('Do not use wandb since it is not found.')
    USE_WANDB = False
26 |
27 | # Please specify MODEL_PATH (the base folder for storing models) in `path.py`.
28 | from path import MODEL_PATH
29 |
def parse_args():
    """Build the CoTPC training command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
        The optimizer-related options (``init_lr``, ``weight_decay``,
        ``beta1``, ``beta2``, ``dropout``) are declared as strings; the
        training script converts them with ``float()``.
    """
    cli = argparse.ArgumentParser()
    opt = cli.add_argument

    # Training hyper-parameters.
    opt("--n_iters", default=1_600_000, type=int, help="Number of training iterations.")
    opt("--batch_size", default=256, type=int, help="Batch size.")
    opt("--init_lr", default='5e-4', type=str, help="The initial learning rate.")
    opt("--weight_decay", default='0', type=str, help="Weight decay coefficient.")
    opt("--beta1", default='0.9', type=str, help="Beta1 in the Adam optimizer.")
    opt("--beta2", default='0.95', type=str, help="Beta2 in the Adam optimizer.")
    opt("--dropout", default='0.0', type=str, help="Dropout probability.")
    opt("--lr_schedule", default='cos_decay_with_warmup', type=str,
        help="The learning rate schedule.")

    # Hyper-parameters regarding CoTPC.
    opt("--key_state_coeff", default=0.0, type=float,
        help="Coefficient for the key state prediction loss.")
    opt('--model_type', type=str, default='s+a+cot',
        help="Model type for the CoTPC model (see GPTConfig).")
    opt('--key_states', type=str, default='a',
        help="Which key states to use (see GPTConfig for the spec. format).")
    opt("--key_state_loss", default='', type=str,
        help=("Features out of what attention layers to use for key state prediction "
              "losses (see GPTConfig for the spec. format)."))
    opt('--cot_decoder', type=str, default='256', help="Specs of the CoT decoder.")

    # General hyper-parameters regarding model loading and saving
    opt("--model_name", default='', type=str, help="Model name (for storing ckpts).")
    opt("--from_model_name", default='', type=str, help="Name of the pretrained model.")
    opt("--from_ckpt", default=-1, type=int, help="Ckpt of pretrained model.")

    # Hyper-parameters regarding the demo dataset
    opt('--task', type=str, default='PickCube-v0', help="Task (env-id) in ManiSkill2.")
    opt('--control_mode', type=str, default='pd_joint_delta_pos',
        help="Control mode used in envs from ManiSkill2.")
    opt('--obs_mode', type=str, default='state',
        help="State mode used in envs from ManiSkill2.")
    opt("--seed", default=0, type=int, help="Random seed for data spliting.")
    opt("--num_traj", default=-1, type=int, help="Number of training trajectories.")
    opt('--context_length', type=int, default=60,
        help=("Context size of CoTPC (the maximium length of sequences "
              "sampled from demo trajectories in training)."))
    opt('--min_seq_length', type=int, default=60,
        help="Mininum length of sequences sampled from demo trajectories in training.")

    # Save and log frequencies.
    opt("--save_every", default=40000, type=int, help="Save model every # iters.")
    opt("--log_every", default=2000, type=int, help="log metrics every # iters.")

    # General hyper-parameters for the GPT architecture.
    opt("--n_layer", default=4, type=int, help="Number of attention layers.")
    opt("--n_head", default=8, type=int, help="Number of attention heads.")
    opt("--n_embd", default=128, type=int, help="Hidden feature dimension.")

    # For faster data loader.
    opt("--num_workers", default=2, type=int,
        help="A positive number for fast async data loading.")
    opt('--multiplier', type=int, default=20,
        help="Duplicate the dataset to reduce data loader overhead.")

    return cli.parse_args()
91 |
92 |
def mse_loss_with_weights(preds, targets, weights=None):
    """Squared error averaged over the last axis, then over everything else.

    Args:
        preds: prediction tensor.
        targets: tensor with the same shape as ``preds``.
        weights: optional tensor matching ``preds.shape[:-1]``; each
            per-item error is multiplied by it before the final mean.

    Returns:
        Scalar tensor. The final mean always divides by the number of items
        (not by the sum of ``weights``).
    """
    per_item = ((preds - targets) ** 2).mean(-1)
    if weights is not None:
        assert per_item.shape == weights.shape, per_item.shape
        per_item = per_item * weights
    return per_item.mean()
100 |
101 |
def get_loss(preds, targets, lengths):
    """Masked MSE between predicted and target sequences.

    Sequences in a mini-batch can have different true lengths (the collate
    function zero-pads them to the longest one). Padded steps are therefore
    excluded from both the numerator and the normalizer. When every
    sequence has the same length (max_seq_length == min_seq_length), this
    reduces to the plain MSE.

    Args:
        preds: (B, L, D) predictions, where L is the padded length of the batch.
        targets: tensor with the same shape as ``preds``.
        lengths: (B,) integer tensor holding the true length of each sequence.

    Returns:
        Scalar tensor: the per-step squared error (averaged over the last
        dim), averaged over the valid (unpadded) steps only.
    """
    B = preds.shape[0]
    max_len = int(torch.max(lengths))  # Max length of the current mini-batch.
    # Build the mask on the same device as the predictions instead of
    # hard-coding CUDA.
    steps = torch.arange(max_len, device=preds.device)[None].expand(B, -1)  # B x max_len
    masks = (steps < lengths.to(preds.device)[:, None]).float()  # B x max_len

    per_step = torch.mean((preds - targets) ** 2, -1)  # B x L
    assert per_step.shape == masks.shape, per_step.shape
    # Normalize by the number of valid steps (not B * max_len) so padding does
    # not dilute the loss; the clamp guards against a batch with no valid steps.
    return torch.sum(per_step * masks) / masks.sum().clamp(min=1.0)
117 |
118 |
if __name__ == "__main__":

    args = parse_args()
    assert args.model_name != '', 'Should specify --model_name'
    print('Model name:', args.model_name)

    # Whether the model has the chain-of-thought (key state) branch.
    use_cot = 'cot' in args.model_type
    if use_cot:
        assert args.key_states, 'Should specify --key_states.'

    train_dataset = MS2Demos(
        control_mode=args.control_mode,
        obs_mode=args.obs_mode,
        length=args.num_traj, seed=args.seed,
        min_seq_length=args.min_seq_length,
        max_seq_length=args.context_length,
        with_key_states=use_cot,
        task=args.task, multiplier=args.multiplier)
    print('Training data size:', len(train_dataset))
    print('Max steps:', train_dataset.max_steps)

    input_dict = ['s', 'a', 't'] + (['k'] if use_cot else [])
    collate_fn = get_padding_fn(input_dict)
    train_data = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,  # Faster data loading if using GPU.
        num_workers=args.num_workers,
        # Faster data loader resets. DataLoader rejects persistent_workers=True
        # when num_workers == 0, so only enable it with worker processes.
        persistent_workers=args.num_workers > 0,
        collate_fn=collate_fn,
        drop_last=True,
    )
    data_iter = iter(train_data)

    state_dim, action_dim = train_dataset.info()
    conf = GPTConfig(
        args.context_length,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        model_type=args.model_type,
        key_states=args.key_states,
        key_state_loss=args.key_state_loss,
        max_timestep=train_dataset.max_steps,
        embd_pdrop=float(args.dropout),
        resid_pdrop=float(args.dropout),
        attn_pdrop=float(args.dropout),
        cot_decoder=args.cot_decoder,
    )
    model = GPTWithCoT(conf, state_dim=state_dim, action_dim=action_dim).cuda()
    optimizer = model.configure_adamw_optimizers({
        'init_lr': float(args.init_lr),
        'weight_decay': float(args.weight_decay),
        'beta1': float(args.beta1),
        'beta2': float(args.beta2),
    })

    # Learning rate schedules (which might require more tuning). Both branches
    # must bind `lr_scheduler`, since the training loop steps it every iteration.
    if args.lr_schedule == 'cos_decay_with_warmup':
        lr_scheduler = CosineAnnealingLRWarmup(
            optimizer, T_max=args.n_iters, T_warmup=1000)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[78000], gamma=0.1)

    model_path = os.path.join(MODEL_PATH, args.model_name)
    os.makedirs(model_path, exist_ok=True)

    # If loaded from pretrained model first.
    if args.from_ckpt > 0:
        if args.from_model_name:
            path = os.path.join(
                MODEL_PATH, args.from_model_name, f'{args.from_ckpt}.pth')
        else:
            path = os.path.join(model_path, f'{args.from_ckpt}.pth')
        ckpt = torch.load(path, map_location='cpu')
        # Checkpoints saved by this script are {'model': state_dict,
        # 'metadata': vars(args)}. A bare state dict is accepted as well.
        if isinstance(ckpt, dict) and 'model' in ckpt:
            ckpt = ckpt['model']
        model.load_state_dict(ckpt, strict=True)
        print(f'Pretrained model loaded from {path}.')

    log_path = os.path.join(model_path, 'log.txt')
    if USE_WANDB:
        wandb.init(
            project=PROJECT_NAME, name=args.model_name, config=args,
            config_exclude_keys=['model_name', 'save_every', 'log_every'],
        )
        wandb.run.log_code(".")  # Need to first enable it on wandb web UI.

    # Running windows of the most recent losses, averaged for logging.
    losses_act_pred = deque(maxlen=1000)
    losses_key_states = deque(maxlen=1000)

    # Convert key states to integers ('a' -> 0, 'b' -> 1, ...).
    key_states = [ord(c) - ord('a') for c in args.key_states]

    # Main training loop.
    for idx in tqdm(range(args.n_iters + 1)):

        # When resuming from a ckpt, fast-forward the lr schedule through the
        # iterations that were already trained.
        if args.from_ckpt > 0 and idx <= args.from_ckpt:
            lr_scheduler.step()
            continue

        # Obtain the current mini-batch (infinite loop).
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_data)
            batch = next(data_iter)
        batch = {k: v.cuda() for k, v in batch.items()}

        # Forward pass.
        act_pred, key_states_pred = model(batch['s'], batch['t'], batch['a'])

        # Obtain training losses.
        loss_act_pred = get_loss(act_pred, batch['a'], batch['lengths'])
        total_loss = loss_act_pred

        loss_key_states = torch.tensor(-1)  # -1 means N/A.
        if use_cot:
            ks_gt = torch.stack(
                [batch['k'][:, k_idx] for k_idx in key_states], 1)
            loss_key_states = torch.mean(torch.stack(
                [F.mse_loss(ks_pred, ks_gt) for ks_pred in key_states_pred]))
            if args.key_state_coeff > 0:
                # Out-of-place add: `+=` would modify `loss_act_pred` (the same
                # tensor object) in place and corrupt the logged action loss.
                total_loss = total_loss + args.key_state_coeff * loss_key_states

        losses_act_pred.append(loss_act_pred.item())
        losses_key_states.append(loss_key_states.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if idx % args.log_every == 0:
            avg_loss_act_pred = np.mean(losses_act_pred)
            avg_loss_key_states = np.mean(losses_key_states)
            print(f'Iteration {idx}: {avg_loss_act_pred}, {avg_loss_key_states}')
            with open(log_path, 'a') as f:  # 'a' also creates a missing file.
                f.write(f'{idx},{avg_loss_act_pred},{avg_loss_key_states}\n')
            if USE_WANDB:
                log_dict = {
                    'n_iter': idx,
                    'loss_actions': avg_loss_act_pred,
                    'loss_sum': avg_loss_act_pred,
                }
                if use_cot:
                    log_dict['loss_key_states'] = avg_loss_key_states
                    log_dict['loss_sum'] = avg_loss_act_pred + avg_loss_key_states
                wandb.log(log_dict)

        if idx > 0 and idx % args.save_every == 0:
            save_path = os.path.join(model_path, f'{idx}.pth')
            torch.save({
                'model': model.state_dict(),
                'metadata': vars(args)
            }, save_path)

        # Update learning rate.
        lr_scheduler.step()
276 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import h5py
4 |
5 | from torch.utils.data.dataset import Dataset
6 | from torch.nn.utils.rnn import pad_sequence
7 | import torch
8 |
9 | # Please specify the DATA_PATH (the base folder for storing data) in `path.py`.
10 | from path import DATA_PATH
11 |
12 |
class MS2Demos(Dataset):
    """Random sub-sequences of ManiSkill2 demo trajectories.

    Trajectories are read eagerly from
    ``DATA_PATH/{task}/trajectory.{obs_mode}.{control_mode}.h5``. Each item is
    a dict with ``'s'`` (states), ``'a'`` (actions) and ``'t'`` (start step
    index). When ``with_key_states`` is set, it also has ``'k'`` (key states,
    see ``get_key_states``).
    """

    def __init__(self,
                 data_split='train',
                 task='PickCube-v0',
                 obs_mode='state',
                 control_mode='pd_joint_delta_pos',
                 length=-1,
                 min_seq_length=None,
                 max_seq_length=None,
                 with_key_states=False,
                 multiplier=20,  # Used for faster data loading.
                 seed=None):  # seed for train/test spliting.
        super().__init__()
        self.task = task
        self.data_split = data_split
        self.seed = seed
        self.min_seq_length = min_seq_length  # For sampling trajectories.
        self.max_seq_length = max_seq_length  # For sampling trajectories.
        self.with_key_states = with_key_states  # Whether output key states.
        self.multiplier = multiplier

        # Usually set min and max traj length to be the same value.
        self.max_steps = -1  # Maximum timesteps across all trajectories.
        traj_path = os.path.join(DATA_PATH,
                                 f'{task}/trajectory.{obs_mode}.{control_mode}.h5')
        print('Traj path:', traj_path)
        self.data = self.load_demo_dataset(traj_path, length)

        # Cache key states for faster data loading.
        if self.with_key_states:
            self.idx_to_key_states = dict()

    def __len__(self):
        return len(self.data['env_states'])

    def __getitem__(self, index):
        """Sample a window of trajectory ``index``.

        The window length is drawn uniformly from
        [min_seq_length, max_seq_length] (both inclusive). If it exceeds the
        trajectory, the whole trajectory is returned instead.
        """
        # Offset by one since the last obs does not have a corresponding action.
        l = len(self.data['obs'][index]) - 1

        # Sample starting and ending index given the min and max traj length.
        if self.min_seq_length is None and self.max_seq_length is None:
            s_idx, e_idx = 0, l
        else:
            min_length = 0 if self.min_seq_length is None else self.min_seq_length
            max_length = l if self.max_seq_length is None else self.max_seq_length
            assert min_length <= max_length
            if min_length == max_length:
                length = min_length
            else:
                # `randint`'s upper bound is exclusive; +1 so max_length is reachable.
                length = np.random.randint(min_length, max_length + 1, 1)[0]
            if length <= l:
                s_idx = np.random.randint(0, l - length + 1, 1)[0]
                e_idx = s_idx + length
            else:
                s_idx, e_idx = 0, l
        assert e_idx <= l, f'{e_idx}, {l}'

        # Call get_key_states() if you want to use the key states.
        # Here `s` is the state observation, `a` is the action,
        # `env_states` not used during training (can be used to reconstruct env for debugging).
        # `t` is used for positional embedding as in Decision Transformer.
        data_dict = {
            's': self.data['obs'][index][s_idx:e_idx].astype(np.float32),
            'a': self.data['actions'][index][s_idx:e_idx].astype(np.float32),
            't': np.array([s_idx]).astype(np.float32),
            # 'env_states': self.data['env_states'][index][s_idx:e_idx].astype(np.float32),
        }
        if self.with_key_states:
            if f'key_states_{index}' not in self.idx_to_key_states:
                self.idx_to_key_states[f'key_states_{index}'] = self.get_key_states(index)
            data_dict['k'] = self.idx_to_key_states[f'key_states_{index}']
        return data_dict

    def info(self):  # Get observation and action shapes.
        return self.data['obs'][0].shape[-1], self.data['actions'][0].shape[-1]

    def load_demo_dataset(self, path, length):
        """Read ``length`` trajectories (all if -1) from the HDF5 file at ``path``.

        Every array is copied into memory with ``np.array``, so the file is
        opened read-only and closed before returning.
        """
        dataset = {}
        with h5py.File(path, 'r') as traj_all:
            if length == -1:
                length = len(traj_all)
            np.random.seed(self.seed)  # Fix the random seed for train/test data split.

            # Since TurnFaucet uses 10 different faucet models, we shuffle the data
            # such that the resulting sampled data are evenly sampled across faucet models.
            if self.task == 'TurnFaucet-v0':
                ids = []
                for i in range(10):  # Hard-code the 10 data splits for permutation.
                    t_ids = np.random.permutation(len(traj_all)//10)[:length//10]
                    t_ids += i*len(traj_all)//10
                    ids.append(t_ids)
                ids = np.concatenate(ids)
            # Since PushChair uses 5 different chair models, we shuffle the data
            # such that the resulting sampled data are evenly sampled across chair models.
            elif self.task == 'PushChair-v1':
                ids = []
                for i in range(5):  # Hard-code the 5 data splits for permutation.
                    t_ids = np.random.permutation(len(traj_all)//5)[:length//5]
                    t_ids += i*len(traj_all)//5
                    ids.append(t_ids)
                ids = np.concatenate(ids)
            else:
                ids = np.random.permutation(len(traj_all))[:length]

            ids = ids.tolist() * self.multiplier  # Duplicate the data for faster loading.

            # Note that the size of `env_states` and `obs` is that of the others + 1.
            # And most `infos` is for the next obs rather than the current obs.

            # `env_states` is used for reseting the env (might be helpful for eval)
            dataset['env_states'] = [np.array(
                traj_all[f"traj_{i}"]['env_states']) for i in ids]
            # `obs` is the observation of each step.
            dataset['obs'] = [np.array(traj_all[f"traj_{i}"]["obs"]) for i in ids]
            dataset['actions'] = [np.array(traj_all[f"traj_{i}"]["actions"]) for i in ids]

            # `rewards` is not currently used in CoTPC training.
            dataset['rewards'] = [np.array(traj_all[f"traj_{i}"]["rewards"]) for i in ids]
            for k in traj_all['traj_0']['infos'].keys():
                dataset[f'infos/{k}'] = [np.array(
                    traj_all[f"traj_{i}"]["infos"][k]) for i in ids]
                if k == 'info':  # For PushChair.
                    for kk in traj_all['traj_0']['infos'][k].keys():
                        dataset[f'infos/demo_{kk}'] = [np.array(
                            traj_all[f"traj_{i}"]["infos"][k][kk]) for i in ids]

        self.max_steps = np.max([len(s) for s in dataset['env_states']])

        return dataset

    def get_key_states(self, idx):
        """Return the stacked key states (float32) of trajectory ``idx``.

        The per-task rules are listed below. The final state of the
        trajectory is always appended as the last key state.
        """
        # Note that `infos` is for the next obs rather than the current obs.
        # Thus, we need to offset the `step_idx`` by one.
        key_states = []

        # If TurnFaucet (two key states)
        # key state I: is_contacted -> true
        # key state II: end of the trajectory
        if self.task == 'TurnFaucet-v0':
            for step_idx, key in enumerate(self.data['infos/is_contacted'][idx]):
                if key: break
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))

        # If PegInsertion (three key states)
        # key state I: is_grasped -> true
        # key state II: pre_inserted -> true
        # key state III: end of the trajectory
        if self.task == 'PegInsertionSide-v0':
            for step_idx, key in enumerate(self.data['infos/is_grasped'][idx]):
                if key: break
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))
            for step_idx, key in enumerate(self.data['infos/pre_inserted'][idx]):
                if key: break
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))

        # If PickCube (two key states)
        # key state I: is_grasped -> true
        # key state II: end of the trajectory
        if self.task == 'PickCube-v0':
            for step_idx, key in enumerate(self.data['infos/is_grasped'][idx]):
                if key: break
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))

        # If StackCube (three key states)
        # key state I: is_cubaA_grasped -> true
        # key state II: the last state of is_cubeA_on_cubeB -> true
        #               right before is_cubaA_grasped -> false
        # key state III: end of the trajectory
        if self.task == 'StackCube-v0':
            for step_idx, key in enumerate(self.data['infos/is_cubaA_grasped'][idx]):
                if key: break
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))
            for step_idx, k1 in enumerate(self.data['infos/is_cubeA_on_cubeB'][idx]):
                k2 = self.data['infos/is_cubaA_grasped'][idx][step_idx]
                if k1 and not k2: break
            # Right before such a state and so we do not use step_idx+1.
            key_states.append(self.data['obs'][idx][step_idx].astype(np.float32))

        # If PushChair (four key states):
        # key state I: right before demo_rotate -> true
        # key state II: right before demo_move -> true
        # key state III: when chair_close_to_target & chair_standing -> true
        # key state IV: end of the trajectory
        lengths = []
        # In PushChair, demo_* indicate the current state (not the next).
        if self.task == 'PushChair-v1':
            for step_idx, key in enumerate(self.data['infos/demo_rotate'][idx]):
                if key: break
            lengths.append(step_idx)
            key_states.append(self.data['obs'][idx][step_idx].astype(np.float32))
            for step_idx, key in enumerate(self.data['infos/demo_move'][idx]):
                if key: break
            lengths.append(step_idx - np.sum(lengths))
            key_states.append(self.data['obs'][idx][step_idx].astype(np.float32))
            for step_idx, key in enumerate(np.bitwise_and(
                    self.data['infos/chair_close_to_target'][idx],
                    self.data['infos/chair_standing'][idx])):
                if key: break
            lengths.append(step_idx + 1 - np.sum(lengths))
            key_states.append(self.data['obs'][idx][step_idx+1].astype(np.float32))
            lengths.append(len(self.data['infos/success'][idx]) - np.sum(lengths))

        # Always append the last state in the trajectory as the last key state.
        key_states.append(self.data['obs'][idx][-1].astype(np.float32))

        assert len(key_states) > 0, self.task
        key_states = np.stack(key_states, 0).astype(np.float32)
        return key_states
226 |
227 |
228 | # To obtain the padding function for sequences.
def get_padding_fn(data_names):
    """Return a DataLoader ``collate_fn`` that zero-pads variable-length samples.

    Args:
        data_names: keys to collect from every sample dict; must include ``'s'``.

    Returns:
        A callable that takes the list of sample dicts of one mini-batch. It
        returns a dict mapping each name in ``data_names`` to a batch-first
        tensor, zero-padded along the first axis of each sample. The dict
        also holds ``'lengths'``: the unpadded length of each sample's ``'s'``.
    """
    assert 's' in data_names, 'Should at least include `s` in data_names.'

    def pad_collate(*args):
        assert len(args) == 1
        samples = args[0]

        per_key = {name: [] for name in data_names}
        for sample in samples:
            for name in data_names:
                per_key[name].append(torch.from_numpy(sample[name]))

        # True length of every sampled sequence; constant across the batch
        # when max_seq_length == min_seq_length.
        per_key['lengths'] = torch.tensor([len(seq) for seq in per_key['s']])

        for name in data_names:
            per_key[name] = pad_sequence(per_key[name], batch_first=True, padding_value=0)
        return per_key

    return pad_collate
250 |
251 |
252 | # Sample code for the data loader.
if __name__ == "__main__":

    from torch.utils.data import DataLoader

    # The default values for CoTPC for tasks in ManiSkill2. The expected
    # output shapes listed at the bottom are for this PickCube-v0 setting.
    batch_size, num_traj, seed, min_seq_length, max_seq_length = 256, 500, 0, 60, 60
    task, control_mode = 'PickCube-v0', 'pd_joint_delta_pos'
    # For PushChair, use instead:
    # task, control_mode = 'PushChair-v1', 'base_pd_joint_vel_arm_pd_joint_vel'

    train_dataset = MS2Demos(
        control_mode=control_mode,
        length=num_traj, seed=seed,
        min_seq_length=min_seq_length,
        max_seq_length=max_seq_length,
        with_key_states=True,
        task=task)

    collate_fn = get_padding_fn(['s', 'a', 't', 'k'])
    train_data = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn)

    data = next(iter(train_data))
    for k, v in data.items():
        print(k, tuple(v.shape))
    # Expected output for PickCube-v0 (5 entries, including `lengths`):
    # s (256, 60, 51)
    # a (256, 60, 8)
    # t (256, 1)
    # k (256, 2, 51)
    # lengths (256,)
287 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | from tqdm import tqdm
5 | from collections import defaultdict
6 |
7 | from mani_skill2.utils.io_utils import load_json
8 | import mani_skill2.envs # Load ManiSkill2 envs.
9 | import torch # Load pytorch after maniskill2 to avoid some import error.
10 |
11 | from model import GPTConfig, GPTWithCoT
12 |
13 | from vec_env import get_mp_envs # Used for parallel evaluation.
14 |
15 | try:
16 | # Use might need this for wandb to work due to protobuf issues.
17 | os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
18 | import wandb
19 | assert wandb.__version__
20 | USE_WANDB = True
21 | PROJECT_NAME = 'CoTPC' # Please specify the project name.
22 | except Exception:
23 | print('Do not use wandb since it is not found.')
24 | USE_WANDB = False
25 |
26 | # Please specify MODEL_PATH and DATA_PATH (both are base folders) in `path.py`.
27 | from path import MODEL_PATH, DATA_PATH
28 |
29 |
@torch.no_grad()
def predict(model, action_hist, state_hist, t):
    """Predict the next action for each parallel env from the current context.

    Args:
        model: the policy network; `model.model_type` must be one of
            's', 's+a', 's+a+cot'. For 'cot' types it must also expose
            `block_size`, `len_key_states` and `config.n_head`.
        action_hist: list of past action tensors (empty at the first step).
            They are stacked along dim 1.
        state_hist: list of past state tensors, stacked along dim 1.
        t: numpy array with one timestep value per env, passed to the model
            as `timesteps` after adding a trailing axis. Presumably shape
            (B,), e.g. `np.zeros([n_env])` in the eval loop — TODO confirm.

    Returns:
        `preds[:, -1]`, i.e. the model's prediction at the last sequence
        position.
    """
    assert model.model_type in ['s', 's+a', 's+a+cot']

    timesteps = torch.from_numpy(t)[:, None].cuda()
    if not action_hist:  # The first step.
        actions = None
    else:
        actions = torch.stack(action_hist, 1).float().cuda()
    states = torch.stack(state_hist, 1).float().cuda()

    if 'cot' in model.model_type:
        # T is the max sequence size; S is the current number of steps.
        B, T = states.shape[0], model.block_size + model.len_key_states
        n_head, S = model.config.n_head, states.shape[1] - 1  # Exclude the init state.

        # Masks for the all-to-all key state query tokens in attention layers.
        # The built-in masks for causal (auto-regressive) tokens are in `model.py`.
        # Only the first `len_key_states` query rows are filled below; all other
        # rows stay False (all zeros).
        key_state_mask = torch.zeros([B, n_head, T, T], dtype=bool)
        m1 = torch.arange(0, T).repeat(B, 1)  # (B, T) token indices.
        # Index threshold: tokens with a larger index count as "future". The
        # 2*S term presumably accounts for interleaved (state, action) tokens
        # after the key-state query tokens — NOTE(review): confirm against the
        # token layout in `model.py`.
        m2 = torch.ones([B, 1]) * (S * 2 + model.len_key_states)
        m3 = m1 > m2  # Tokens in the future are masked out.
        # Broadcast to (B, n_head, len_key_states, T): same mask for every head
        # and every key-state query row.
        m3 = m3[:, None, None, :].repeat(1, n_head, model.len_key_states, 1)
        key_state_mask[:, :, :model.len_key_states, :] = m3
        key_state_mask = key_state_mask.cuda()
        preds, _ = model(
            states, timesteps, actions=actions, key_state_mask=key_state_mask)
    else:
        preds, _ = model(states, timesteps, actions=actions)

    return preds[:, -1]  # Only output the last action predictions.
61 |
62 |
def update(model, action_hist, state_hist, actions, states, t):
    """Push the latest (action, next state) pair into the rolling context.

    The context holds at most `model.block_size // 2` states and one action
    fewer. While there is room, both lists are appended in place. Once the
    window is full, the oldest entries are dropped (new lists are built) and
    `t` is incremented in place.

    Args:
        model: provides `model_type` and `block_size`.
        action_hist: list of past action tensors.
        state_hist: list of past states.
        actions: numpy array of the actions just taken (converted to a tensor).
        states: the states observed after taking `actions`.
        t: per-env timestep offset (numpy array), shifted when the window slides.

    Returns:
        Tuple `(action_hist, state_hist, t)`.
    """
    assert model.model_type in ['s', 's+a', 's+a+cot']

    new_action = torch.from_numpy(actions)
    capacity = model.block_size // 2

    if len(state_hist) != capacity:
        # Window not yet full: grow both histories.
        state_hist.append(states)
        action_hist.append(new_action)
        return action_hist, state_hist, t

    # The context buffer is full: slide the window by one step.
    assert len(action_hist) == capacity - 1
    slid_states = state_hist[1:] + [states]
    slid_actions = action_hist[1:] + [new_action]
    t += 1
    return slid_actions, slid_states, t
77 |
78 |
def parse_args():
    """Parse the CoTPC evaluation command line from ``sys.argv``.

    Returns:
        argparse.Namespace with the fields task, control_mode, obs_mode,
        seed, model_name, from_ckpt, eval_max_steps and cot_decoder.
    """
    cli = argparse.ArgumentParser()
    flag = cli.add_argument

    # Hyper-parameters regarding the demo dataset (used to gather eval_ids)
    flag('--task', type=str, default='PickCube-v0', help="Task (env-id) in ManiSkill2.")
    flag('--control_mode', type=str, default='pd_joint_delta_pos',
         help="Control mode used in envs from ManiSkill2.")
    flag('--obs_mode', type=str, default='state',
         help="State mode used in envs from ManiSkill2.")
    flag("--seed", default=0, type=int, help="Random seed for data spliting.")

    # Hyper-parameters regarding the model.
    flag("--model_name", default='', type=str, help="Model name to be loaded.")
    flag("--from_ckpt", default=-1, type=int, help="Ckpt of the model to be loaded.")

    # Rollout settings.
    flag("--eval_max_steps", default=200, type=int, help="Max steps allowed in eval.")
    flag('--cot_decoder', type=str, default='256', help="Specs of the CoT decoder.")

    return cli.parse_args()
98 |
99 |
100 | if __name__ == "__main__":
101 |
102 | args = parse_args()
103 | assert args.model_name, 'Should specify --model_name'
104 | assert args.from_ckpt > 0, 'Should specify --from_ckpt'
105 |
106 | # Load the model.
107 | path = os.path.join(MODEL_PATH, f'{args.model_name}/{args.from_ckpt}.pth')
108 | # Load to cpu first to avoid cuda related errors from ManiSkill2.
109 | ckpt = torch.load(path, map_location=torch.device('cpu'))
110 | state_dict_from_ckpt, params = ckpt['model'], ckpt['metadata']
111 | state_dim = state_dict_from_ckpt['state_encoder.net.0.weight'].shape[1]
112 | action_dim = state_dict_from_ckpt['action_encoder.net.0.weight'].shape[1]
113 | max_timestep = state_dict_from_ckpt['global_pos_emb'].shape[1]
114 | print('Loaded ckpt from:', path)
115 |
116 | # Load demos to fetch the env. seeds used in training.
117 | json_path = os.path.join(
118 | DATA_PATH, f'{args.task}/trajectory.{args.obs_mode}.{args.control_mode}.json')
119 | json_data = load_json(json_path)
120 | env_kwargs = json_data["env_info"]["env_kwargs"]
121 | env_kwargs["obs_mode"] = args.obs_mode
122 | env_kwargs["control_mode"] = args.control_mode
123 | np.random.seed(args.seed)
124 | if args.task == 'TurnFaucet-v0':
125 | length_all = len(json_data["episodes"])
126 | ids = []
127 | for i in range(10): # Hard-code the 10 data splits for permutation.
128 | t_ids = np.random.permutation(
129 | length_all // 10)[:params['num_traj'] // 10]
130 | t_ids += i * length_all // 10
131 | ids.append(t_ids)
132 | eval_ids = np.concatenate(ids)
133 | elif args.task == 'PushChair-v1':
134 | length_all = len(json_data["episodes"])
135 | ids = []
136 | for i in range(5): # Hard-code the 5 data splits for permutation.
137 | t_ids = np.random.permutation(length_all // 5)[:100]
138 | t_ids += i * length_all // 5
139 | ids.append(t_ids)
140 | eval_ids = np.concatenate(ids)
141 | else:
142 | # Only evaluate at most 500 scene configs.
143 | eval_ids = np.random.permutation(
144 | len(json_data["episodes"]))[:params['num_traj']][:500]
145 |
146 | n_env = 25 # Number of parallel environments.
147 | assert len(eval_ids) % n_env == 0, f'{len(eval_ids)}'
148 | envs = get_mp_envs(args.task, n_env, **env_kwargs)
149 |
150 | # Load the ckpt after envs init to avoid cuda related errors from ManiSkill2.
151 | cot_decoder = params['cot_decoder'] if 'cot_decoder' in params else args.cot_decoder
152 | conf = GPTConfig(
153 | params['context_length'],
154 | n_layer=params['n_layer'],
155 | n_head=params['n_head'],
156 | n_embd=params['n_embd'],
157 | model_type=params['model_type'],
158 | key_states=params['key_states'], # Rules for the CoT.
159 | key_state_loss=params['key_state_loss'], # Layers used for CoT modeling.
160 | cot_decoder=cot_decoder,
161 | max_timestep=max_timestep,
162 | )
163 | model = GPTWithCoT(conf, state_dim=state_dim, action_dim=action_dim).cuda()
164 | model.load_state_dict(state_dict_from_ckpt, strict=False)
165 | model.eval()
166 |
167 | if USE_WANDB:
168 | wandb.init(project=PROJECT_NAME, name=f'eval/{args.model_name}',
169 | id=f'wandb_metrics_{args.model_name}', resume='auto')
170 |
171 | output_str, output_dict = '', dict()
172 |
173 | # Seen scene configurations.
174 | metric_dict = defaultdict(lambda: [[] for _ in range(len(eval_ids))])
175 | for start_idx in tqdm(range(0, len(eval_ids), n_env)):
176 | reset_args_list = []
177 | for i in range(start_idx, min(start_idx + n_env, len(eval_ids))):
178 | reset_kwargs = json_data["episodes"][eval_ids[i]]['reset_kwargs']
179 | reset_args_list.append(reset_kwargs)
180 |
181 | s = torch.from_numpy(envs.reset(reset_args_list)).float()
182 | state_hist, action_hist, t = [s], [], np.zeros([n_env])
183 |
184 | for step in range(args.eval_max_steps):
185 | a = predict(model, action_hist, state_hist, t).cpu().numpy()
186 |
187 | s, _, _, infos = envs.step(a)
188 | s = torch.from_numpy(s).float()
189 |
190 | action_hist, state_hist, t = update(
191 | model, action_hist, state_hist, a, s, t)
192 |
193 | # Update metrics.
194 | for i, info in enumerate(infos):
195 | j = start_idx + i
196 | # You might want to use these additional metrics.
197 | if args.task == 'PickCube-v0':
198 | metric_dict['is_grasped'][j].append(info['is_grasped'])
199 | if args.task == 'StackCube-v0':
200 | metric_dict['is_cubaA_grasped'][j].append(info['is_cubaA_grasped'])
201 | metric_dict['is_cubeA_on_cubeB'][j].append(info['is_cubeA_on_cubeB'])
202 | if args.task == 'PegInsertionSide-v0':
203 | metric_dict['is_grasped'][j].append(info['is_grasped'])
204 | metric_dict['pre_inserted'][j].append(info['pre_inserted'])
205 | if args.task == 'TurnFaucet-v0':
206 | metric_dict['is_contacted'][j].append(info['is_contacted'])
207 | if args.task == 'PushChair-v1':
208 | metric_dict['close_to_target'][j].append(info['chair_close_to_target'])
209 | metric_dict['static_at_last'][j].append(
210 | info['chair_close_to_target'] and info['chair_static'])
211 | metric_dict['success'][j].append(info['success'])
212 |
213 | for k, v in metric_dict.items():
214 | v = np.mean([np.any(vv) for vv in v]) * 100
215 | output_str += f'{k} {v:.2f}, '
216 | output_dict[k] = v
217 | output_str = output_str[:-2]
218 | print(output_str)
219 |
220 | # Unseen scene configurations.
221 | # Unseen objects for peg insertion and seen objects otherwise.
222 | all_reset_kwargs = []
223 | if args.task == 'TurnFaucet-v0':
224 | length_all = len(json_data["episodes"])
225 | ids = []
226 | for i in range(10): # Hard-code the 10 data splits for permutation.
227 | t_ids = np.random.permutation(length_all // 10)
228 | t_ids = t_ids[params['num_traj']//10:params['num_traj']//10+10]
229 | t_ids += i * length_all // 10
230 | ids.append(t_ids)
231 | eval_ids = np.concatenate(ids)
232 | for eval_id in eval_ids:
233 | all_reset_kwargs.append(json_data["episodes"][eval_id]['reset_kwargs'])
234 | elif args.task == 'PushChair-v1':
235 | length_all = len(json_data["episodes"])
236 | ids = []
237 | for i in range(5): # Hard-code the 5 data splits for permutation.
238 | t_ids = np.random.permutation(length_all // 5)
239 | t_ids = t_ids[params['num_traj']//5:params['num_traj']//5+50]
240 | t_ids += i * length_all // 5
241 | ids.append(t_ids)
242 | eval_ids = np.concatenate(ids)
243 | for eval_id in eval_ids:
244 | all_reset_kwargs.append(json_data["episodes"][eval_id]['reset_kwargs'])
245 | elif args.task == 'PegInsertionSide-v0':
246 | for i in range(400):
247 | all_reset_kwargs.append({'seed': i + 2000})
248 | else:
249 | for i in range(100):
250 | all_reset_kwargs.append({'seed': i + 2000})
251 | metric_dict = defaultdict(lambda: [[] for _ in range(len(all_reset_kwargs))])
252 |
253 | for start_idx in tqdm(range(0, len(all_reset_kwargs), n_env)):
254 | reset_args_list = []
255 | for i in range(start_idx, min(start_idx + n_env, len(all_reset_kwargs))):
256 | reset_args_list.append(all_reset_kwargs[i])
257 |
258 | s = torch.from_numpy(envs.reset(reset_args_list)).float()
259 | state_hist, action_hist, t = [s], [], np.zeros([n_env])
260 |
261 | for step in range(args.eval_max_steps):
262 | a = predict(model, action_hist, state_hist, t).cpu().numpy()
263 | s, _, _, infos = envs.step(a)
264 | s = torch.from_numpy(s).float()
265 |
266 | action_hist, state_hist, t = update(
267 | model, action_hist, state_hist, a, s, t)
268 |
269 | # Update metrics.
270 | for i, info in enumerate(infos):
271 | j = start_idx + i
272 | # You might want to use these additional metrics.
273 | if args.task == 'PickCube-v0':
274 | metric_dict['test/is_grasped'][j].append(info['is_grasped'])
275 | if args.task == 'StackCube-v0':
276 | metric_dict['test/is_cubaA_grasped'][j].append(info['is_cubaA_grasped'])
277 | metric_dict['test/is_cubeA_on_cubeB'][j].append(info['is_cubeA_on_cubeB'])
278 | if args.task == 'PegInsertionSide-v0':
279 | metric_dict['test/is_grasped'][j].append(info['is_grasped'])
280 | metric_dict['test/pre_inserted'][j].append(info['pre_inserted'])
281 | if args.task == 'TurnFaucet-v0':
282 | metric_dict['test/is_contacted'][j].append(info['is_contacted'])
283 | if args.task == 'PushChair-v1':
284 | metric_dict['test/close_to_target'][j].append(info['chair_close_to_target'])
285 | metric_dict['test/static_at_last'][j].append(
286 | info['chair_close_to_target'] and info['chair_static'])
287 | metric_dict['test/success'][j].append(info['success'])
288 |
289 | output_str = ''
290 | for k, v in metric_dict.items():
291 | v = np.mean([np.any(vv) for vv in v]) * 100
292 | output_str += f'{k} {v:.2f}, '
293 | output_dict[k] = v
294 | output_str = output_str[:-2]
295 | print(output_str)
296 |
297 | # Unseen scene configurations with unseen objects (zero-shot).
298 | all_reset_kwargs = []
299 | if args.task == 'TurnFaucet-v0':
300 | model_ids = [
301 | 5014, 5037, 5053, 5062,
302 | ]
303 | elif args.task == 'PushChair-v1':
304 | model_ids = [
305 | 3003, 3013, 3020,
306 | ]
307 | else:
308 | model_ids = []
309 | for model_id in model_ids:
310 | for i in range(100):
311 | all_reset_kwargs.append({'seed': i + 2000, 'model_id': str(model_id)})
312 | metric_dict = defaultdict(lambda: [[] for _ in range(len(all_reset_kwargs))])
313 |
314 | for start_idx in tqdm(range(0, len(all_reset_kwargs), n_env)):
315 | reset_args_list = []
316 | for i in range(start_idx, min(start_idx + n_env, len(all_reset_kwargs))):
317 | reset_args_list.append(all_reset_kwargs[i])
318 |
319 | s = torch.from_numpy(envs.reset(reset_args_list)).float()
320 | state_hist, action_hist, t = [s], [], np.zeros([n_env])
321 |
322 | for step in range(args.eval_max_steps):
323 | a = predict(model, action_hist, state_hist, t).cpu().numpy()
324 |
325 | s, _, _, infos = envs.step(a)
326 | s = torch.from_numpy(s).float()
327 |
328 | action_hist, state_hist, t = update(
329 | model, action_hist, state_hist, a, s, t)
330 |
331 | # Update metrics.
332 | for i, info in enumerate(infos):
333 | j = start_idx + i
334 | if args.task == 'PushChair-v1':
335 | metric_dict['test_h/close_to_target'][j].append(info['chair_close_to_target'])
336 | metric_dict['test_h/static_at_last'][j].append(
337 | info['chair_close_to_target'] and info['chair_static'])
338 | if args.task == 'TurnFaucet-v0':
339 | metric_dict['test_h/is_contacted'][j].append(info['is_contacted'])
340 | metric_dict['test_h/success'][j].append(info['success'])
341 |
342 | if all_reset_kwargs:
343 | output_str = ''
344 | for k, v in metric_dict.items():
345 | v = np.mean([np.any(vv) for vv in v]) * 100
346 | output_str += f'{k} {v:.2f}, '
347 | output_dict[k] = v
348 | output_str = output_str[:-2]
349 | print(output_str)
350 |
351 | if USE_WANDB:
352 | output_dict['n_iter'] = args.from_ckpt
353 | wandb.log(output_dict)
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for the model architecture of CoTPC, based on the GPT implementation.
3 | Some of the key hyper-parameters are explained in GPTConfig.
4 |
5 | References:
6 | (1) https://github.com/karpathy/minGPT
7 | (2) https://github.com/kzl/decision-transformer
8 | """
9 |
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import functional as F
14 | import numpy as np
15 |
16 |
class MLP(nn.Module):
    """Stack of Linear layers with an optional activation between them.

    The activation follows every hidden layer but NOT the output layer, so the
    network always ends with ``nn.Linear(..., output_dim)``.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dims: sizes of the hidden layers; None or empty gives a single Linear.
        act_fn: 'relu', 'tanh', or None / '' for no activation at all.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=None, act_fn='relu'):
        super().__init__()
        assert act_fn in ['relu', 'tanh', None, '']
        # Copy so a caller-supplied list (or tuple) is never shared or mutated.
        hidden_dims = [] if hidden_dims is None else list(hidden_dims)
        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i, j in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(i, j))
            if act_fn == 'relu':
                layers.append(nn.ReLU())
            if act_fn == 'tanh':
                layers.append(nn.Tanh())
        # Remove only the activation that follows the output layer. When no
        # activation is used, the last element is the output Linear and must stay.
        if act_fn in ('relu', 'tanh'):
            layers.pop()
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
33 |
34 |
class GELU(nn.Module):
    """Module wrapper around ``F.gelu`` (default arguments) for use in nn.Sequential."""

    def forward(self, input):
        activated = F.gelu(input)
        return activated
38 |
39 |
class GPTConfig:
    """Base GPT config, params common to all GPT versions.

    Args:
        block_size: context length, not counting key state query tokens. It is
            doubled when the action history is also tokenized ('+a' model types).
        **kwargs: must contain `model_type`, one of 's', 's+a', 's+cot',
            's+a+cot'. CoT model types also need `key_states` and `key_state_loss`.
            Every kwarg is additionally stored as an attribute of the config.
    """

    embd_pdrop = 0.0
    resid_pdrop = 0.0
    attn_pdrop = 0.0

    def __init__(self, block_size, **kwargs):
        model_type = kwargs.get('model_type')
        assert model_type in ['s', 's+a', 's+cot', 's+a+cot'], \
            f"Unsupported model_type: {model_type}"

        # With action history, each time step contributes an s and an a token.
        self.block_size = block_size * 2 if '+a' in model_type else block_size

        if 'cot' in model_type:
            # `key_states` specifies which of the key states should be used for CoT.
            assert 'key_states' in kwargs, 'Should specify `key_states`'
            key_states = kwargs['key_states']
            # It is in the form of 'acd...' that represents whether the key
            # state x is used. e.g., here a,c,d is used while b is skipped.
            assert key_states not in ['', None] and \
                all(ord('z') >= ord(g) >= ord('a') for g in key_states)

            # `key_state_loss` specifies which layer's features in GPT should be used
            # for the auxiliary key state prediction losses.
            assert 'key_state_loss' in kwargs, 'Should specify `key_state_loss`'
            key_state_loss = kwargs['key_state_loss']
            # It is in the form of e.g., '023', meaning the features out of attention
            # layers of idx 0, 2, 3 are used for key state prediction losses. Each
            # character is later parsed with int(), so require decimal digits.
            assert key_state_loss not in ['', None] and \
                all(c.isdecimal() for c in key_state_loss)

            self.key_states = key_states
            self.key_state_loss = key_state_loss
            self.len_key_states = len(key_states)
        else:
            # No key state query tokens. Give the CoT-only attributes empty defaults
            # so model code that reads them works for non-CoT configs too
            # (explicit kwargs below still override these).
            self.key_states = ''
            self.key_state_loss = ''
            self.cot_decoder = None
            self.len_key_states = 0

        # Set up other attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)
81 |
82 |
class CausalSelfAttentionWithCoT(nn.Module):
    """
    Multi-head masked self-attention with support for learnable key state query
    tokens (chain-of-thought predictive control), adapted from minGPT.

    The first `len_key_states` positions are the key state query tokens. Their
    rows of the causal mask are fully open; which later tokens they may actually
    see is decided per sample by `key_state_mask` in forward(...).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd

        # Projections producing keys, queries and values for all heads at once.
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)

        # Regularization.
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)

        # Output projection.
        self.proj = nn.Linear(width, width)

        # Lower-triangular (causal) mask over key state + context tokens, with the
        # key state query rows switched to all-to-all.
        n_tokens = config.block_size + config.len_key_states
        causal = torch.tril(torch.ones(n_tokens, n_tokens))
        causal[:config.len_key_states] = 1.0
        self.register_buffer("mask", causal.view(1, 1, n_tokens, n_tokens))

        self.n_head = config.n_head
        self.model_type = config.model_type
        self.len_key_states = config.len_key_states

    def forward(self, x, key_state_mask=None):
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T).
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        # CoT models additionally block (True entries) whatever the per-sample
        # key state mask disallows; it is mandatory for them.
        if 'cot' in self.model_type:
            assert key_state_mask is not None
            scores = scores.masked_fill(key_state_mask, float('-inf'))

        weights = self.attn_drop(F.softmax(scores, dim=-1))
        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(merged))
144 |
145 |
class Block(nn.Module):
    """
    Pre-LayerNorm Transformer block (attention then MLP, each with a residual
    connection) whose attention honours the key state query token masks.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttentionWithCoT(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, key_state_mask=None):
        attn_out = self.attn(self.ln1(x), key_state_mask=key_state_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
167 |
168 |
class BlocksWithCoT(nn.Module):
    """
    A stack of Transformer blocks sharing one key state mask. It returns the final
    features together with the output of every block, so callers can attach
    auxiliary heads to intermediate layers.
    """

    def __init__(self, config):
        super().__init__()
        blocks = [Block(config) for _ in range(config.n_layer)]
        self.block_list = nn.ModuleList(blocks)
        self.model_type = config.model_type
        self.n_head = config.n_head
        self.len_key_states = config.len_key_states

    def _sample_key_state_mask(self, x):
        """Draw a random training mask for the key state query tokens.

        For each sample a cut-off r is drawn. Key state query rows may attend to
        positions up to index r + len_key_states and are blocked (True) after it;
        all other rows are left fully unblocked (False) by this mask. With action
        tokens ('+a') r is even, so the visible prefix ends on a state token and
        the action that follows it is always hidden.
        """
        batch, seq_len, _ = x.shape
        n_ks = self.len_key_states
        if '+a' in self.model_type:
            cutoff = torch.randint(0, (seq_len - n_ks) // 2, [batch])[:, None] * 2
        else:
            cutoff = torch.randint(0, seq_len - n_ks, [batch])[:, None]
        blocked = torch.arange(0, seq_len).repeat(batch, 1) > cutoff + n_ks
        ks_mask = torch.zeros(
            [batch, self.n_head, seq_len, seq_len], dtype=bool, device=x.device)
        ks_mask[:, :, :n_ks, :] = \
            blocked[:, None, None, :].repeat(1, self.n_head, n_ks, 1)
        return ks_mask

    def forward(self, x, key_state_mask=None):
        # Training passes no mask and gets a random one; at inference the
        # evaluator is expected to supply the key state mask explicitly.
        if key_state_mask is None:
            key_state_mask = self._sample_key_state_mask(x)

        per_layer = []  # Output of every block, in order.
        for blk in self.block_list:
            x = blk(x, key_state_mask=key_state_mask)
            per_layer.append(x)

        return x, per_layer
211 |
212 |
class GPTWithCoT(nn.Module):
    """
    GPT implementation with the support of the learnable key state query tokens,
    which is used for the chain-of-thought predictive control. Here, the context size
    is specified as block_size, which does not count the key state query tokens.

    Args:
        config: GPTConfig-like object. Reads model_type, block_size, len_key_states,
            n_embd, max_timestep and embd_pdrop, plus whatever BlocksWithCoT reads.
            key_states, key_state_loss and cot_decoder are only needed for 'cot'
            model types and may be absent otherwise.
        state_dim: size of one state vector (must be > 0).
        action_dim: size of one action vector (must be > 0).
    """

    def __init__(self, config, state_dim=-1, action_dim=-1):
        super().__init__()

        assert state_dim > 0 and action_dim > 0
        self.config = config
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.model_type = config.model_type
        # CoT-only settings; configs for 's' / 's+a' need not define them.
        self.key_states = getattr(config, 'key_states', None)
        self.key_state_loss = getattr(config, 'key_state_loss', None)
        self.len_key_states = config.len_key_states
        self.block_size = config.block_size
        self.cot_decoder = getattr(config, 'cot_decoder', None)

        # Set up learnable position embedding synchronized for s and a tokens, as proposed
        # in Decision Transformer. We use a similar global+local position embedding design.
        # With '+a', block_size counts both s and a tokens, so the local table has one
        # row per time step and forward() repeats each row for the (s, a) pair.
        p_size = config.block_size // 2 if '+a' in self.model_type else config.block_size
        self.local_pos_emb = nn.Parameter(torch.zeros(1, p_size, config.n_embd))
        self.global_pos_emb = nn.Parameter(
            torch.zeros(1, config.max_timestep, config.n_embd))

        self.drop = nn.Dropout(config.embd_pdrop)

        if 'cot' in self.model_type:
            self.key_state_pos_emb = nn.Parameter(
                torch.zeros(1, self.len_key_states, config.n_embd))

        # Transformer (attention layers) with CoT.
        self.blocks = BlocksWithCoT(config)

        # State embeddings.
        self.state_encoder = MLP(self.state_dim, config.n_embd, hidden_dims=[256])

        # Action embeddings.
        if '+a' in self.model_type:
            self.action_encoder = MLP(self.action_dim, config.n_embd, hidden_dims=[256])

        # Action predictor.
        self.ln = nn.LayerNorm(config.n_embd)
        self.action_predictor = MLP(config.n_embd, action_dim, hidden_dims=[256, 256])

        # Key state predictors: one per layer index listed in `key_state_loss`. Each
        # decodes that layer's features at the key state query tokens into states.
        if 'cot' in self.model_type:
            self.key_state_predictors = nn.ModuleList(
                MLP(config.n_embd, self.state_dim, hidden_dims=[int(self.cot_decoder)])
                for _ in self.key_state_loss)

        self.apply(self._init_weights)
        print(f"Total # of parameters: {sum(p.numel() for p in self.parameters())}")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), zero their biases, and
        reset LayerNorm to weight 1 / bias 0. Parameters created directly on this
        module (the position embeddings) are not touched and keep their zero init."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, states, timesteps, actions=None, key_state_mask=None):
        """Predict actions (and key states as CoT) from the state (and action) history.

        Args:
            states: (B, T, state_dim) state history with T <= context length.
            timesteps: per-sample global time step used to index `global_pos_emb`,
                shape (B, 1); a flat (B,) tensor is accepted as well.
            actions: optional action history, only used by '+a' model types.
                Entries from index T-1 on are never fed as inputs, so during
                inference only past actions need to be given.
            key_state_mask: optional boolean mask (True = blocked) for the key state
                query tokens; when None the blocks sample a random one (training).

        Returns:
            (act_preds, key_state_preds). act_preds is (B, T, action_dim) for '+a'
            model types and (B, block_size, action_dim) otherwise. key_state_preds
            is, for 'cot' model types, a list with one (B, len_key_states, state_dim)
            tensor per entry of `key_state_loss`, else None.
        """
        B, T = states.shape[0], states.shape[1]
        state_embeddings = self.state_encoder(states)

        # Embeddings for state (action, and key state query) tokens.
        token_embeddings = torch.zeros([B, self.block_size, self.config.n_embd],
                                       dtype=torch.float32, device=states.device)

        # If using action history as inputs: during training, all actions are
        # specified; during inference, only actions in the past are specified.
        # That is, the first action prediction has no action history as inputs.
        if '+a' in self.model_type:
            token_embeddings[:, :T*2:2, :] = state_embeddings
            if actions is not None:
                # The last action is never used as an input.
                token_embeddings[:, 1:T*2-1:2, :] = self.action_encoder(actions[:, :T-1])
        else:
            token_embeddings[:, :T, :] = state_embeddings

        # Global position embedding: one table row per sample, chosen by its time
        # step. Direct indexing avoids materialising B copies of the whole table.
        timestep_idx = timesteps.long().reshape(B, 1)
        global_pos_emb = self.global_pos_emb[0, timestep_idx]  # B x 1 x D
        local_pos_emb = torch.repeat_interleave(self.local_pos_emb, 2, dim=1) \
            if '+a' in self.model_type else self.local_pos_emb

        x = token_embeddings + global_pos_emb + local_pos_emb
        if 'cot' in self.model_type:
            key_state_embeddings = self.key_state_pos_emb.repeat(B, 1, 1)
            x = torch.cat([key_state_embeddings, x], 1)

        x = self.drop(x)
        x, intermediate_feats = self.blocks(x, key_state_mask=key_state_mask)
        x = self.ln(x)
        act_preds = self.action_predictor(x)

        if 'cot' in self.model_type:
            key_state_preds = []
            for idx, loss_layer_idx in enumerate([int(c) for c in self.key_state_loss]):
                key_state_preds.append(self.key_state_predictors[idx](
                    intermediate_feats[loss_layer_idx][:, :self.len_key_states]))

            # Get rid of dims for key state query tokens.
            act_preds = torch.split(
                act_preds, [self.len_key_states, self.block_size], dim=1)[1]
        else:
            key_state_preds = None

        # Get rid of dims for action tokens.
        if '+a' in self.model_type:
            # Remove the extra tokens when in eval mode.
            act_preds = act_preds[:, :T*2:2]

        return act_preds, key_state_preds

    def configure_adamw_optimizers(self, config):
        """
        Build an AdamW optimizer with two parameter groups. Weights of Linear/Conv2d
        layers get `config['weight_decay']`. Biases, LayerNorm/Embedding weights and
        the position embeddings get no weight decay. `config` must also provide
        'init_lr', 'beta1' and 'beta2'.
        """

        # Separate out all parameters to those that will and won't experience regularizing weight decay.
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = f'{mn}.{pn}' if mn else pn  # full param name

                if pn.endswith('bias'):
                    # All biases will not be decayed.
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # Weights of whitelist modules will be weight decayed.
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # Weights of blacklist modules will NOT be weight decayed.
                    no_decay.add(fpn)

        # Special case the position embedding parameters in the root module as not decayed.
        no_decay.add('local_pos_emb')
        no_decay.add('global_pos_emb')
        if 'cot' in self.model_type:
            no_decay.add('key_state_pos_emb')

        # Validate that we considered every parameter.
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, \
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # Create the pytorch optimizer object.
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)],
             "weight_decay": config['weight_decay']},
            {"params": [param_dict[pn] for pn in sorted(no_decay)],
             "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(
            optim_groups,
            lr=config['init_lr'],
            betas=(config['beta1'], config['beta2'])
        )
401 |
--------------------------------------------------------------------------------