├── DT_TIT ├── DT.png ├── DT_TIT.png ├── LICENSE.md ├── README.md ├── atari │ ├── LICENSE │ ├── conda_env.yml │ ├── create_dataset.py │ ├── fixed_replay_buffer.py │ ├── log │ │ └── read_log.py │ ├── mingpt │ │ ├── __init__.py │ │ ├── model_atari.py │ │ ├── model_tit.py │ │ ├── trainer_atari.py │ │ └── utils.py │ ├── readme-atari.md │ ├── run.sh │ └── run_dt_atari.py └── gym │ ├── conda_env.yml │ ├── data │ └── download_d4rl_datasets.py │ ├── decision_transformer │ ├── envs │ │ ├── assets │ │ │ └── reacher_2d.xml │ │ └── reacher_2d.py │ ├── evaluation │ │ └── evaluate_episodes.py │ ├── models │ │ ├── TIT.py │ │ ├── decision_transformer.py │ │ ├── mlp_bc.py │ │ ├── model.py │ │ └── trajectory_gpt2.py │ └── training │ │ ├── act_trainer.py │ │ ├── seq_trainer.py │ │ └── trainer.py │ ├── experiment.py │ ├── log │ ├── halfcheetah-medium-dt_log.txt │ ├── halfcheetah-medium-replay-dt_log.txt │ ├── halfcheetah-medium-replay-tit_log.txt │ ├── halfcheetah-medium-tit_log.txt │ ├── hopper-medium-dt_log.txt │ ├── hopper-medium-replay-dt_log.txt │ ├── hopper-medium-replay-tit_log.txt │ ├── hopper-medium-tit_log.txt │ ├── read_log.py │ ├── walker2d-medium-dt_log.txt │ ├── walker2d-medium-replay-dt_log.txt │ ├── walker2d-medium-replay-tit_log.txt │ └── walker2d-medium-tit_log.txt │ ├── readme-gym.md │ └── run.sh ├── PPO_TIT_and_CQL_TIT ├── .gitignore ├── hyperparameter_final_enhanced_tit_cql.yaml ├── hyperparameter_final_enhanced_tit_ppo.yaml ├── hyperparameter_final_vanilla_tit_ppo.yaml ├── hyperparameter_sampling.py ├── main.py ├── network.py ├── offline_main.py ├── offline_network.py ├── readme.md ├── requirements_offlineRL_CQL.txt ├── requirements_onlineRL_PPO.txt ├── results_offline.py ├── results_online.py ├── run.sh └── utils.py ├── RL_Foundation_BabyAI_including_DT_GATO_and_TIT ├── .gitignore ├── config │ ├── algo │ │ ├── bc.yaml │ │ ├── bc_tit.yaml │ │ ├── dt.yaml │ │ ├── dt_tit.yaml │ │ ├── mgdt.yaml │ │ └── mgdt_wo_sample.yaml │ ├── default.yaml │ └── env │ │ ├── BabyAI-BossLevel-v0.yaml │ │ ├── BabyAI-GoToObj-v0.yaml │ │ ├── BabyAI-GoToRedBall-v0.yaml │ │ ├── BabyAI-GoToRedBallGrey-v0.yaml │ │ ├── BabyAI-GoToSeq-v0.yaml │ │ └── Mix-GoTo.yaml ├── demos │ └── BabyAI-BossLevel-v0.pkl ├── evaluation.py ├── main.py ├── network.py ├── trainner.py └── utils.py ├── RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT ├── config │ ├── algo │ │ ├── bc.yaml │ │ ├── dt.yaml │ │ ├── mgdt.yaml │ │ └── mgdt_wo_sample.yaml │ ├── default.yaml │ └── env │ │ ├── antmaze_large_diverse.yaml │ │ ├── antmaze_medium.yaml │ │ ├── antmaze_medium_diverse.yaml │ │ ├── antmaze_umaze.yaml │ │ ├── antmaze_umaze_diverse.yaml │ │ ├── door_cloned.yaml │ │ ├── halfcheetah_medium.yaml │ │ ├── halfcheetah_medium_expert.yaml │ │ ├── halfcheetah_medium_replay.yaml │ │ ├── hammer_cloned.yaml │ │ ├── hopper_medium.yaml │ │ ├── hopper_medium_expert.yaml │ │ ├── hopper_medium_replay.yaml │ │ ├── pen_cloned.yaml │ │ ├── relocate_cloned.yaml │ │ ├── walker2d_medium.yaml │ │ ├── walker2d_medium_expert.yaml │ │ └── walker2d_medium_replay.yaml ├── download_d4rl_datasets.py ├── evaluation.py ├── main.py ├── network.py ├── requirements.txt ├── trainner.py └── utils.py └── readme.md /DT_TIT/DT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maohangyu/TIT_open_source/e1bc0aab48166dfa80f2520f9a03b5b7a9392df8/DT_TIT/DT.png -------------------------------------------------------------------------------- /DT_TIT/DT_TIT.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maohangyu/TIT_open_source/e1bc0aab48166dfa80f2520f9a03b5b7a9392df8/DT_TIT/DT_TIT.png -------------------------------------------------------------------------------- /DT_TIT/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Decision Transformer (Decision Transformer: Reinforcement Learning via Sequence Modeling) Authors (https://arxiv.org/abs/2106.01345) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /DT_TIT/README.md: -------------------------------------------------------------------------------- 1 | # Transformer in Transformer as Backbone for Deep Reinforcement Learning 2 | 3 | 4 | 5 | ## Overview 6 | This is the official implementation of Transformer in Transformer (TIT) for Deep Reinforcement Learning. 7 | It contains scripts to reproduce our offline-SL experiments (i.e., Decision Transformer and DT_TIT). 8 | 9 | 10 | 11 | ## Network Architecture 12 | 13 | Decision Transformer (DT): 14 | 15 | ![image info](./DT.png) 16 | 17 | DT vs DT_TIT (a minimal code sketch of this two-level design is given after the Instructions): 18 | 19 | ![image info](./DT_TIT.png) 20 | 21 | 22 | 23 | ## Instructions 24 | 25 | Our implementation is heavily based on [Decision Transformer (DT)](https://github.com/kzl/decision-transformer). 26 | 27 | We provide code in two sub-directories: 28 | 1. `gym` contains code for the D4RL-MuJoCo experiments. Running `bash run.sh` reproduces the results reported in the paper. 29 | 2. `atari` contains code for the Atari (DQN-replay) experiments. Running `bash run.sh` works, but the results may be poor (our GPU often ran out of memory, so we did not evaluate the Atari setting at all).
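
The figures above contrast DT's linear state embedding with DT_TIT's inner Transformer. Below is a minimal, self-contained sketch of that two-level design. It mirrors the structure of `gym/decision_transformer/models/TIT.py` (an inner Transformer over state "patches" whose class token feeds an outer causal Transformer over interleaved return/state/action tokens), but it is deliberately simplified (no timestep embeddings, no Conv1d patching, no GPT-2 backbone), and all class and variable names are illustrative rather than taken from the repository.

```python
import torch
import torch.nn as nn


class InnerStateEncoder(nn.Module):
    """Inner Transformer: encodes one state vector into a single class-token embedding."""

    def __init__(self, state_dim, embed_dim=128, num_heads=1, num_blocks=1):
        super().__init__()
        # treat every state dimension as one "patch" token (patch size 1 for simplicity)
        self.patch_embed = nn.Linear(1, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads,
                                           dim_feedforward=embed_dim, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_blocks)

    def forward(self, states):                                  # (B, T, state_dim)
        B, T, D = states.shape
        tokens = self.patch_embed(states.reshape(B * T, D, 1))  # (B*T, D, embed_dim)
        tokens = torch.cat([self.cls_token.expand(B * T, -1, -1), tokens], dim=1)
        tokens = self.blocks(tokens)
        return tokens[:, 0].reshape(B, T, -1)                   # class token per state


class TinyTIT(nn.Module):
    """Outer causal Transformer over interleaved (return-to-go, state, action) tokens."""

    def __init__(self, state_dim, act_dim, embed_dim=128, num_heads=1, num_blocks=1):
        super().__init__()
        self.inner = InnerStateEncoder(state_dim, embed_dim, num_heads, num_blocks)
        self.embed_rtg = nn.Linear(1, embed_dim)
        self.embed_act = nn.Linear(act_dim, embed_dim)
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads,
                                           dim_feedforward=embed_dim, batch_first=True)
        self.outer = nn.TransformerEncoder(layer, num_blocks)
        self.predict_action = nn.Linear(embed_dim, act_dim)

    def forward(self, rtg, states, actions):   # (B, T, 1), (B, T, state_dim), (B, T, act_dim)
        B, T = states.shape[:2]
        s = self.inner(states)                 # a plain DT would use a single nn.Linear here
        r, a = self.embed_rtg(rtg), self.embed_act(actions)
        # interleave as (R_1, s_1, a_1, R_2, s_2, a_2, ...), the same ordering DT uses
        x = torch.stack((r, s, a), dim=2).reshape(B, 3 * T, -1)
        causal_mask = torch.triu(torch.full((3 * T, 3 * T), float("-inf")), diagonal=1)
        x = self.outer(x, mask=causal_mask)
        return self.predict_action(x[:, 1::3])  # read predicted actions off the state positions


if __name__ == "__main__":
    model = TinyTIT(state_dim=11, act_dim=3)   # e.g., Hopper-like dimensions
    a_pred = model(torch.zeros(2, 20, 1), torch.zeros(2, 20, 11), torch.zeros(2, 20, 3))
    print(a_pred.shape)                        # torch.Size([2, 20, 3])
```

The only structural difference from a plain DT in this sketch is `InnerStateEncoder`: DT embeds each state with a single linear layer, whereas TIT first runs a small Transformer over the state's patches and passes its class token to the outer sequence model.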
30 | 31 | 32 | 33 | ## Cite 34 | 35 | Please cite our paper as: 36 | ``` 37 | @article{mao2022TIT, 38 | title={Transformer in Transformer as Backbone for Deep Reinforcement Learning}, 39 | author={Mao, Hangyu and Zhao, Rui and Chen, Hao and Hao, Jianye and Chen, Yiqun and Li, Dong and Zhang, Junge and Xiao, Zhen}, 40 | journal={arXiv preprint arXiv:2212.14538}, 41 | year={2022} 42 | } 43 | ``` 44 | 45 | 46 | ## Acknowledgements 47 | 48 | Please cite Decision Transformer as: 49 | ``` 50 | @article{chen2021decisiontransformer, 51 | title={Decision Transformer: Reinforcement Learning via Sequence Modeling}, 52 | author={Lili Chen and Kevin Lu and Aravind Rajeswaran and Kimin Lee and Aditya Grover and Michael Laskin and Pieter Abbeel and Aravind Srinivas and Igor Mordatch}, 53 | journal={arXiv preprint arXiv:2106.01345}, 54 | year={2021} 55 | } 56 | ``` 57 | 58 | ## License 59 | 60 | MIT 61 | -------------------------------------------------------------------------------- /DT_TIT/atari/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /DT_TIT/atari/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: decision-transformer-atari 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.7.9 6 | - pytorch=1.2 7 | - cudatoolkit=10. 
8 | - numpy 9 | - psutil 10 | - opencv 11 | - pip 12 | - pip: 13 | - atari-py 14 | - pyprind 15 | - tensorflow-gpu>=1.13 16 | - absl-py 17 | - atari-py 18 | - gin-config 19 | - gym 20 | - tqdm 21 | - blosc 22 | - git+https://github.com/google/dopamine.git 23 | -------------------------------------------------------------------------------- /DT_TIT/atari/create_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fixed_replay_buffer import FixedReplayBuffer 3 | 4 | def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer): 5 | # -- load data from memory (make more efficient) 6 | obss = [] 7 | actions = [] 8 | returns = [0] 9 | done_idxs = [] 10 | stepwise_returns = [] 11 | 12 | transitions_per_buffer = np.zeros(50, dtype=int) 13 | num_trajectories = 0 14 | while len(obss) < num_steps: 15 | buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0] 16 | i = transitions_per_buffer[buffer_num] 17 | print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) 18 | frb = FixedReplayBuffer( 19 | data_dir=data_dir_prefix + game + '/1/replay_logs', 20 | replay_suffix=buffer_num, 21 | observation_shape=(84, 84), 22 | stack_size=4, 23 | update_horizon=1, 24 | gamma=0.99, 25 | observation_dtype=np.uint8, 26 | batch_size=32, 27 | replay_capacity=100000) 28 | if frb._loaded_buffers: 29 | done = False 30 | curr_num_transitions = len(obss) 31 | trajectories_to_load = trajectories_per_buffer 32 | while not done: 33 | states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) 34 | states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) 35 | obss += [states] 36 | actions += [ac[0]] 37 | stepwise_returns += [ret[0]] 38 | if terminal[0]: 39 | done_idxs += [len(obss)] 40 | returns += [0] 41 | if trajectories_to_load == 0: 42 | done = True 43 | else: 44 | trajectories_to_load -= 1 45 | returns[-1] += ret[0] 46 | i += 1 47 | if i >= 100000: 48 | obss = obss[:curr_num_transitions] 49 | actions = actions[:curr_num_transitions] 50 | stepwise_returns = stepwise_returns[:curr_num_transitions] 51 | returns[-1] = 0 52 | i = transitions_per_buffer[buffer_num] 53 | done = True 54 | num_trajectories += (trajectories_per_buffer - trajectories_to_load) 55 | transitions_per_buffer[buffer_num] = i 56 | print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) 57 | 58 | actions = np.array(actions) 59 | returns = np.array(returns) 60 | stepwise_returns = np.array(stepwise_returns) 61 | done_idxs = np.array(done_idxs) 62 | 63 | # -- create reward-to-go dataset 64 | start_index = 0 65 | rtg = np.zeros_like(stepwise_returns) 66 | for i in done_idxs: 67 | i = int(i) 68 | curr_traj_returns = stepwise_returns[start_index:i] 69 | for j in range(i-1, start_index-1, -1): # start from i-1 70 | rtg_j = curr_traj_returns[j-start_index:i-start_index] 71 | rtg[j] = sum(rtg_j) 72 | start_index = i 73 | print('max rtg is %d' % max(rtg)) 74 | 75 | # -- create timestep dataset 76 | start_index = 0 77 | timesteps = np.zeros(len(actions)+1, dtype=int) 78 | for i in done_idxs: 79 | i = int(i) 80 | timesteps[start_index:i+1] = np.arange(i+1 - start_index) 81 | start_index = i+1 82 | print('max timestep is %d' % max(timesteps)) 83 | 84 | return obss, actions, returns, done_idxs, rtg, timesteps 85 | 
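
The `create reward-to-go dataset` block in the file above is the core of the data preparation: for every timestep j inside a trajectory, `rtg[j]` is the sum of `stepwise_returns` from j to the end of that trajectory, and this value later conditions the model as the return-to-go token. The following is a tiny self-contained example (toy numbers, not from any real replay buffer); the reversed cumulative sum is an equivalent vectorized form of the repository's explicit nested loop.

```python
# Toy illustration (not part of the repo) of the return-to-go targets that
# create_dataset.py builds: rtg[j] = sum(stepwise_returns[j:end_of_trajectory]).
import numpy as np

stepwise_returns = np.array([1.0, 0.0, 2.0, 3.0, 1.0])
done_idxs = np.array([2, 5])  # trajectory 1 = steps 0..1, trajectory 2 = steps 2..4

rtg = np.zeros_like(stepwise_returns)
start_index = 0
for i in done_idxs:
    i = int(i)
    traj = stepwise_returns[start_index:i]
    # reversed cumulative sum == suffix sums, equivalent to the repo's nested loop
    rtg[start_index:i] = np.cumsum(traj[::-1])[::-1]
    start_index = i

print(rtg)  # [1. 0. 6. 4. 1.]
```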
-------------------------------------------------------------------------------- /DT_TIT/atari/fixed_replay_buffer.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/google-research/batch_rl/blob/master/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py 2 | 3 | import collections 4 | from concurrent import futures 5 | from dopamine.replay_memory import circular_replay_buffer 6 | import numpy as np 7 | import tensorflow.compat.v1 as tf 8 | import gin 9 | 10 | gfile = tf.gfile 11 | 12 | STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX 13 | 14 | class FixedReplayBuffer(object): 15 | """Object composed of a list of OutofGraphReplayBuffers.""" 16 | 17 | def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg 18 | """Initialize the FixedReplayBuffer class. 19 | Args: 20 | data_dir: str, log Directory from which to load the replay buffer. 21 | replay_suffix: int, If not None, then only load the replay buffer 22 | corresponding to the specific suffix in data directory. 23 | *args: Arbitrary extra arguments. 24 | **kwargs: Arbitrary keyword arguments. 25 | """ 26 | self._args = args 27 | self._kwargs = kwargs 28 | self._data_dir = data_dir 29 | self._loaded_buffers = False 30 | self.add_count = np.array(0) 31 | self._replay_suffix = replay_suffix 32 | if not self._loaded_buffers: 33 | if replay_suffix is not None: 34 | assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' 35 | self.load_single_buffer(replay_suffix) 36 | else: 37 | self._load_replay_buffers(num_buffers=50) 38 | 39 | def load_single_buffer(self, suffix): 40 | """Load a single replay buffer.""" 41 | replay_buffer = self._load_buffer(suffix) 42 | if replay_buffer is not None: 43 | self._replay_buffers = [replay_buffer] 44 | self.add_count = replay_buffer.add_count 45 | self._num_replay_buffers = 1 46 | self._loaded_buffers = True 47 | 48 | def _load_buffer(self, suffix): 49 | """Loads a OutOfGraphReplayBuffer replay buffer.""" 50 | try: 51 | # pytype: disable=attribute-error 52 | replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer( 53 | *self._args, **self._kwargs) 54 | replay_buffer.load(self._data_dir, suffix) 55 | tf.logging.info('Loaded replay buffer ckpt {} from {}'.format( 56 | suffix, self._data_dir)) 57 | # pytype: enable=attribute-error 58 | return replay_buffer 59 | except tf.errors.NotFoundError: 60 | return None 61 | 62 | def _load_replay_buffers(self, num_buffers=None): 63 | """Loads multiple checkpoints into a list of replay buffers.""" 64 | if not self._loaded_buffers: # pytype: disable=attribute-error 65 | ckpts = gfile.ListDirectory(self._data_dir) # pytype: disable=attribute-error 66 | # Assumes that the checkpoints are saved in a format CKPT_NAME.{SUFFIX}.gz 67 | ckpt_counters = collections.Counter( 68 | [name.split('.')[-2] for name in ckpts]) 69 | # Should contain the files for add_count, action, observation, reward, 70 | # terminal and invalid_range 71 | ckpt_suffixes = [x for x in ckpt_counters if ckpt_counters[x] in [6, 7]] 72 | if num_buffers is not None: 73 | ckpt_suffixes = np.random.choice( 74 | ckpt_suffixes, num_buffers, replace=False) 75 | self._replay_buffers = [] 76 | # Load the replay buffers in parallel 77 | with futures.ThreadPoolExecutor( 78 | max_workers=num_buffers) as thread_pool_executor: 79 | replay_futures = [thread_pool_executor.submit( 80 | self._load_buffer, suffix) for suffix in ckpt_suffixes] 81 | for f in 
replay_futures: 82 | replay_buffer = f.result() 83 | if replay_buffer is not None: 84 | self._replay_buffers.append(replay_buffer) 85 | self.add_count = max(replay_buffer.add_count, self.add_count) 86 | self._num_replay_buffers = len(self._replay_buffers) 87 | if self._num_replay_buffers: 88 | self._loaded_buffers = True 89 | 90 | def get_transition_elements(self): 91 | return self._replay_buffers[0].get_transition_elements() 92 | 93 | def sample_transition_batch(self, batch_size=None, indices=None): 94 | buffer_index = np.random.randint(self._num_replay_buffers) 95 | return self._replay_buffers[buffer_index].sample_transition_batch( 96 | batch_size=batch_size, indices=indices) 97 | 98 | def load(self, *args, **kwargs): # pylint: disable=unused-argument 99 | pass 100 | 101 | def reload_buffer(self, num_buffers=None): 102 | self._loaded_buffers = False 103 | self._load_replay_buffers(num_buffers) 104 | 105 | def save(self, *args, **kwargs): # pylint: disable=unused-argument 106 | pass 107 | 108 | def add(self, *args, **kwargs): # pylint: disable=unused-argument 109 | pass -------------------------------------------------------------------------------- /DT_TIT/atari/log/read_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # how to get the score of an expert policy? report the score in the last epoch or the best score during training? 5 | # https://github.com/kzl/decision-transformer/issues/16#issuecomment-890423427 6 | # https://github.com/kzl/decision-transformer/issues/46#issuecomment-1214102788 7 | def normalize_atari_score(env, score): 8 | Random = { 9 | 'Breakout': 2, 10 | 'Qbert': 164, 11 | 'Pong': -21, 12 | 'Seaquest': 68 13 | } 14 | Gamer = { 15 | 'Breakout': 30, 16 | 'Qbert': 13455, 17 | 'Pong': 15, 18 | 'Seaquest': 42055 19 | } 20 | min_score = Random[env] 21 | max_score = Gamer[env] 22 | normalized_score = 100.0 * (score - min_score) / (max_score - min_score) 23 | return normalized_score 24 | 25 | 26 | def read_target_log(file_name, info_type='eval return:'): 27 | value_list = [] 28 | with open(file_name, 'r') as fp: 29 | for line in fp.readlines(): 30 | if info_type in line: 31 | info, value = line.split(info_type) 32 | value = float(value) 33 | value_list.append(value) 34 | return value_list 35 | 36 | 37 | def read_all_log(): 38 | for env in ['Breakout', 'Pong']: 39 | for seed in ['123', '231', '312']: 40 | for algo in ['dt', 'tit']: 41 | file_name = f'{env}-{algo}_log_{seed}.txt' 42 | print('='*50, 'file_name ==>', file_name) 43 | value_list = read_target_log(file_name) 44 | assert len(value_list) == 5 45 | print('value_list ==>', value_list) 46 | return_mean = np.mean(value_list) 47 | return_std = np.std(value_list) 48 | print('return_mean/std ==>', return_mean, return_std) 49 | normalized_return_mean = normalize_atari_score(env, score=return_mean) 50 | normalized_return_std = normalize_atari_score(env, score=return_std) 51 | print('normalized_return_mean/std ==>', normalized_return_mean, normalized_return_std) 52 | 53 | -------------------------------------------------------------------------------- /DT_TIT/atari/mingpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maohangyu/TIT_open_source/e1bc0aab48166dfa80f2520f9a03b5b7a9392df8/DT_TIT/atari/mingpt/__init__.py -------------------------------------------------------------------------------- /DT_TIT/atari/mingpt/utils.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | """ 10 | 11 | import random 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def top_k_logits(logits, k): 24 | v, ix = torch.topk(logits, k) 25 | out = logits.clone() 26 | out[out < v[:, [-1]]] = -float('Inf') 27 | return out 28 | 29 | @torch.no_grad() 30 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None): 31 | """ 32 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 33 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 34 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 35 | of block_size, unlike an RNN that has an infinite context window. 
36 | """ 37 | block_size = model.get_block_size() 38 | model.eval() 39 | for k in range(steps): 40 | # x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 41 | x_cond = x if x.size(1) <= block_size//3 else x[:, -block_size//3:] # crop context if needed 42 | if actions is not None: 43 | actions = actions if actions.size(1) <= block_size//3 else actions[:, -block_size//3:] # crop context if needed 44 | rtgs = rtgs if rtgs.size(1) <= block_size//3 else rtgs[:, -block_size//3:] # crop context if needed 45 | logits, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps) 46 | # pluck the logits at the final step and scale by temperature 47 | logits = logits[:, -1, :] / temperature 48 | # optionally crop probabilities to only the top k options 49 | if top_k is not None: 50 | logits = top_k_logits(logits, top_k) 51 | # apply softmax to convert to probabilities 52 | probs = F.softmax(logits, dim=-1) 53 | # sample from the distribution or take the most likely 54 | if sample: 55 | ix = torch.multinomial(probs, num_samples=1) 56 | else: 57 | _, ix = torch.topk(probs, k=1, dim=-1) 58 | # append to the sequence and continue 59 | # x = torch.cat((x, ix), dim=1) 60 | x = ix 61 | 62 | return x 63 | -------------------------------------------------------------------------------- /DT_TIT/atari/readme-atari.md: -------------------------------------------------------------------------------- 1 | 2 | # Atari 3 | 4 | We build our Atari implementation on top of [minGPT](https://github.com/karpathy/minGPT) and benchmark our results on the [DQN-replay](https://github.com/google-research/batch_rl) dataset. 5 | 6 | ## Installation 7 | 8 | Dependencies can be installed with the following command: 9 | 10 | ``` 11 | conda env create -f conda_env.yml 12 | ``` 13 | 14 | ## Downloading datasets 15 | 16 | Create a directory for the dataset and load the dataset using [gsutil](https://cloud.google.com/storage/docs/gsutil_install#install). Replace `[DIRECTORY_NAME]` and `[GAME_NAME]` accordingly (e.g., `./dqn_replay` for `[DIRECTORY_NAME]` and `Breakout` for `[GAME_NAME]`) 17 | ``` 18 | mkdir [DIRECTORY_NAME] 19 | gsutil -m cp -R gs://atari-replay-datasets/dqn/[GAME_NAME] [DIRECTORY_NAME] 20 | ``` 21 | 22 | ## Example usage 23 | 24 | Scripts to reproduce our Decision Transformer results can be found in `run.sh`. 
25 | 26 | ``` 27 | python run_dt_atari.py --seed 123 --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 --data_dir_prefix [DIRECTORY_NAME] 28 | ``` 29 | -------------------------------------------------------------------------------- /DT_TIT/atari/run.sh: -------------------------------------------------------------------------------- 1 | #nohup python run_dt_atari.py --seed 123 --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-dt_log_123.txt 2>&1 & 2 | #nohup python run_dt_atari.py --seed 231 --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-dt_log_231.txt 2>&1 & 3 | #nohup python run_dt_atari.py --seed 312 --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-dt_log_312.txt 2>&1 & 4 | 5 | #nohup python run_dt_atari.py --seed 123 --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-dt_log_123.txt 2>&1 & 6 | #nohup python run_dt_atari.py --seed 231 --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-dt_log_231.txt 2>&1 & 7 | #nohup python run_dt_atari.py --seed 312 --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-dt_log_312.txt 2>&1 & 8 | 9 | 10 | 11 | #nohup python run_dt_atari.py --seed 123 --context_length 30 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-tit_log_123.txt 2>&1 & 12 | #nohup python run_dt_atari.py --seed 231 --context_length 30 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-tit_log_231.txt 2>&1 & 13 | #nohup python run_dt_atari.py --seed 312 --context_length 30 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 > ./log/Breakout-tit_log_312.txt 2>&1 & 14 | 15 | #nohup python run_dt_atari.py --seed 123 --context_length 50 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-tit_log_123.txt 2>&1 & 16 | #nohup python run_dt_atari.py --seed 231 --context_length 50 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-tit_log_231.txt 2>&1 & 17 | #nohup python run_dt_atari.py --seed 312 --context_length 50 --epochs 5 --model_type 'reward_conditioned_tit' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 > ./log/Pong-tit_log_312.txt 2>&1 & 18 | 19 | 20 | 21 | 22 | # Decision Transformer (DT) 23 | #for seed in 123 231 312 24 | #do 25 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 26 | #done 27 | # 28 | #for seed in 123 231 312 29 | #do 30 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Qbert'
--batch_size 128 31 | #done 32 | # 33 | #for seed in 123 231 312 34 | #do 35 | # python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 36 | #done 37 | # 38 | #for seed in 123 231 312 39 | #do 40 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 41 | #done 42 | # 43 | ## Behavior Cloning (BC) 44 | #for seed in 123 231 312 45 | #do 46 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 47 | #done 48 | # 49 | #for seed in 123 231 312 50 | #do 51 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128 52 | #done 53 | # 54 | #for seed in 123 231 312 55 | #do 56 | # python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 57 | #done 58 | # 59 | #for seed in 123 231 312 60 | #do 61 | # python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 62 | #done -------------------------------------------------------------------------------- /DT_TIT/atari/run_dt_atari.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from mingpt.utils import set_seed 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from mingpt.model_atari import GPT, GPTConfig 6 | from mingpt.model_tit import GPT_TIT, GPTConfig_TIT 7 | from mingpt.trainer_atari import Trainer, TrainerConfig 8 | import torch 9 | import argparse 10 | from create_dataset import create_dataset 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--seed', type=int, default=123) 14 | parser.add_argument('--context_length', type=int, default=30) 15 | parser.add_argument('--epochs', type=int, default=5) 16 | parser.add_argument('--model_type', type=str, default='reward_conditioned') 17 | parser.add_argument('--num_steps', type=int, default=500000) 18 | parser.add_argument('--num_buffers', type=int, default=50) 19 | parser.add_argument('--game', type=str, default='Breakout') 20 | parser.add_argument('--batch_size', type=int, default=128) 21 | # 22 | parser.add_argument('--trajectories_per_buffer', type=int, default=10, help='Number of trajectories to sample from each of the buffers.') 23 | parser.add_argument('--data_dir_prefix', type=str, default='./dqn_replay/') 24 | args = parser.parse_args() 25 | 26 | set_seed(args.seed) 27 | 28 | class StateActionReturnDataset(Dataset): 29 | 30 | def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps): 31 | self.block_size = block_size 32 | self.vocab_size = max(actions) + 1 33 | self.data = data 34 | self.actions = actions 35 | self.done_idxs = done_idxs 36 | self.rtgs = rtgs 37 | self.timesteps = timesteps 38 | 39 | def __len__(self): 40 | return len(self.data) - self.block_size 41 | 42 | def __getitem__(self, idx): 43 | block_size = self.block_size // 3 44 | done_idx = idx + block_size 45 | for i in self.done_idxs: 46 | if i > idx: # first done_idx greater than idx 47 | done_idx = min(int(i), done_idx) 48 | break 49 | idx = done_idx - block_size 50 | states = 
torch.tensor(np.array(self.data[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84) 51 | states = states / 255. 52 | actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) 53 | rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) 54 | timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) 55 | 56 | return states, actions, rtgs, timesteps 57 | 58 | obss, actions, returns, done_idxs, rtgs, timesteps = create_dataset(args.num_buffers, args.num_steps, args.game, args.data_dir_prefix, args.trajectories_per_buffer) 59 | 60 | # set up logging 61 | logging.basicConfig( 62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 63 | datefmt="%m/%d/%Y %H:%M:%S", 64 | level=logging.INFO, 65 | ) 66 | 67 | train_dataset = StateActionReturnDataset(obss, args.context_length*3, actions, done_idxs, rtgs, timesteps) 68 | 69 | if args.model_type == 'reward_conditioned' or args.model_type == 'naive': 70 | mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 71 | n_layer=6, n_head=8, n_embd=128, model_type=args.model_type, max_timestep=max(timesteps)) 72 | model = GPT(mconf) 73 | elif args.model_type == 'reward_conditioned_tit': 74 | mconf = GPTConfig_TIT(train_dataset.vocab_size, train_dataset.block_size, 75 | n_layer=6, n_head=8, n_embd=128, model_type=args.model_type, max_timestep=max(timesteps)) 76 | model = GPT_TIT(mconf) 77 | 78 | # initialize a trainer instance and kick off training 79 | epochs = args.epochs 80 | tconf = TrainerConfig(max_epochs=epochs, batch_size=args.batch_size, learning_rate=6e-4, 81 | lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*args.context_length*3, 82 | num_workers=4, seed=args.seed, model_type=args.model_type, game=args.game, max_timestep=max(timesteps)) 83 | trainer = Trainer(model, train_dataset, None, tconf) 84 | 85 | trainer.train() 86 | -------------------------------------------------------------------------------- /DT_TIT/gym/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: decision-transformer-gym 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.8.5 6 | - anaconda 7 | - cudatoolkit=10. 
8 | - numpy 9 | - pip 10 | - pip: 11 | - gym==0.18.3 12 | - mujoco-py<2.2,>=2.1 #==2.0.2.13 we use mojoco2.1 as https://github.com/openai/mujoco-py#install-and-use-mujoco-py 13 | - numpy==1.20.3 14 | - torch==1.8.1 15 | - transformers==4.5.1 16 | - wandb==0.9.1 17 | -------------------------------------------------------------------------------- /DT_TIT/gym/data/download_d4rl_datasets.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | import collections 5 | import pickle 6 | 7 | import d4rl 8 | 9 | 10 | datasets = [] 11 | 12 | for env_name in ['halfcheetah', 'hopper', 'walker2d']: 13 | for dataset_type in ['medium', 'medium-replay', 'expert']: 14 | name = f'{env_name}-{dataset_type}-v2' 15 | env = gym.make(name) 16 | dataset = env.get_dataset() 17 | 18 | N = dataset['rewards'].shape[0] 19 | data_ = collections.defaultdict(list) 20 | 21 | use_timeouts = False 22 | if 'timeouts' in dataset: 23 | use_timeouts = True 24 | 25 | episode_step = 0 26 | paths = [] 27 | for i in range(N): 28 | done_bool = bool(dataset['terminals'][i]) 29 | if use_timeouts: 30 | final_timestep = dataset['timeouts'][i] 31 | else: 32 | final_timestep = (episode_step == 1000-1) 33 | for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']: 34 | data_[k].append(dataset[k][i]) 35 | if done_bool or final_timestep: 36 | episode_step = 0 37 | episode_data = {} 38 | for k in data_: 39 | episode_data[k] = np.array(data_[k]) 40 | paths.append(episode_data) 41 | data_ = collections.defaultdict(list) 42 | episode_step += 1 43 | 44 | returns = np.array([np.sum(p['rewards']) for p in paths]) 45 | num_samples = np.sum([p['rewards'].shape[0] for p in paths]) 46 | print(f'Number of samples collected: {num_samples}') 47 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') 48 | 49 | with open(f'{name}.pkl', 'wb') as f: 50 | pickle.dump(paths, f) 51 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/envs/assets/reacher_2d.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 34 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/envs/reacher_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | import os 6 | 7 | 8 | class Reacher2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | 10 | def __init__(self): 11 | self.fingertip_sid = 0 12 | self.target_bid = 0 13 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 14 | mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/reacher_2d.xml', 15) 15 | self.fingertip_sid = self.sim.model.site_name2id('fingertip') 16 | self.target_bid = self.sim.model.body_name2id('target') 17 | utils.EzPickle.__init__(self) 18 | 19 | def step(self, action): 20 | action = np.clip(action, -1.0, 1.0) 21 | self.do_simulation(action, self.frame_skip) 22 | tip = self.data.site_xpos[self.fingertip_sid][:2] 23 | tar = self.data.body_xpos[self.target_bid][:2] 24 | dist = np.sum(np.abs(tip - tar)) 25 | reward_dist = 0. 
# - 0.1 * dist 26 | reward_ctrl = 0.0 27 | reward_bonus = 1.0 if dist < 0.1 else 0.0 28 | reward = reward_bonus + reward_ctrl + reward_dist 29 | done = False 30 | ob = self._get_obs() 31 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, reward_bonus=reward_bonus) 32 | 33 | def _get_obs(self): 34 | theta = self.data.qpos.ravel() 35 | tip = self.data.site_xpos[self.fingertip_sid][:2] 36 | tar = self.data.body_xpos[self.target_bid][:2] 37 | return np.concatenate([ 38 | # self.data.qpos.flat, 39 | np.sin(theta), 40 | np.cos(theta), 41 | self.dt * self.data.qvel.ravel(), 42 | tip, 43 | tar, 44 | tip-tar, 45 | ]) 46 | 47 | def reset_model(self): 48 | # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos 49 | # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 50 | qpos = self.np_random.uniform(low=-2.0, high=2.0, size=self.model.nq) 51 | qvel = self.init_qvel * 0.0 52 | while True: 53 | self.goal = self.np_random.uniform(low=-1.5, high=1.5, size=2) 54 | if np.linalg.norm(self.goal) <= 1.0 and np.linalg.norm(self.goal) >= 0.5: 55 | break 56 | self.set_state(qpos, qvel) 57 | self.model.body_pos[self.target_bid][:2] = self.goal 58 | self.sim.forward() 59 | return self._get_obs() 60 | 61 | def viewer_setup(self): 62 | self.viewer.cam.distance = self.model.stat.extent * 5.0 63 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/evaluation/evaluate_episodes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def evaluate_episode( 6 | env, 7 | state_dim, 8 | act_dim, 9 | model, 10 | max_ep_len=1000, 11 | device='cuda', 12 | target_return=None, 13 | mode='normal', 14 | state_mean=0., 15 | state_std=1., 16 | ): 17 | 18 | model.eval() 19 | model.to(device=device) 20 | 21 | state_mean = torch.from_numpy(state_mean).to(device=device) 22 | state_std = torch.from_numpy(state_std).to(device=device) 23 | 24 | state = env.reset() 25 | 26 | # we keep all the histories on the device 27 | # note that the latest action and reward will be "padding" 28 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) 29 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) 30 | rewards = torch.zeros(0, device=device, dtype=torch.float32) 31 | target_return = torch.tensor(target_return, device=device, dtype=torch.float32) 32 | sim_states = [] 33 | 34 | episode_return, episode_length = 0, 0 35 | for t in range(max_ep_len): 36 | 37 | # add padding 38 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) 39 | rewards = torch.cat([rewards, torch.zeros(1, device=device)]) 40 | 41 | action = model.get_action( 42 | (states.to(dtype=torch.float32) - state_mean) / state_std, 43 | actions.to(dtype=torch.float32), 44 | rewards.to(dtype=torch.float32), 45 | target_return=target_return, 46 | ) 47 | actions[-1] = action 48 | action = action.detach().cpu().numpy() 49 | 50 | state, reward, done, _ = env.step(action) 51 | 52 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) 53 | states = torch.cat([states, cur_state], dim=0) 54 | rewards[-1] = reward 55 | 56 | episode_return += reward 57 | episode_length += 1 58 | 59 | if done: 60 | break 61 | 62 | return episode_return, episode_length 63 | 64 | 65 | def evaluate_episode_rtg( 66 | env, 67 | state_dim, 68 | act_dim, 69 | model, 70 | 
max_ep_len=1000, 71 | scale=1000., 72 | state_mean=0., 73 | state_std=1., 74 | device='cuda', 75 | target_return=None, 76 | mode='normal', 77 | ): 78 | 79 | model.eval() 80 | model.to(device=device) 81 | 82 | state_mean = torch.from_numpy(state_mean).to(device=device) 83 | state_std = torch.from_numpy(state_std).to(device=device) 84 | 85 | state = env.reset() 86 | if mode == 'noise': 87 | state = state + np.random.normal(0, 0.1, size=state.shape) 88 | 89 | # we keep all the histories on the device 90 | # note that the latest action and reward will be "padding" 91 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) 92 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) 93 | rewards = torch.zeros(0, device=device, dtype=torch.float32) 94 | 95 | ep_return = target_return 96 | target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1) 97 | timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1) 98 | 99 | sim_states = [] 100 | 101 | episode_return, episode_length = 0, 0 102 | for t in range(max_ep_len): 103 | 104 | # add padding 105 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) 106 | rewards = torch.cat([rewards, torch.zeros(1, device=device)]) 107 | 108 | action = model.get_action( 109 | (states.to(dtype=torch.float32) - state_mean) / state_std, 110 | actions.to(dtype=torch.float32), 111 | rewards.to(dtype=torch.float32), 112 | target_return.to(dtype=torch.float32), 113 | timesteps.to(dtype=torch.long), 114 | ) 115 | actions[-1] = action 116 | action = action.detach().cpu().numpy() 117 | 118 | state, reward, done, _ = env.step(action) 119 | 120 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) 121 | states = torch.cat([states, cur_state], dim=0) 122 | rewards[-1] = reward 123 | 124 | if mode != 'delayed': 125 | pred_return = target_return[0,-1] - (reward/scale) 126 | else: 127 | pred_return = target_return[0,-1] 128 | target_return = torch.cat( 129 | [target_return, pred_return.reshape(1, 1)], dim=1) 130 | timesteps = torch.cat( 131 | [timesteps, 132 | torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1) 133 | 134 | episode_return += reward 135 | episode_length += 1 136 | 137 | if done: 138 | break 139 | 140 | return episode_return, episode_length 141 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/models/TIT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import transformers 6 | 7 | from decision_transformer.models.model import TrajectoryModel 8 | from decision_transformer.models.trajectory_gpt2 import GPT2Model 9 | 10 | 11 | class InnerConfig: 12 | # try to keep the hyper-parameters as close as possible to the values shown in experiment.py 13 | def __init__(self): 14 | self.patch_dim = 11 15 | self.num_blocks = 1 16 | self.embed_dim_inner = 128 17 | self.num_heads_inner = 1 18 | self.attention_dropout_inner = 0.0 19 | self.ffn_dropout_inner = 0.0 20 | self.activation_fn_inner = nn.ReLU 21 | self.dim_expand_inner = 1 22 | self.have_position_encoding = False 23 | self.share_tit_blocks = False 24 | 25 | 26 | class InnerTransformerBlock(nn.Module): 27 | def __init__(self, config): 28 | super(InnerTransformerBlock, self).__init__() 29 | self.ln1 = nn.LayerNorm(config.embed_dim_inner) 30 | self.attention = nn.MultiheadAttention( 
31 | embed_dim=config.embed_dim_inner, 32 | num_heads=config.num_heads_inner, 33 | dropout=config.attention_dropout_inner, 34 | batch_first=True, 35 | ) 36 | self.ln2 = nn.LayerNorm(config.embed_dim_inner) 37 | self.ffn = nn.Sequential( 38 | nn.Linear(config.embed_dim_inner, config.dim_expand_inner * config.embed_dim_inner), 39 | config.activation_fn_inner(), 40 | nn.Linear(config.dim_expand_inner * config.embed_dim_inner, config.embed_dim_inner), 41 | nn.Dropout(config.ffn_dropout_inner), 42 | ) 43 | 44 | def forward(self, x): 45 | x_ln1 = self.ln1(x) 46 | attn_outputs, attn_weights = self.attention(query=x_ln1, key=x_ln1, value=x_ln1) 47 | x = x + attn_outputs 48 | 49 | x_ln2 = self.ln2(x) 50 | ffn_outputs = self.ffn(x_ln2) 51 | x = x + ffn_outputs 52 | return x 53 | 54 | 55 | class TIT(TrajectoryModel): 56 | 57 | """ 58 | This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...) 59 | """ 60 | 61 | def __init__( 62 | self, 63 | state_dim, 64 | act_dim, 65 | hidden_size, 66 | max_length=None, 67 | max_ep_len=4096, 68 | action_tanh=True, 69 | **kwargs 70 | ): 71 | super().__init__(state_dim, act_dim, max_length=max_length) 72 | 73 | self.hidden_size = hidden_size 74 | config = transformers.GPT2Config( 75 | vocab_size=1, # doesn't matter -- we don't use the vocab 76 | n_embd=hidden_size, 77 | **kwargs 78 | ) 79 | 80 | # note: the only difference between this GPT2Model and the default Huggingface version 81 | # is that the positional embeddings are removed (since we'll add those ourselves) 82 | self.transformer = GPT2Model(config) 83 | 84 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) 85 | self.embed_return = torch.nn.Linear(1, hidden_size) 86 | # self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) 87 | self.embed_action = torch.nn.Linear(self.act_dim, hidden_size) 88 | # TIT note: change Linear state_embedding by InnerTransformer state_embedding 89 | # the key idea of TIT is that processing state with one Transformer, and processing sequential states(+action+return) by another Transformer 90 | inner_config = InnerConfig() 91 | inner_config.patch_dim = state_dim 92 | assert inner_config.embed_dim_inner == self.hidden_size 93 | print('inner_config.patch_dim ==>', inner_config.patch_dim, state_dim, self.hidden_size) 94 | self.inner_blocks = nn.ModuleList([InnerTransformerBlock(inner_config) for _ in range(inner_config.num_blocks)]) 95 | self.obs_patch_embed = nn.Conv1d( 96 | in_channels=1, 97 | out_channels=inner_config.embed_dim_inner, 98 | kernel_size=inner_config.patch_dim, 99 | stride=inner_config.patch_dim, 100 | bias=False, 101 | ) 102 | self.class_token_encoding = nn.Parameter(torch.zeros(1, 1, inner_config.embed_dim_inner)) 103 | nn.init.trunc_normal_(self.class_token_encoding, mean=0.0, std=0.02) 104 | 105 | self.embed_ln = nn.LayerNorm(hidden_size) 106 | 107 | # note: we don't predict states or returns for the paper 108 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim) 109 | self.predict_action = nn.Sequential( 110 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else [])) 111 | ) 112 | self.predict_return = torch.nn.Linear(hidden_size, 1) 113 | 114 | def _observation_patch_embedding(self, obs): 115 | B, context_len_outer, D = obs.size() 116 | B = B * context_len_outer # new_B 117 | obs = obs.view(B, D) 118 | obs = torch.unsqueeze(obs, dim=1) # (new_B, 1, D), first apply unsqueeze() before applying Conv1d() 119 | obs_patch_embedding = self.obs_patch_embed(obs) # shape is (new_B, out_C, 
out_length), 120 | # where out_C=embed_dim_inner, out_length=context_len_inner 121 | obs_patch_embedding = obs_patch_embedding.transpose(2, 1) # (new_B, context_len_inner, embed_dim_inner) 122 | return obs_patch_embedding 123 | 124 | def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None): 125 | # print('TIT =======> states.shape', states.shape) # torch.Size([64, 20, 11]) 126 | 127 | batch_size, seq_length = states.shape[0], states.shape[1] 128 | 129 | if attention_mask is None: 130 | # attention mask for GPT: 1 if can be attended to, 0 if not 131 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 132 | 133 | # embed each modality with a different head 134 | # state_embeddings = self.embed_state(states) 135 | action_embeddings = self.embed_action(actions) 136 | returns_embeddings = self.embed_return(returns_to_go) 137 | time_embeddings = self.embed_timestep(timesteps) 138 | # TIT note: change Linear state_embedding by InnerTransformer state_embedding 139 | # the key idea of TIT is that processing state with one Transformer, and processing sequential states(+action+return) by another Transformer 140 | patch_embeddings = self._observation_patch_embedding(states) 141 | # print('TIT =======> patch_embeddings.shape', patch_embeddings.shape) # torch.Size([1280, 11, 128]) where 1280=64*20 142 | context_len_inner = patch_embeddings.shape[1] 143 | inner_tokens = torch.cat([self.class_token_encoding.expand(batch_size*seq_length, -1, -1), patch_embeddings], dim=1) 144 | # print('TIT =======> inner_tokens.shape', inner_tokens.shape) # torch.Size([1280, 12, 128]) 145 | for inner_block in self.inner_blocks: 146 | inner_tokens = inner_block(inner_tokens) 147 | temp = inner_tokens.view(batch_size, seq_length, context_len_inner + 1, self.hidden_size) 148 | # print('TIT =======> temp.shape', temp.shape) # torch.Size([64, 20, 12, 128]) 149 | state_embeddings = temp[:, :, 0, :] # 0 means class_tokens, which serve as the input of outer DT 150 | # print('TIT =======> state_embeddings.shape', state_embeddings.shape) # torch.Size([64, 20, 128]) 151 | 152 | # time embeddings are treated similar to positional embeddings 153 | state_embeddings = state_embeddings + time_embeddings 154 | action_embeddings = action_embeddings + time_embeddings 155 | returns_embeddings = returns_embeddings + time_embeddings 156 | 157 | # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 158 | # which works nice in an autoregressive sense since states predict actions 159 | stacked_inputs = torch.stack( 160 | (returns_embeddings, state_embeddings, action_embeddings), dim=1 161 | ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size) 162 | stacked_inputs = self.embed_ln(stacked_inputs) 163 | 164 | # to make the attention mask fit the stacked inputs, have to stack it as well 165 | stacked_attention_mask = torch.stack( 166 | (attention_mask, attention_mask, attention_mask), dim=1 167 | ).permute(0, 2, 1).reshape(batch_size, 3*seq_length) 168 | 169 | # we feed in the input embeddings (not word indices as in NLP) to the model 170 | transformer_outputs = self.transformer( 171 | inputs_embeds=stacked_inputs, 172 | attention_mask=stacked_attention_mask, 173 | ) 174 | x = transformer_outputs['last_hidden_state'] 175 | 176 | # reshape x so that the second dimension corresponds to the original 177 | # returns (0), states (1), or actions (2); i.e. 
x[:,1,t] is the token for s_t 178 | x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) 179 | 180 | # get predictions 181 | return_preds = self.predict_return(x[:,2]) # predict next return given state and action 182 | state_preds = self.predict_state(x[:,2]) # predict next state given state and action 183 | action_preds = self.predict_action(x[:,1]) # predict next action given state 184 | 185 | return state_preds, action_preds, return_preds 186 | 187 | def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs): 188 | # we don't care about the past rewards in this model 189 | 190 | states = states.reshape(1, -1, self.state_dim) 191 | actions = actions.reshape(1, -1, self.act_dim) 192 | returns_to_go = returns_to_go.reshape(1, -1, 1) 193 | timesteps = timesteps.reshape(1, -1) 194 | 195 | if self.max_length is not None: 196 | states = states[:,-self.max_length:] 197 | actions = actions[:,-self.max_length:] 198 | returns_to_go = returns_to_go[:,-self.max_length:] 199 | timesteps = timesteps[:,-self.max_length:] 200 | 201 | # pad all tokens to sequence length 202 | attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])]) 203 | attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1) 204 | states = torch.cat( 205 | [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states], 206 | dim=1).to(dtype=torch.float32) 207 | actions = torch.cat( 208 | [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim), 209 | device=actions.device), actions], 210 | dim=1).to(dtype=torch.float32) 211 | returns_to_go = torch.cat( 212 | [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go], 213 | dim=1).to(dtype=torch.float32) 214 | timesteps = torch.cat( 215 | [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps], 216 | dim=1 217 | ).to(dtype=torch.long) 218 | else: 219 | attention_mask = None 220 | 221 | _, action_preds, return_preds = self.forward( 222 | states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs) 223 | 224 | return action_preds[0,-1] 225 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/models/decision_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import transformers 6 | 7 | from decision_transformer.models.model import TrajectoryModel 8 | from decision_transformer.models.trajectory_gpt2 import GPT2Model 9 | 10 | 11 | class DecisionTransformer(TrajectoryModel): 12 | 13 | """ 14 | This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...) 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | state_dim, 20 | act_dim, 21 | hidden_size, 22 | max_length=None, 23 | max_ep_len=4096, 24 | action_tanh=True, 25 | **kwargs 26 | ): 27 | super().__init__(state_dim, act_dim, max_length=max_length) 28 | 29 | self.hidden_size = hidden_size 30 | config = transformers.GPT2Config( 31 | vocab_size=1, # doesn't matter -- we don't use the vocab 32 | n_embd=hidden_size, 33 | **kwargs 34 | ) 35 | 36 | # note: the only difference between this GPT2Model and the default Huggingface version 37 | # is that the positional embeddings are removed (since we'll add those ourselves) 38 | self.transformer = GPT2Model(config) 39 | 40 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) 41 | self.embed_return = torch.nn.Linear(1, hidden_size) 42 | self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) 43 | self.embed_action = torch.nn.Linear(self.act_dim, hidden_size) 44 | 45 | self.embed_ln = nn.LayerNorm(hidden_size) 46 | 47 | # note: we don't predict states or returns for the paper 48 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim) 49 | self.predict_action = nn.Sequential( 50 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else [])) 51 | ) 52 | self.predict_return = torch.nn.Linear(hidden_size, 1) 53 | 54 | def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None): 55 | 56 | batch_size, seq_length = states.shape[0], states.shape[1] 57 | 58 | if attention_mask is None: 59 | # attention mask for GPT: 1 if can be attended to, 0 if not 60 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 61 | 62 | # embed each modality with a different head 63 | state_embeddings = self.embed_state(states) 64 | action_embeddings = self.embed_action(actions) 65 | returns_embeddings = self.embed_return(returns_to_go) 66 | time_embeddings = self.embed_timestep(timesteps) 67 | 68 | # time embeddings are treated similar to positional embeddings 69 | state_embeddings = state_embeddings + time_embeddings 70 | action_embeddings = action_embeddings + time_embeddings 71 | returns_embeddings = returns_embeddings + time_embeddings 72 | 73 | # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 74 | # which works nice in an autoregressive sense since states predict actions 75 | stacked_inputs = torch.stack( 76 | (returns_embeddings, state_embeddings, action_embeddings), dim=1 77 | ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size) 78 | stacked_inputs = self.embed_ln(stacked_inputs) 79 | 80 | # to make the attention mask fit the stacked inputs, have to stack it as well 81 | stacked_attention_mask = torch.stack( 82 | (attention_mask, attention_mask, attention_mask), dim=1 83 | ).permute(0, 2, 1).reshape(batch_size, 3*seq_length) 84 | 85 | # we feed in the input embeddings (not word indices as in NLP) to the model 86 | transformer_outputs = self.transformer( 87 | inputs_embeds=stacked_inputs, 88 | attention_mask=stacked_attention_mask, 89 | ) 90 | x = transformer_outputs['last_hidden_state'] 91 | 92 | # reshape x so that the second dimension corresponds to the original 93 | # returns (0), states (1), or actions (2); i.e. 
x[:,1,t] is the token for s_t 94 | x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) 95 | 96 | # get predictions 97 | return_preds = self.predict_return(x[:,2]) # predict next return given state and action 98 | state_preds = self.predict_state(x[:,2]) # predict next state given state and action 99 | action_preds = self.predict_action(x[:,1]) # predict next action given state 100 | 101 | return state_preds, action_preds, return_preds 102 | 103 | def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs): 104 | # we don't care about the past rewards in this model 105 | 106 | states = states.reshape(1, -1, self.state_dim) 107 | actions = actions.reshape(1, -1, self.act_dim) 108 | returns_to_go = returns_to_go.reshape(1, -1, 1) 109 | timesteps = timesteps.reshape(1, -1) 110 | 111 | if self.max_length is not None: 112 | states = states[:,-self.max_length:] 113 | actions = actions[:,-self.max_length:] 114 | returns_to_go = returns_to_go[:,-self.max_length:] 115 | timesteps = timesteps[:,-self.max_length:] 116 | 117 | # pad all tokens to sequence length 118 | attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])]) 119 | attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1) 120 | states = torch.cat( 121 | [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states], 122 | dim=1).to(dtype=torch.float32) 123 | actions = torch.cat( 124 | [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim), 125 | device=actions.device), actions], 126 | dim=1).to(dtype=torch.float32) 127 | returns_to_go = torch.cat( 128 | [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go], 129 | dim=1).to(dtype=torch.float32) 130 | timesteps = torch.cat( 131 | [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps], 132 | dim=1 133 | ).to(dtype=torch.long) 134 | else: 135 | attention_mask = None 136 | 137 | _, action_preds, return_preds = self.forward( 138 | states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs) 139 | 140 | return action_preds[0,-1] 141 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/models/mlp_bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from decision_transformer.models.model import TrajectoryModel 6 | 7 | 8 | class MLPBCModel(TrajectoryModel): 9 | 10 | """ 11 | Simple MLP that predicts next action a from past states s. 
12 | """ 13 | 14 | def __init__(self, state_dim, act_dim, hidden_size, n_layer, dropout=0.1, max_length=1, **kwargs): 15 | super().__init__(state_dim, act_dim) 16 | 17 | self.hidden_size = hidden_size 18 | self.max_length = max_length 19 | 20 | layers = [nn.Linear(max_length*self.state_dim, hidden_size)] 21 | for _ in range(n_layer-1): 22 | layers.extend([ 23 | nn.ReLU(), 24 | nn.Dropout(dropout), 25 | nn.Linear(hidden_size, hidden_size) 26 | ]) 27 | layers.extend([ 28 | nn.ReLU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_size, self.act_dim), 31 | nn.Tanh(), 32 | ]) 33 | 34 | self.model = nn.Sequential(*layers) 35 | 36 | def forward(self, states, actions, rewards, attention_mask=None, target_return=None): 37 | 38 | states = states[:,-self.max_length:].reshape(states.shape[0], -1) # concat states 39 | actions = self.model(states).reshape(states.shape[0], 1, self.act_dim) 40 | 41 | return None, actions, None 42 | 43 | def get_action(self, states, actions, rewards, **kwargs): 44 | states = states.reshape(1, -1, self.state_dim) 45 | if states.shape[1] < self.max_length: 46 | states = torch.cat( 47 | [torch.zeros((1, self.max_length-states.shape[1], self.state_dim), 48 | dtype=torch.float32, device=states.device), states], dim=1) 49 | states = states.to(dtype=torch.float32) 50 | _, actions, _ = self.forward(states, None, None, **kwargs) 51 | return actions[0,-1] 52 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/models/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TrajectoryModel(nn.Module): 7 | 8 | def __init__(self, state_dim, act_dim, max_length=None): 9 | super().__init__() 10 | 11 | self.state_dim = state_dim 12 | self.act_dim = act_dim 13 | self.max_length = max_length 14 | 15 | def forward(self, states, actions, rewards, masks=None, attention_mask=None): 16 | # "masked" tokens or unspecified inputs can be passed in as None 17 | return None, None, None 18 | 19 | def get_action(self, states, actions, rewards, **kwargs): 20 | # these will come as tensors on the correct device 21 | return torch.zeros_like(actions[-1]) 22 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/training/act_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class ActTrainer(Trainer): 8 | 9 | def train_step(self): 10 | states, actions, rewards, dones, rtg, _, attention_mask = self.get_batch(self.batch_size) 11 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, attention_mask=attention_mask, target_return=rtg[:,0], 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim) 19 | action_target = action_target[:,-1].reshape(-1, act_dim) 20 | 21 | loss = self.loss_fn( 22 | state_preds, action_preds, reward_preds, 23 | state_target, action_target, reward_target, 24 | ) 25 | self.optimizer.zero_grad() 26 | loss.backward() 27 | self.optimizer.step() 28 | 29 | return loss.detach().cpu().item() 30 | -------------------------------------------------------------------------------- 
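The (R, s, a) token interleaving and the matching de-interleaving used in the model code above are easy to misread. The following standalone sketch is not part of the repository; the tensor names and sizes are illustrative assumptions. It only checks that the stack/permute/reshape round trip puts return tokens at index 0, state tokens at index 1, and action tokens at index 2, which is what the `x[:,1]`/`x[:,2]` prediction heads rely on:

```
import torch

batch_size, seq_length, hidden_size = 2, 4, 8
returns_embeddings = torch.randn(batch_size, seq_length, hidden_size)
state_embeddings = torch.randn(batch_size, seq_length, hidden_size)
action_embeddings = torch.randn(batch_size, seq_length, hidden_size)

# (B, 3, T, H) -> (B, T, 3, H) -> (B, 3T, H): the sequence becomes R_1, s_1, a_1, R_2, s_2, a_2, ...
stacked_inputs = torch.stack(
    (returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, hidden_size)

# the inverse reshape applied after the transformer groups the tokens by modality again
x = stacked_inputs.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)
assert torch.equal(x[:, 0], returns_embeddings)  # x[:, 0, t] is the token for R_t
assert torch.equal(x[:, 1], state_embeddings)    # x[:, 1, t] is the token for s_t
assert torch.equal(x[:, 2], action_embeddings)   # x[:, 2, t] is the token for a_t
```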
/DT_TIT/gym/decision_transformer/training/seq_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class SequenceTrainer(Trainer): 8 | 9 | def train_step(self): 10 | states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size) 11 | action_target = torch.clone(actions) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask, 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 19 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 20 | 21 | loss = self.loss_fn( 22 | None, action_preds, None, 23 | None, action_target, None, 24 | ) 25 | 26 | self.optimizer.zero_grad() 27 | loss.backward() 28 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) 29 | self.optimizer.step() 30 | 31 | with torch.no_grad(): 32 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() 33 | 34 | return loss.detach().cpu().item() 35 | -------------------------------------------------------------------------------- /DT_TIT/gym/decision_transformer/training/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import time 5 | 6 | 7 | class Trainer: 8 | 9 | def __init__(self, model, optimizer, batch_size, get_batch, loss_fn, scheduler=None, eval_fns=None): 10 | self.model = model 11 | self.optimizer = optimizer 12 | self.batch_size = batch_size 13 | self.get_batch = get_batch 14 | self.loss_fn = loss_fn 15 | self.scheduler = scheduler 16 | self.eval_fns = [] if eval_fns is None else eval_fns 17 | self.diagnostics = dict() 18 | 19 | self.start_time = time.time() 20 | 21 | def train_iteration(self, num_steps, iter_num=0, print_logs=False): 22 | 23 | train_losses = [] 24 | logs = dict() 25 | 26 | train_start = time.time() 27 | 28 | self.model.train() 29 | for _ in range(num_steps): 30 | train_loss = self.train_step() 31 | train_losses.append(train_loss) 32 | if self.scheduler is not None: 33 | self.scheduler.step() 34 | 35 | logs['time/training'] = time.time() - train_start 36 | 37 | eval_start = time.time() 38 | 39 | self.model.eval() 40 | for eval_fn in self.eval_fns: 41 | outputs = eval_fn(self.model) 42 | for k, v in outputs.items(): 43 | logs[f'evaluation/{k}'] = v 44 | 45 | logs['time/total'] = time.time() - self.start_time 46 | logs['time/evaluation'] = time.time() - eval_start 47 | logs['training/train_loss_mean'] = np.mean(train_losses) 48 | logs['training/train_loss_std'] = np.std(train_losses) 49 | 50 | for k in self.diagnostics: 51 | logs[k] = self.diagnostics[k] 52 | 53 | if print_logs: 54 | print('=' * 80) 55 | print(f'Iteration {iter_num}') 56 | for k, v in logs.items(): 57 | print(f'{k}: {v}') 58 | 59 | return logs 60 | 61 | def train_step(self): 62 | states, actions, rewards, dones, attention_mask, returns = self.get_batch(self.batch_size) 63 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 64 | 65 | state_preds, action_preds, reward_preds = self.model.forward( 66 | states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns, 67 | ) 68 | 69 | # note: currently 
indexing & masking is not fully correct 70 | loss = self.loss_fn( 71 | state_preds, action_preds, reward_preds, 72 | state_target[:,1:], action_target, reward_target[:,1:], 73 | ) 74 | self.optimizer.zero_grad() 75 | loss.backward() 76 | self.optimizer.step() 77 | 78 | return loss.detach().cpu().item() 79 | -------------------------------------------------------------------------------- /DT_TIT/gym/log/read_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from d3rlpy_benchmarks.utils import normalize_d4rl_score 3 | 4 | 5 | def read_target_log(file_name, info_type='return_mean'): 6 | value_list = [] 7 | with open(file_name, 'r') as fp: 8 | for line in fp.readlines(): 9 | if info_type in line: 10 | info, value = line.split(':') 11 | value = float(value) 12 | value_list.append(value) 13 | return value_list 14 | 15 | 16 | def read_all_log(): 17 | for env in ['halfcheetah', 'hopper', 'walker2d']: 18 | for dataset in ['medium', 'medium-replay']: 19 | for algo in ['dt', 'tit']: 20 | file_name = f'{env}-{dataset}-{algo}_log.txt' 21 | print('='*50, 'file_name ==>', file_name) 22 | value_list = read_target_log(file_name) 23 | assert len(value_list) == 60 24 | 25 | return_mean = np.mean([max(value_list[:20]), max(value_list[20:40]), max(value_list[40:])]) 26 | return_std = np.std([max(value_list[:20]), max(value_list[20:40]), max(value_list[40:])]) 27 | print('return_mean/std ==>', return_mean, return_std) 28 | 29 | # https://github.com/takuseno/d3rlpy-benchmarks/blob/main/d3rlpy_benchmarks/utils.py#L16 30 | normalized_return_mean = normalize_d4rl_score(env, score=return_mean) 31 | normalized_return_std = normalize_d4rl_score(env, score=return_std) 32 | print('normalized_return_mean/std ==>', normalized_return_mean, normalized_return_std) 33 | 34 | 35 | if __name__ == '__main__': 36 | read_all_log() 37 | 38 | -------------------------------------------------------------------------------- /DT_TIT/gym/readme-gym.md: -------------------------------------------------------------------------------- 1 | 2 | # OpenAI Gym 3 | 4 | ## Installation 5 | 6 | Experiments require MuJoCo. 7 | Follow the instructions in the [mujoco-py repo](https://github.com/openai/mujoco-py) to install. 8 | Then, dependencies can be installed with the following command: 9 | 10 | ``` 11 | conda env create -f conda_env.yml 12 | ``` 13 | 14 | ## Downloading datasets 15 | 16 | Datasets are stored in the `data` directory. 17 | Install the [D4RL repo](https://github.com/rail-berkeley/d4rl), following the instructions there. 18 | Then, run the following script in order to download the datasets and save them in our format: 19 | 20 | ``` 21 | python download_d4rl_datasets.py 22 | ``` 23 | 24 | ## Example usage 25 | 26 | Experiments can be reproduced with the following: 27 | 28 | ``` 29 | python experiment.py --env hopper --dataset medium --model_type dt 30 | ``` 31 | 32 | Adding `-w True` will log results to Weights and Biases. 
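The TIT variant is selected through the same script. Judging from the commands in `run.sh` (shown below), the analogous invocation is, for example:

```
python experiment.py --env hopper --dataset medium --model_type tit
```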
33 | -------------------------------------------------------------------------------- /DT_TIT/gym/run.sh: -------------------------------------------------------------------------------- 1 | #nohup python -u experiment.py --env hopper --dataset medium --model_type tit --device cuda:1 > ./log/hopper-medium-tit_log.txt 2>&1 & 2 | #nohup python -u experiment.py --env hopper --dataset medium --model_type dt --device cuda:2 > ./log/hopper-medium-dt_log.txt 2>&1 & 3 | #nohup python -u experiment.py --env walker2d --dataset medium --model_type tit --device cuda:1 > ./log/walker2d-medium-tit_log.txt 2>&1 & 4 | #nohup python -u experiment.py --env walker2d --dataset medium --model_type dt --device cuda:2 > ./log/walker2d-medium-dt_log.txt 2>&1 & 5 | #nohup python -u experiment.py --env halfcheetah --dataset medium --model_type tit --device cuda:1 > ./log/halfcheetah-medium-tit_log.txt 2>&1 & 6 | #nohup python -u experiment.py --env halfcheetah --dataset medium --model_type dt --device cuda:2 > ./log/halfcheetah-medium-dt_log.txt 2>&1 & 7 | 8 | 9 | 10 | 11 | 12 | #nohup python -u experiment.py --env hopper --dataset medium-replay --model_type tit --device cuda:1 > ./log/hopper-medium-replay-tit_log.txt 2>&1 & 13 | #nohup python -u experiment.py --env hopper --dataset medium-replay --model_type dt --device cuda:2 > ./log/hopper-medium-replay-dt_log.txt 2>&1 & 14 | #nohup python -u experiment.py --env walker2d --dataset medium-replay --model_type tit --device cuda:3 > ./log/walker2d-medium-replay-tit_log.txt 2>&1 & 15 | #nohup python -u experiment.py --env walker2d --dataset medium-replay --model_type dt --device cuda:3 > ./log/walker2d-medium-replay-dt_log.txt 2>&1 & 16 | #nohup python -u experiment.py --env halfcheetah --dataset medium-replay --model_type tit --device cuda:3 > ./log/halfcheetah-medium-replay-tit_log.txt 2>&1 & 17 | #nohup python -u experiment.py --env halfcheetah --dataset medium-replay --model_type dt --device cuda:3 > ./log/halfcheetah-medium-replay-dt_log.txt 2>&1 & 18 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/.gitignore: -------------------------------------------------------------------------------- 1 | *log* -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/hyperparameter_final_enhanced_tit_cql.yaml: -------------------------------------------------------------------------------- 1 | halfcheetah-medium-v0: # 42.554218901217745 2 | n_timesteps: 500000 3 | patch_dim: 17 4 | num_heads_inner: 1 5 | attention_dropout_inner: 0.0 6 | ffn_dropout_inner: 0.0 7 | activation_fn_inner: 'gelu' 8 | activation_fn_outer: 'gelu' 9 | dim_expand_inner: 1 10 | dim_expand_outer: 1 11 | have_position_encoding: 0 12 | # the above is fixed while the following is tuned 13 | num_blocks: 2 14 | features_dim: 512 15 | embed_dim_inner: 256 16 | embed_dim_outer: 256 17 | num_heads_outer: 4 18 | attention_dropout_outer: 0.0 19 | ffn_dropout_outer: 0.1 20 | activation_fn_other: 'gelu' 21 | share_tit_blocks: 0 22 | 23 | 24 | 25 | hopper-medium-v0: # 100.57904232310528 26 | n_timesteps: 500000 27 | patch_dim: 11 28 | num_heads_inner: 1 29 | attention_dropout_inner: 0.0 30 | ffn_dropout_inner: 0.0 31 | activation_fn_inner: 'gelu' 32 | activation_fn_outer: 'gelu' 33 | dim_expand_inner: 1 34 | dim_expand_outer: 1 35 | have_position_encoding: 0 36 | # the above is fixed while the following is tuned 37 | num_blocks: 1 38 | features_dim: 512 39 | embed_dim_inner: 32 40 | embed_dim_outer: 32 41 | 
num_heads_outer: 2 42 | attention_dropout_outer: 0.0 43 | ffn_dropout_outer: 0.0 44 | activation_fn_other: 'relu' 45 | share_tit_blocks: 1 46 | 47 | 48 | 49 | walker2d-medium-v0: # 84.38059453016008 50 | n_timesteps: 500000 51 | patch_dim: 17 52 | num_heads_inner: 1 53 | attention_dropout_inner: 0.0 54 | ffn_dropout_inner: 0.0 55 | activation_fn_inner: 'gelu' 56 | activation_fn_outer: 'gelu' 57 | dim_expand_inner: 1 58 | dim_expand_outer: 1 59 | have_position_encoding: 0 60 | # the above is fixed while the following is tuned 61 | num_blocks: 1 62 | features_dim: 512 63 | embed_dim_inner: 32 64 | embed_dim_outer: 32 65 | num_heads_outer: 2 66 | attention_dropout_outer: 0.0 67 | ffn_dropout_outer: 0.0 68 | activation_fn_other: 'relu' 69 | share_tit_blocks: 1 70 | 71 | 72 | 73 | halfcheetah-medium-replay-v0: # 48.02615866146572 74 | n_timesteps: 500000 75 | patch_dim: 17 76 | num_heads_inner: 1 77 | attention_dropout_inner: 0.0 78 | ffn_dropout_inner: 0.0 79 | activation_fn_inner: 'gelu' 80 | activation_fn_outer: 'gelu' 81 | dim_expand_inner: 1 82 | dim_expand_outer: 1 83 | have_position_encoding: 0 84 | # the above is fixed while the following is tuned 85 | num_blocks: 2 86 | features_dim: 1024 87 | embed_dim_inner: 128 88 | embed_dim_outer: 128 89 | num_heads_outer: 2 90 | attention_dropout_outer: 0.0 91 | ffn_dropout_outer: 0.0 92 | activation_fn_other: 'tanh' 93 | share_tit_blocks: 1 94 | 95 | 96 | 97 | hopper-medium-replay-v0: # 101.5299377469072 98 | n_timesteps: 500000 99 | patch_dim: 11 100 | num_heads_inner: 1 101 | attention_dropout_inner: 0.0 102 | ffn_dropout_inner: 0.0 103 | activation_fn_inner: 'gelu' 104 | activation_fn_outer: 'gelu' 105 | dim_expand_inner: 1 106 | dim_expand_outer: 1 107 | have_position_encoding: 0 108 | # the above is fixed while the following is tuned 109 | num_blocks: 2 110 | features_dim: 1024 111 | embed_dim_inner: 128 112 | embed_dim_outer: 128 113 | num_heads_outer: 1 114 | attention_dropout_outer: 0.1 115 | ffn_dropout_outer: 0.0 116 | activation_fn_other: 'relu' 117 | share_tit_blocks: 1 118 | 119 | 120 | 121 | walker2d-medium-replay-v0: # 59.352175301737354 122 | n_timesteps: 500000 123 | patch_dim: 17 124 | num_heads_inner: 1 125 | attention_dropout_inner: 0.0 126 | ffn_dropout_inner: 0.0 127 | attention_dropout_outer: 0.0 128 | ffn_dropout_outer: 0.0 129 | activation_fn_inner: 'gelu' 130 | activation_fn_outer: 'gelu' 131 | dim_expand_inner: 1 132 | # the above is fixed while the following is tuned 133 | num_blocks: 2 134 | features_dim: 512 135 | embed_dim_inner: 32 136 | embed_dim_outer: 32 137 | num_heads_outer: 2 138 | activation_fn_other: 'relu' 139 | dim_expand_outer: 2 140 | have_position_encoding: 0 141 | share_tit_blocks: 1 142 | 143 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/hyperparameter_final_enhanced_tit_ppo.yaml: -------------------------------------------------------------------------------- 1 | # the performance of online-RL (e.g., PPO) is highly dependent on the evaluation environment (e.g., what kind of GPU), 2 | # so we opensource the hypertuning.py, and you need search your ouw hyperparameter for your own environment. 
3 | 4 | MountainCar-v0: # -97.92 7.025 5 | n_timesteps: 100000 6 | patch_dim: 1 7 | num_blocks: 2 8 | attention_dropout_inner: 0.0 9 | ffn_dropout_inner: 0.0 10 | attention_dropout_outer: 0.0 11 | ffn_dropout_outer: 0.0 12 | activation_fn_inner: 'gelu' 13 | activation_fn_outer: 'gelu' 14 | dim_expand_inner: 4 15 | dim_expand_outer: 4 16 | have_position_encoding: 1 17 | share_tit_blocks: 0 18 | # the above is fixed while the following is tuned 19 | features_dim: 128 20 | embed_dim_inner: 16 21 | num_heads_inner: 4 22 | embed_dim_outer: 16 23 | num_heads_outer: 8 24 | activation_fn_other: 'tanh' 25 | 26 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/hyperparameter_final_vanilla_tit_ppo.yaml: -------------------------------------------------------------------------------- 1 | # the performance of online-RL (e.g., PPO) is highly dependent on the evaluation environment (e.g., what kind of GPU), 2 | # so we opensource the hypertuning.py, and you need search your ouw hyperparameter for your own environment. 3 | 4 | MountainCar-v0: # -96.64 7.867 5 | n_timesteps: 100000 6 | patch_dim: 1 7 | num_blocks: 2 8 | attention_dropout_inner: 0.0 9 | ffn_dropout_inner: 0.0 10 | attention_dropout_outer: 0.0 11 | ffn_dropout_outer: 0.0 12 | activation_fn_inner: 'gelu' 13 | activation_fn_outer: 'gelu' 14 | dim_expand_inner: 4 15 | dim_expand_outer: 4 16 | have_position_encoding: 1 17 | share_tit_blocks: 0 18 | # the above is fixed while the following is tuned 19 | features_dim: 64 20 | embed_dim_inner: 32 21 | num_heads_inner: 4 22 | embed_dim_outer: 32 23 | num_heads_outer: 8 24 | activation_fn_other: 'tanh' 25 | 26 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from stable_baselines3 import PPO 4 | from stable_baselines3.common.utils import set_random_seed 5 | from stable_baselines3.common.env_util import make_atari_env 6 | from stable_baselines3.common.vec_env import VecFrameStack 7 | from stable_baselines3.common.evaluation import evaluate_policy 8 | from utils import linear_schedule, update_args, load_policy_kwargs 9 | from network import OFENetActorCriticPolicy, D2RLNetActorCriticPolicy 10 | 11 | 12 | def make_image_agent(env_name, algo, seed, args): 13 | log_folder = args.log_folder 14 | device = args.device 15 | 16 | policy_type = 'CnnPolicy' 17 | if algo == 'ppo': # use default setting 18 | # https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml 19 | # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#id2 [Atari Games] 20 | env = VecFrameStack(make_atari_env(env_name, n_envs=8, seed=seed), n_stack=4) 21 | agent = PPO(policy=policy_type, env=env, tensorboard_log=log_folder, verbose=1, seed=seed, 22 | n_steps=128, n_epochs=4, batch_size=256, learning_rate=linear_schedule(2.5e-4), 23 | clip_range=linear_schedule(0.1), vf_coef=0.5, ent_coef=0.01, device=device) 24 | elif algo in ['resnet_ppo', 'catformer_ppo']: 25 | policy_kwargs = load_policy_kwargs(args) 26 | env = VecFrameStack(make_atari_env(env_name, n_envs=8, seed=seed), n_stack=4) 27 | agent = PPO(policy=policy_type, env=env, tensorboard_log=log_folder, verbose=1, seed=seed, 28 | n_steps=128, n_epochs=4, batch_size=256, learning_rate=linear_schedule(2.5e-4), 29 | clip_range=linear_schedule(0.1), vf_coef=0.5, ent_coef=0.01, device=device, 30 | 
policy_kwargs=policy_kwargs) 31 | else: # 'vanilla_tit_ppo', 'enhanced_tit_ppo' 32 | policy_kwargs = load_policy_kwargs(args) 33 | env = VecFrameStack(make_atari_env(env_name, n_envs=8, seed=seed), n_stack=4) # keep the same as SB3-PPO 34 | agent = PPO(policy=policy_type, env=env, tensorboard_log=log_folder, verbose=1, seed=seed, 35 | n_steps=128, n_epochs=4, batch_size=64, learning_rate=linear_schedule(2.5e-4), # remove 36 | clip_range=linear_schedule(0.1), vf_coef=0.5, ent_coef=0.01, device=device, 37 | policy_kwargs=policy_kwargs) 38 | # we don't use linear_schedule for learning_rate/clip_range, 39 | # we also use batch_size=64 rather than 256 to avoid CUDA OOM, 40 | # we will show that even with these un-tuned hyperparameters, TIT works well! 41 | 42 | return agent 43 | 44 | 45 | def make_array_agent(env_name, algo, seed, args): 46 | log_folder = args.log_folder 47 | device = args.device 48 | 49 | if algo == 'ppo': # use default setting 50 | policy_type = 'MlpPolicy' 51 | agent = PPO(policy=policy_type, env=env_name, tensorboard_log=log_folder, verbose=1, seed=seed, 52 | device=device) 53 | elif algo in ['ofe_ppo', 'd2rl_ppo']: 54 | policy_type = OFENetActorCriticPolicy if algo=='ofe_ppo' else D2RLNetActorCriticPolicy 55 | agent = PPO(policy=policy_type, env=env_name, tensorboard_log=log_folder, verbose=1, seed=seed, 56 | device=device) 57 | else: # 'vanilla_tit_ppo', 'enhanced_tit_ppo' 58 | policy_type = 'MlpPolicy' 59 | policy_kwargs = load_policy_kwargs(args) 60 | agent = PPO(policy=policy_type, env=env_name, tensorboard_log=log_folder, verbose=1, seed=seed, 61 | device=device, policy_kwargs=policy_kwargs) 62 | 63 | return agent 64 | 65 | 66 | def train(env_name, algo, n_timesteps, seed, args): 67 | args.log_folder = './log/' + env_name + '__' + algo + '__' + str(seed) + '__running' 68 | 69 | if 'NoFrameskip' in env_name: 70 | agent = make_image_agent(env_name, algo, seed, args) 71 | print('make_image_agent') 72 | else: 73 | agent = make_array_agent(env_name, algo, seed, args) 74 | print('==' * 20, 'policy structure ==>', agent.policy) 75 | print('==' * 20, 'number of parameters: %d' % sum(p.numel() for p in agent.policy.parameters())) 76 | print('==' * 20, 'observation_space.shape ==>', agent.get_env().observation_space.shape) # (4,) or (4, 84, 84) 77 | 78 | # agent.learn(total_timesteps=train_step, eval_freq=int(train_step // 4), 79 | # eval_env=gym.make(env_name), n_eval_episodes=100, eval_log_path=log_path) 80 | agent.learn(total_timesteps=n_timesteps) 81 | agent.save(args.log_folder + '/final_model') 82 | 83 | episode_rewards, episode_lengths = evaluate_policy(agent.policy, agent.get_env(), 84 | n_eval_episodes=100, return_episode_rewards=True) 85 | print('==' * 20, 'mean/std episode_rewards ==>', np.mean(episode_rewards), np.std(episode_rewards)) 86 | np.save(args.log_folder + '/eval_episode_rewards.npy', episode_rewards) 87 | np.save(args.log_folder + '/eval_episode_lengths.npy', episode_lengths) 88 | 89 | 90 | if __name__ == '__main__': 91 | # import torch 92 | # torch.set_num_threads(3) 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--env-name", help="Environment ID", type=str, default="CartPole-v1") 95 | parser.add_argument("--algo", help="RL Algorithm", type=str, default="enhanced_tit_ppo", 96 | choices=['ppo', 'ofe_ppo', 'd2rl_ppo', 'resnet_ppo', 'catformer_ppo', 'vanilla_tit_ppo', 'enhanced_tit_ppo']) 97 | parser.add_argument("--device", help="PyTorch device (ex: cpu, cuda:0, cuda:1, ...)", type=str, default="auto") 98 | 
parser.add_argument("--log-folder", help="Log folder", type=str, default="./log/") 99 | # 100 | parser.add_argument("--n-timesteps", help="Timesteps to run the env for one trial", type=int, default=100000) 101 | parser.add_argument("--patch-dim", help="patch_dim", type=int, default=6) 102 | parser.add_argument("--num-blocks", help="how many Transformer blocks to use", type=int, default=2) 103 | parser.add_argument("--features-dim", help="features_dim of last layer", type=int, default=64) 104 | parser.add_argument("--embed-dim-inner", help="embed_dim_inner", type=int, default=8) 105 | parser.add_argument("--num-heads-inner", help="num_heads_inner", type=int, default=4) 106 | parser.add_argument("--attention-dropout-inner", help="attention_dropout_inner", type=float, default=0.0) 107 | parser.add_argument("--ffn-dropout-inner", help="ffn_dropout_inner", type=float, default=0.0) 108 | parser.add_argument("--embed-dim-outer", help="embed_dim_outer", type=int, default=64) 109 | parser.add_argument("--num-heads-outer", help="num_heads_outer", type=int, default=4) 110 | parser.add_argument("--attention-dropout-outer", help="attention_dropout_outer", type=float, default=0.0) 111 | parser.add_argument("--ffn-dropout-outer", help="ffn_dropout_outer", type=float, default=0.0) 112 | parser.add_argument("--activation-fn-inner", help="activation_function_inner", default=None) 113 | parser.add_argument("--activation-fn-outer", help="activation_function_outer", default=None) 114 | parser.add_argument("--activation-fn-other", help="activation_function_other", default=None) 115 | args = parser.parse_args() 116 | 117 | if args.algo == 'ppo': 118 | env_name_list = [ 119 | 'Acrobot-v1', 'CartPole-v1', 'MountainCar-v0', 120 | 'Ant-v3', 'Hopper-v3', 'Walker2d-v3', 121 | 'BreakoutNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'PongNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 122 | ] 123 | train_step_list = [ 124 | 100000, 100000, 100000, 125 | 1000000, 1000000, 1000000, 126 | 10000000, 10000000, 10000000, 10000000, 127 | ] 128 | for env_name, n_timesteps in zip(env_name_list, train_step_list): 129 | for seed in range(5): 130 | set_random_seed(seed) 131 | train(env_name, args.algo, n_timesteps, seed, args) 132 | elif args.algo in ['ofe_ppo', 'd2rl_ppo']: 133 | env_name_list = [ 134 | 'Acrobot-v1', 'CartPole-v1', 'MountainCar-v0', 135 | 'Ant-v3', 'Hopper-v3', 'Walker2d-v3', 136 | ] 137 | train_step_list = [ 138 | 100000, 100000, 100000, 139 | 1000000, 1000000, 1000000, 140 | ] 141 | for env_name, n_timesteps in zip(env_name_list, train_step_list): 142 | for seed in range(5): 143 | set_random_seed(seed) 144 | train(env_name, args.algo, n_timesteps, seed, args) 145 | elif args.algo in ['resnet_ppo', 'catformer_ppo']: 146 | env_name_list = [ 147 | 'BreakoutNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'PongNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 148 | ] 149 | train_step_list = [ 150 | 10000000, 10000000, 10000000, 10000000, 151 | ] 152 | for env_name, n_timesteps in zip(env_name_list, train_step_list): 153 | for seed in range(5): 154 | set_random_seed(seed) 155 | train(env_name, args.algo, n_timesteps, seed, args) 156 | elif args.algo in ['vanilla_tit_ppo', 'enhanced_tit_ppo']: 157 | args = update_args(args) 158 | experiment_count = 2 if 'NoFrameskip' in args.env_name else 5 159 | for seed in range(experiment_count): 160 | set_random_seed(seed) 161 | train(args.env_name, args.algo, args.n_timesteps, seed, args) 162 | 163 | -------------------------------------------------------------------------------- 
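For a single run of the entry point above, `run.sh` (shown later in this repository) suggests invocations of the following form; the flags are the ones defined in the argparse block above, and the exact values are illustrative:

```
python main.py --env-name MountainCar-v0 --algo enhanced_tit_ppo --device cuda:0 --log-folder ./log/
```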
/PPO_TIT_and_CQL_TIT/offline_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import d3rlpy 3 | from sklearn.model_selection import train_test_split 4 | from offline_network import MyCustomEncoderFactory 5 | from utils import update_args 6 | 7 | 8 | def make_agent(args): 9 | if "medium-v0" in args.env_name: 10 | conservative_weight = 10.0 11 | else: 12 | conservative_weight = 5.0 13 | 14 | # https://d3rlpy.readthedocs.io/en/v1.1.1/references/network_architectures.html 15 | if args.algo in ['vanilla_tit_cql', 'enhanced_tit_cql']: 16 | actor_encoder = MyCustomEncoderFactory( 17 | algo=args.algo, 18 | patch_dim=args.patch_dim, 19 | num_blocks=args.num_blocks, 20 | features_dim=args.features_dim, 21 | embed_dim_inner=args.embed_dim_inner, 22 | num_heads_inner=args.num_heads_inner, 23 | attention_dropout_inner=args.attention_dropout_inner, 24 | ffn_dropout_inner=args.ffn_dropout_inner, 25 | embed_dim_outer=args.embed_dim_outer, 26 | num_heads_outer=args.num_heads_outer, 27 | attention_dropout_outer=args.attention_dropout_outer, 28 | ffn_dropout_outer=args.ffn_dropout_outer, 29 | activation_fn_inner=args.activation_fn_inner, 30 | activation_fn_outer=args.activation_fn_outer, 31 | activation_fn_other=args.activation_fn_other, 32 | dim_expand_inner=args.dim_expand_inner, 33 | dim_expand_outer=args.dim_expand_outer, 34 | have_position_encoding=args.have_position_encoding, 35 | share_tit_blocks=args.share_tit_blocks 36 | ) 37 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256]) 38 | agent = d3rlpy.algos.CQL(actor_learning_rate=1e-4, 39 | critic_learning_rate=3e-4, 40 | temp_learning_rate=1e-4, 41 | actor_encoder_factory=actor_encoder, 42 | critic_encoder_factory=encoder, 43 | batch_size=256, 44 | n_action_samples=10, 45 | alpha_learning_rate=0.0, 46 | conservative_weight=conservative_weight, 47 | use_gpu=d3rlpy.gpu.Device(idx=int(args.device))) 48 | elif args.algo == 'cql': 49 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256]) 50 | agent = d3rlpy.algos.CQL(actor_learning_rate=1e-4, 51 | critic_learning_rate=3e-4, 52 | temp_learning_rate=1e-4, 53 | actor_encoder_factory=encoder, 54 | critic_encoder_factory=encoder, 55 | batch_size=256, 56 | n_action_samples=10, 57 | alpha_learning_rate=0.0, 58 | conservative_weight=conservative_weight, 59 | use_gpu=args.device) 60 | 61 | print('agent ==>', agent) 62 | return agent 63 | 64 | 65 | def train(env_name, seed, args): 66 | # https://github.com/takuseno/d3rlpy/blob/master/reproductions/offline/cql.py 67 | dataset, env = d3rlpy.datasets.get_dataset(env_name) 68 | print("len(dataset):", len(dataset), type(dataset)) 69 | print("len(dataset[0]):", len(dataset), type(dataset[0])) 70 | 71 | # fix seed 72 | d3rlpy.seed(seed) 73 | env.seed(seed) 74 | 75 | agent = make_agent(args) 76 | agent.build_with_dataset(dataset) 77 | print('agent.impl._policy ==>', agent.impl._policy) 78 | print('agent.impl._q_func ==>', agent.impl._q_func) 79 | 80 | _, test_episodes = train_test_split(dataset, test_size=0.2) 81 | results = agent.fit(dataset, 82 | eval_episodes=test_episodes, 83 | n_steps=args.n_timesteps, 84 | n_steps_per_epoch=1000, 85 | save_interval=10, 86 | scorers={ 87 | 'environment': d3rlpy.metrics.evaluate_on_environment(env), 88 | 'value_scale': d3rlpy.metrics.average_value_estimation_scorer, 89 | }, 90 | experiment_name=f"CQL_{env_name}_{args.algo}_{seed}", 91 | logdir=args.log_folder, 92 | show_progress=args.show_progress) 93 | 94 | 
agent.save_policy(f"{args.log_folder}/CQL_{env_name}_{args.algo}_{seed}.pt") # save greedy-policy as TorchScript 95 | 96 | return results 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument("--env-name", help="Environment ID", type=str, default="halfcheetah-medium-v0", 102 | choices=['halfcheetah-medium-v0', 'hopper-medium-v0', 'walker2d-medium-v0', 103 | 'halfcheetah-medium-replay-v0', 'hopper-medium-replay-v0', 'walker2d-medium-replay-v0']) 104 | parser.add_argument("--algo", help="RL Algorithm", type=str, default="cql", choices=['cql', 'enhanced_tit_cql']) 105 | parser.add_argument("--device", help="PyTorch device (ex: cpu, cuda:0, cuda:1, ...)", default=True) 106 | parser.add_argument("--log-folder", help="Log folder", type=str, default="./log/") 107 | parser.add_argument("--show-progress", help="flag to show progress bar for iterations", default=True) 108 | # 109 | parser.add_argument("--n-timesteps", help="Timesteps to run the env for one trial", type=int, default=500000) 110 | parser.add_argument("--patch-dim", help="patch_dim", type=int, default=6) 111 | parser.add_argument("--num-blocks", help="how many Transformer blocks to use", type=int, default=2) 112 | parser.add_argument("--features-dim", help="features_dim of last layer", type=int, default=64) 113 | parser.add_argument("--embed-dim-inner", help="embed_dim_inner", type=int, default=8) 114 | parser.add_argument("--num-heads-inner", help="num_heads_inner", type=int, default=4) 115 | parser.add_argument("--attention-dropout-inner", help="attention_dropout_inner", type=float, default=0.0) 116 | parser.add_argument("--ffn-dropout-inner", help="ffn_dropout_inner", type=float, default=0.0) 117 | parser.add_argument("--embed-dim-outer", help="embed_dim_outer", type=int, default=64) 118 | parser.add_argument("--num-heads-outer", help="num_heads_outer", type=int, default=4) 119 | parser.add_argument("--attention-dropout-outer", help="attention_dropout_outer", type=float, default=0.0) 120 | parser.add_argument("--ffn-dropout-outer", help="ffn_dropout_outer", type=float, default=0.0) 121 | parser.add_argument("--activation-fn-inner", help="activation_function_inner", default=None) 122 | parser.add_argument("--activation-fn-outer", help="activation_function_outer", default=None) 123 | parser.add_argument("--activation-fn-other", help="activation_function_other", default=None) 124 | args = parser.parse_args() 125 | 126 | if args.algo == 'cql': 127 | env_name_list = [ 128 | 'halfcheetah-medium-v0', 'hopper-medium-v0', 'walker2d-medium-v0', 129 | 'halfcheetah-medium-replay-v0', 'hopper-medium-replay-v0', 'walker2d-medium-replay-v0' 130 | ] 131 | for env_name in env_name_list: 132 | for seed in range(5): 133 | train(env_name, seed, args) 134 | elif args.algo in ['vanilla_tit_cql', 'enhanced_tit_cql']: 135 | args = update_args(args) 136 | for seed in range(5): 137 | train(args.env_name, seed, args) 138 | 139 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/offline_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from d3rlpy.models.encoders import EncoderFactory 4 | from network import Config, TIT 5 | 6 | 7 | # https://d3rlpy.readthedocs.io/en/v1.1.1/references/network_architectures.html 8 | # [You can also build your own encoder factory.] 9 | # self-defined CQL: 1. 
define your own neural network 10 | class MyCustomEncoder(nn.Module): 11 | def __init__(self, observation_space, 12 | algo, 13 | patch_dim, num_blocks, features_dim, 14 | embed_dim_inner, num_heads_inner, attention_dropout_inner, ffn_dropout_inner, 15 | embed_dim_outer, num_heads_outer, attention_dropout_outer, ffn_dropout_outer, 16 | activation_fn_inner, activation_fn_outer, activation_fn_other, 17 | dim_expand_inner, dim_expand_outer, have_position_encoding, share_tit_blocks): 18 | super(MyCustomEncoder, self).__init__() 19 | self.features_dim = features_dim 20 | 21 | C, H, W, D = 0, 0, 0, 0 22 | print("observation_space:", observation_space) 23 | if len(observation_space) == 3: # (4, 84, 84) 24 | observation_type = 'image' 25 | C, H, W = observation_space[0], observation_space[1], observation_space[2] 26 | assert (H % patch_dim == 0) and (W % patch_dim == 0) 27 | context_len_inner = (H // patch_dim) * (W // patch_dim) 28 | n_stack = 1 # 4 29 | context_len_outer = n_stack 30 | elif len(observation_space) == 1: # (4,) 31 | observation_type = 'array' 32 | D = observation_space[0] 33 | # patch_dim = 1 34 | # assert patch_dim == 1 35 | # context_len_inner = D // patch_dim 36 | context_len_inner = int(np.ceil(D / patch_dim)) 37 | n_stack = 1 38 | context_len_outer = n_stack 39 | else: 40 | raise ValueError('len(observation_space.shape) should either be 1 or 3') 41 | config = Config(algo, 42 | patch_dim, 43 | num_blocks, 44 | features_dim, 45 | embed_dim_inner, 46 | num_heads_inner, 47 | attention_dropout_inner, 48 | ffn_dropout_inner, 49 | context_len_inner, 50 | embed_dim_outer, 51 | num_heads_outer, 52 | attention_dropout_outer, 53 | ffn_dropout_outer, 54 | context_len_outer, 55 | observation_type, 56 | C, H, W, D, 57 | activation_fn_inner, 58 | activation_fn_outer, 59 | activation_fn_other, 60 | dim_expand_inner, 61 | dim_expand_outer, 62 | have_position_encoding, 63 | share_tit_blocks) 64 | self.pure_transformer_backbone = TIT(config) 65 | 66 | def forward(self, observations): 67 | # print("observations.shape:", observations.shape) # torch.Size([32, 100]) torch.Size([32, 1, 84, 84]) 68 | return self.pure_transformer_backbone(observations) 69 | 70 | # THIS IS IMPORTANT! 71 | def get_feature_size(self): 72 | return self.features_dim 73 | 74 | 75 | # self-defined CQL: 2. 
define your own encoder factory 76 | class MyCustomEncoderFactory(EncoderFactory): 77 | TYPE = 'custom' # this is necessary 78 | 79 | def __init__(self, 80 | algo, 81 | patch_dim, num_blocks, features_dim, 82 | embed_dim_inner, num_heads_inner, attention_dropout_inner, ffn_dropout_inner, 83 | embed_dim_outer, num_heads_outer, attention_dropout_outer, ffn_dropout_outer, 84 | activation_fn_inner, activation_fn_outer, activation_fn_other, 85 | dim_expand_inner, dim_expand_outer, have_position_encoding, share_tit_blocks): 86 | self.algo = algo 87 | self.patch_dim = patch_dim 88 | self.num_blocks = num_blocks 89 | self.features_dim = features_dim 90 | self.embed_dim_inner = embed_dim_inner 91 | self.num_heads_inner = num_heads_inner 92 | self.attention_dropout_inner = attention_dropout_inner 93 | self.ffn_dropout_inner = ffn_dropout_inner 94 | self.embed_dim_outer = embed_dim_outer 95 | self.num_heads_outer = num_heads_outer 96 | self.attention_dropout_outer = attention_dropout_outer 97 | self.ffn_dropout_outer = ffn_dropout_outer 98 | self.activation_fn_inner = activation_fn_inner 99 | self.activation_fn_outer = activation_fn_outer 100 | self.activation_fn_other = activation_fn_other 101 | self.dim_expand_inner = dim_expand_inner 102 | self.dim_expand_outer = dim_expand_outer 103 | self.have_position_encoding = have_position_encoding 104 | self.share_tit_blocks = share_tit_blocks 105 | 106 | def create(self, observation_shape): 107 | return MyCustomEncoder( 108 | observation_shape, 109 | self.algo, 110 | self.patch_dim, self.num_blocks, self.features_dim, 111 | self.embed_dim_inner, self.num_heads_inner, self.attention_dropout_inner, self.ffn_dropout_inner, 112 | self.embed_dim_outer, self.num_heads_outer, self.attention_dropout_outer, self.ffn_dropout_outer, 113 | self.activation_fn_inner, self.activation_fn_outer, self.activation_fn_other, 114 | self.dim_expand_inner, self.dim_expand_outer, self.have_position_encoding, self.share_tit_blocks 115 | ) 116 | 117 | def get_params(self, deep=False): 118 | return {'feature_size': self.features_dim} 119 | 120 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/readme.md: -------------------------------------------------------------------------------- 1 | # Transformer in Transformer as Backbone for Deep Reinforcement Learning 2 | 3 | 4 | 5 | ## Overview 6 | This is the official implementation of Transformer in Transformer (TIT) for Deep Reinforcement Learning. 7 | Contains scripts to reproduce experiments of the online-RL (i.e., PPO and PPO_TIT) and offline-RL (i.e., CQL and CQL_TIT). 8 | 9 | 10 | 11 | ## Result Reproduction 12 | You can run ```bash run.sh``` to do hyperparameter search and result reproduction. 13 | In general, you can produce the results of offline-RL (i.e., CQL and CQL_TIT) and 14 | offline-SL (i.e., Decision Transformer and DT_TIT) easily. 15 | 16 | However, to reproduce the results of online-RL (i.e., PPO and PPO_TIT), you should 17 | search your own hyperparameters for your own environment, because we found that the 18 | performance of online-RL is highly dependent on the evaluation environment (e.g., 19 | what kind of GPU). You can also verify this by running the ```check_reproduction_of_optimization()``` 20 | function in hypertuning.py. 21 | 22 | Note that, although PPO_TIT needs some effort to search the hyperparameters for good 23 | results, it doesn't need complex optimization skills.
Moreover, for CQL_TIT and DT_TIT, 24 | we can get good results even without searching the hyperparameters. 25 | 26 | 27 | 28 | ## Network Architecture 29 | We try to implement the network architectures of baseline methods as closely as possible 30 | to their original papers and open-source code repositories. If you have more accurate 31 | implementations, we are happy to redo the experiments based on yours. 32 | 33 | Our TIT architecture can be found in the paper. 34 | 35 | 36 | 37 | ## Cite 38 | Please cite our paper as: 39 | ``` 40 | @article{mao2022TIT, 41 | title={Transformer in Transformer as Backbone for Deep Reinforcement Learning}, 42 | author={Mao, Hangyu and Zhao, Rui and Chen, Hao and Hao, Jianye and Chen, Yiqun and Li, Dong and Zhang, Junge and Xiao, Zhen}, 43 | journal={arXiv preprint arXiv:2212.14538}, 44 | year={2022} 45 | } 46 | ``` 47 | 48 | 49 | 50 | ## License 51 | 52 | MIT 53 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/requirements_offlineRL_CQL.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aiosignal==1.2.0 4 | ale-py==0.7.4 5 | alembic==1.8.1 6 | arch==5.3.0 7 | argcomplete==2.0.0 8 | astunparse==1.6.3 9 | async-timeout==4.0.2 10 | atari-py==0.2.9 11 | attrs==22.1.0 12 | autopage==0.5.1 13 | AutoROM==0.4.2 14 | AutoROM.accept-rom-license==0.4.2 15 | boto==2.49.0 16 | cachetools==5.2.0 17 | certifi==2022.9.24 18 | cffi==1.15.1 19 | charset-normalizer==2.1.1 20 | chex==0.1.5 21 | click==8.1.3 22 | cliff==4.1.0 23 | cloudpickle==2.2.0 24 | cmaes==0.9.0 25 | cmd2==2.4.2 26 | colorama==0.4.6 27 | colorlog==6.7.0 28 | commonmark==0.9.1 29 | configparser==5.3.0 30 | contourpy==1.0.6 31 | crcmod==1.7 32 | cryptography==38.0.3 33 | cycler==0.11.0 34 | Cython==0.29.32 35 | d3rlpy==1.1.1 36 | -e git+https://github.com/takuseno/d3rlpy-benchmarks@587ff2b25119a70af1c6049150b2e99da03ef407#egg=d3rlpy_benchmarks 37 | D4RL @ git+https://github.com/rail-berkeley/d4rl@4235ef21ac5ba35285ecfce133d9eff62f3490e5 38 | d4rl-atari @ git+https://github.com/takuseno/d4rl-atari@799428bbc570a224c2df58a78c878f3410b0fa59 39 | decorator==5.1.1 40 | dm-control==1.0.8 41 | dm-env==1.5 42 | dm-tree==0.1.7 43 | docker-pycreds==0.4.0 44 | dopamine-rl==4.0.6 45 | fasteners==0.18 46 | filelock==3.8.0 47 | flatbuffers==22.10.26 48 | flax==0.6.2 49 | fonttools==4.38.0 50 | frozenlist==1.3.1 51 | gast==0.4.0 52 | gcs-oauth2-boto-plugin==3.0 53 | gin-config==0.5.0 54 | gitdb==4.0.9 55 | GitPython==3.1.29 56 | glfw==2.5.5 57 | google-apitools==0.5.32 58 | google-auth==2.14.0 59 | google-auth-oauthlib==0.4.6 60 | google-pasta==0.2.0 61 | google-reauth==0.1.1 62 | gql==0.2.0 63 | graphql-core==1.1 64 | greenlet==2.0.1 65 | grpcio==1.50.0 66 | gsutil==5.16 67 | gym==0.21.0 68 | gym-notices==0.0.8 69 | h5py==3.7.0 70 | httplib2==0.21.0 71 | idna==3.4 72 | imageio==2.22.3 73 | importlib-metadata==4.13.0 74 | importlib-resources==5.10.0 75 | jax==0.3.25 76 | jaxlib==0.3.25 77 | joblib==1.2.0 78 | keras==2.11.0 79 | kiwisolver==1.4.4 80 | labmaze==1.0.5 81 | libclang==14.0.6 82 | lxml==4.9.1 83 | Mako==1.2.4 84 | Markdown==3.4.1 85 | MarkupSafe==2.1.1 86 | matplotlib==3.6.2 87 | mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc 88 | mkl-fft==1.3.1 89 | mkl-random==1.2.2 90 | mkl-service==2.4.0 91 | monotonic==1.6 92 | msgpack==1.0.4 93 | mujoco==2.3.0 94 | mujoco-py==2.1.2.14 95 | multidict==6.0.2 96 | numpy==1.23.4 97 |
nvidia-cublas-cu11==11.10.3.66 98 | nvidia-cuda-nvrtc-cu11==11.7.99 99 | nvidia-cuda-runtime-cu11==11.7.99 100 | nvidia-cudnn-cu11==8.5.0.96 101 | nvidia-ml-py3==7.352.0 102 | oauth2client==4.1.3 103 | oauthlib==3.2.2 104 | olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work 105 | opencv-python==4.6.0.66 106 | opt-einsum==3.3.0 107 | optax==0.1.3 108 | optuna==3.0.4 109 | packaging==21.3 110 | pandas==1.5.1 111 | patsy==0.5.3 112 | pbr==5.11.0 113 | Pillow==9.3.0 114 | prettytable==3.5.0 115 | promise==2.3 116 | property-cached==1.6.4 117 | protobuf==3.19.6 118 | psutil==5.9.3 119 | pyasn1==0.4.8 120 | pyasn1-modules==0.2.8 121 | pybullet==3.2.5 122 | pycparser==2.21 123 | pygame==2.1.2 124 | Pygments==2.13.0 125 | PyOpenGL==3.1.6 126 | pyOpenSSL==22.1.0 127 | pyparsing==2.4.7 128 | pyperclip==1.8.2 129 | python-dateutil==2.8.2 130 | pytz==2022.6 131 | pyu2f==0.1.5 132 | PyYAML==6.0 133 | regex==2022.10.31 134 | requests==2.28.1 135 | requests-oauthlib==1.3.1 136 | retry-decorator==1.1.1 137 | rich==12.6.0 138 | rliable==1.0.8 139 | rsa==4.7.2 140 | sacremoses==0.0.53 141 | scikit-learn==1.1.3 142 | scipy==1.8.1 143 | seaborn==0.12.1 144 | sentry-sdk==1.11.0 145 | shortuuid==1.0.11 146 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 147 | smmap==5.0.0 148 | SQLAlchemy==1.4.44 149 | stable-baselines3==1.6.2 150 | statsmodels==0.13.5 151 | stevedore==4.1.1 152 | structlog==22.1.0 153 | subprocess32==3.5.4 154 | tensorboard==2.11.0 155 | tensorboard-data-server==0.6.1 156 | tensorboard-plugin-wit==1.8.1 157 | tensorboardX==2.5.1 158 | tensorflow==2.11.0 159 | tensorflow-estimator==2.11.0 160 | tensorflow-io-gcs-filesystem==0.27.0 161 | tensorflow-probability==0.18.0 162 | tensorstore==0.1.28 163 | termcolor==2.1.0 164 | tf-slim==1.1.0 165 | threadpoolctl==3.1.0 166 | tokenizers==0.10.3 167 | toolz==0.12.0 168 | torch==1.13.0 169 | torchaudio==0.10.1 170 | torchvision==0.11.2 171 | tqdm==4.64.1 172 | transformers==4.5.1 173 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work 174 | urllib3==1.26.12 175 | wandb==0.9.1 176 | watchdog==2.1.9 177 | wcwidth==0.2.5 178 | Werkzeug==2.2.2 179 | wrapt==1.14.1 180 | yarl==1.8.1 181 | zipp==3.10.0 182 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/requirements_onlineRL_PPO.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | alembic==1.7.7 3 | atari-py==0.2.5 4 | attrs==22.1.0 5 | autopage==0.5.1 6 | Box2D==2.3.10 7 | Box2D-kengz==2.3.3 8 | cached-property==1.5.2 9 | cachetools==4.2.4 10 | certifi==2021.5.30 11 | cffi==1.15.1 12 | charset-normalizer==2.0.12 13 | click==8.0.4 14 | cliff==3.10.1 15 | cloudpickle==1.6.0 16 | cmaes==0.8.2 17 | cmd2==2.4.2 18 | colorama==0.4.5 19 | colorlog==6.7.0 20 | conda-pack==0.6.0 21 | cycler==0.11.0 22 | Cython==0.29.32 23 | d3rlpy==1.0.0 24 | dataclasses @ file:///tmp/build/80754af9/dataclasses_1614363715916/work 25 | fasteners==0.18 26 | glfw==2.5.5 27 | google-auth==2.12.0 28 | google-auth-oauthlib==0.4.6 29 | GPUtil==1.4.0 30 | grad-cam==1.4.6 31 | greenlet==1.1.3 32 | grpcio==1.48.2 33 | gym==0.19.0 34 | h5py==3.1.0 35 | idna==3.4 36 | imageio==2.15.0 37 | importlib-metadata==4.8.3 38 | importlib-resources==5.4.0 39 | joblib==1.1.1 40 | kiwisolver==1.3.1 41 | Mako==1.1.6 42 | Markdown==3.3.7 43 | MarkupSafe==2.0.1 44 | matplotlib==3.3.4 45 | mkl-fft==1.3.0 46 | mkl-random==1.1.1 47 
| mkl-service==2.3.0 48 | mujoco-py==2.1.2.14 49 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603487797006/work 50 | oauthlib==3.2.1 51 | olefile==0.46 52 | opencv-python==4.6.0.66 53 | optuna==2.10.1 54 | packaging==21.3 55 | pandas==1.1.5 56 | pbr==5.10.0 57 | Pillow==8.4.0 58 | prettytable==2.5.0 59 | protobuf==3.19.5 60 | psutil==5.9.2 61 | pyasn1==0.4.8 62 | pyasn1-modules==0.2.8 63 | pycparser==2.21 64 | pyparsing==3.0.9 65 | pyperclip==1.8.2 66 | python-dateutil==2.8.2 67 | pytz==2022.2.1 68 | PyYAML==6.0 69 | requests==2.27.1 70 | requests-oauthlib==1.3.1 71 | rsa==4.9 72 | scikit-learn==0.24.2 73 | scipy==1.5.4 74 | seaborn==0.11.2 75 | six @ file:///tmp/build/80754af9/six_1644875935023/work 76 | SQLAlchemy==1.4.41 77 | stable-baselines3==1.3.0 78 | stevedore==3.5.0 79 | structlog==21.5.0 80 | swig==4.0.2 81 | tensorboard==2.10.1 82 | tensorboard-data-server==0.6.1 83 | tensorboard-plugin-wit==1.8.1 84 | tensorboardX @ file:///tmp/build/80754af9/tensorboardx_1621440489103/work 85 | threadpoolctl==3.1.0 86 | torch==1.10.2 87 | torchaudio==0.10.2 88 | torchvision==0.11.3 89 | tqdm==4.64.1 90 | ttach==0.0.3 91 | typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work 92 | urllib3==1.26.12 93 | wcwidth==0.2.5 94 | Werkzeug==2.0.3 95 | zipp==3.6.0 96 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/results_offline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import glob 6 | import dataclasses 7 | from d3rlpy_benchmarks.data_loader import load_d4rl_score 8 | from d3rlpy_benchmarks.plot_utils import plot_score_curve 9 | from d3rlpy_benchmarks.utils import get_canonical_algo_name, normalize_d4rl_score 10 | 11 | 12 | @dataclasses.dataclass(frozen=True) 13 | class ScoreData: 14 | algo: str 15 | env: str 16 | dataset: str 17 | steps: np.ndarray 18 | raw_scores: np.ndarray 19 | scores: np.ndarray 20 | 21 | 22 | def load_my_score(algo: str, env: str, dataset: str, MY_DIR = "./d3rlpy_logs/"): 23 | # https://github.com/takuseno/d3rlpy-benchmarks/blob/main/d3rlpy_benchmarks/data_loader.py#L33 24 | score_list = [] 25 | step_list = [] 26 | for log_dir in glob.glob(os.path.join(MY_DIR, f"CQL_{env}-{dataset}_{algo}_*")): 27 | if log_dir.endswith('.pt'): 28 | continue 29 | with open(os.path.join(log_dir, "environment.csv"), "r") as f: 30 | data = np.loadtxt(f, delimiter=",", skiprows=1) 31 | if len(data[:, 2]) < 499: # discard incomplete data 32 | continue 33 | score_list.append(data[:, 2]) 34 | step_list.append(data[:, 1]) 35 | raw_scores = np.array(score_list) 36 | steps = np.array(step_list) 37 | 38 | if algo == 'cql': 39 | algo = 'CQL_Reproduced' 40 | elif algo == 'enhanced_tit_cql': 41 | algo = 'CQL_TIT_Enhanced' 42 | elif algo == 'vanilla_tit_cql': 43 | algo = 'CQL_TIT_Vanilla' 44 | 45 | return ScoreData( 46 | algo=get_canonical_algo_name(algo), 47 | env=env, 48 | dataset=dataset, 49 | steps=steps, 50 | raw_scores=raw_scores, 51 | scores=normalize_d4rl_score(env, raw_scores), 52 | ) 53 | 54 | 55 | def load_all_d4rl_scores(): 56 | env_name_list = ["halfcheetah", "hopper", "walker2d"] 57 | type_name_list = ["medium-v0", "medium-replay-v0"][1:] 58 | algo_name_list = ['CQL', 'cql', 'enhanced_tit_cql'] 59 | for env_name in env_name_list: 60 | for type_name in type_name_list: 61 | plt.cla() 62 | for algo_name in algo_name_list: 63 | try: 64 | print('=='*30, f'load results of 
{env_name}-{type_name}-{algo_name}') 65 | if algo_name == "CQL": # load baseline CQL 66 | score = load_d4rl_score(algo_name, env_name, type_name) # score.scores.shape ==> (10, 499) 67 | else: # load our implementation of algorithms 68 | score = load_my_score(algo_name, env_name, type_name, MY_DIR='./log/') 69 | print(score.scores.max(axis=1), np.mean(score.scores.max(axis=1)), np.std(score.scores.max(axis=1))) 70 | plot_score_curve(score, window_size=100) 71 | except: 72 | pass 73 | plt.plot() 74 | cur_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 75 | plt.savefig('./log/' + env_name + '_' + type_name + '_' + cur_time + '_learning_curve.pdf') 76 | 77 | 78 | def load_my_score_for_one_experiment(log_dir: str): 79 | score_list = [] 80 | step_list = [] 81 | with open(os.path.join(log_dir, "environment.csv"), "r") as f: 82 | data = np.loadtxt(f, delimiter=",", skiprows=1) 83 | score_list.append(data[:, 2]) 84 | step_list.append(data[:, 1]) 85 | raw_scores = np.array(score_list) 86 | steps = np.array(step_list) 87 | 88 | algo = log_dir.split('_')[-1] 89 | env_dataset = log_dir.split('_')[1] 90 | env = env_dataset.split('-')[0] 91 | dataset = env_dataset.split('-')[1] 92 | if algo == 'cql': 93 | algo = 'CQL_Reproduced' 94 | elif algo == 'enhanced_tit_cql': 95 | algo = 'CQL_TIT_Enhanced' 96 | elif algo == 'vanilla_tit_cql': 97 | algo = 'CQL_TIT_Vanilla' 98 | 99 | return ScoreData( 100 | algo=get_canonical_algo_name(algo), 101 | env=env, 102 | dataset=dataset, 103 | steps=steps, 104 | raw_scores=raw_scores, 105 | scores=normalize_d4rl_score(env, raw_scores), 106 | ) 107 | 108 | 109 | def load_all_d4rl_scores_for_one_experiment(): 110 | env_name_list = ["halfcheetah", "hopper", "walker2d"] 111 | type_name_list = ["medium-v0", "medium-replay-v0"][:1] 112 | algo_name_list = ['CQL', 'enhanced_tit_cql'] 113 | for env_name in env_name_list: 114 | for type_name in type_name_list: 115 | plt.cla() 116 | for algo_name in algo_name_list: 117 | print('=='*30, f'load results of {env_name}-{type_name}-{algo_name}') 118 | if algo_name == "CQL": # load baseline CQL 119 | score = load_d4rl_score(algo_name, env_name, type_name) # score.scores.shape ==> (10, 499) 120 | print(score.scores.max(axis=1), np.mean(score.scores.max(axis=1)), np.std(score.scores.max(axis=1))) 121 | # pass 122 | plot_score_curve(score, window_size=100) 123 | else: # load our implementation of algorithms 124 | for log_dir in glob.glob(os.path.join(f"./log/CQL_{env_name}-{type_name}_{algo_name}_*")): 125 | if log_dir.endswith('.pt'): 126 | continue 127 | print('log_dir ==>', log_dir) 128 | score = load_my_score_for_one_experiment(log_dir) 129 | print(score.scores.max(axis=1), np.mean(score.scores.max(axis=1)), np.std(score.scores.max(axis=1))) 130 | # pass 131 | plot_score_curve(score, window_size=100) 132 | plt.plot() 133 | cur_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 134 | plt.savefig('./log/' + env_name + '_' + type_name + '_' + cur_time + '_learning_curve.pdf') 135 | 136 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/results_online.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tensorboard.backend.event_processing import event_accumulator 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def read_evaluation_results(dir_name='./log/'): 8 | eval_log_path_list = os.listdir(dir_name) 9 | for log_path in sorted(eval_log_path_list): 10 | try: 11 | episode_rewards = 
np.load(dir_name + log_path + '/eval_episode_rewards.npy') 12 | print('==' * 20, 'eval_log_path ==>', log_path) 13 | print('mean/std rewards ==>', np.mean(episode_rewards), np.std(episode_rewards)) 14 | except: 15 | pass # if some experiments are still running, skip it 16 | 17 | 18 | def read_tensorboard_logs(dir_name='./log/', env_name='CartPole-v1'): 19 | def plot(steps_values, label): 20 | import matplotlib.pyplot as plt 21 | fig = plt.figure(figsize=(20, 16)) 22 | ax1 = fig.add_subplot(111) 23 | ax1.plot([i.step for i in steps_values], [i.value for i in steps_values], label=label) 24 | ax1.set_xlabel('training step') 25 | ax1.set_ylabel('mean episode reward') 26 | plt.legend(loc='lower right') 27 | plt.show() 28 | 29 | fig = plt.figure(figsize=(20, 16)) 30 | ax1 = fig.add_subplot(111) 31 | # file_name='./log/CartPole-v1__PPO__0/PPO_1/events.out.tfevents.1665235.8280L-SYS-7049GP-TRT.96687.0' 32 | for log_path in sorted(os.listdir(dir_name)): 33 | try: 34 | if env_name in log_path and 'hypertuning' not in log_path: 35 | dir_name_temp = dir_name + log_path + '/PPO_1/' 36 | file_name = dir_name_temp + os.listdir(dir_name_temp)[0] 37 | ea = event_accumulator.EventAccumulator(file_name) 38 | ea.Reload() 39 | print(ea.scalars.Keys()) # ['rollout/ep_len_mean', 'rollout/ep_rew_mean', 'train/approx_kl', ...] 40 | ep_rew_mean = ea.scalars.Items('rollout/ep_rew_mean') 41 | print('episode mean rewards ==>', [(i.step, i.value) for i in ep_rew_mean]) 42 | # plot(steps_values=ep_rew_mean, label=log_path) 43 | ax1.plot([i.step for i in ep_rew_mean], [i.value for i in ep_rew_mean], label=log_path) 44 | except: 45 | pass # if some experiments are still running, skip it 46 | ax1.set_xlabel('training step') 47 | ax1.set_ylabel('mean episode reward') 48 | plt.legend(loc='lower right') 49 | # plt.show() 50 | plt.savefig(f'./log/{env_name}_tensorboard.png') 51 | 52 | 53 | def read_hypertuning_tensorboard_logs(dir_name='./log/', env_name='CartPole-v1__PPO__0__hypertuning'): 54 | fig = plt.figure(figsize=(20, 16)) 55 | ax1 = fig.add_subplot(111) 56 | # file_name='./log/CartPole-v1__PPO__0__hypertuning/PPO_1/events.out.tfevents.1665235.8280L-SYS-7049GP-TRT.96687.0' 57 | for ppo_i in sorted(os.listdir(dir_name + env_name)): 58 | dir_name_temp = dir_name + env_name + '/' + ppo_i + '/' 59 | file_name = dir_name_temp + os.listdir(dir_name_temp)[0] 60 | ea = event_accumulator.EventAccumulator(file_name) 61 | ea.Reload() 62 | print(ea.scalars.Keys()) # ['rollout/ep_len_mean', 'rollout/ep_rew_mean', 'train/approx_kl', ...] 
63 | ep_rew_mean = ea.scalars.Items('rollout/ep_rew_mean') 64 | print('episode mean rewards ==>', [(i.step, i.value) for i in ep_rew_mean]) 65 | ax1.plot([i.step for i in ep_rew_mean], [i.value for i in ep_rew_mean], label=dir_name_temp) 66 | ax1.set_xlabel('training step') 67 | ax1.set_ylabel('mean episode reward') 68 | plt.legend() 69 | # plt.show() 70 | plt.savefig(f'./log/{env_name}_tensorboard.png') 71 | 72 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/run.sh: -------------------------------------------------------------------------------- 1 | # run the baselines 2 | #screen python main.py --algo ppo --device cuda:0 --log-folder ./log/ 3 | #screen python main.py --algo ofe_ppo --device cuda:1 --log-folder ./log/ 4 | #screen python main.py --algo d2rl_ppo --device cuda:2 --log-folder ./log/ 5 | #screen python main.py --algo resnet_ppo --device cuda:3 --log-folder ./log/ 6 | 7 | 8 | 9 | 10 | 11 | # hypertuning & run the online-ClassicControl 12 | #nohup python hypertuning.py --env-name Acrobot-v1 --algo vanilla_tit_ppo --n-timesteps 100000 --device cuda:1 > ./log/TuningAcrobotVanilla.txt 2>&1 & 13 | #nohup python hypertuning.py --env-name Acrobot-v1 --algo enhanced_tit_ppo --n-timesteps 100000 --device cuda:1 > ./log/TuningAcrobotEnhanced.txt 2>&1 & 14 | #nohup python hypertuning.py --env-name CartPole-v1 --algo vanilla_tit_ppo --n-timesteps 100000 --device cuda:2 > ./log/TuningCartPoleVanilla.txt 2>&1 & 15 | #nohup python hypertuning.py --env-name CartPole-v1 --algo enhanced_tit_ppo --n-timesteps 100000 --device cuda:2 > ./log/TuningCartPoleEnhanced.txt 2>&1 & 16 | #nohup python hypertuning.py --env-name MountainCar-v0 --algo vanilla_tit_ppo --n-timesteps 100000 --device cuda:3 > ./log/TuningMountainCarVanilla.txt 2>&1 & 17 | #nohup python hypertuning.py --env-name MountainCar-v0 --algo enhanced_tit_ppo --n-timesteps 100000 --device cuda:3 > ./log/TuningMountainCarEnhanced.txt 2>&1 & 18 | 19 | #nohup python main.py --env-name Acrobot-v1 --algo vanilla_tit_ppo --device cuda:1 > ./log/RunningAcrobotVanilla.txt 2>&1 & 20 | #nohup python main.py --env-name Acrobot-v1 --algo enhanced_tit_ppo --device cuda:1 > ./log/RunningAcrobotEnhanced.txt 2>&1 & 21 | #nohup python main.py --env-name CartPole-v1 --algo vanilla_tit_ppo --device cuda:2 > ./log/RunningCartPoleVanilla.txt 2>&1 & 22 | #nohup python main.py --env-name CartPole-v1 --algo enhanced_tit_ppo --device cuda:2 > ./log/RunningCartPoleEnhanced.txt 2>&1 & 23 | #nohup python main.py --env-name MountainCar-v0 --algo vanilla_tit_ppo --device cuda:3 > ./log/RunningMountainCarVanilla.txt 2>&1 & 24 | #nohup python main.py --env-name MountainCar-v0 --algo enhanced_tit_ppo --device cuda:3 > ./log/RunningMountainCarEnhanced.txt 2>&1 & 25 | 26 | 27 | 28 | 29 | 30 | # hypertuning & run the online-MuJoCo 31 | #nohup python -u hypertuning.py --env-name Ant-v3 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:1 > ./log/TuningAntVanilla.txt 2>&1 & 32 | #nohup python -u hypertuning.py --env-name Ant-v3 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:1 > ./log/TuningAntEnhanced.txt 2>&1 & 33 | #nohup python -u hypertuning.py --env-name Hopper-v3 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:2 > ./log/TuningHopperVanilla.txt 2>&1 & 34 | #nohup python -u hypertuning.py --env-name Hopper-v3 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:2 > ./log/TuningHopperEnhanced.txt 2>&1 & 35 | #nohup python -u hypertuning.py --env-name Walker2d-v3 --algo 
vanilla_tit_ppo --n-timesteps 1000000 --device cuda:3 > ./log/TuningWalker2dVanilla.txt 2>&1 & 36 | #nohup python -u hypertuning.py --env-name Walker2d-v3 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:3 > ./log/TuningWalker2dEnhanced.txt 2>&1 & 37 | 38 | #nohup python main.py --env-name Ant-v3 --algo vanilla_tit_ppo --device cuda:1 > ./log/RunningAntVanilla.txt 2>&1 & 39 | #nohup python main.py --env-name Ant-v3 --algo enhanced_tit_ppo --device cuda:1 > ./log/RunningAntEnhanced.txt 2>&1 & 40 | #nohup python main.py --env-name Hopper-v3 --algo vanilla_tit_ppo --device cuda:2 > ./log/RunningHopperVanilla.txt 2>&1 & 41 | #nohup python main.py --env-name Hopper-v3 --algo enhanced_tit_ppo --device cuda:2 > ./log/RunningHopperEnhanced.txt 2>&1 & 42 | #nohup python main.py --env-name Walker2d-v3 --algo vanilla_tit_ppo --device cuda:3 > ./log/RunningWalker2dVanilla.txt 2>&1 & 43 | #nohup python main.py --env-name Walker2d-v3 --algo enhanced_tit_ppo --device cuda:3 > ./log/RunningWalker2dEnhanced.txt 2>&1 & 44 | 45 | 46 | 47 | 48 | 49 | # hypertuning & run the online-Atari 50 | #nohup python hypertuning.py --env-name BreakoutNoFrameskip-v4 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:0 > ./log/TuningBreakoutVanilla.txt 2>&1 & 51 | #nohup python hypertuning.py --env-name BreakoutNoFrameskip-v4 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:0 > ./log/TuningBreakoutEnhanced.txt 2>&1 & 52 | #nohup python hypertuning.py --env-name MsPacmanNoFrameskip-v4 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:1 > ./log/TuningMsPacmanVanilla.txt 2>&1 & 53 | #nohup python hypertuning.py --env-name MsPacmanNoFrameskip-v4 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:1 > ./log/TuningMsPacmanEnhanced.txt 2>&1 & 54 | #nohup python hypertuning.py --env-name PongNoFrameskip-v4 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:2 > ./log/TuningPongVanilla.txt 2>&1 & 55 | #nohup python hypertuning.py --env-name PongNoFrameskip-v4 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:2 > ./log/TuningPongEnhanced.txt 2>&1 & 56 | #nohup python hypertuning.py --env-name SpaceInvadersNoFrameskip-v4 --algo vanilla_tit_ppo --n-timesteps 1000000 --device cuda:3 > ./log/TuningSpaceInvadersVanilla.txt 2>&1 & 57 | #nohup python hypertuning.py --env-name SpaceInvadersNoFrameskip-v4 --algo enhanced_tit_ppo --n-timesteps 1000000 --device cuda:3 > ./log/TuningSpaceInvadersEnhanced.txt 2>&1 & 58 | 59 | #nohup python main.py --env-name BreakoutNoFrameskip-v4 --algo vanilla_tit_ppo --device cuda:0 > ./log/RunningBreakoutVanilla.txt 2>&1 & 60 | #nohup python main.py --env-name BreakoutNoFrameskip-v4 --algo enhanced_tit_ppo --device cuda:0 > ./log/RunningBreakoutEnhanced.txt 2>&1 & 61 | #nohup python main.py --env-name MsPacmanNoFrameskip-v4 --algo vanilla_tit_ppo --device cuda:1 > ./log/RunningMsPacmanVanilla.txt 2>&1 & 62 | #nohup python main.py --env-name MsPacmanNoFrameskip-v4 --algo enhanced_tit_ppo --device cuda:1 > ./log/RunningMsPacmanEnhanced.txt 2>&1 & 63 | #nohup python main.py --env-name PongNoFrameskip-v4 --algo vanilla_tit_ppo --device cuda:2 > ./log/RunningPongVanilla.txt 2>&1 & 64 | #nohup python main.py --env-name PongNoFrameskip-v4 --algo enhanced_tit_ppo --device cuda:2 > ./log/RunningPongEnhanced.txt 2>&1 & 65 | #nohup python main.py --env-name SpaceInvadersNoFrameskip-v4 --algo vanilla_tit_ppo --device cuda:3 > ./log/RunningSpaceInvadersVanilla.txt 2>&1 & 66 | #nohup python main.py --env-name SpaceInvadersNoFrameskip-v4 --algo 
enhanced_tit_ppo --device cuda:3 > ./log/RunningSpaceInvadersEnhanced.txt 2>&1 & 67 | 68 | 69 | 70 | 71 | 72 | # hypertuning & run the offline-MuJoCo-medium 73 | #nohup python -u offline_hypertuning.py --env-name halfcheetah-medium-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 1 > ./log/TuningHalfcheetahMediumEnhanced.txt 2>&1 & 74 | #nohup python -u offline_hypertuning.py --env-name hopper-medium-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 2 > ./log/TuningHopperMediumEnhanced.txt 2>&1 & 75 | #nohup python -u offline_hypertuning.py --env-name walker2d-medium-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 3 > ./log/TuningWalker2dMediumEnhanced.txt 2>&1 & 76 | 77 | #nohup python offline_main.py --env-name halfcheetah-medium-v0 --algo enhanced_tit_cql --device 1 > ./log/RunningHalfcheetahMediumEnhanced.txt 2>&1 & 78 | #nohup python offline_main.py --env-name hopper-medium-v0 --algo enhanced_tit_cql --device 2 > ./log/RunningHopperMediumEnhanced.txt 2>&1 & 79 | #nohup python offline_main.py --env-name walker2d-medium-v0 --algo enhanced_tit_cql --device 3 > ./log/RunningWalker2dMediumEnhanced.txt 2>&1 & 80 | 81 | 82 | 83 | 84 | 85 | # hypertuning & run the offline-MuJoCo-medium-replay 86 | #nohup python -u offline_hypertuning.py --env-name halfcheetah-medium-replay-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 1 > ./log/TuningHalfcheetahMediumReplayEnhanced.txt 2>&1 & 87 | #nohup python -u offline_hypertuning.py --env-name hopper-medium-replay-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 2 > ./log/TuningHopperMediumReplayEnhanced.txt 2>&1 & 88 | #nohup python -u offline_hypertuning.py --env-name walker2d-medium-replay-v0 --algo enhanced_tit_cql --n-timesteps 500000 --device 3 > ./log/TuningWalker2dMediumReplayEnhanced.txt 2>&1 & 89 | 90 | #nohup python offline_main.py --env-name halfcheetah-medium-replay-v0 --algo enhanced_tit_cql --device 1 > ./log/RunningHalfcheetahMediumReplayEnhanced.txt 2>&1 & 91 | #nohup python offline_main.py --env-name hopper-medium-replay-v0 --algo enhanced_tit_cql --device 2 > ./log/RunningHopperMediumReplayEnhanced.txt 2>&1 & 92 | #nohup python offline_main.py --env-name walker2d-medium-replay-v0 --algo enhanced_tit_cql --device 3 > ./log/RunningWalker2dMediumReplayEnhanced.txt 2>&1 & 93 | 94 | -------------------------------------------------------------------------------- /PPO_TIT_and_CQL_TIT/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch.nn as nn 4 | import yaml 5 | from network import TitFeaturesExtractor, ResnetFeaturesExtractor, CatformerFeaturesExtractor 6 | 7 | 8 | def linear_schedule(initial_value: float): 9 | # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule 10 | def func(progress_remaining: float) -> float: 11 | return progress_remaining * initial_value 12 | return func 13 | 14 | 15 | def update_args(args, hyperparams=None): 16 | if hyperparams is None: 17 | yaml_file = './hyperparameter_final_' + args.algo + '.yaml' 18 | with open(yaml_file) as f: 19 | hyperparams_dict = yaml.safe_load(f) 20 | if args.env_name in list(hyperparams_dict.keys()): 21 | hyperparams = hyperparams_dict[args.env_name] 22 | else: 23 | raise ValueError(f'Hyperparameters not found for {args.algo}-{args.env_name}') 24 | print('the loaded hyperparams ==>', hyperparams) 25 | else: 26 | print('the given hyperparams ==>', hyperparams) 27 | 28 | args.n_timesteps = hyperparams['n_timesteps'] 29 | 
args.patch_dim = hyperparams['patch_dim'] 30 | args.num_blocks = hyperparams['num_blocks'] 31 | args.features_dim = hyperparams['features_dim'] 32 | args.embed_dim_inner = hyperparams['embed_dim_inner'] 33 | args.num_heads_inner = hyperparams['num_heads_inner'] 34 | args.attention_dropout_inner = hyperparams['attention_dropout_inner'] 35 | args.ffn_dropout_inner = hyperparams['ffn_dropout_inner'] 36 | args.embed_dim_outer = hyperparams['embed_dim_outer'] 37 | args.num_heads_outer = hyperparams['num_heads_outer'] 38 | args.attention_dropout_outer = hyperparams['attention_dropout_outer'] 39 | args.ffn_dropout_outer = hyperparams['ffn_dropout_outer'] 40 | activation_fn = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'gelu': nn.GELU} 41 | args.activation_fn_inner = activation_fn[hyperparams['activation_fn_inner']] 42 | args.activation_fn_outer = activation_fn[hyperparams['activation_fn_outer']] 43 | args.activation_fn_other = activation_fn[hyperparams['activation_fn_other']] 44 | args.dim_expand_inner = hyperparams['dim_expand_inner'] 45 | args.dim_expand_outer = hyperparams['dim_expand_outer'] 46 | args.have_position_encoding = hyperparams['have_position_encoding'] 47 | args.share_tit_blocks = hyperparams['share_tit_blocks'] 48 | print('the updated args ==>', args) 49 | 50 | return args 51 | 52 | 53 | def load_policy_kwargs(args): 54 | if args.algo == 'resnet_ppo': 55 | policy_kwargs = dict( 56 | features_extractor_class=ResnetFeaturesExtractor, 57 | features_extractor_kwargs=dict(features_dim=512), 58 | net_arch=[], 59 | ) 60 | elif args.algo == 'catformer_ppo': 61 | policy_kwargs = dict( 62 | features_extractor_class=CatformerFeaturesExtractor, 63 | features_extractor_kwargs=dict(features_dim=512), 64 | net_arch=[], 65 | ) 66 | else: # vanilla_tit and enhanced_tit 67 | policy_kwargs = dict( 68 | features_extractor_class=TitFeaturesExtractor, 69 | features_extractor_kwargs=dict( 70 | algo=args.algo, 71 | patch_dim=args.patch_dim, 72 | num_blocks=args.num_blocks, 73 | features_dim=args.features_dim, 74 | embed_dim_inner=args.embed_dim_inner, 75 | num_heads_inner=args.num_heads_inner, 76 | attention_dropout_inner=args.attention_dropout_inner, 77 | ffn_dropout_inner=args.ffn_dropout_inner, 78 | embed_dim_outer=args.embed_dim_outer, 79 | num_heads_outer=args.num_heads_outer, 80 | attention_dropout_outer=args.attention_dropout_outer, 81 | ffn_dropout_outer=args.ffn_dropout_outer, 82 | activation_fn_inner=args.activation_fn_inner, 83 | activation_fn_outer=args.activation_fn_outer, 84 | activation_fn_other=args.activation_fn_other, 85 | dim_expand_inner=args.dim_expand_inner, 86 | dim_expand_outer=args.dim_expand_outer, 87 | have_position_encoding=args.have_position_encoding, 88 | share_tit_blocks=args.share_tit_blocks, 89 | ), 90 | net_arch=[], 91 | ) 92 | return policy_kwargs 93 | 94 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode 3 | log/ 4 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/bc.yaml: -------------------------------------------------------------------------------- 1 | # behavior clone 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: 
"bc" 14 | length_times: 2 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/bc_tit.yaml: -------------------------------------------------------------------------------- 1 | # behavior clone 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: "bc" 14 | length_times: 2 15 | tit: True -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/dt.yaml: -------------------------------------------------------------------------------- 1 | # decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: "dt" 14 | length_times: 3 15 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/dt_tit.yaml: -------------------------------------------------------------------------------- 1 | # decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: "dt" 14 | length_times: 3 15 | tit: True -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/mgdt.yaml: -------------------------------------------------------------------------------- 1 | # multi-game decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: "mgdt" 14 | length_times: 4 15 | sample_return: True 16 | num_sample_return: 256 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/algo/mgdt_wo_sample.yaml: -------------------------------------------------------------------------------- 1 | # multi-game decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | model_type: "mgdt" 14 | length_times: 4 15 | sample_return: False 16 | num_sample_return: 256 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/default.yaml: -------------------------------------------------------------------------------- 1 | device: 'cuda' 2 | log_to_tensorboard: True 3 | num_steps_per_iter: 1000 4 | max_iters: 50 5 | num_eval_episodes: 100 6 | warmup_steps: 10000 7 | weight_decay: 0.0001 8 | learning_rate: 0.0001 9 | dataset_path: "./demos" 10 | text_max_size: 100 11 | word_embedding_size: 32 12 | env_targets: [1, 2] 13 | save_model: True 14 | step_num: 15 | 16 | tit: False 17 | inner: 18 | n_ctx: 1024 19 | n_embd: 128 20 | hidden_size: 128 21 | n_layer: 3 22 | n_head: 1 23 | n_inner: 512 24 | activation_function: "relu" 25 | n_position: 1024 26 | resid_pdrop: 0.1 27 | attn_pdrop: 0.1 
-------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/BabyAI-BossLevel-v0.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-BossLevel-v0"] 2 | env_name: ["BabyAI-BossLevel-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/BabyAI-GoToObj-v0.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-GoToObj-v0"] 2 | env_name: ["BabyAI-GoToObj-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/BabyAI-GoToRedBall-v0.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-GoToRedBall-v0"] 2 | env_name: ["BabyAI-GoToRedBall-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/BabyAI-GoToRedBallGrey-v0.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-GoToRedBallGrey-v0"] 2 | env_name: ["BabyAI-GoToRedBallGrey-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/BabyAI-GoToSeq-v0.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-GoToSeq-v0"] 2 | env_name: ["BabyAI-GoToSeq-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 7 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/config/env/Mix-GoTo.yaml: -------------------------------------------------------------------------------- 1 | data_name: ["BabyAI-GoToSeq-v0", "BabyAI-GoToRedBall-v0", "BabyAI-GoToObj-v0"] 2 | env_name: ["BabyAI-GoToSeq-v0", "BabyAI-GoToRedBall-v0", "BabyAI-GoToObj-v0"] 3 | pct_traj: 1.0 4 | K: 20 5 | batch_size: 64 6 | max_ep_len: 1000 7 | step_num: 5000000 -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/demos/BabyAI-BossLevel-v0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maohangyu/TIT_open_source/e1bc0aab48166dfa80f2520f9a03b5b7a9392df8/RL_Foundation_BabyAI_including_DT_GATO_and_TIT/demos/BabyAI-BossLevel-v0.pkl -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on 3 | https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L166 4 | https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/evaluation/evaluate_episodes.py 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import gym 12 | from utils import preprocess_texts 13 | 14 | 15 | class Evaluation(object): 16 
| def __init__(self, config, vocab): 17 | self.config = config 18 | self.device = config.get('device', 'cuda') 19 | self.env_name = config['env_name'] 20 | self.num_eval_episodes = config['num_eval_episodes'] 21 | self.model_type = config['model_type'] 22 | 23 | self.env = gym.make(self.env_name[0], disable_env_checker=True) 24 | self.image_dim = self.env.observation_space['image'].shape 25 | self.act_dim = self.env.action_space.n 26 | 27 | self.vocab = vocab 28 | self.max_ep_len = config['max_ep_len'] 29 | 30 | def evaluate_episode(self, model, env): 31 | self.env = gym.make(env, disable_env_checker=True) 32 | model.eval() 33 | model.to(device=self.device) 34 | 35 | state = self.env.reset() 36 | image = state['image'] 37 | mission = state['mission'] 38 | mission, _ = preprocess_texts([mission], self.vocab) 39 | 40 | # we keep all the histories on the device 41 | # note that the latest action and reward will be "padding" 42 | images = torch.from_numpy(image).reshape(1, *self.image_dim).to(device=self.device, dtype=torch.float32) 43 | missions = torch.from_numpy(mission).to(device=self.device, dtype=torch.long) 44 | actions = torch.zeros((0, self.act_dim), device=self.device, dtype=torch.float32) 45 | rewards = torch.zeros(0, device=self.device, dtype=torch.float32) 46 | returns = torch.zeros(0, device=self.device, dtype=torch.float32) 47 | timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(1, 1) 48 | 49 | episode_return, episode_length = 0, 0 50 | for t in range(self.max_ep_len): 51 | # add padding 52 | actions = torch.cat([actions, torch.zeros((1, self.act_dim), device=self.device)], dim=0) 53 | rewards = torch.cat([rewards, torch.zeros(1, device=self.device)]) 54 | returns = torch.cat([returns, torch.zeros(1, device=self.device)]) 55 | 56 | if self.config['model_type'] in ['bc']: 57 | action = model.get_action( 58 | images.to(dtype=torch.float32), 59 | missions.to(dtype=torch.long), 60 | actions.to(dtype=torch.float32), 61 | rewards.to(dtype=torch.float32), 62 | returns.to(dtype=torch.float32), 63 | timesteps.to(dtype=torch.long), 64 | ) 65 | elif self.config['model_type'] in ['mgdt']: 66 | _, ret = model.get_action( 67 | images.to(dtype=torch.float32), 68 | missions.to(dtype=torch.long), 69 | actions.to(dtype=torch.float32), 70 | rewards.to(dtype=torch.float32), 71 | returns.to(dtype=torch.float32), 72 | timesteps.to(dtype=torch.long), 73 | ) 74 | if self.config['sample_return'] == True: 75 | eps = torch.randn(self.config['num_sample_return'], 1).to(ret[1].device) 76 | ret_tmp = ret[0] + eps * torch.exp(0.5 * ret[1]) 77 | ret = ret_tmp.max(0)[0] 78 | returns[-1] = ret 79 | action, _ = model.get_action( 80 | images.to(dtype=torch.float32), 81 | missions.to(dtype=torch.long), 82 | actions.to(dtype=torch.float32), 83 | rewards.to(dtype=torch.float32), 84 | returns.to(dtype=torch.float32), 85 | timesteps.to(dtype=torch.long), 86 | ) 87 | actions[-1] = F.one_hot(action.argmax(), self.act_dim) 88 | action = action.argmax().detach().cpu().numpy() 89 | 90 | state, reward, done, _ = self.env.step(action) 91 | if reward > 0: reward = max(0, 1 - 0.9 * ((t + 1) / self.max_ep_len)) 92 | 93 | cur_image = torch.from_numpy(state['image']).to(device=self.device).reshape(1, *self.image_dim) 94 | images = torch.cat([images, cur_image], dim=0) 95 | missions = torch.cat([missions, torch.from_numpy(mission).to(device=self.device, dtype=torch.long)], dim=0) 96 | rewards[-1] = reward 97 | timesteps = torch.cat([timesteps, torch.ones((1, 1), device=self.device, dtype=torch.long) * (t + 
1)], dim=1) 98 | 99 | episode_return += reward 100 | episode_length += 1 101 | 102 | if done: 103 | break 104 | 105 | return episode_return, episode_length 106 | 107 | def evaluate_episode_rtg(self, model, env, target_return=None): 108 | self.env = gym.make(env, disable_env_checker=True) 109 | model.eval() 110 | model.to(device=self.device) 111 | 112 | state = self.env.reset() 113 | image = state['image'] 114 | mission = state['mission'] 115 | mission, _ = preprocess_texts([mission], self.vocab) 116 | 117 | # we keep all the histories on the device 118 | # note that the latest action and reward will be "padding" 119 | images = torch.from_numpy(image).reshape(1, *self.image_dim).to(device=self.device, dtype=torch.float32) 120 | missions = torch.from_numpy(mission).to(device=self.device, dtype=torch.long) 121 | actions = torch.zeros((0, self.act_dim), device=self.device, dtype=torch.float32) 122 | rewards = torch.zeros(0, device=self.device, dtype=torch.float32) 123 | 124 | ep_return = target_return 125 | target_return = torch.tensor(ep_return, device=self.device, dtype=torch.float32).reshape(1, 1) 126 | timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(1, 1) 127 | 128 | episode_return, episode_length = 0, 0 129 | for t in range(self.max_ep_len): 130 | # add padding 131 | actions = torch.cat([actions, torch.zeros((1, self.act_dim), device=self.device)], dim=0) 132 | rewards = torch.cat([rewards, torch.zeros(1, device=self.device)]) 133 | 134 | action = model.get_action( 135 | images.to(dtype=torch.float32), 136 | missions.to(dtype=torch.long), 137 | actions.to(dtype=torch.float32), 138 | rewards.to(dtype=torch.float32), 139 | target_return.to(dtype=torch.float32), 140 | timesteps.to(dtype=torch.long), 141 | ) 142 | actions[-1] = F.one_hot(action.argmax(), self.act_dim) 143 | action = action.argmax().detach().cpu().numpy() 144 | 145 | state, reward, done, _ = self.env.step(action) 146 | if reward > 0: reward = max(0, 1 - 0.9 * ((t + 1) / self.max_ep_len)) 147 | 148 | cur_image = torch.from_numpy(state['image']).to(device=self.device).reshape(1, *self.image_dim) 149 | images = torch.cat([images, cur_image], dim=0) 150 | missions = torch.cat([missions, torch.from_numpy(mission).to(device=self.device, dtype=torch.long)], dim=0) 151 | rewards[-1] = reward 152 | 153 | pred_return = target_return[0,-1] - reward 154 | target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1) 155 | timesteps = torch.cat([timesteps, torch.ones((1, 1), device=self.device, dtype=torch.long) * (t + 1)], dim=1) 156 | 157 | episode_return += reward 158 | episode_length += 1 159 | 160 | if done: 161 | break 162 | 163 | return episode_return, episode_length 164 | 165 | def eval_fn(self, target_rew): 166 | def fn(model): 167 | returns, lengths = [[] for _ in range(len(self.env_name))], [[] for _ in range(len(self.env_name))] 168 | successes = [[] for _ in range(len(self.env_name))] 169 | for _ in range(self.num_eval_episodes): 170 | with torch.no_grad(): 171 | if self.model_type in ['dt']: 172 | for i, env in enumerate(self.env_name): 173 | ret, length = self.evaluate_episode_rtg(model, env, target_return=target_rew) 174 | returns[i].append(ret) 175 | successes[i].append(1 if ret > 0 else 0) 176 | lengths[i].append(length) 177 | else: 178 | for i, env in enumerate(self.env_name): 179 | ret, length = self.evaluate_episode(model, env) 180 | returns[i].append(ret) 181 | successes[i].append(1 if ret > 0 else 0) 182 | lengths[i].append(length) 183 | log = {} 184 | for i, env in 
enumerate(self.env_name): 185 | log[f'{env}_target_{target_rew}_return_mean'] = np.mean(returns[i]) 186 | log[f'{env}_target_{target_rew}_return_std'] = np.std(returns[i]) 187 | log[f'{env}_target_{target_rew}_successes'] = np.mean(successes[i]) 188 | log[f'{env}_target_{target_rew}_length_mean'] = np.mean(lengths[i]) 189 | log[f'{env}_target_{target_rew}_length_std'] = np.std(lengths[i]) 190 | return log 191 | return fn 192 | 193 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L208 3 | """ 4 | 5 | # import wandb 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | 9 | import argparse 10 | import yaml 11 | import os 12 | import numpy as np 13 | 14 | from network import DecisionTransformer, TIT_DecisionTransformer 15 | from trainner import Trainer 16 | from evaluation import Evaluation 17 | from utils import SequenceDataset 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--algo', type=str, default='dt') 23 | parser.add_argument('--env', type=str, default='BabyAI-BossLevel-v0') 24 | args = parser.parse_args() 25 | 26 | with open('config/default.yaml'.format(args.algo), 'r') as f: 27 | config = yaml.safe_load(f) 28 | with open('config/env/{}.yaml'.format(args.env), 'r') as f: 29 | config.update(yaml.safe_load(f)) 30 | with open('config/algo/{}.yaml'.format(args.algo), 'r') as f: 31 | config.update(yaml.safe_load(f)) 32 | 33 | if config['log_to_tensorboard']: 34 | path = './log/{}/{}/'.format(args.algo, args.env) 35 | os.makedirs(path, exist_ok=True) 36 | list_files = os.listdir(path) 37 | list_files = [int(x) for x in list_files] 38 | file_name = 0 if len(list_files) == 0 else max(list_files) + 1 39 | final_path = path+'{}'.format(file_name) 40 | writer = SummaryWriter(final_path) 41 | with open(final_path+'/config.txt', 'w') as f: 42 | yaml.dump(config, f) 43 | f.close() 44 | else: 45 | writer = None 46 | 47 | dataset = SequenceDataset(config) 48 | if config['tit']: 49 | model = TIT_DecisionTransformer(config).to(config['device']) 50 | else: 51 | model = DecisionTransformer(config).to(config['device']) 52 | 53 | evaluation = Evaluation(config, vocab=dataset.vocab) 54 | 55 | warmup_steps = config['warmup_steps'] 56 | optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']) 57 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps+1)/warmup_steps, 1)) 58 | trainer = Trainer( 59 | model=model, 60 | optimizer=optimizer, 61 | batch_size=config['batch_size'], 62 | dataset=dataset, 63 | scheduler=scheduler, 64 | config=config, 65 | eval_fns=[evaluation.eval_fn(tar) for tar in config['env_targets']], 66 | writer=writer 67 | ) 68 | 69 | for iter in range(config['max_iters']): 70 | outputs = trainer.train_iteration(num_steps=config['num_steps_per_iter'], iter_num=iter+1, print_logs=True) 71 | if config['log_to_tensorboard']: 72 | for k, v in outputs.items(): 73 | writer.add_scalar(k, v, iter) 74 | 75 | if config['save_model']: 76 | save_path = './model' 77 | os.makedirs(save_path, exist_ok=True) 78 | torch.save(model, save_path+'/{}_{}_{}.pkl'.format(args.algo, args.env, np.random.randint(10000))) -------------------------------------------------------------------------------- 
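The AdamW + LambdaLR pairing in main.py above implements a linear learning-rate warmup: the base rate is multiplied by min((steps+1)/warmup_steps, 1), so it ramps from near zero up to learning_rate over the first warmup_steps optimizer steps and stays constant afterwards. A self-contained sketch of that schedule, using the warmup_steps=10000 and learning_rate=0.0001 values from config/default.yaml (the lr_at helper is illustrative, not repo code):

warmup_steps = 10000    # config/default.yaml
base_lr = 1e-4          # config/default.yaml

def lr_at(step: int) -> float:
    # same multiplier that main.py passes to torch.optim.lr_scheduler.LambdaLR
    return base_lr * min((step + 1) / warmup_steps, 1)

for s in (0, 4999, 9999, 20000):
    print(s, lr_at(s))   # 1e-08, 5e-05, 0.0001, 0.0001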
/RL_Foundation_BabyAI_including_DT_GATO_and_TIT/trainner.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/training/seq_trainer.py 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | # import wandb 10 | 11 | import time 12 | from torch.utils.data import DataLoader, WeightedRandomSampler 13 | from tqdm import tqdm 14 | 15 | 16 | class Trainer: 17 | def __init__(self, model, optimizer, batch_size, dataset, writer, config, scheduler=None, eval_fns=None): 18 | self.model = model 19 | self.optimizer = optimizer 20 | self.batch_size = batch_size 21 | self.dataset = dataset 22 | self.scheduler = scheduler 23 | self.eval_fns = [] if eval_fns is None else eval_fns 24 | self.diagnostics = dict() 25 | self.writer = writer 26 | self.model_type = config['model_type'] 27 | self.config = config 28 | 29 | self.train_count = 0 30 | 31 | self.start_time = time.time() 32 | 33 | def train_iteration(self, num_steps, iter_num=0, print_logs=False): 34 | train_losses = [] 35 | logs = dict() 36 | 37 | train_start = time.time() 38 | sampler = WeightedRandomSampler(self.dataset.p_sample, num_samples=num_steps*self.batch_size, replacement=True) 39 | dataloader = DataLoader(self.dataset, sampler=sampler, batch_size=self.batch_size) 40 | 41 | self.model.train() 42 | for images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask in tqdm(dataloader): 43 | train_loss = self.train_step(images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask) 44 | train_losses.append(train_loss) 45 | if self.writer is not None: 46 | self.writer.add_scalar('train_loss', train_loss, self.train_count) 47 | self.train_count += 1 48 | if self.scheduler is not None: 49 | self.scheduler.step() 50 | 51 | logs['time/training'] = time.time() - train_start 52 | 53 | eval_start = time.time() 54 | 55 | self.model.eval() 56 | for eval_fn in self.eval_fns: 57 | outputs = eval_fn(self.model) 58 | for k, v in outputs.items(): 59 | logs[f'evaluation/{k}'] = v 60 | 61 | logs['time/total'] = time.time() - self.start_time 62 | logs['time/evaluation'] = time.time() - eval_start 63 | logs['training/train_loss_mean'] = np.mean(train_losses) 64 | logs['training/train_loss_std'] = np.std(train_losses) 65 | 66 | for k in self.diagnostics: 67 | logs[k] = self.diagnostics[k] 68 | 69 | if print_logs: 70 | print('=' * 80) 71 | print(f'Iteration {iter_num}') 72 | for k, v in logs.items(): 73 | print(f'{k}: {v}') 74 | 75 | return logs 76 | 77 | def train_step(self, images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask): 78 | rewards_target, action_target, rtg_target = torch.clone(rewards), torch.clone(actions), torch.clone(rtg) 79 | 80 | state_preds, action_preds, return_preds, reward_preds = self.model.forward( 81 | images, missions, mission_masks, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask, 82 | ) 83 | 84 | act_dim = action_preds.shape[2] 85 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 86 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 87 | 88 | if self.model_type in ['dt', 'bc']: 89 | loss = F.cross_entropy(action_preds, action_target.max(-1)[1]) 90 | elif self.model_type in ['mgdt']: 91 | if self.config['sample_return'] == True: 92 | eps = 
torch.randn_like(return_preds[1]) 93 | return_preds_tmp = return_preds[0] + eps * torch.exp(0.5 * return_preds[1]) 94 | return_preds = return_preds_tmp 95 | return_preds = return_preds.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 96 | return_target = rtg_target[:,:-1].reshape(-1, 1)[attention_mask.reshape(-1) > 0] 97 | reward_preds = reward_preds.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 98 | reward_target = rewards_target.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 99 | loss = F.cross_entropy(action_preds, action_target.max(-1)[1]) \ 100 | + torch.mean((return_preds - return_target) ** 2) \ 101 | + torch.mean((reward_preds - reward_target) ** 2) 102 | 103 | self.optimizer.zero_grad() 104 | loss.backward() 105 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) 106 | self.optimizer.step() 107 | 108 | with torch.no_grad(): 109 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() 110 | 111 | return loss.detach().cpu().item() 112 | 113 | -------------------------------------------------------------------------------- /RL_Foundation_BabyAI_including_DT_GATO_and_TIT/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py 3 | """ 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | import numpy as np 9 | import pickle 10 | import random 11 | 12 | import os 13 | import blosc 14 | import gym 15 | import babyai 16 | import re 17 | 18 | 19 | class SequenceDataset(Dataset): 20 | def __init__(self, config): 21 | super(SequenceDataset, self).__init__() 22 | self.device = config.get('device', 'cuda') 23 | self.env_name = config['env_name'] 24 | 25 | self.env = gym.make(self.env_name[0], disable_env_checker=True) 26 | 27 | self.image_dim = self.env.observation_space['image'].shape 28 | self.act_dim = self.env.action_space.n 29 | 30 | dataset_path = config['dataset_path'] 31 | self.demos = load_demos(dataset_path, config['data_name'], config['step_num']) 32 | 33 | self.max_ep_len = config['max_ep_len'] 34 | self.vocab = Vocabulary(config['text_max_size']) 35 | # save all path information into separate lists 36 | self.images, self.traj_lens, self.returns = [], [], [] 37 | for demo in self.demos: 38 | self.images.append(demo['image']) 39 | self.traj_lens.append(len(demo['action'])) 40 | demo['reward'][-1] = max(0, 1 - 0.9 * (len(demo['reward']) / self.max_ep_len)) 41 | self.returns.append(sum(demo['reward'])) 42 | self.mission, self.mission_mask = preprocess_texts([demo['mission'] for demo in self.demos], self.vocab) 43 | self.traj_lens, self.returns = np.array(self.traj_lens), np.array(self.returns) 44 | 45 | # used for input normalization 46 | self.images = np.concatenate(self.images, axis=0) 47 | 48 | self.K = config['K'] 49 | self.pct_traj = config.get('pct_traj', 1.) 
50 | 51 | # only train on top pct_traj trajectories (for %BC experiment) 52 | num_timesteps = sum(self.traj_lens) 53 | num_timesteps = max(int(self.pct_traj * num_timesteps), 1) 54 | sorted_inds = np.argsort(self.returns) # lowest to highest 55 | num_trajectories = 1 56 | timesteps = self.traj_lens[sorted_inds[-1]] 57 | ind = len(self.demos) - 2 58 | while ind >= 0 and timesteps + self.traj_lens[sorted_inds[ind]] <= num_timesteps: 59 | timesteps += self.traj_lens[sorted_inds[ind]] 60 | num_trajectories += 1 61 | ind -= 1 62 | self.sorted_inds = sorted_inds[-num_trajectories:] 63 | 64 | # used to reweight sampling so we sample according to timesteps instead of trajectories 65 | self.p_sample = self.traj_lens[self.sorted_inds] / sum(self.traj_lens[self.sorted_inds]) 66 | 67 | def __getitem__(self, index): 68 | traj = self.demos[int(self.sorted_inds[index])] 69 | start_t = random.randint(0, len(traj['action']) - 1) 70 | 71 | s = traj['image'][start_t: start_t + self.K] 72 | a = traj['action'][start_t: start_t + self.K] 73 | a = np.eye(self.act_dim)[a] 74 | r = traj['reward'][start_t: start_t + self.K] 75 | m = self.mission[int(self.sorted_inds[index])] 76 | m_mask = self.mission_mask[int(self.sorted_inds[index])] 77 | if 'terminal' in traj: 78 | d = traj['terminal'][start_t: start_t + self.K] 79 | else: 80 | d = traj['done'][start_t: start_t + self.K] 81 | timesteps = np.arange(start_t, start_t + s.shape[0]) 82 | timesteps[timesteps >= self.max_ep_len] = self.max_ep_len - 1 # padding cutoff 83 | rtg = self.discount_cumsum(traj['reward'][start_t:], gamma=1.)[:s.shape[0] + 1].reshape(-1, 1) 84 | if rtg.shape[0] <= s.shape[0]: 85 | rtg = np.concatenate([rtg, np.zeros((1, 1))], axis=0) 86 | 87 | # padding and state + reward + rtg normalization 88 | tlen = s.shape[0] 89 | s = np.concatenate([np.zeros((self.K - tlen, *self.image_dim)), s], axis=0) 90 | m = np.expand_dims(m, 0).repeat(self.K, axis=0) 91 | m_mask = np.expand_dims(m_mask, 0).repeat(self.K, axis=0) 92 | a = np.concatenate([np.zeros((self.K - tlen, self.act_dim)), a], axis=0) # how to pad action 93 | r = np.concatenate([np.zeros((self.K - tlen, 1)), np.array(r).reshape(-1, 1)], axis=0) 94 | d = np.concatenate([np.ones((self.K - tlen)) * 2, d], axis=0) 95 | rtg = np.concatenate([np.zeros((self.K - tlen, 1)), rtg], axis=0) 96 | timesteps = np.concatenate([np.zeros((self.K - tlen)), timesteps], axis=0) 97 | mask = np.concatenate([np.zeros((self.K - tlen)), np.ones((tlen))], axis=0) 98 | 99 | s = torch.from_numpy(s).to(dtype=torch.float32, device=self.device) 100 | m = torch.from_numpy(m).to(dtype=torch.long, device=self.device) 101 | m_mask = torch.from_numpy(m_mask).to(dtype=torch.float32, device=self.device) 102 | a = torch.from_numpy(a).to(dtype=torch.float32, device=self.device) 103 | r = torch.from_numpy(r).to(dtype=torch.float32, device=self.device) 104 | d = torch.from_numpy(d).to(dtype=torch.long, device=self.device) 105 | rtg = torch.from_numpy(rtg).to(dtype=torch.float32, device=self.device) 106 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.long, device=self.device) 107 | mask = torch.from_numpy(mask).to(device=self.device) 108 | return s, m, m_mask, a, r, d, rtg, timesteps, mask 109 | 110 | def discount_cumsum(self, x, gamma=1.): 111 | discount_cumsum = np.zeros_like(x) 112 | discount_cumsum[-1] = x[-1] 113 | for t in reversed(range(len(x)-1)): 114 | discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] 115 | return discount_cumsum 116 | 117 | 118 | 119 | def load_pickle(path, raise_not_found=True): 120 | try: 121 
| return pickle.load(open(path, "rb")) 122 | except FileNotFoundError: 123 | if raise_not_found: 124 | raise FileNotFoundError("No demos found at {}".format(path)) 125 | else: 126 | return [] 127 | 128 | 129 | def transform_demos(demos): 130 | new_demos = [] 131 | for demo in demos: 132 | new_demo = {} 133 | mission = demo[0] 134 | all_images = demo[1] 135 | directions = demo[2] 136 | actions = demo[3] 137 | all_images = blosc.unpack_array(all_images) 138 | n_observations = all_images.shape[0] 139 | assert len(directions) == len(actions) == n_observations, "error transforming demos" 140 | new_demo['image'] = all_images 141 | new_demo['direction'] = directions 142 | new_demo['mission'] = mission 143 | new_demo['action'] = [a.value for a in actions] 144 | new_demo['done'] = [i == n_observations - 1 for i in range(n_observations)] 145 | new_demo['reward'] = [1 if i == n_observations - 1 else 0 for i in range(n_observations)] 146 | new_demos.append(new_demo) 147 | return new_demos 148 | 149 | 150 | def load_demos(path, files, step_num=None): 151 | all_demos = [] 152 | for f in files: 153 | demos = load_pickle(path+'/'+f+'.pkl') 154 | demos = transform_demos(demos) 155 | demos_step_num = [len(d['action']) for d in demos] 156 | print('{} has {} time steps'.format(f, sum(demos_step_num))) 157 | if step_num is not None: 158 | n = 0 159 | for i, l in enumerate(demos_step_num): 160 | n += l 161 | if n >= step_num: 162 | break 163 | all_demos.extend(demos[:i+1]) 164 | else: 165 | all_demos.extend(demos) 166 | return all_demos 167 | 168 | 169 | class Vocabulary: 170 | def __init__(self, max_size): 171 | self.max_size = max_size 172 | self.vocab = {} 173 | 174 | def load_vocab(self, vocab): 175 | self.vocab = vocab 176 | 177 | def __getitem__(self, token): 178 | if not token in self.vocab.keys(): 179 | if len(self.vocab) >= self.max_size: 180 | raise ValueError("Maximum vocabulary capacity reached") 181 | self.vocab[token] = len(self.vocab) + 1 182 | return self.vocab[token] 183 | 184 | 185 | def preprocess_texts(texts, vocab): 186 | var_indexed_texts = [] 187 | max_text_len = 0 188 | 189 | for text in texts: 190 | tokens = re.findall("([a-z]+)", text.lower()) 191 | var_indexed_text = np.array([vocab[token] for token in tokens]) 192 | var_indexed_texts.append(var_indexed_text) 193 | max_text_len = max(len(var_indexed_text), max_text_len) 194 | 195 | indexed_texts = np.zeros((len(texts), max_text_len)) 196 | texts_mask = np.zeros((len(texts), max_text_len)) 197 | 198 | for i, indexed_text in enumerate(var_indexed_texts): 199 | indexed_texts[i, :len(indexed_text)] = indexed_text 200 | texts_mask[i, :len(indexed_text)] = 1 201 | 202 | return indexed_texts, texts_mask 203 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/algo/bc.yaml: -------------------------------------------------------------------------------- 1 | # behavior clone 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | action_tanh: True 14 | model_type: "bc" 15 | length_times: 2 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/algo/dt.yaml: -------------------------------------------------------------------------------- 1 | # decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | 
n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | action_tanh: True 14 | model_type: "dt" 15 | length_times: 3 16 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/algo/mgdt.yaml: -------------------------------------------------------------------------------- 1 | # multi-game decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | action_tanh: True 14 | model_type: "mgdt" 15 | length_times: 4 16 | sample_return: True 17 | num_sample_return: 256 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/algo/mgdt_wo_sample.yaml: -------------------------------------------------------------------------------- 1 | # multi-game decision transformer 2 | 3 | n_ctx: 1024 4 | n_embd: 128 5 | hidden_size: 128 6 | n_layer: 3 7 | n_head: 1 8 | n_inner: 512 9 | activation_function: "relu" 10 | n_position: 1024 11 | resid_pdrop: 0.1 12 | attn_pdrop: 0.1 13 | action_tanh: True 14 | model_type: "mgdt" 15 | length_times: 4 16 | sample_return: False 17 | num_sample_return: 256 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/default.yaml: -------------------------------------------------------------------------------- 1 | device: 'cuda' 2 | log_to_tensorboard: True 3 | num_steps_per_iter: 10000 4 | max_iters: 10 5 | num_eval_episodes: 100 6 | warmup_steps: 10000 7 | weight_decay: 0.0001 8 | learning_rate: 0.0001 9 | save_model: True 10 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/antmaze_large_diverse.yaml: -------------------------------------------------------------------------------- 1 | data_name: "antmaze-large-diverse-v0" 2 | env_name: "antmaze-large-play-v0" 3 | max_ep_len: 999 4 | env_targets: [2, 1] 5 | scale: 1.0 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [0., 1.] 11 | reward_scale: [0., 1.] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/antmaze_medium.yaml: -------------------------------------------------------------------------------- 1 | data_name: "antmaze-medium-diverse-v0" 2 | env_name: "antmaze-medium-play-v0" 3 | max_ep_len: 999 4 | env_targets: [2, 1] 5 | scale: 1.0 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/antmaze_medium_diverse.yaml: -------------------------------------------------------------------------------- 1 | data_name: "antmaze-medium-diverse-v0" 2 | env_name: "antmaze-medium-play-v0" 3 | max_ep_len: 999 4 | env_targets: [2, 1] 5 | scale: 1.0 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [0., 1.] 11 | reward_scale: [0., 1.] 
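A note on the return_scale / reward_scale pairs in these Mujoco env YAMLs: the values look like [min, max] statistics of per-episode returns and per-step rewards of the named D4RL dataset (download_d4rl_datasets.py at the end of this dump computes the same quantities), but this reading is an inference from the numbers, not something the repo documents. A hedged sketch of how such pairs could be recomputed, assuming that interpretation; compute_scales is illustrative, not repo code:

import gym
import numpy as np
import d4rl  # noqa: F401 -- registers the D4RL environments

def compute_scales(data_name: str):
    # assumption: return_scale = [min, max] of episode returns,
    #             reward_scale = [min, max] of per-step rewards
    dataset = gym.make(data_name).get_dataset()
    rewards = dataset['rewards']
    ends = dataset['terminals'].astype(bool)
    if 'timeouts' in dataset:
        ends |= dataset['timeouts'].astype(bool)   # episode also ends on timeout
    returns, ep_ret = [], 0.0
    for r, end in zip(rewards, ends):
        ep_ret += float(r)
        if end:
            returns.append(ep_ret)
            ep_ret = 0.0
    return ([float(np.min(returns)), float(np.max(returns))],
            [float(np.min(rewards)), float(np.max(rewards))])

# e.g. compute_scales('hopper-medium-v2') should land close to the
# return_scale / reward_scale listed in hopper_medium.yaml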
-------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/antmaze_umaze.yaml: -------------------------------------------------------------------------------- 1 | data_name: "antmaze-umaze-diverse-v0" 2 | env_name: "antmaze-umaze-v0" 3 | max_ep_len: 999 4 | env_targets: [2, 1] 5 | scale: 1.0 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/antmaze_umaze_diverse.yaml: -------------------------------------------------------------------------------- 1 | data_name: "antmaze-umaze-diverse-v0" 2 | env_name: "antmaze-umaze-v0" 3 | max_ep_len: 999 4 | env_targets: [2, 1] 5 | scale: 1.0 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [0., 1.] 11 | reward_scale: [0., 1.] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/door_cloned.yaml: -------------------------------------------------------------------------------- 1 | data_name: "door-cloned-v0" 2 | env_name: "door-v0" 3 | max_ep_len: 200 4 | env_targets: [1000, 500] 5 | scale: 500 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-60.68874309829196, 1136.837553024292] 11 | reward_scale: [-0.34398257147165034, 19.992286682128906] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/halfcheetah_medium.yaml: -------------------------------------------------------------------------------- 1 | data_name: "halfcheetah-medium-v2" 2 | env_name: "HalfCheetah-v2" 3 | max_ep_len: 1000 4 | env_targets: [12000, 6000] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-310.23419189453125, 5309.37939453125] 11 | reward_scale: [-2.8353304862976074, 8.32674503326416] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/halfcheetah_medium_expert.yaml: -------------------------------------------------------------------------------- 1 | data_name: "halfcheetah-medium-expert-v2" 2 | env_name: "HalfCheetah-v2" 3 | max_ep_len: 1000 4 | env_targets: [12000, 6000] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-310.23419189453125, 11252.03515625] 11 | reward_scale: [-3.0135135650634766, 13.854623794555664] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/halfcheetah_medium_replay.yaml: -------------------------------------------------------------------------------- 1 | data_name: "halfcheetah-medium-replay-v2" 2 | env_name: "HalfCheetah-v2" 3 | max_ep_len: 1000 4 | env_targets: [12000, 6000] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-638.4852905273438, 4985.1416015625] 11 | reward_scale: [-3.298140287399292, 7.619414806365967] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/hammer_cloned.yaml: -------------------------------------------------------------------------------- 1 | data_name: 
"hammer-cloned-v0" 2 | env_name: "hammer-v0" 3 | max_ep_len: 200 4 | env_targets: [10000, 5000] 5 | scale: 2000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-260.24013566970825, 9570.416120767593] 11 | reward_scale: [-1.927051305770874, 101.8631591796875] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/hopper_medium.yaml: -------------------------------------------------------------------------------- 1 | data_name: "hopper-medium-v2" 2 | env_name: "Hopper-v2" 3 | max_ep_len: 1000 4 | env_targets: [3600, 1800] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [315.8680114746094, 3222.360595703125] 11 | reward_scale: [0.548995316028595, 5.944143295288086] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/hopper_medium_expert.yaml: -------------------------------------------------------------------------------- 1 | data_name: "hopper-medium-expert-v2" 2 | env_name: "Hopper-v2" 3 | max_ep_len: 1000 4 | env_targets: [3600, 1800] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [315.8680114746094, 3759.083740234375] 11 | reward_scale: [0.548995316028595, 6.628322124481201] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/hopper_medium_replay.yaml: -------------------------------------------------------------------------------- 1 | data_name: "hopper-medium-replay-v2" 2 | env_name: "Hopper-v2" 3 | max_ep_len: 1000 4 | env_targets: [3600, 1800] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-1.4400691986083984, 3192.925048828125] 11 | reward_scale: [-1.2545479536056519, 6.385763168334961] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/pen_cloned.yaml: -------------------------------------------------------------------------------- 1 | data_name: "pen-cloned-v0" 2 | env_name: "pen-v0" 3 | max_ep_len: 100 4 | env_targets: [6000, 3000] 5 | scale: 2000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-98.1524640354733, 6096.043075561523] 11 | reward_scale: [-6.293779002433353, 60.98057630916589] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/relocate_cloned.yaml: -------------------------------------------------------------------------------- 1 | data_name: "relocate-cloned-v0" 2 | env_name: "relocate-v0" 3 | max_ep_len: 200 4 | env_targets: [5000, 2500] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-21.950507557005587, 4713.180918633938] 11 | reward_scale: [-0.13278540455051674, 30.991315841674805] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/walker2d_medium.yaml: -------------------------------------------------------------------------------- 1 | data_name: "walker2d-medium-v2" 2 | env_name: "Walker2d-v2" 3 | max_ep_len: 1000 4 | env_targets: [5000, 2500] 5 | scale: 1000 6 | delayed_reward: False 7 
| pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-6.6056718826293945, 4226.93994140625] 11 | reward_scale: [-2.557255268096924, 8.469034194946289] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/walker2d_medium_expert.yaml: -------------------------------------------------------------------------------- 1 | data_name: "walker2d-medium-expert-v2" 2 | env_name: "Walker2d-v2" 3 | max_ep_len: 1000 4 | env_targets: [5000, 2500] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-6.6056718826293945, 5011.693359375] 11 | reward_scale: [-2.557255268096924, 8.469034194946289] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/config/env/walker2d_medium_replay.yaml: -------------------------------------------------------------------------------- 1 | data_name: "walker2d-medium-replay-v2" 2 | env_name: "Walker2d-v2" 3 | max_ep_len: 1000 4 | env_targets: [5000, 2500] 5 | scale: 1000 6 | delayed_reward: False 7 | pct_traj: 1.0 8 | K: 20 9 | batch_size: 64 10 | return_scale: [-50.196834564208984, 4132.00048828125] 11 | reward_scale: [-4.713055610656738, 8.553143501281738] -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/download_d4rl_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/data/download_d4rl_datasets.py 3 | """ 4 | 5 | import gym 6 | import numpy as np 7 | import os 8 | 9 | import collections 10 | import pickle 11 | 12 | import d4rl # Import required to register environments, you may need to also import the submodule 13 | 14 | os.makedirs('./data', exist_ok=True) 15 | 16 | dataset_name = [ 17 | 'hopper-medium-v2', 18 | 'hopper-medium-replay-v2', 19 | 'hopper-medium-expert-v2', 20 | 'halfcheetah-medium-v2', 21 | 'halfcheetah-medium-replay-v2', 22 | 'halfcheetah-medium-expert-v2', 23 | 'walker2d-medium-v2', 24 | 'walker2d-medium-replay-v2', 25 | 'walker2d-medium-expert-v2', 26 | 'pen-cloned-v0', 27 | 'door-cloned-v0', 28 | 'relocate-cloned-v0', 29 | 'hammer-cloned-v0', 30 | 'antmaze-umaze-diverse-v0', 31 | 'antmaze-medium-diverse-v0', 32 | 'antmaze-large-diverse-v0', 33 | ] 34 | 35 | for env_name in dataset_name: 36 | env = gym.make(env_name) 37 | dataset = env.get_dataset() 38 | 39 | N = dataset['rewards'].shape[0] 40 | max_reward = max(dataset['rewards']) 41 | min_reward = min(dataset['rewards']) 42 | data_ = collections.defaultdict(list) 43 | 44 | use_timeouts = False 45 | if 'timeouts' in dataset: 46 | use_timeouts = True 47 | 48 | max_ep_len = 0 49 | 50 | episode_step = 0 51 | paths = [] 52 | for i in range(N): 53 | done_bool = bool(dataset['terminals'][i]) 54 | if use_timeouts: 55 | final_timestep = dataset['timeouts'][i] 56 | else: 57 | final_timestep = (episode_step == 1000 - 1) 58 | for k in ['observations', 'actions', 'rewards', 'terminals']: 59 | data_[k].append(dataset[k][i]) 60 | 61 | if done_bool or final_timestep: 62 | max_ep_len = max(episode_step, max_ep_len) 63 | episode_step = 0 64 | episode_data = {} 65 | for k in data_: 66 | episode_data[k] = np.array(data_[k]) 67 | paths.append(episode_data) 68 | data_ = collections.defaultdict(list) 69 | episode_step += 1 70 | 71 | returns = np.array([np.sum(p['rewards']) for p 
in paths]) 72 | num_samples = np.sum([p['rewards'].shape[0] for p in paths]) 73 | print('env: {}'.format(env_name)) 74 | print(f'Number of samples collected: {num_samples}') 75 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}, max_ep_len = {int(max_ep_len)}') 76 | print(f'Trajectory rewards: min = {min_reward}, max = {max_reward}') 77 | 78 | with open(f'./data/{env_name}.pkl', 'wb') as f: 79 | pickle.dump(paths, f) 80 | 81 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on 3 | https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L166 4 | https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/evaluation/evaluate_episodes.py 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import gym 10 | 11 | 12 | class Evaluation(object): 13 | def __init__(self, config, state_mean, state_std): 14 | self.config = config 15 | self.state_mean = state_mean 16 | self.state_std = state_std 17 | self.device = config.get('device', 'cuda') 18 | self.env_name = config['env_name'] 19 | self.max_ep_len = config['max_ep_len'] 20 | self.scale = config['scale'] 21 | self.num_eval_episodes = config['num_eval_episodes'] 22 | self.model_type = config['model_type'] 23 | 24 | self.env = gym.make(self.env_name) 25 | self.state_dim = self.env.observation_space.shape[0] 26 | self.act_dim = self.env.action_space.shape[0] 27 | 28 | self.is_delayed_reward = config['delayed_reward'] 29 | 30 | def evaluate_episode(self, model, target_return=None): 31 | model.eval() 32 | model.to(device=self.device) 33 | 34 | state_mean = torch.from_numpy(self.state_mean).to(device=self.device) 35 | state_std = torch.from_numpy(self.state_std).to(device=self.device) 36 | 37 | state = self.env.reset() 38 | 39 | # we keep all the histories on the device 40 | # note that the latest action and reward will be "padding" 41 | states = torch.from_numpy(state).reshape(1, self.state_dim).to(device=self.device, dtype=torch.float32) 42 | actions = torch.zeros((0, self.act_dim), device=self.device, dtype=torch.float32) 43 | rewards = torch.zeros(0, device=self.device, dtype=torch.float32) 44 | returns = torch.zeros(0, device=self.device, dtype=torch.float32) 45 | timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(1, 1) 46 | 47 | episode_return, episode_length = 0, 0 48 | for t in range(self.max_ep_len): 49 | # add padding 50 | actions = torch.cat([actions, torch.zeros((1, self.act_dim), device=self.device)], dim=0) 51 | rewards = torch.cat([rewards, torch.zeros(1, device=self.device)]) 52 | returns = torch.cat([returns, torch.zeros(1, device=self.device)]) 53 | 54 | if self.config['model_type'] in ['bc']: 55 | action = model.get_action( 56 | (states.to(dtype=torch.float32) - state_mean) / state_std, 57 | actions.to(dtype=torch.float32), 58 | rewards.to(dtype=torch.float32), 59 | returns.to(dtype=torch.float32), 60 | timesteps.to(dtype=torch.long), 61 | ) 62 | elif self.config['model_type'] in ['mgdt']: 63 | _, ret = model.get_action( 64 | (states.to(dtype=torch.float32) - state_mean) / state_std, 65 | actions.to(dtype=torch.float32), 66 | rewards.to(dtype=torch.float32), 67 | returns.to(dtype=torch.float32), 68 | timesteps.to(dtype=torch.long), 69 | ) 70 | if self.config['sample_return'] == True: 71 | 
eps = torch.randn(self.config['num_sample_return'], 1).to(ret[1].device) 72 | ret_tmp = ret[0] + eps * torch.exp(0.5 * ret[1]) 73 | ret = ret_tmp.max(0)[0] 74 | returns[-1] = ret 75 | action, _ = model.get_action( 76 | (states.to(dtype=torch.float32) - state_mean) / state_std, 77 | actions.to(dtype=torch.float32), 78 | rewards.to(dtype=torch.float32), 79 | returns.to(dtype=torch.float32), 80 | timesteps.to(dtype=torch.long), 81 | ) 82 | actions[-1] = action 83 | action = action.detach().cpu().numpy() 84 | 85 | state, reward, done, _ = self.env.step(action) 86 | 87 | cur_state = torch.from_numpy(state).to(device=self.device).reshape(1, self.state_dim) 88 | states = torch.cat([states, cur_state], dim=0) 89 | if self.config['model_type'] in ['mgdt']: # only MGDT actually uses the reward, so we should normalize it 90 | rewards[-1] = -1 + 2 * (reward - self.config['reward_scale'][0]) / (self.config['reward_scale'][1] - self.config['reward_scale'][0]) 91 | elif self.config['model_type'] in ['bc']: 92 | rewards[-1] = reward 93 | timesteps = torch.cat([timesteps, torch.ones((1, 1), device=self.device, dtype=torch.long) * (t + 1)], dim=1) 94 | 95 | episode_return += reward 96 | episode_length += 1 97 | 98 | if done: 99 | break 100 | 101 | return episode_return, episode_length 102 | 103 | def evaluate_episode_rtg(self, model, target_return=None): 104 | model.eval() 105 | model.to(device=self.device) 106 | 107 | state_mean = torch.from_numpy(self.state_mean).to(device=self.device) 108 | state_std = torch.from_numpy(self.state_std).to(device=self.device) 109 | 110 | state = self.env.reset() 111 | 112 | # we keep all the histories on the device 113 | # note that the latest action and reward will be "padding" 114 | states = torch.from_numpy(state).reshape(1, self.state_dim).to(device=self.device, dtype=torch.float32) 115 | actions = torch.zeros((0, self.act_dim), device=self.device, dtype=torch.float32) 116 | rewards = torch.zeros(0, device=self.device, dtype=torch.float32) 117 | 118 | ep_return = target_return 119 | target_return = torch.tensor(ep_return, device=self.device, dtype=torch.float32).reshape(1, 1) 120 | timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(1, 1) 121 | 122 | episode_return, episode_length = 0, 0 123 | for t in range(self.max_ep_len): 124 | 125 | # add padding 126 | actions = torch.cat([actions, torch.zeros((1, self.act_dim), device=self.device)], dim=0) 127 | rewards = torch.cat([rewards, torch.zeros(1, device=self.device)]) 128 | 129 | action = model.get_action( 130 | (states.to(dtype=torch.float32) - state_mean) / state_std, 131 | actions.to(dtype=torch.float32), 132 | rewards.to(dtype=torch.float32), 133 | target_return.to(dtype=torch.float32), 134 | timesteps.to(dtype=torch.long), 135 | ) 136 | actions[-1] = action 137 | action = action.detach().cpu().numpy() 138 | 139 | state, reward, done, _ = self.env.step(action) 140 | 141 | cur_state = torch.from_numpy(state).to(device=self.device).reshape(1, self.state_dim) 142 | states = torch.cat([states, cur_state], dim=0) 143 | rewards[-1] = reward 144 | 145 | if not self.is_delayed_reward:  # delayed_reward is a boolean flag in the env configs 146 | pred_return = target_return[0,-1] - (reward / self.scale) 147 | else: 148 | pred_return = target_return[0,-1] 149 | target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1) 150 | timesteps = torch.cat([timesteps, torch.ones((1, 1), device=self.device, dtype=torch.long) * (t + 1)], dim=1) 151 | 152 | episode_return += reward 153 | episode_length += 1 154 | 155 | if done: 156 | break 157 | 
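        # A quick numerical illustration of the return-to-go conditioning used in the loop above
        # (illustrative numbers only, not taken from a real run): with hopper-style settings of
        # env_target = 3600 and scale = 1000, the initial conditioning value is 3600 / 1000 = 3.6;
        # if the agent then collects a reward of 3.0 at every step, the value fed to the model
        # shrinks as 3.600 -> 3.597 -> 3.594 -> ..., i.e. pred_return = previous_target - reward / scale,
        # so the model is always conditioned on the return that still remains to be collected.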
158 | return episode_return, episode_length 159 | 160 | def eval_fn(self, target_rew): 161 | def fn(model): 162 | returns, lengths = [], [] 163 | for _ in range(self.num_eval_episodes): 164 | with torch.no_grad(): 165 | if self.model_type in ['dt']: 166 | ret, length = self.evaluate_episode_rtg(model, target_return=target_rew/self.scale) 167 | else: 168 | ret, length = self.evaluate_episode(model, target_return=target_rew/self.scale) 169 | returns.append(ret) 170 | lengths.append(length) 171 | return { 172 | f'target_{target_rew}_return_mean': np.mean(returns), 173 | f'target_{target_rew}_return_std': np.std(returns), 174 | f'target_{target_rew}_length_mean': np.mean(lengths), 175 | f'target_{target_rew}_length_std': np.std(lengths), 176 | } 177 | return fn 178 | 179 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L208 3 | """ 4 | 5 | # import wandb 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | 9 | import argparse 10 | import yaml 11 | import os 12 | 13 | from network import DecisionTransformer 14 | from trainner import Trainer 15 | from evaluation import Evaluation 16 | from utils import SequenceDataset 17 | import numpy as np 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--algo', type=str, default='dt') 23 | parser.add_argument('--env', type=str, default='pen_cloned') 24 | args = parser.parse_args() 25 | 26 | with open('config/default.yaml'.format(args.algo), 'r') as f: 27 | config = yaml.safe_load(f) 28 | with open('config/env/{}.yaml'.format(args.env), 'r') as f: 29 | config.update(yaml.safe_load(f)) 30 | with open('config/algo/{}.yaml'.format(args.algo), 'r') as f: 31 | config.update(yaml.safe_load(f)) 32 | 33 | if config['log_to_tensorboard']: 34 | path = './log/{}/{}/'.format(args.algo, args.env) 35 | os.makedirs(path, exist_ok=True) 36 | list_files = os.listdir(path) 37 | list_files = [int(x) for x in list_files] 38 | file_name = 0 if len(list_files) == 0 else max(list_files) + 1 39 | final_path = path+'{}'.format(file_name) 40 | writer = SummaryWriter(final_path) 41 | with open(final_path+'/config.txt', 'w') as f: 42 | yaml.dump(config, f) 43 | f.close() 44 | else: 45 | writer = None 46 | 47 | dataset = SequenceDataset(config) 48 | model = DecisionTransformer(config).to(config['device']) 49 | 50 | evaluation = Evaluation(config, state_mean=dataset.state_mean, state_std=dataset.state_std) 51 | 52 | warmup_steps = config['warmup_steps'] 53 | optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']) 54 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps+1)/warmup_steps, 1)) 55 | trainer = Trainer( 56 | model=model, 57 | optimizer=optimizer, 58 | batch_size=config['batch_size'], 59 | dataset=dataset, 60 | scheduler=scheduler, 61 | config=config, 62 | eval_fns=[evaluation.eval_fn(tar) for tar in config['env_targets']], 63 | writer=writer 64 | ) 65 | 66 | for iter in range(config['max_iters']): 67 | outputs = trainer.train_iteration(num_steps=config['num_steps_per_iter'], iter_num=iter+1, print_logs=True) 68 | if config['log_to_tensorboard']: 69 | for k, v in outputs.items(): 70 | writer.add_scalar(k, v, iter) 71 | 72 | if config['save_model']: 
73 | save_path = './model' 74 | os.makedirs(save_path, exist_ok=True) 75 | torch.save(model, save_path+'/{}_{}_{}.pkl'.format(args.algo, args.env, np.random.randint(10000))) 76 | 77 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on 3 | https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/model_atari.py 4 | https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/models/decision_transformer.py 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import math 12 | import gym 13 | 14 | 15 | class CausalSelfAttention(nn.Module): 16 | def __init__(self, config): 17 | super().__init__() 18 | assert config['n_embd'] % config['n_head'] == 0 19 | # key, query, value projections for all heads 20 | self.key = nn.Linear(config['n_embd'], config['n_embd']) 21 | self.query = nn.Linear(config['n_embd'], config['n_embd']) 22 | self.value = nn.Linear(config['n_embd'], config['n_embd']) 23 | # regularization 24 | self.attn_drop = nn.Dropout(config['attn_pdrop']) 25 | self.resid_drop = nn.Dropout(config['resid_pdrop']) 26 | 27 | # causal mask to ensure that attention is only applied to the left in the input sequence 28 | self.register_buffer("bias", torch.tril(torch.ones(config['n_ctx'], config['n_ctx'])).view(1, 1, config['n_ctx'], config['n_ctx'])) 29 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 30 | 31 | # output projection 32 | self.proj = nn.Linear(config['n_embd'], config['n_embd']) 33 | self.n_head = config['n_head'] 34 | 35 | def forward(self, x, mask): 36 | B, T, C = x.size() 37 | 38 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 39 | ## [ B x n_heads x T x head_dim ] 40 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 41 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 42 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 43 | 44 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 45 | ## [ B x n_heads x T x T ] 46 | mask = mask.view(B, -1) 47 | mask = mask[:, None, None, :] 48 | mask = (1.0 - mask) * -10000.0 49 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 50 | att = torch.where(self.bias[:, :, :T, :T].bool(), att, self.masked_bias.to(att.dtype)) 51 | att = att + mask 52 | att = F.softmax(att, dim=-1) 53 | self._attn_map = att.clone() 54 | att = self.attn_drop(att) 55 | ## [ B x n_heads x T x head_size ] 56 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 57 | ## [ B x T x embedding_dim ] 58 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 59 | 60 | # output projection 61 | y = self.resid_drop(self.proj(y)) 62 | return y 63 | 64 | 65 | class Block(nn.Module): 66 | def __init__(self, config): 67 | super().__init__() 68 | self.ln1 = nn.LayerNorm(config['n_embd']) 69 | self.ln2 = nn.LayerNorm(config['n_embd']) 70 | self.attn = CausalSelfAttention(config) 71 | self.mlp = nn.Sequential( 72 | nn.Linear(config['n_embd'], config['n_inner']), 73 | nn.GELU(), 74 | nn.Linear(config['n_inner'], config['n_embd']), 75 | nn.Dropout(config['resid_pdrop']), 76 | ) 77 | 78 | def forward(self, inputs_embeds, 
attention_mask): 79 | x = inputs_embeds + self.attn(self.ln1(inputs_embeds), attention_mask) 80 | x = x + self.mlp(self.ln2(x)) 81 | return x 82 | 83 | 84 | class DecisionTransformer(nn.Module): 85 | """ 86 | This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...) 87 | """ 88 | def __init__(self, config, action_tanh=True, **kwargs): 89 | super(DecisionTransformer, self).__init__() 90 | 91 | self.config = config 92 | self.length_times = config['length_times'] 93 | self.hidden_size = config['hidden_size'] 94 | assert self.hidden_size == config['n_embd'] 95 | self.max_length = config['K'] 96 | self.max_ep_len = config['max_ep_len'] 97 | 98 | self.env = gym.make(config['env_name']) 99 | self.state_dim = self.env.observation_space.shape[0] 100 | self.act_dim = self.env.action_space.shape[0] 101 | 102 | # note: the only difference between this GPT2Model and the default Huggingface version 103 | # is that the positional embeddings are removed (since we'll add those ourselves) 104 | # self.transformer = GPT2Model(config) 105 | self.transformer = nn.ModuleList([Block(config) for _ in range(config['n_layer'])]) 106 | 107 | self.embed_timestep = nn.Embedding(self.max_ep_len, self.hidden_size) 108 | self.embed_return = torch.nn.Linear(1, self.hidden_size) 109 | self.embed_reward = torch.nn.Linear(1, self.hidden_size) 110 | self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) 111 | self.embed_action = torch.nn.Linear(self.act_dim, self.hidden_size) 112 | 113 | self.embed_ln = nn.LayerNorm(self.hidden_size) 114 | 115 | # note: we don't predict states or returns for the paper 116 | self.predict_state = torch.nn.Linear(self.hidden_size, self.state_dim) 117 | self.predict_action = nn.Sequential( 118 | *([nn.Linear(self.hidden_size, self.act_dim)] + ([nn.Tanh()] if config['action_tanh'] else [])) 119 | ) 120 | if self.config['model_type'] in ['mgdt']: 121 | if self.config['sample_return'] == False: 122 | self.predict_return = torch.nn.Linear(self.hidden_size, 1) 123 | else: 124 | self.predict_return_mu = torch.nn.Linear(self.hidden_size, 1) 125 | self.predict_return_sigma = torch.nn.Linear(self.hidden_size, 1) 126 | else: 127 | self.predict_return = torch.nn.Linear(self.hidden_size, 1) 128 | self.predict_reward = torch.nn.Linear(self.hidden_size, 1) 129 | 130 | def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None): 131 | 132 | batch_size, seq_length = states.shape[0], states.shape[1] 133 | 134 | if attention_mask is None: 135 | # attention mask for GPT: 1 if can be attended to, 0 if not 136 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 137 | 138 | # embed each modality with a different head 139 | state_embeddings = self.embed_state(states) 140 | action_embeddings = self.embed_action(actions) 141 | returns_embeddings = self.embed_return(returns_to_go) 142 | rewards_embeddings = self.embed_reward(rewards) 143 | time_embeddings = self.embed_timestep(timesteps) 144 | 145 | # time embeddings are treated similar to positional embeddings 146 | state_embeddings = state_embeddings + time_embeddings 147 | action_embeddings = action_embeddings + time_embeddings 148 | returns_embeddings = returns_embeddings + time_embeddings 149 | rewards_embeddings = rewards_embeddings + time_embeddings 150 | 151 | # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 
152 | # which works nice in an autoregressive sense since states predict actions 153 | if self.config['model_type'] in ['dt']: 154 | stacked_inputs = torch.stack( 155 | (returns_embeddings, state_embeddings, action_embeddings), dim=1 156 | ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size) 157 | stacked_inputs = self.embed_ln(stacked_inputs) 158 | 159 | elif self.config['model_type'] in ['bc']: 160 | stacked_inputs = torch.stack( 161 | (state_embeddings, action_embeddings), dim=1 162 | ).permute(0, 2, 1, 3).reshape(batch_size, 2*seq_length, self.hidden_size) 163 | stacked_inputs = self.embed_ln(stacked_inputs) 164 | 165 | elif self.config['model_type'] in ['mgdt']: 166 | stacked_inputs = torch.stack( 167 | (state_embeddings, returns_embeddings, action_embeddings, rewards_embeddings), dim=1 168 | ).permute(0, 2, 1, 3).reshape(batch_size, 4*seq_length, self.hidden_size) 169 | stacked_inputs = self.embed_ln(stacked_inputs) 170 | 171 | # to make the attention mask fit the stacked inputs, have to stack it as well 172 | stacked_attention_mask = torch.stack( 173 | ([attention_mask for _ in range(self.length_times)]), dim=1 174 | ).permute(0, 2, 1).reshape(batch_size, self.length_times*seq_length).to(stacked_inputs.dtype) 175 | 176 | # we feed in the input embeddings (not word indices as in NLP) to the model 177 | x = stacked_inputs 178 | for block in self.transformer: 179 | x = block(x, stacked_attention_mask) 180 | 181 | # reshape x so that the second dimension corresponds to the original 182 | # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t 183 | x = x.reshape(batch_size, seq_length, self.length_times, self.hidden_size).permute(0, 2, 1, 3) 184 | 185 | # get predictions 186 | if self.config['model_type'] in ['dt']: 187 | return_preds = self.predict_return(x[:,2]) # predict next return given state and action 188 | state_preds = self.predict_state(x[:,2]) # predict next state given state and action 189 | action_preds = self.predict_action(x[:,1]) # predict next action given state 190 | return state_preds, action_preds, return_preds, None 191 | elif self.config['model_type'] in ['bc']: 192 | action_preds = self.predict_action(x[:,0]) # predict next action given state 193 | return None, action_preds, None, None 194 | elif self.config['model_type'] in ['mgdt']: 195 | if self.config['sample_return'] == False: 196 | return_preds = self.predict_return(x[:,0]) # predict next return 197 | else: 198 | return_preds_mu = self.predict_return_mu(x[:,0]) 199 | return_preds_sigma = self.predict_return_sigma(x[:,0]) 200 | # eps = torch.randn_like(return_preds_sigma) 201 | # return_preds = return_preds_mu + eps * torch.exp(0.5 * return_preds_sigma) 202 | reward_preds = self.predict_reward(x[:,2]) # predict next rewards 203 | action_preds = self.predict_action(x[:,1]) # predict next action 204 | if self.config['sample_return'] == False: 205 | return None, action_preds, return_preds, reward_preds 206 | else: 207 | return None, action_preds, [return_preds_mu, return_preds_sigma], reward_preds 208 | 209 | def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs): 210 | # we don't care about the past rewards in this model 211 | states = states.reshape(1, -1, self.state_dim) 212 | actions = actions.reshape(1, -1, self.act_dim) 213 | returns_to_go = returns_to_go.reshape(1, -1, 1) 214 | rewards = rewards.reshape(1, -1, 1) 215 | timesteps = timesteps.reshape(1, -1) 216 | 217 | if self.max_length is not None: 218 | states = 
states[:,-self.max_length:] 219 | actions = actions[:,-self.max_length:] 220 | returns_to_go = returns_to_go[:,-self.max_length:] 221 | rewards = rewards[:,-self.max_length:] 222 | timesteps = timesteps[:,-self.max_length:] 223 | 224 | # pad all tokens to sequence length 225 | attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])]) 226 | attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1) 227 | states = torch.cat( 228 | [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states], 229 | dim=1).to(dtype=torch.float32) 230 | actions = torch.cat( 231 | [torch.zeros((actions.shape[0], self.max_length-actions.shape[1], self.act_dim), device=actions.device), actions], 232 | dim=1).to(dtype=torch.float32) 233 | returns_to_go = torch.cat( 234 | [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go], 235 | dim=1).to(dtype=torch.float32) 236 | rewards = torch.cat( 237 | [torch.zeros((rewards.shape[0], self.max_length-rewards.shape[1], 1), device=rewards.device), rewards], 238 | dim=1).to(dtype=torch.float32) 239 | timesteps = torch.cat( 240 | [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps], 241 | dim=1).to(dtype=torch.long) 242 | else: 243 | attention_mask = None 244 | 245 | _, action_preds, return_preds, reward_preds = self.forward( 246 | states, actions, rewards, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs) 247 | 248 | if self.config['model_type'] in ['bc', 'dt']: 249 | return action_preds[0, -1] 250 | elif self.config['model_type'] in ['mgdt']: 251 | if self.config['sample_return'] == False: 252 | return action_preds[0, -1], return_preds[0, -1] 253 | else: 254 | return action_preds[0, -1], [return_preds[0][0, -1], return_preds[1][0, -1]] 255 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | ale-py==0.8.1 3 | astunparse==1.6.3 4 | atari-py==0.2.9 5 | blosc==1.11.1 6 | cached-property==1.5.2 7 | cachetools==5.3.0 8 | certifi==2022.12.7 9 | cffi==1.15.1 10 | charset-normalizer==3.0.1 11 | chex==0.1.6 12 | click==8.1.3 13 | cloudpickle==2.2.1 14 | configparser==5.3.0 15 | contourpy==1.0.7 16 | cycler==0.11.0 17 | cython==0.29.33 18 | d4rl==1.1 19 | decorator==5.1.1 20 | dm-control==1.0.10 21 | dm-env==1.6 22 | dm-tree==0.1.8 23 | docker-pycreds==0.4.0 24 | dopamine-rl==4.0.6 25 | etils==1.0.0 26 | fasteners==0.18 27 | filelock==3.9.0 28 | flatbuffers==23.1.21 29 | flax==0.6.6 30 | fonttools==4.38.0 31 | gast==0.4.0 32 | gin-config==0.5.0 33 | gitdb==4.0.10 34 | gitpython==3.1.31 35 | glfw==2.5.6 36 | google-auth-oauthlib==0.4.6 37 | google-auth==2.16.1 38 | google-pasta==0.2.0 39 | gql==0.2.0 40 | graphql-core==1.1 41 | grpcio==1.51.3 42 | gym-notices==0.0.8 43 | gym==0.23.1 44 | h5py==3.8.0 45 | huggingface-hub==0.12.1 46 | idna==3.4 47 | imageio==2.25.1 48 | importlib-metadata==6.0.0 49 | importlib-resources==5.12.0 50 | jax==0.4.4 51 | jaxlib==0.4.4 52 | joblib==1.2.0 53 | keras==2.11.0 54 | kiwisolver==1.4.4 55 | labmaze==1.0.6 56 | libclang==15.0.6.1 57 | lxml==4.9.2 58 | markdown-it-py==2.2.0 59 | markdown==3.4.1 60 | markupsafe==2.1.2 61 | matplotlib==3.7.0 62 | mdurl==0.1.2 63 | mj-envs==1.0.0 64 | 
mjrl==1.0.0 65 | msgpack==1.0.4 66 | mujoco-py==2.1.2.14 67 | mujoco==2.3.2 68 | mypy-extensions==1.0.0 69 | numpy==1.24.2 70 | nvidia-ml-py3==7.352.0 71 | oauthlib==3.2.2 72 | opencv-python==4.7.0.72 73 | opt-einsum==3.3.0 74 | optax==0.1.4 75 | orbax==0.1.2 76 | packaging==23.0 77 | pandas==1.5.3 78 | pillow==9.4.0 79 | pip==22.3.1 80 | promise==2.3 81 | protobuf==3.19.6 82 | psutil==5.9.4 83 | pyasn1-modules==0.2.8 84 | pyasn1==0.4.8 85 | pybullet==3.2.5 86 | pycparser==2.21 87 | pygame==2.2.0 88 | pygments==2.14.0 89 | pyopengl==3.1.6 90 | pyparsing==3.0.9 91 | python-dateutil==2.8.2 92 | pytz==2022.7.1 93 | pyyaml==6.0 94 | regex==2022.10.31 95 | requests-oauthlib==1.3.1 96 | requests==2.28.2 97 | rich==13.3.1 98 | rsa==4.9 99 | sacremoses==0.0.53 100 | scikit-video==1.1.11 101 | scipy==1.10.1 102 | sentry-sdk==1.15.0 103 | setuptools==65.6.3 104 | shortuuid==1.0.11 105 | six==1.16.0 106 | smmap==5.0.0 107 | style==1.1.0 108 | subprocess32==3.5.4 109 | tensorboard-data-server==0.6.1 110 | tensorboard-plugin-wit==1.8.1 111 | tensorboard==2.11.2 112 | tensorflow-estimator==2.11.0 113 | tensorflow-io-gcs-filesystem==0.30.0 114 | tensorflow-probability==0.19.0 115 | tensorflow==2.11.0 116 | tensorstore==0.1.33 117 | termcolor==2.2.0 118 | tf-slim==1.1.0 119 | tokenizers==0.10.3 120 | toolz==0.12.0 121 | torch==1.10.0+cu111 122 | torchaudio==0.10.0+rocm4.1 123 | torchvision==0.11.0+cu111 124 | tqdm==4.64.1 125 | transformers==4.11.3 126 | typed-argument-parser==1.7.2 127 | typing-extensions==4.5.0 128 | typing-inspect==0.8.0 129 | update==0.0.1 130 | urllib3==1.26.14 131 | wandb==0.9.1 132 | watchdog==2.3.0 133 | werkzeug==2.2.3 134 | wheel==0.38.4 135 | wrapt==1.14.1 136 | zipp==3.14.0 -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/trainner.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/training/seq_trainer.py 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | # import wandb 8 | 9 | import time 10 | from torch.utils.data import DataLoader, WeightedRandomSampler 11 | from tqdm import tqdm 12 | 13 | 14 | class Trainer: 15 | def __init__(self, model, optimizer, batch_size, dataset, writer, config, scheduler=None, eval_fns=None): 16 | self.model = model 17 | self.optimizer = optimizer 18 | self.batch_size = batch_size 19 | self.dataset = dataset 20 | self.scheduler = scheduler 21 | self.eval_fns = [] if eval_fns is None else eval_fns 22 | self.diagnostics = dict() 23 | self.writer = writer 24 | self.model_type = config['model_type'] 25 | self.reward_scale = config['reward_scale'] 26 | self.config = config 27 | 28 | self.train_count = 0 29 | 30 | self.start_time = time.time() 31 | 32 | def train_iteration(self, num_steps, iter_num=0, print_logs=False): 33 | train_losses = [] 34 | logs = dict() 35 | 36 | train_start = time.time() 37 | sampler = WeightedRandomSampler(self.dataset.p_sample, num_samples=num_steps*self.batch_size, replacement=True) 38 | dataloader = DataLoader(self.dataset, sampler=sampler, batch_size=self.batch_size) 39 | 40 | self.model.train() 41 | for states, actions, rewards, dones, rtg, timesteps, attention_mask in tqdm(dataloader): 42 | train_loss = self.train_step(states, actions, rewards, dones, rtg, timesteps, attention_mask) 43 | train_losses.append(train_loss) 44 | if self.writer is not None: 45 | self.writer.add_scalar('train_loss', 
train_loss, self.train_count) 46 | self.train_count += 1 47 | if self.scheduler is not None: 48 | self.scheduler.step() 49 | 50 | logs['time/training'] = time.time() - train_start 51 | 52 | eval_start = time.time() 53 | 54 | self.model.eval() 55 | for eval_fn in self.eval_fns: 56 | outputs = eval_fn(self.model) 57 | for k, v in outputs.items(): 58 | logs[f'evaluation/{k}'] = v 59 | 60 | logs['time/total'] = time.time() - self.start_time 61 | logs['time/evaluation'] = time.time() - eval_start 62 | logs['training/train_loss_mean'] = np.mean(train_losses) 63 | logs['training/train_loss_std'] = np.std(train_losses) 64 | 65 | for k in self.diagnostics: 66 | logs[k] = self.diagnostics[k] 67 | 68 | if print_logs: 69 | print('=' * 80) 70 | print(f'Iteration {iter_num}') 71 | for k, v in logs.items(): 72 | print(f'{k}: {v}') 73 | 74 | return logs 75 | 76 | def train_step(self, states, actions, rewards, dones, rtg, timesteps, attention_mask): 77 | rewards_target, action_target, rtg_target = torch.clone(rewards), torch.clone(actions), torch.clone(rtg) 78 | 79 | state_preds, action_preds, return_preds, reward_preds = self.model.forward( 80 | states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask, 81 | ) 82 | 83 | act_dim = action_preds.shape[2] 84 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 85 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 86 | 87 | if self.model_type in ['dt', 'bc']: 88 | loss = torch.mean((action_preds - action_target) ** 2) 89 | elif self.model_type in ['mgdt']: 90 | if self.config['sample_return'] == True: 91 | eps = torch.randn_like(return_preds[1]) 92 | return_preds_tmp = return_preds[0] + eps * torch.exp(0.5 * return_preds[1]) 93 | return_preds = return_preds_tmp 94 | return_preds = return_preds.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 95 | return_target = rtg_target[:,:-1].reshape(-1, 1)[attention_mask.reshape(-1) > 0] 96 | reward_preds = reward_preds.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 97 | reward_target = rewards_target.reshape(-1, 1)[attention_mask.reshape(-1) > 0] 98 | loss = torch.mean((action_preds - action_target) ** 2) \ 99 | + torch.mean((return_preds - return_target) ** 2) \ 100 | + torch.mean((reward_preds - reward_target) ** 2) 101 | 102 | self.optimizer.zero_grad() 103 | loss.backward() 104 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) 105 | self.optimizer.step() 106 | 107 | with torch.no_grad(): 108 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() 109 | 110 | return loss.detach().cpu().item() 111 | 112 | -------------------------------------------------------------------------------- /RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | highly based on https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py 3 | """ 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | import gym 9 | import d4rl 10 | import numpy as np 11 | import pickle 12 | import random 13 | 14 | 15 | class SequenceDataset(Dataset): 16 | def __init__(self, config): 17 | super(SequenceDataset, self).__init__() 18 | self.device = config.get('device', 'cuda') 19 | self.env_name = config['env_name'] 20 | 21 | self.env = gym.make(self.env_name) 22 | self.max_ep_len = config['max_ep_len'] 23 | self.scale = config['scale'] 24 | self.reward_scale = config['reward_scale'] 25 
| 26 | self.state_dim = self.env.observation_space.shape[0] 27 | self.act_dim = self.env.action_space.shape[0] 28 | 29 | dataset_path = 'data/{}.pkl'.format(config['data_name']) 30 | with open(dataset_path, 'rb') as f: 31 | self.trajectories = pickle.load(f) 32 | 33 | # save all path information into separate lists 34 | self.is_delayed_reward = config['delayed_reward'] 35 | self.states, self.traj_lens, self.returns = [], [], [] 36 | for path in self.trajectories: 37 | if self.is_delayed_reward: # delayed: all rewards moved to end of trajectory 38 | path['rewards'][-1] = path['rewards'].sum() 39 | path['rewards'][:-1] = 0. 40 | self.states.append(path['observations']) 41 | self.traj_lens.append(len(path['observations'])) 42 | self.returns.append(path['rewards'].sum()) 43 | self.traj_lens, self.returns = np.array(self.traj_lens), np.array(self.returns) 44 | 45 | # used for input normalization 46 | self.states = np.concatenate(self.states, axis=0) 47 | self.state_mean, self.state_std = np.mean(self.states, axis=0), np.std(self.states, axis=0) + 1e-6 48 | 49 | self.K = config['K'] 50 | self.pct_traj = config.get('pct_traj', 1.) 51 | 52 | # only train on top pct_traj trajectories (for %BC experiment) 53 | num_timesteps = sum(self.traj_lens) 54 | num_timesteps = max(int(self.pct_traj * num_timesteps), 1) 55 | sorted_inds = np.argsort(self.returns) # lowest to highest 56 | num_trajectories = 1 57 | timesteps = self.traj_lens[sorted_inds[-1]] 58 | ind = len(self.trajectories) - 2 59 | while ind >= 0 and timesteps + self.traj_lens[sorted_inds[ind]] <= num_timesteps: 60 | timesteps += self.traj_lens[sorted_inds[ind]] 61 | num_trajectories += 1 62 | ind -= 1 63 | self.sorted_inds = sorted_inds[-num_trajectories:] 64 | 65 | # used to reweight sampling so we sample according to timesteps instead of trajectories 66 | self.p_sample = self.traj_lens[self.sorted_inds] / sum(self.traj_lens[self.sorted_inds]) 67 | 68 | def __getitem__(self, index): 69 | traj = self.trajectories[int(self.sorted_inds[index])] 70 | start_t = random.randint(0, traj['rewards'].shape[0] - 1) 71 | 72 | s = traj['observations'][start_t: start_t + self.K] 73 | a = traj['actions'][start_t: start_t + self.K] 74 | r = traj['rewards'][start_t: start_t + self.K].reshape(-1, 1) 75 | if 'terminals' in traj: 76 | d = traj['terminals'][start_t: start_t + self.K] 77 | else: 78 | d = traj['dones'][start_t: start_t + self.K] 79 | timesteps = np.arange(start_t, start_t + s.shape[0]) 80 | timesteps[timesteps >= self.max_ep_len] = self.max_ep_len - 1 # padding cutoff 81 | rtg = self.discount_cumsum(traj['rewards'][start_t:], gamma=1.)[:s.shape[0] + 1].reshape(-1, 1) 82 | if rtg.shape[0] <= s.shape[0]: 83 | rtg = np.concatenate([rtg, np.zeros((1, 1))], axis=0) 84 | 85 | # padding and state + reward + rtg normalization 86 | tlen = s.shape[0] 87 | s = np.concatenate([np.zeros((self.K - tlen, self.state_dim)), s], axis=0) 88 | s = (s - self.state_mean) / self.state_std 89 | a = np.concatenate([np.ones((self.K - tlen, self.act_dim)) * -10., a], axis=0) 90 | r = np.concatenate([np.zeros((self.K - tlen, 1)), r], axis=0) 91 | r = -1 + 2 * (r - self.reward_scale[0]) / (self.reward_scale[1] - self.reward_scale[0]) 92 | d = np.concatenate([np.ones((self.K - tlen)) * 2, d], axis=0) 93 | rtg = np.concatenate([np.zeros((self.K - tlen, 1)), rtg], axis=0) / self.scale 94 | timesteps = np.concatenate([np.zeros((self.K - tlen)), timesteps], axis=0) 95 | mask = np.concatenate([np.zeros((self.K - tlen)), np.ones((tlen))], axis=0) 96 | 97 | s = 
torch.from_numpy(s).to(dtype=torch.float32, device=self.device) 98 | a = torch.from_numpy(a).to(dtype=torch.float32, device=self.device) 99 | r = torch.from_numpy(r).to(dtype=torch.float32, device=self.device) 100 | d = torch.from_numpy(d).to(dtype=torch.long, device=self.device) 101 | rtg = torch.from_numpy(rtg).to(dtype=torch.float32, device=self.device) 102 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.long, device=self.device) 103 | mask = torch.from_numpy(mask).to(device=self.device) 104 | return s, a, r, d, rtg, timesteps, mask 105 | 106 | def discount_cumsum(self, x, gamma=1.): 107 | discount_cumsum = np.zeros_like(x) 108 | discount_cumsum[-1] = x[-1] 109 | for t in reversed(range(x.shape[0]-1)): 110 | discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] 111 | return discount_cumsum 112 | 113 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Transformer in Transformer as Backbone for Deep Reinforcement Learning 2 | Paper: 3 | [old version](https://arxiv.org/abs/2212.14538): Transformer in Transformer as Backbone for Deep Reinforcement Learning 4 | [new version](https://arxiv.org/abs/2312.15863): PDiT: Interleaving Perception and Decision-making Transformers for Deep Reinforcement Learning. AAMAS 2024 (full paper with oral presentation) 5 | 6 | Code: 7 | 1) The two folders, DT_TIT and PPO_TIT_and_CQL_TIT, contain the old version; 8 | 2) The two folders, RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT and RL_Foundation_BabyAI_including_DT_GATO_and_TIT, contain the new version; 9 | 3) **We recommend that readers use the new version, which achieves satisfactory performance and has a clean file structure (and is therefore easy to modify when designing new algorithms). Thanks to [Zhiwei Xu](https://github.com/deligentfool) for contributing this new version.** 10 | 11 | 12 | # Cite 13 | Please cite our paper as: 14 | ``` 15 | @inproceedings{mao2024PDiT, 16 | title={PDiT: Interleaving Perception and Decision-making Transformers for Deep Reinforcement Learning}, 17 | author={Mao, Hangyu and Zhao, Rui and Li, Ziyue and Xu, Zhiwei and Chen, Hao and Chen, Yiqun and Zhang, Bin and Xiao, Zhen and Zhang, Junge and Yin, Jiangjin}, 18 | booktitle={Proceedings of the 23rd International Conference on Autonomous Agents and MultiAgent Systems}, 19 | year={2024} 20 | } 21 | ``` 22 | 23 | and cite the preliminary study as: 24 | ``` 25 | @article{mao2022transformer, 26 | title={Transformer in Transformer as Backbone for Deep Reinforcement Learning}, 27 | author={Mao, Hangyu and Zhao, Rui and Chen, Hao and Hao, Jianye and Chen, Yiqun and Li, Dong and Zhang, Junge and Xiao, Zhen}, 28 | journal={arXiv preprint arXiv:2212.14538}, 29 | year={2022} 30 | } 31 | ``` 32 | 33 | 34 | ## License 35 | MIT 36 | --------------------------------------------------------------------------------
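As a rough quick-start for the recommended new version, the entry points above suggest an invocation along the following lines. This is only a sketch inferred from the argparse defaults in `RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT/main.py` and the config file names, not an official recipe; adjust `--algo` and `--env` to any of the provided `config/algo/*.yaml` and `config/env/*.yaml` names.

```
cd RL_Foundation_Mujoco_including_DT_MGDT_GATO_and_TIT
python download_d4rl_datasets.py               # writes ./data/<dataset-name>.pkl
python main.py --algo dt --env hopper_medium   # loads config/default.yaml, then config/env/hopper_medium.yaml, then config/algo/dt.yaml
```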