├── LICENSE.md ├── README.md ├── architecture.png ├── atari ├── LICENSE ├── conda_env.yml ├── create_dataset.py ├── fixed_replay_buffer.py ├── mingpt │ ├── __init__.py │ ├── model_atari.py │ ├── trainer_atari.py │ └── utils.py ├── readme-atari.md ├── run.sh └── run_dt_atari.py └── gym ├── conda_env.yml ├── data └── download_d4rl_datasets.py ├── decision_transformer ├── envs │ ├── assets │ │ └── reacher_2d.xml │ └── reacher_2d.py ├── evaluation │ └── evaluate_episodes.py ├── models │ ├── decision_transformer.py │ ├── mlp_bc.py │ ├── model.py │ └── trajectory_gpt2.py └── training │ ├── act_trainer.py │ ├── seq_trainer.py │ └── trainer.py ├── experiment.py └── readme-gym.md /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Decision Transformer (Decision Transformer: Reinforcement Learning via Sequence Modeling) Authors (https://arxiv.org/abs/2106.01345) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Decision Transformer 3 | 4 | Lili Chen\*, Kevin Lu\*, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas†, and Igor Mordatch† 5 | 6 | \*equal contribution, †equal advising 7 | 8 | A link to our paper can be found on [arXiv](https://arxiv.org/abs/2106.01345). 9 | 10 | ## Overview 11 | 12 | Official codebase for [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://sites.google.com/berkeley.edu/decision-transformer). 13 | Contains scripts to reproduce experiments. 14 | 15 | ![image info](./architecture.png) 16 | 17 | ## Instructions 18 | 19 | We provide code in two sub-directories: `atari` containing code for Atari experiments and `gym` containing code for OpenAI Gym experiments. 20 | See corresponding READMEs in each folder for instructions; scripts should be run from the respective directories. 21 | It may be necessary to add the respective directories to your PYTHONPATH. 
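For example, a minimal way to do this from the sub-directory you are working in (an illustrative shell snippet; adjust the path for your shell and checkout):

```
export PYTHONPATH=$PYTHONPATH:$(pwd)
```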
22 | 23 | ## Citation 24 | 25 | Please cite our paper as: 26 | 27 | ``` 28 | @article{chen2021decisiontransformer, 29 | title={Decision Transformer: Reinforcement Learning via Sequence Modeling}, 30 | author={Lili Chen and Kevin Lu and Aravind Rajeswaran and Kimin Lee and Aditya Grover and Michael Laskin and Pieter Abbeel and Aravind Srinivas and Igor Mordatch}, 31 | journal={arXiv preprint arXiv:2106.01345}, 32 | year={2021} 33 | } 34 | ``` 35 | 36 | Note: this is not an official Google or Facebook product. 37 | 38 | ## License 39 | 40 | MIT 41 | -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/decision-transformer/e2d82e68f330c00f763507b3b01d774740bee53f/architecture.png -------------------------------------------------------------------------------- /atari/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /atari/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: decision-transformer-atari 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.7.9 6 | - pytorch=1.2 7 | - cudatoolkit=10. 
8 | - numpy 9 | - psutil 10 | - opencv 11 | - pip 12 | - pip: 13 | - atari-py 14 | - pyprind 15 | - tensorflow-gpu>=1.13 16 | - absl-py 17 | - atari-py 18 | - gin-config 19 | - gym 20 | - tqdm 21 | - blosc 22 | - git+https://github.com/google/dopamine.git 23 | -------------------------------------------------------------------------------- /atari/create_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | # make deterministic 4 | from mingpt.utils import set_seed 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | import math 10 | from torch.utils.data import Dataset 11 | from mingpt.model_atari import GPT, GPTConfig 12 | from mingpt.trainer_atari import Trainer, TrainerConfig 13 | from mingpt.utils import sample 14 | from collections import deque 15 | import random 16 | import torch 17 | import pickle 18 | import blosc 19 | import argparse 20 | from fixed_replay_buffer import FixedReplayBuffer 21 | 22 | def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer): 23 | # -- load data from memory (make more efficient) 24 | obss = [] 25 | actions = [] 26 | returns = [0] 27 | done_idxs = [] 28 | stepwise_returns = [] 29 | 30 | transitions_per_buffer = np.zeros(50, dtype=int) 31 | num_trajectories = 0 32 | while len(obss) < num_steps: 33 | buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0] 34 | i = transitions_per_buffer[buffer_num] 35 | print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) 36 | frb = FixedReplayBuffer( 37 | data_dir=data_dir_prefix + game + '/1/replay_logs', 38 | replay_suffix=buffer_num, 39 | observation_shape=(84, 84), 40 | stack_size=4, 41 | update_horizon=1, 42 | gamma=0.99, 43 | observation_dtype=np.uint8, 44 | batch_size=32, 45 | replay_capacity=100000) 46 | if frb._loaded_buffers: 47 | done = False 48 | curr_num_transitions = len(obss) 49 | trajectories_to_load = trajectories_per_buffer 50 | while not done: 51 | states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) 52 | states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) 53 | obss += [states] 54 | actions += [ac[0]] 55 | stepwise_returns += [ret[0]] 56 | if terminal[0]: 57 | done_idxs += [len(obss)] 58 | returns += [0] 59 | if trajectories_to_load == 0: 60 | done = True 61 | else: 62 | trajectories_to_load -= 1 63 | returns[-1] += ret[0] 64 | i += 1 65 | if i >= 100000: 66 | obss = obss[:curr_num_transitions] 67 | actions = actions[:curr_num_transitions] 68 | stepwise_returns = stepwise_returns[:curr_num_transitions] 69 | returns[-1] = 0 70 | i = transitions_per_buffer[buffer_num] 71 | done = True 72 | num_trajectories += (trajectories_per_buffer - trajectories_to_load) 73 | transitions_per_buffer[buffer_num] = i 74 | print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) 75 | 76 | actions = np.array(actions) 77 | returns = np.array(returns) 78 | stepwise_returns = np.array(stepwise_returns) 79 | done_idxs = np.array(done_idxs) 80 | 81 | # -- create reward-to-go dataset 82 | start_index = 0 83 | rtg = np.zeros_like(stepwise_returns) 84 | for i in done_idxs: 85 | i = int(i) 86 | curr_traj_returns = stepwise_returns[start_index:i] 87 | for j in range(i-1, start_index-1, -1): # start from i-1 88 | rtg_j = 
curr_traj_returns[j-start_index:i-start_index] 89 | rtg[j] = sum(rtg_j) 90 | start_index = i 91 | print('max rtg is %d' % max(rtg)) 92 | 93 | # -- create timestep dataset 94 | start_index = 0 95 | timesteps = np.zeros(len(actions)+1, dtype=int) 96 | for i in done_idxs: 97 | i = int(i) 98 | timesteps[start_index:i+1] = np.arange(i+1 - start_index) 99 | start_index = i+1 100 | print('max timestep is %d' % max(timesteps)) 101 | 102 | return obss, actions, returns, done_idxs, rtg, timesteps 103 | -------------------------------------------------------------------------------- /atari/fixed_replay_buffer.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/google-research/batch_rl/blob/master/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py 2 | 3 | import collections 4 | from concurrent import futures 5 | from dopamine.replay_memory import circular_replay_buffer 6 | import numpy as np 7 | import tensorflow.compat.v1 as tf 8 | import gin 9 | 10 | gfile = tf.gfile 11 | 12 | STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX 13 | 14 | class FixedReplayBuffer(object): 15 | """Object composed of a list of OutofGraphReplayBuffers.""" 16 | 17 | def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg 18 | """Initialize the FixedReplayBuffer class. 19 | Args: 20 | data_dir: str, log Directory from which to load the replay buffer. 21 | replay_suffix: int, If not None, then only load the replay buffer 22 | corresponding to the specific suffix in data directory. 23 | *args: Arbitrary extra arguments. 24 | **kwargs: Arbitrary keyword arguments. 25 | """ 26 | self._args = args 27 | self._kwargs = kwargs 28 | self._data_dir = data_dir 29 | self._loaded_buffers = False 30 | self.add_count = np.array(0) 31 | self._replay_suffix = replay_suffix 32 | if not self._loaded_buffers: 33 | if replay_suffix is not None: 34 | assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' 35 | self.load_single_buffer(replay_suffix) 36 | else: 37 | self._load_replay_buffers(num_buffers=50) 38 | 39 | def load_single_buffer(self, suffix): 40 | """Load a single replay buffer.""" 41 | replay_buffer = self._load_buffer(suffix) 42 | if replay_buffer is not None: 43 | self._replay_buffers = [replay_buffer] 44 | self.add_count = replay_buffer.add_count 45 | self._num_replay_buffers = 1 46 | self._loaded_buffers = True 47 | 48 | def _load_buffer(self, suffix): 49 | """Loads a OutOfGraphReplayBuffer replay buffer.""" 50 | try: 51 | # pytype: disable=attribute-error 52 | replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer( 53 | *self._args, **self._kwargs) 54 | replay_buffer.load(self._data_dir, suffix) 55 | tf.logging.info('Loaded replay buffer ckpt {} from {}'.format( 56 | suffix, self._data_dir)) 57 | # pytype: enable=attribute-error 58 | return replay_buffer 59 | except tf.errors.NotFoundError: 60 | return None 61 | 62 | def _load_replay_buffers(self, num_buffers=None): 63 | """Loads multiple checkpoints into a list of replay buffers.""" 64 | if not self._loaded_buffers: # pytype: disable=attribute-error 65 | ckpts = gfile.ListDirectory(self._data_dir) # pytype: disable=attribute-error 66 | # Assumes that the checkpoints are saved in a format CKPT_NAME.{SUFFIX}.gz 67 | ckpt_counters = collections.Counter( 68 | [name.split('.')[-2] for name in ckpts]) 69 | # Should contain the files for add_count, action, observation, reward, 70 | # terminal and invalid_range 71 | 
ckpt_suffixes = [x for x in ckpt_counters if ckpt_counters[x] in [6, 7]] 72 | if num_buffers is not None: 73 | ckpt_suffixes = np.random.choice( 74 | ckpt_suffixes, num_buffers, replace=False) 75 | self._replay_buffers = [] 76 | # Load the replay buffers in parallel 77 | with futures.ThreadPoolExecutor( 78 | max_workers=num_buffers) as thread_pool_executor: 79 | replay_futures = [thread_pool_executor.submit( 80 | self._load_buffer, suffix) for suffix in ckpt_suffixes] 81 | for f in replay_futures: 82 | replay_buffer = f.result() 83 | if replay_buffer is not None: 84 | self._replay_buffers.append(replay_buffer) 85 | self.add_count = max(replay_buffer.add_count, self.add_count) 86 | self._num_replay_buffers = len(self._replay_buffers) 87 | if self._num_replay_buffers: 88 | self._loaded_buffers = True 89 | 90 | def get_transition_elements(self): 91 | return self._replay_buffers[0].get_transition_elements() 92 | 93 | def sample_transition_batch(self, batch_size=None, indices=None): 94 | buffer_index = np.random.randint(self._num_replay_buffers) 95 | return self._replay_buffers[buffer_index].sample_transition_batch( 96 | batch_size=batch_size, indices=indices) 97 | 98 | def load(self, *args, **kwargs): # pylint: disable=unused-argument 99 | pass 100 | 101 | def reload_buffer(self, num_buffers=None): 102 | self._loaded_buffers = False 103 | self._load_replay_buffers(num_buffers) 104 | 105 | def save(self, *args, **kwargs): # pylint: disable=unused-argument 106 | pass 107 | 108 | def add(self, *args, **kwargs): # pylint: disable=unused-argument 109 | pass -------------------------------------------------------------------------------- /atari/mingpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/decision-transformer/e2d82e68f330c00f763507b3b01d774740bee53f/atari/mingpt/__init__.py -------------------------------------------------------------------------------- /atari/mingpt/model_atari.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | """ 10 | 11 | """ 12 | GPT model: 13 | - the initial stem consists of a combination of token encoding and a positional encoding 14 | - the meat of it is a uniform sequence of Transformer blocks 15 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 16 | - all blocks feed into a central residual pathway similar to resnets 17 | - the final decoder is a linear projection into a vanilla Softmax classifier 18 | """ 19 | 20 | import math 21 | import logging 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch.nn import functional as F 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | import numpy as np 30 | 31 | class GELU(nn.Module): 32 | def forward(self, input): 33 | return F.gelu(input) 34 | 35 | class GPTConfig: 36 | """ base GPT config, params common to all GPT versions """ 37 | embd_pdrop = 0.1 38 | resid_pdrop = 0.1 39 | attn_pdrop = 0.1 40 | 41 | def __init__(self, vocab_size, block_size, **kwargs): 42 | self.vocab_size = vocab_size 43 | self.block_size = block_size 44 | for k,v in kwargs.items(): 45 | setattr(self, k, v) 46 | 47 | class GPT1Config(GPTConfig): 48 | """ GPT-1 like network roughly 125M params """ 49 | n_layer = 12 50 | n_head = 12 51 | n_embd = 768 52 | 53 | class CausalSelfAttention(nn.Module): 54 | """ 55 | A vanilla multi-head masked self-attention layer with a projection at the end. 56 | It is possible to use torch.nn.MultiheadAttention here but I am including an 57 | explicit implementation here to show that there is nothing too scary here. 58 | """ 59 | 60 | def __init__(self, config): 61 | super().__init__() 62 | assert config.n_embd % config.n_head == 0 63 | # key, query, value projections for all heads 64 | self.key = nn.Linear(config.n_embd, config.n_embd) 65 | self.query = nn.Linear(config.n_embd, config.n_embd) 66 | self.value = nn.Linear(config.n_embd, config.n_embd) 67 | # regularization 68 | self.attn_drop = nn.Dropout(config.attn_pdrop) 69 | self.resid_drop = nn.Dropout(config.resid_pdrop) 70 | # output projection 71 | self.proj = nn.Linear(config.n_embd, config.n_embd) 72 | # causal mask to ensure that attention is only applied to the left in the input sequence 73 | # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 74 | # .view(1, 1, config.block_size, config.block_size)) 75 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1)) 76 | .view(1, 1, config.block_size + 1, config.block_size + 1)) 77 | self.n_head = config.n_head 78 | 79 | def forward(self, x, layer_past=None): 80 | B, T, C = x.size() 81 | 82 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 83 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 84 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 85 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 86 | 87 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 88 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 89 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 90 | att = F.softmax(att, dim=-1) 91 | att = self.attn_drop(att) 92 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 93 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 94 | 95 | # output projection 96 | y = 
self.resid_drop(self.proj(y)) 97 | return y 98 | 99 | class Block(nn.Module): 100 | """ an unassuming Transformer block """ 101 | 102 | def __init__(self, config): 103 | super().__init__() 104 | self.ln1 = nn.LayerNorm(config.n_embd) 105 | self.ln2 = nn.LayerNorm(config.n_embd) 106 | self.attn = CausalSelfAttention(config) 107 | self.mlp = nn.Sequential( 108 | nn.Linear(config.n_embd, 4 * config.n_embd), 109 | GELU(), 110 | nn.Linear(4 * config.n_embd, config.n_embd), 111 | nn.Dropout(config.resid_pdrop), 112 | ) 113 | 114 | def forward(self, x): 115 | x = x + self.attn(self.ln1(x)) 116 | x = x + self.mlp(self.ln2(x)) 117 | return x 118 | 119 | class GPT(nn.Module): 120 | """ the full GPT language model, with a context size of block_size """ 121 | 122 | def __init__(self, config): 123 | super().__init__() 124 | 125 | self.config = config 126 | 127 | self.model_type = config.model_type 128 | 129 | # input embedding stem 130 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 131 | # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 132 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd)) 133 | self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd)) 134 | self.drop = nn.Dropout(config.embd_pdrop) 135 | 136 | # transformer 137 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 138 | # decoder head 139 | self.ln_f = nn.LayerNorm(config.n_embd) 140 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 141 | 142 | self.block_size = config.block_size 143 | self.apply(self._init_weights) 144 | 145 | 146 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 147 | 148 | 149 | self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), 150 | nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), 151 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), 152 | nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh()) 153 | 154 | self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh()) 155 | 156 | self.action_embeddings = nn.Sequential(nn.Embedding(config.vocab_size, config.n_embd), nn.Tanh()) 157 | nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02) 158 | 159 | def get_block_size(self): 160 | return self.block_size 161 | 162 | def _init_weights(self, module): 163 | if isinstance(module, (nn.Linear, nn.Embedding)): 164 | module.weight.data.normal_(mean=0.0, std=0.02) 165 | if isinstance(module, nn.Linear) and module.bias is not None: 166 | module.bias.data.zero_() 167 | elif isinstance(module, nn.LayerNorm): 168 | module.bias.data.zero_() 169 | module.weight.data.fill_(1.0) 170 | 171 | def configure_optimizers(self, train_config): 172 | """ 173 | This long function is unfortunately doing something very simple and is being very defensive: 174 | We are separating out all parameters of the model into two buckets: those that will experience 175 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 176 | We are then returning the PyTorch optimizer object. 
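Concretely, the weights of Linear and Conv2d modules go into the decay group, while all biases, LayerNorm and Embedding weights, and the pos_emb / global_pos_emb parameters go into the no-decay group.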
177 | """ 178 | 179 | # separate out all parameters to those that will and won't experience regularizing weight decay 180 | decay = set() 181 | no_decay = set() 182 | # whitelist_weight_modules = (torch.nn.Linear, ) 183 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) 184 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 185 | for mn, m in self.named_modules(): 186 | for pn, p in m.named_parameters(): 187 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 188 | 189 | if pn.endswith('bias'): 190 | # all biases will not be decayed 191 | no_decay.add(fpn) 192 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 193 | # weights of whitelist modules will be weight decayed 194 | decay.add(fpn) 195 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 196 | # weights of blacklist modules will NOT be weight decayed 197 | no_decay.add(fpn) 198 | 199 | # special case the position embedding parameter in the root GPT module as not decayed 200 | no_decay.add('pos_emb') 201 | no_decay.add('global_pos_emb') 202 | 203 | # validate that we considered every parameter 204 | param_dict = {pn: p for pn, p in self.named_parameters()} 205 | inter_params = decay & no_decay 206 | union_params = decay | no_decay 207 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 208 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 209 | % (str(param_dict.keys() - union_params), ) 210 | 211 | # create the pytorch optimizer object 212 | optim_groups = [ 213 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 214 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 215 | ] 216 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 217 | return optimizer 218 | 219 | # state, action, and return 220 | def forward(self, states, actions, targets=None, rtgs=None, timesteps=None): 221 | # states: (batch, block_size, 4*84*84) 222 | # actions: (batch, block_size, 1) 223 | # targets: (batch, block_size, 1) 224 | # rtgs: (batch, block_size, 1) 225 | # timesteps: (batch, 1, 1) 226 | 227 | state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd) 228 | state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) 229 | 230 | if actions is not None and self.model_type == 'reward_conditioned': 231 | rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) 232 | action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) 233 | 234 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) 235 | token_embeddings[:,::3,:] = rtg_embeddings 236 | token_embeddings[:,1::3,:] = state_embeddings 237 | token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:] 238 | elif actions is None and self.model_type == 'reward_conditioned': # only happens at very first timestep of evaluation 239 | rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) 240 | 241 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2, self.config.n_embd), dtype=torch.float32, 
device=state_embeddings.device) 242 | token_embeddings[:,::2,:] = rtg_embeddings # really just [:,0,:] 243 | token_embeddings[:,1::2,:] = state_embeddings # really just [:,1,:] 244 | elif actions is not None and self.model_type == 'naive': 245 | action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) 246 | 247 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) 248 | token_embeddings[:,::2,:] = state_embeddings 249 | token_embeddings[:,1::2,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:] 250 | elif actions is None and self.model_type == 'naive': # only happens at very first timestep of evaluation 251 | token_embeddings = state_embeddings 252 | else: 253 | raise NotImplementedError() 254 | 255 | batch_size = states.shape[0] 256 | all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd 257 | 258 | position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] 259 | 260 | x = self.drop(token_embeddings + position_embeddings) 261 | x = self.blocks(x) 262 | x = self.ln_f(x) 263 | logits = self.head(x) 264 | 265 | if actions is not None and self.model_type == 'reward_conditioned': 266 | logits = logits[:, 1::3, :] # only keep predictions from state_embeddings 267 | elif actions is None and self.model_type == 'reward_conditioned': 268 | logits = logits[:, 1:, :] 269 | elif actions is not None and self.model_type == 'naive': 270 | logits = logits[:, ::2, :] # only keep predictions from state_embeddings 271 | elif actions is None and self.model_type == 'naive': 272 | logits = logits # for completeness 273 | else: 274 | raise NotImplementedError() 275 | 276 | # if we are given some desired targets also calculate the loss 277 | loss = None 278 | if targets is not None: 279 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) 280 | 281 | return logits, loss 282 | -------------------------------------------------------------------------------- /atari/mingpt/trainer_atari.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | """ 10 | 11 | """ 12 | Simple training loop; Boilerplate that could apply to any arbitrary neural network, 13 | so nothing in this file really has anything to do with GPT specifically. 14 | """ 15 | 16 | import math 17 | import logging 18 | 19 | from tqdm import tqdm 20 | import numpy as np 21 | 22 | import torch 23 | import torch.optim as optim 24 | from torch.optim.lr_scheduler import LambdaLR 25 | from torch.utils.data.dataloader import DataLoader 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | from mingpt.utils import sample 30 | import atari_py 31 | from collections import deque 32 | import random 33 | import cv2 34 | import torch 35 | from PIL import Image 36 | 37 | class TrainerConfig: 38 | # optimization parameters 39 | max_epochs = 10 40 | batch_size = 64 41 | learning_rate = 3e-4 42 | betas = (0.9, 0.95) 43 | grad_norm_clip = 1.0 44 | weight_decay = 0.1 # only applied on matmul weights 45 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 46 | lr_decay = False 47 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 48 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 49 | # checkpoint settings 50 | ckpt_path = None 51 | num_workers = 0 # for DataLoader 52 | 53 | def __init__(self, **kwargs): 54 | for k,v in kwargs.items(): 55 | setattr(self, k, v) 56 | 57 | class Trainer: 58 | 59 | def __init__(self, model, train_dataset, test_dataset, config): 60 | self.model = model 61 | self.train_dataset = train_dataset 62 | self.test_dataset = test_dataset 63 | self.config = config 64 | 65 | # take over whatever gpus are on the system 66 | self.device = 'cpu' 67 | if torch.cuda.is_available(): 68 | self.device = torch.cuda.current_device() 69 | self.model = torch.nn.DataParallel(self.model).to(self.device) 70 | 71 | def save_checkpoint(self): 72 | # DataParallel wrappers keep raw model object in .module attribute 73 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 74 | logger.info("saving %s", self.config.ckpt_path) 75 | # torch.save(raw_model.state_dict(), self.config.ckpt_path) 76 | 77 | def train(self): 78 | model, config = self.model, self.config 79 | raw_model = model.module if hasattr(self.model, "module") else model 80 | optimizer = raw_model.configure_optimizers(config) 81 | 82 | def run_epoch(split, epoch_num=0): 83 | is_train = split == 'train' 84 | model.train(is_train) 85 | data = self.train_dataset if is_train else self.test_dataset 86 | loader = DataLoader(data, shuffle=True, pin_memory=True, 87 | batch_size=config.batch_size, 88 | num_workers=config.num_workers) 89 | 90 | losses = [] 91 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader) 92 | for it, (x, y, r, t) in pbar: 93 | 94 | # place data on the correct device 95 | x = x.to(self.device) 96 | y = y.to(self.device) 97 | r = r.to(self.device) 98 | t = t.to(self.device) 99 | 100 | # forward the model 101 | with torch.set_grad_enabled(is_train): 102 | # logits, loss = model(x, y, r) 103 | logits, loss = model(x, y, y, r, t) 104 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 105 | losses.append(loss.item()) 106 | 107 | if is_train: 108 | 109 | # backprop and update the parameters 110 | model.zero_grad() 111 | loss.backward() 112 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 113 | optimizer.step() 114 | 115 | # decay the learning rate based on our progress 116 | if 
config.lr_decay: 117 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) 118 | if self.tokens < config.warmup_tokens: 119 | # linear warmup 120 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) 121 | else: 122 | # cosine learning rate decay 123 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 124 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 125 | lr = config.learning_rate * lr_mult 126 | for param_group in optimizer.param_groups: 127 | param_group['lr'] = lr 128 | else: 129 | lr = config.learning_rate 130 | 131 | # report progress 132 | pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}") 133 | 134 | if not is_train: 135 | test_loss = float(np.mean(losses)) 136 | logger.info("test loss: %f", test_loss) 137 | return test_loss 138 | 139 | # best_loss = float('inf') 140 | 141 | best_return = -float('inf') 142 | 143 | self.tokens = 0 # counter used for learning rate decay 144 | 145 | for epoch in range(config.max_epochs): 146 | 147 | run_epoch('train', epoch_num=epoch) 148 | # if self.test_dataset is not None: 149 | # test_loss = run_epoch('test') 150 | 151 | # # supports early stopping based on the test loss, or just save always if no test set is provided 152 | # good_model = self.test_dataset is None or test_loss < best_loss 153 | # if self.config.ckpt_path is not None and good_model: 154 | # best_loss = test_loss 155 | # self.save_checkpoint() 156 | 157 | # -- pass in target returns 158 | if self.config.model_type == 'naive': 159 | eval_return = self.get_returns(0) 160 | elif self.config.model_type == 'reward_conditioned': 161 | if self.config.game == 'Breakout': 162 | eval_return = self.get_returns(90) 163 | elif self.config.game == 'Seaquest': 164 | eval_return = self.get_returns(1150) 165 | elif self.config.game == 'Qbert': 166 | eval_return = self.get_returns(14000) 167 | elif self.config.game == 'Pong': 168 | eval_return = self.get_returns(20) 169 | else: 170 | raise NotImplementedError() 171 | else: 172 | raise NotImplementedError() 173 | 174 | def get_returns(self, ret): 175 | self.model.train(False) 176 | args=Args(self.config.game.lower(), self.config.seed) 177 | env = Env(args) 178 | env.eval() 179 | 180 | T_rewards, T_Qs = [], [] 181 | done = True 182 | for i in range(10): 183 | state = env.reset() 184 | state = state.type(torch.float32).to(self.device).unsqueeze(0).unsqueeze(0) 185 | rtgs = [ret] 186 | # first state is from env, first rtg is target return, and first timestep is 0 187 | sampled_action = sample(self.model.module, state, 1, temperature=1.0, sample=True, actions=None, 188 | rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1), 189 | timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device)) 190 | 191 | j = 0 192 | all_states = state 193 | actions = [] 194 | while True: 195 | if done: 196 | state, reward_sum, done = env.reset(), 0, False 197 | action = sampled_action.cpu().numpy()[0,-1] 198 | actions += [sampled_action] 199 | state, reward, done = env.step(action) 200 | reward_sum += reward 201 | j += 1 202 | 203 | if done: 204 | T_rewards.append(reward_sum) 205 | break 206 | 207 | state = state.unsqueeze(0).unsqueeze(0).to(self.device) 208 | 209 | all_states = torch.cat([all_states, state], dim=0) 210 | 211 | rtgs += [rtgs[-1] - reward] 212 | # all_states has all previous states and rtgs has all previous rtgs (will be cut to block_size in 
utils.sample) 213 | # timestep is just current timestep 214 | sampled_action = sample(self.model.module, all_states.unsqueeze(0), 1, temperature=1.0, sample=True, 215 | actions=torch.tensor(actions, dtype=torch.long).to(self.device).unsqueeze(1).unsqueeze(0), 216 | rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1), 217 | timesteps=(min(j, self.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device))) 218 | env.close() 219 | eval_return = sum(T_rewards)/10. 220 | print("target return: %d, eval return: %d" % (ret, eval_return)) 221 | self.model.train(True) 222 | return eval_return 223 | 224 | 225 | class Env(): 226 | def __init__(self, args): 227 | self.device = args.device 228 | self.ale = atari_py.ALEInterface() 229 | self.ale.setInt('random_seed', args.seed) 230 | self.ale.setInt('max_num_frames_per_episode', args.max_episode_length) 231 | self.ale.setFloat('repeat_action_probability', 0) # Disable sticky actions 232 | self.ale.setInt('frame_skip', 0) 233 | self.ale.setBool('color_averaging', False) 234 | self.ale.loadROM(atari_py.get_game_path(args.game)) # ROM loading must be done after setting options 235 | actions = self.ale.getMinimalActionSet() 236 | self.actions = dict([i, e] for i, e in zip(range(len(actions)), actions)) 237 | self.lives = 0 # Life counter (used in DeepMind training) 238 | self.life_termination = False # Used to check if resetting only from loss of life 239 | self.window = args.history_length # Number of frames to concatenate 240 | self.state_buffer = deque([], maxlen=args.history_length) 241 | self.training = True # Consistent with model training mode 242 | 243 | def _get_state(self): 244 | state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR) 245 | return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255) 246 | 247 | def _reset_buffer(self): 248 | for _ in range(self.window): 249 | self.state_buffer.append(torch.zeros(84, 84, device=self.device)) 250 | 251 | def reset(self): 252 | if self.life_termination: 253 | self.life_termination = False # Reset flag 254 | self.ale.act(0) # Use a no-op after loss of life 255 | else: 256 | # Reset internals 257 | self._reset_buffer() 258 | self.ale.reset_game() 259 | # Perform up to 30 random no-ops before starting 260 | for _ in range(random.randrange(30)): 261 | self.ale.act(0) # Assumes raw action 0 is always no-op 262 | if self.ale.game_over(): 263 | self.ale.reset_game() 264 | # Process and return "initial" state 265 | observation = self._get_state() 266 | self.state_buffer.append(observation) 267 | self.lives = self.ale.lives() 268 | return torch.stack(list(self.state_buffer), 0) 269 | 270 | def step(self, action): 271 | # Repeat action 4 times, max pool over last 2 frames 272 | frame_buffer = torch.zeros(2, 84, 84, device=self.device) 273 | reward, done = 0, False 274 | for t in range(4): 275 | reward += self.ale.act(self.actions.get(action)) 276 | if t == 2: 277 | frame_buffer[0] = self._get_state() 278 | elif t == 3: 279 | frame_buffer[1] = self._get_state() 280 | done = self.ale.game_over() 281 | if done: 282 | break 283 | observation = frame_buffer.max(0)[0] 284 | self.state_buffer.append(observation) 285 | # Detect loss of life as terminal in training mode 286 | if self.training: 287 | lives = self.ale.lives() 288 | if lives < self.lives and lives > 0: # Lives > 0 for Q*bert 289 | self.life_termination = not done # Only set flag when not truly done 290 | done = True 291 | self.lives = lives 292 | # 
Return state, reward, done 293 | return torch.stack(list(self.state_buffer), 0), reward, done 294 | 295 | # Uses loss of life as terminal signal 296 | def train(self): 297 | self.training = True 298 | 299 | # Uses standard terminal signal 300 | def eval(self): 301 | self.training = False 302 | 303 | def action_space(self): 304 | return len(self.actions) 305 | 306 | def render(self): 307 | cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1]) 308 | cv2.waitKey(1) 309 | 310 | def close(self): 311 | cv2.destroyAllWindows() 312 | 313 | class Args: 314 | def __init__(self, game, seed): 315 | self.device = torch.device('cuda') 316 | self.seed = seed 317 | self.max_episode_length = 108e3 318 | self.game = game 319 | self.history_length = 4 320 | -------------------------------------------------------------------------------- /atari/mingpt/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | """ 10 | 11 | import random 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def top_k_logits(logits, k): 24 | v, ix = torch.topk(logits, k) 25 | out = logits.clone() 26 | out[out < v[:, [-1]]] = -float('Inf') 27 | return out 28 | 29 | @torch.no_grad() 30 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None): 31 | """ 32 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 33 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 34 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 35 | of block_size, unlike an RNN that has an infinite context window. 
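In this Decision Transformer variant the context is cropped to block_size//3 game timesteps rather than block_size tokens, since every timestep contributes three tokens (return-to-go, state, action) to the model's input.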
36 | """ 37 | block_size = model.get_block_size() 38 | model.eval() 39 | for k in range(steps): 40 | # x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 41 | x_cond = x if x.size(1) <= block_size//3 else x[:, -block_size//3:] # crop context if needed 42 | if actions is not None: 43 | actions = actions if actions.size(1) <= block_size//3 else actions[:, -block_size//3:] # crop context if needed 44 | rtgs = rtgs if rtgs.size(1) <= block_size//3 else rtgs[:, -block_size//3:] # crop context if needed 45 | logits, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps) 46 | # pluck the logits at the final step and scale by temperature 47 | logits = logits[:, -1, :] / temperature 48 | # optionally crop probabilities to only the top k options 49 | if top_k is not None: 50 | logits = top_k_logits(logits, top_k) 51 | # apply softmax to convert to probabilities 52 | probs = F.softmax(logits, dim=-1) 53 | # sample from the distribution or take the most likely 54 | if sample: 55 | ix = torch.multinomial(probs, num_samples=1) 56 | else: 57 | _, ix = torch.topk(probs, k=1, dim=-1) 58 | # append to the sequence and continue 59 | # x = torch.cat((x, ix), dim=1) 60 | x = ix 61 | 62 | return x 63 | -------------------------------------------------------------------------------- /atari/readme-atari.md: -------------------------------------------------------------------------------- 1 | 2 | # Atari 3 | 4 | We build our Atari implementation on top of [minGPT](https://github.com/karpathy/minGPT) and benchmark our results on the [DQN-replay](https://github.com/google-research/batch_rl) dataset. 5 | 6 | ## Installation 7 | 8 | Dependencies can be installed with the following command: 9 | 10 | ``` 11 | conda env create -f conda_env.yml 12 | ``` 13 | 14 | ## Downloading datasets 15 | 16 | Create a directory for the dataset and load the dataset using [gsutil](https://cloud.google.com/storage/docs/gsutil_install#install). Replace `[DIRECTORY_NAME]` and `[GAME_NAME]` accordingly (e.g., `./dqn_replay` for `[DIRECTORY_NAME]` and `Breakout` for `[GAME_NAME]`) 17 | ``` 18 | mkdir [DIRECTORY_NAME] 19 | gsutil -m cp -R gs://atari-replay-datasets/dqn/[GAME_NAME] [DIRECTORY_NAME] 20 | ``` 21 | 22 | ## Example usage 23 | 24 | Scripts to reproduce our Decision Transformer results can be found in `run.sh`. 
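It sweeps three seeds (123, 231, 312) over Breakout, Qbert, Pong, and Seaquest for both the Decision Transformer (`reward_conditioned`) and behavior cloning (`naive`) model types; a single run can also be launched directly, for example: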
25 | 26 | ``` 27 | python run_dt_atari.py --seed 123 --block_size 90 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 --data_dir_prefix [DIRECTORY_NAME] 28 | ``` 29 | -------------------------------------------------------------------------------- /atari/run.sh: -------------------------------------------------------------------------------- 1 | # Decision Transformer (DT) 2 | for seed in 123 231 312 3 | do 4 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 5 | done 6 | 7 | for seed in 123 231 312 8 | do 9 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128 10 | done 11 | 12 | for seed in 123 231 312 13 | do 14 | python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 15 | done 16 | 17 | for seed in 123 231 312 18 | do 19 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 20 | done 21 | 22 | # Behavior Cloning (BC) 23 | for seed in 123 231 312 24 | do 25 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 26 | done 27 | 28 | for seed in 123 231 312 29 | do 30 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128 31 | done 32 | 33 | for seed in 123 231 312 34 | do 35 | python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 36 | done 37 | 38 | for seed in 123 231 312 39 | do 40 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 41 | done -------------------------------------------------------------------------------- /atari/run_dt_atari.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | # make deterministic 4 | from mingpt.utils import set_seed 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | import math 10 | from torch.utils.data import Dataset 11 | from mingpt.model_atari import GPT, GPTConfig 12 | from mingpt.trainer_atari import Trainer, TrainerConfig 13 | from mingpt.utils import sample 14 | from collections import deque 15 | import random 16 | import torch 17 | import pickle 18 | import blosc 19 | import argparse 20 | from create_dataset import create_dataset 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--seed', type=int, default=123) 24 | parser.add_argument('--context_length', type=int, default=30) 25 | parser.add_argument('--epochs', type=int, default=5) 26 | parser.add_argument('--model_type', type=str, default='reward_conditioned') 27 | parser.add_argument('--num_steps', type=int, default=500000) 28 | parser.add_argument('--num_buffers', type=int, default=50) 29 | parser.add_argument('--game', type=str, default='Breakout') 30 | parser.add_argument('--batch_size', type=int, 
default=128) 31 | # 32 | parser.add_argument('--trajectories_per_buffer', type=int, default=10, help='Number of trajectories to sample from each of the buffers.') 33 | parser.add_argument('--data_dir_prefix', type=str, default='./dqn_replay/') 34 | args = parser.parse_args() 35 | 36 | set_seed(args.seed) 37 | 38 | class StateActionReturnDataset(Dataset): 39 | 40 | def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps): 41 | self.block_size = block_size 42 | self.vocab_size = max(actions) + 1 43 | self.data = data 44 | self.actions = actions 45 | self.done_idxs = done_idxs 46 | self.rtgs = rtgs 47 | self.timesteps = timesteps 48 | 49 | def __len__(self): 50 | return len(self.data) - self.block_size 51 | 52 | def __getitem__(self, idx): 53 | block_size = self.block_size // 3 54 | done_idx = idx + block_size 55 | for i in self.done_idxs: 56 | if i > idx: # first done_idx greater than idx 57 | done_idx = min(int(i), done_idx) 58 | break 59 | idx = done_idx - block_size 60 | states = torch.tensor(np.array(self.data[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84) 61 | states = states / 255. 62 | actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) 63 | rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) 64 | timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) 65 | 66 | return states, actions, rtgs, timesteps 67 | 68 | obss, actions, returns, done_idxs, rtgs, timesteps = create_dataset(args.num_buffers, args.num_steps, args.game, args.data_dir_prefix, args.trajectories_per_buffer) 69 | 70 | # set up logging 71 | logging.basicConfig( 72 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 73 | datefmt="%m/%d/%Y %H:%M:%S", 74 | level=logging.INFO, 75 | ) 76 | 77 | train_dataset = StateActionReturnDataset(obss, args.context_length*3, actions, done_idxs, rtgs, timesteps) 78 | 79 | mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 80 | n_layer=6, n_head=8, n_embd=128, model_type=args.model_type, max_timestep=max(timesteps)) 81 | model = GPT(mconf) 82 | 83 | # initialize a trainer instance and kick off training 84 | epochs = args.epochs 85 | tconf = TrainerConfig(max_epochs=epochs, batch_size=args.batch_size, learning_rate=6e-4, 86 | lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*args.context_length*3, 87 | num_workers=4, seed=args.seed, model_type=args.model_type, game=args.game, max_timestep=max(timesteps)) 88 | trainer = Trainer(model, train_dataset, None, tconf) 89 | 90 | trainer.train() 91 | -------------------------------------------------------------------------------- /gym/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: decision-transformer-gym 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.8.5 6 | - anaconda 7 | - cudatoolkit=10. 
8 | - numpy 9 | - pip 10 | - pip: 11 | - gym==0.18.3 12 | - mujoco-py==2.0.2.13 13 | - numpy==1.20.3 14 | - torch==1.8.1 15 | - transformers==4.5.1 16 | - wandb==0.9.1 17 | -------------------------------------------------------------------------------- /gym/data/download_d4rl_datasets.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | import collections 5 | import pickle 6 | 7 | import d4rl 8 | 9 | 10 | datasets = [] 11 | 12 | for env_name in ['halfcheetah', 'hopper', 'walker2d']: 13 | for dataset_type in ['medium', 'medium-replay', 'expert']: 14 | name = f'{env_name}-{dataset_type}-v2' 15 | env = gym.make(name) 16 | dataset = env.get_dataset() 17 | 18 | N = dataset['rewards'].shape[0] 19 | data_ = collections.defaultdict(list) 20 | 21 | use_timeouts = False 22 | if 'timeouts' in dataset: 23 | use_timeouts = True 24 | 25 | episode_step = 0 26 | paths = [] 27 | for i in range(N): 28 | done_bool = bool(dataset['terminals'][i]) 29 | if use_timeouts: 30 | final_timestep = dataset['timeouts'][i] 31 | else: 32 | final_timestep = (episode_step == 1000-1) 33 | for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']: 34 | data_[k].append(dataset[k][i]) 35 | if done_bool or final_timestep: 36 | episode_step = 0 37 | episode_data = {} 38 | for k in data_: 39 | episode_data[k] = np.array(data_[k]) 40 | paths.append(episode_data) 41 | data_ = collections.defaultdict(list) 42 | episode_step += 1 43 | 44 | returns = np.array([np.sum(p['rewards']) for p in paths]) 45 | num_samples = np.sum([p['rewards'].shape[0] for p in paths]) 46 | print(f'Number of samples collected: {num_samples}') 47 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') 48 | 49 | with open(f'{name}.pkl', 'wb') as f: 50 | pickle.dump(paths, f) 51 | -------------------------------------------------------------------------------- /gym/decision_transformer/envs/assets/reacher_2d.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 34 | -------------------------------------------------------------------------------- /gym/decision_transformer/envs/reacher_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | import os 6 | 7 | 8 | class Reacher2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | 10 | def __init__(self): 11 | self.fingertip_sid = 0 12 | self.target_bid = 0 13 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 14 | mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/reacher_2d.xml', 15) 15 | self.fingertip_sid = self.sim.model.site_name2id('fingertip') 16 | self.target_bid = self.sim.model.body_name2id('target') 17 | utils.EzPickle.__init__(self) 18 | 19 | def step(self, action): 20 | action = np.clip(action, -1.0, 1.0) 21 | self.do_simulation(action, self.frame_skip) 22 | tip = self.data.site_xpos[self.fingertip_sid][:2] 23 | tar = self.data.body_xpos[self.target_bid][:2] 24 | dist = np.sum(np.abs(tip - tar)) 25 | reward_dist = 0. 
# - 0.1 * dist 26 | reward_ctrl = 0.0 27 | reward_bonus = 1.0 if dist < 0.1 else 0.0 28 | reward = reward_bonus + reward_ctrl + reward_dist 29 | done = False 30 | ob = self._get_obs() 31 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, reward_bonus=reward_bonus) 32 | 33 | def _get_obs(self): 34 | theta = self.data.qpos.ravel() 35 | tip = self.data.site_xpos[self.fingertip_sid][:2] 36 | tar = self.data.body_xpos[self.target_bid][:2] 37 | return np.concatenate([ 38 | # self.data.qpos.flat, 39 | np.sin(theta), 40 | np.cos(theta), 41 | self.dt * self.data.qvel.ravel(), 42 | tip, 43 | tar, 44 | tip-tar, 45 | ]) 46 | 47 | def reset_model(self): 48 | # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos 49 | # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 50 | qpos = self.np_random.uniform(low=-2.0, high=2.0, size=self.model.nq) 51 | qvel = self.init_qvel * 0.0 52 | while True: 53 | self.goal = self.np_random.uniform(low=-1.5, high=1.5, size=2) 54 | if np.linalg.norm(self.goal) <= 1.0 and np.linalg.norm(self.goal) >= 0.5: 55 | break 56 | self.set_state(qpos, qvel) 57 | self.model.body_pos[self.target_bid][:2] = self.goal 58 | self.sim.forward() 59 | return self._get_obs() 60 | 61 | def viewer_setup(self): 62 | self.viewer.cam.distance = self.model.stat.extent * 5.0 63 | -------------------------------------------------------------------------------- /gym/decision_transformer/evaluation/evaluate_episodes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def evaluate_episode( 6 | env, 7 | state_dim, 8 | act_dim, 9 | model, 10 | max_ep_len=1000, 11 | device='cuda', 12 | target_return=None, 13 | mode='normal', 14 | state_mean=0., 15 | state_std=1., 16 | ): 17 | 18 | model.eval() 19 | model.to(device=device) 20 | 21 | state_mean = torch.from_numpy(state_mean).to(device=device) 22 | state_std = torch.from_numpy(state_std).to(device=device) 23 | 24 | state = env.reset() 25 | 26 | # we keep all the histories on the device 27 | # note that the latest action and reward will be "padding" 28 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) 29 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) 30 | rewards = torch.zeros(0, device=device, dtype=torch.float32) 31 | target_return = torch.tensor(target_return, device=device, dtype=torch.float32) 32 | sim_states = [] 33 | 34 | episode_return, episode_length = 0, 0 35 | for t in range(max_ep_len): 36 | 37 | # add padding 38 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) 39 | rewards = torch.cat([rewards, torch.zeros(1, device=device)]) 40 | 41 | action = model.get_action( 42 | (states.to(dtype=torch.float32) - state_mean) / state_std, 43 | actions.to(dtype=torch.float32), 44 | rewards.to(dtype=torch.float32), 45 | target_return=target_return, 46 | ) 47 | actions[-1] = action 48 | action = action.detach().cpu().numpy() 49 | 50 | state, reward, done, _ = env.step(action) 51 | 52 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) 53 | states = torch.cat([states, cur_state], dim=0) 54 | rewards[-1] = reward 55 | 56 | episode_return += reward 57 | episode_length += 1 58 | 59 | if done: 60 | break 61 | 62 | return episode_return, episode_length 63 | 64 | 65 | def evaluate_episode_rtg( 66 | env, 67 | state_dim, 68 | act_dim, 69 | model, 70 | 
max_ep_len=1000, 71 | scale=1000., 72 | state_mean=0., 73 | state_std=1., 74 | device='cuda', 75 | target_return=None, 76 | mode='normal', 77 | ): 78 | 79 | model.eval() 80 | model.to(device=device) 81 | 82 | state_mean = torch.from_numpy(state_mean).to(device=device) 83 | state_std = torch.from_numpy(state_std).to(device=device) 84 | 85 | state = env.reset() 86 | if mode == 'noise': 87 | state = state + np.random.normal(0, 0.1, size=state.shape) 88 | 89 | # we keep all the histories on the device 90 | # note that the latest action and reward will be "padding" 91 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) 92 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) 93 | rewards = torch.zeros(0, device=device, dtype=torch.float32) 94 | 95 | ep_return = target_return 96 | target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1) 97 | timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1) 98 | 99 | sim_states = [] 100 | 101 | episode_return, episode_length = 0, 0 102 | for t in range(max_ep_len): 103 | 104 | # add padding 105 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) 106 | rewards = torch.cat([rewards, torch.zeros(1, device=device)]) 107 | 108 | action = model.get_action( 109 | (states.to(dtype=torch.float32) - state_mean) / state_std, 110 | actions.to(dtype=torch.float32), 111 | rewards.to(dtype=torch.float32), 112 | target_return.to(dtype=torch.float32), 113 | timesteps.to(dtype=torch.long), 114 | ) 115 | actions[-1] = action 116 | action = action.detach().cpu().numpy() 117 | 118 | state, reward, done, _ = env.step(action) 119 | 120 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) 121 | states = torch.cat([states, cur_state], dim=0) 122 | rewards[-1] = reward 123 | 124 | if mode != 'delayed': 125 | pred_return = target_return[0,-1] - (reward/scale) 126 | else: 127 | pred_return = target_return[0,-1] 128 | target_return = torch.cat( 129 | [target_return, pred_return.reshape(1, 1)], dim=1) 130 | timesteps = torch.cat( 131 | [timesteps, 132 | torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1) 133 | 134 | episode_return += reward 135 | episode_length += 1 136 | 137 | if done: 138 | break 139 | 140 | return episode_return, episode_length 141 | -------------------------------------------------------------------------------- /gym/decision_transformer/models/decision_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import transformers 6 | 7 | from decision_transformer.models.model import TrajectoryModel 8 | from decision_transformer.models.trajectory_gpt2 import GPT2Model 9 | 10 | 11 | class DecisionTransformer(TrajectoryModel): 12 | 13 | """ 14 | This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...) 
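Each timestep contributes three tokens (return-to-go, state, action), so a context of K timesteps is fed to the transformer as a sequence of 3*K tokens, and the action for step t is decoded from the hidden state of the state token s_t. A rough calling sketch (batch size, context length, and environment dimensions below are illustrative placeholders, not values fixed by this class)::

        model = DecisionTransformer(state_dim=11, act_dim=3, hidden_size=128,
                                    max_length=20, max_ep_len=1000, n_layer=3, n_head=1)
        states = torch.randn(64, 20, 11)         # (batch, K, state_dim)
        actions = torch.randn(64, 20, 3)         # (batch, K, act_dim)
        rewards = torch.randn(64, 20, 1)         # not used by forward
        returns_to_go = torch.randn(64, 20, 1)   # (batch, K, 1)
        timesteps = torch.randint(0, 1000, (64, 20))
        state_preds, action_preds, return_preds = model(
            states, actions, rewards, returns_to_go, timesteps)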
15 | """ 16 | 17 | def __init__( 18 | self, 19 | state_dim, 20 | act_dim, 21 | hidden_size, 22 | max_length=None, 23 | max_ep_len=4096, 24 | action_tanh=True, 25 | **kwargs 26 | ): 27 | super().__init__(state_dim, act_dim, max_length=max_length) 28 | 29 | self.hidden_size = hidden_size 30 | config = transformers.GPT2Config( 31 | vocab_size=1, # doesn't matter -- we don't use the vocab 32 | n_embd=hidden_size, 33 | **kwargs 34 | ) 35 | 36 | # note: the only difference between this GPT2Model and the default Huggingface version 37 | # is that the positional embeddings are removed (since we'll add those ourselves) 38 | self.transformer = GPT2Model(config) 39 | 40 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) 41 | self.embed_return = torch.nn.Linear(1, hidden_size) 42 | self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) 43 | self.embed_action = torch.nn.Linear(self.act_dim, hidden_size) 44 | 45 | self.embed_ln = nn.LayerNorm(hidden_size) 46 | 47 | # note: we don't predict states or returns for the paper 48 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim) 49 | self.predict_action = nn.Sequential( 50 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else [])) 51 | ) 52 | self.predict_return = torch.nn.Linear(hidden_size, 1) 53 | 54 | def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None): 55 | 56 | batch_size, seq_length = states.shape[0], states.shape[1] 57 | 58 | if attention_mask is None: 59 | # attention mask for GPT: 1 if can be attended to, 0 if not 60 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 61 | 62 | # embed each modality with a different head 63 | state_embeddings = self.embed_state(states) 64 | action_embeddings = self.embed_action(actions) 65 | returns_embeddings = self.embed_return(returns_to_go) 66 | time_embeddings = self.embed_timestep(timesteps) 67 | 68 | # time embeddings are treated similar to positional embeddings 69 | state_embeddings = state_embeddings + time_embeddings 70 | action_embeddings = action_embeddings + time_embeddings 71 | returns_embeddings = returns_embeddings + time_embeddings 72 | 73 | # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 74 | # which works nice in an autoregressive sense since states predict actions 75 | stacked_inputs = torch.stack( 76 | (returns_embeddings, state_embeddings, action_embeddings), dim=1 77 | ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size) 78 | stacked_inputs = self.embed_ln(stacked_inputs) 79 | 80 | # to make the attention mask fit the stacked inputs, have to stack it as well 81 | stacked_attention_mask = torch.stack( 82 | (attention_mask, attention_mask, attention_mask), dim=1 83 | ).permute(0, 2, 1).reshape(batch_size, 3*seq_length) 84 | 85 | # we feed in the input embeddings (not word indices as in NLP) to the model 86 | transformer_outputs = self.transformer( 87 | inputs_embeds=stacked_inputs, 88 | attention_mask=stacked_attention_mask, 89 | ) 90 | x = transformer_outputs['last_hidden_state'] 91 | 92 | # reshape x so that the second dimension corresponds to the original 93 | # returns (0), states (1), or actions (2); i.e. 
x[:,1,t] is the token for s_t 94 | x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) 95 | 96 | # get predictions 97 | return_preds = self.predict_return(x[:,2]) # predict next return given state and action 98 | state_preds = self.predict_state(x[:,2]) # predict next state given state and action 99 | action_preds = self.predict_action(x[:,1]) # predict next action given state 100 | 101 | return state_preds, action_preds, return_preds 102 | 103 | def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs): 104 | # we don't care about the past rewards in this model 105 | 106 | states = states.reshape(1, -1, self.state_dim) 107 | actions = actions.reshape(1, -1, self.act_dim) 108 | returns_to_go = returns_to_go.reshape(1, -1, 1) 109 | timesteps = timesteps.reshape(1, -1) 110 | 111 | if self.max_length is not None: 112 | states = states[:,-self.max_length:] 113 | actions = actions[:,-self.max_length:] 114 | returns_to_go = returns_to_go[:,-self.max_length:] 115 | timesteps = timesteps[:,-self.max_length:] 116 | 117 | # pad all tokens to sequence length 118 | attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])]) 119 | attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1) 120 | states = torch.cat( 121 | [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states], 122 | dim=1).to(dtype=torch.float32) 123 | actions = torch.cat( 124 | [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim), 125 | device=actions.device), actions], 126 | dim=1).to(dtype=torch.float32) 127 | returns_to_go = torch.cat( 128 | [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go], 129 | dim=1).to(dtype=torch.float32) 130 | timesteps = torch.cat( 131 | [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps], 132 | dim=1 133 | ).to(dtype=torch.long) 134 | else: 135 | attention_mask = None 136 | 137 | _, action_preds, return_preds = self.forward( 138 | states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs) 139 | 140 | return action_preds[0,-1] 141 | -------------------------------------------------------------------------------- /gym/decision_transformer/models/mlp_bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from decision_transformer.models.model import TrajectoryModel 6 | 7 | 8 | class MLPBCModel(TrajectoryModel): 9 | 10 | """ 11 | Simple MLP that predicts next action a from past states s. 
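The most recent `max_length` states are flattened into a single vector of size max_length * state_dim, and the predicted action is passed through a tanh, so outputs lie in [-1, 1].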
12 | """ 13 | 14 | def __init__(self, state_dim, act_dim, hidden_size, n_layer, dropout=0.1, max_length=1, **kwargs): 15 | super().__init__(state_dim, act_dim) 16 | 17 | self.hidden_size = hidden_size 18 | self.max_length = max_length 19 | 20 | layers = [nn.Linear(max_length*self.state_dim, hidden_size)] 21 | for _ in range(n_layer-1): 22 | layers.extend([ 23 | nn.ReLU(), 24 | nn.Dropout(dropout), 25 | nn.Linear(hidden_size, hidden_size) 26 | ]) 27 | layers.extend([ 28 | nn.ReLU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_size, self.act_dim), 31 | nn.Tanh(), 32 | ]) 33 | 34 | self.model = nn.Sequential(*layers) 35 | 36 | def forward(self, states, actions, rewards, attention_mask=None, target_return=None): 37 | 38 | states = states[:,-self.max_length:].reshape(states.shape[0], -1) # concat states 39 | actions = self.model(states).reshape(states.shape[0], 1, self.act_dim) 40 | 41 | return None, actions, None 42 | 43 | def get_action(self, states, actions, rewards, **kwargs): 44 | states = states.reshape(1, -1, self.state_dim) 45 | if states.shape[1] < self.max_length: 46 | states = torch.cat( 47 | [torch.zeros((1, self.max_length-states.shape[1], self.state_dim), 48 | dtype=torch.float32, device=states.device), states], dim=1) 49 | states = states.to(dtype=torch.float32) 50 | _, actions, _ = self.forward(states, None, None, **kwargs) 51 | return actions[0,-1] 52 | -------------------------------------------------------------------------------- /gym/decision_transformer/models/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TrajectoryModel(nn.Module): 7 | 8 | def __init__(self, state_dim, act_dim, max_length=None): 9 | super().__init__() 10 | 11 | self.state_dim = state_dim 12 | self.act_dim = act_dim 13 | self.max_length = max_length 14 | 15 | def forward(self, states, actions, rewards, masks=None, attention_mask=None): 16 | # "masked" tokens or unspecified inputs can be passed in as None 17 | return None, None, None 18 | 19 | def get_action(self, states, actions, rewards, **kwargs): 20 | # these will come as tensors on the correct device 21 | return torch.zeros_like(actions[-1]) 22 | -------------------------------------------------------------------------------- /gym/decision_transformer/models/trajectory_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import os 19 | from dataclasses import dataclass 20 | from typing import List, Optional, Tuple 21 | 22 | import torch 23 | import torch.nn as nn 24 | from torch.nn import CrossEntropyLoss, MSELoss 25 | 26 | from transformers.activations import ACT2FN 27 | from transformers.file_utils import ( 28 | ModelOutput, 29 | add_code_sample_docstrings, 30 | add_start_docstrings, 31 | add_start_docstrings_to_model_forward, 32 | replace_return_docstrings, 33 | ) 34 | from transformers.modeling_outputs import ( 35 | BaseModelOutputWithPastAndCrossAttentions, 36 | ) 37 | from transformers.modeling_utils import ( 38 | Conv1D, 39 | PreTrainedModel, 40 | SequenceSummary, 41 | find_pruneable_heads_and_indices, 42 | prune_conv1d_layer, 43 | ) 44 | from transformers.utils import logging 45 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 46 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 47 | 48 | logger = logging.get_logger(__name__) 49 | 50 | _CONFIG_FOR_DOC = "GPT2Config" 51 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 52 | 53 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 54 | "gpt2", 55 | "gpt2-medium", 56 | "gpt2-large", 57 | "gpt2-xl", 58 | "distilgpt2", 59 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 60 | ] 61 | 62 | 63 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 64 | """Load tf checkpoints in a pytorch model""" 65 | try: 66 | import re 67 | 68 | import tensorflow as tf 69 | except ImportError: 70 | logger.error( 71 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions." 73 | ) 74 | raise 75 | tf_path = os.path.abspath(gpt2_checkpoint_path) 76 | logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) 77 | # Load weights from TF model 78 | init_vars = tf.train.list_variables(tf_path) 79 | names = [] 80 | arrays = [] 81 | for name, shape in init_vars: 82 | logger.info("Loading TF weight {} with shape {}".format(name, shape)) 83 | array = tf.train.load_variable(tf_path, name) 84 | names.append(name) 85 | arrays.append(array.squeeze()) 86 | 87 | for name, array in zip(names, arrays): 88 | name = name[6:] # skip "model/" 89 | name = name.split("/") 90 | pointer = model 91 | for m_name in name: 92 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 93 | scope_names = re.split(r"(\d+)", m_name) 94 | else: 95 | scope_names = [m_name] 96 | if scope_names[0] == "w" or scope_names[0] == "g": 97 | pointer = getattr(pointer, "weight") 98 | elif scope_names[0] == "b": 99 | pointer = getattr(pointer, "bias") 100 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 101 | pointer = getattr(pointer, scope_names[0]) 102 | pointer = getattr(pointer, "weight") 103 | else: 104 | pointer = getattr(pointer, scope_names[0]) 105 | if len(scope_names) >= 2: 106 | num = int(scope_names[1]) 107 | pointer = pointer[num] 108 | try: 109 | assert ( 110 | pointer.shape == array.shape 111 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 112 | except AssertionError as e: 113 | e.args += (pointer.shape, array.shape) 114 | raise 115 | logger.info("Initialize PyTorch weight {}".format(name)) 116 | pointer.data = torch.from_numpy(array) 117 | return model 118 | 119 | 120 | class Attention(nn.Module): 121 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): 122 | super().__init__() 123 | 124 | n_state = nx # in 
Attention: n_state=768 (nx=n_embd) 125 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 126 | assert n_state % config.n_head == 0 127 | self.register_buffer( 128 | "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx) 129 | ) 130 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 131 | self.n_head = config.n_head 132 | self.split_size = n_state 133 | self.scale = scale 134 | self.is_cross_attention = is_cross_attention 135 | if self.is_cross_attention: 136 | self.c_attn = Conv1D(2 * n_state, nx) 137 | self.q_attn = Conv1D(n_state, nx) 138 | else: 139 | self.c_attn = Conv1D(3 * n_state, nx) 140 | self.c_proj = Conv1D(n_state, nx) 141 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 142 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 143 | self.pruned_heads = set() 144 | 145 | def prune_heads(self, heads): 146 | if len(heads) == 0: 147 | return 148 | heads, index = find_pruneable_heads_and_indices( 149 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads 150 | ) 151 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 152 | 153 | # Prune conv1d layers 154 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 155 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 156 | 157 | # Update hyper params 158 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 159 | self.n_head = self.n_head - len(heads) 160 | self.pruned_heads = self.pruned_heads.union(heads) 161 | 162 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): 163 | w = torch.matmul(q, k) 164 | if self.scale: 165 | w = w / (float(v.size(-1)) ** 0.5) 166 | nd, ns = w.size(-2), w.size(-1) 167 | 168 | if not self.is_cross_attention: 169 | # if only "normal" attention layer implements causal mask 170 | mask = self.bias[:, :, ns - nd: ns, :ns] 171 | w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) 172 | 173 | if attention_mask is not None: 174 | # Apply the attention mask 175 | w = w + attention_mask 176 | 177 | w = nn.Softmax(dim=-1)(w) 178 | w = self.attn_dropout(w) 179 | 180 | # Mask heads if we want to 181 | if head_mask is not None: 182 | w = w * head_mask 183 | 184 | outputs = [torch.matmul(w, v)] 185 | if output_attentions: 186 | outputs.append(w) 187 | return outputs 188 | 189 | def merge_heads(self, x): 190 | x = x.permute(0, 2, 1, 3).contiguous() 191 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 192 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 193 | 194 | def split_heads(self, x, k=False): 195 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 196 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 197 | if k: 198 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 199 | else: 200 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 201 | 202 | def forward( 203 | self, 204 | hidden_states, 205 | layer_past=None, 206 | attention_mask=None, 207 | head_mask=None, 208 | encoder_hidden_states=None, 209 | encoder_attention_mask=None, 210 | use_cache=False, 211 | output_attentions=False, 212 | ): 213 | if encoder_hidden_states is not None: 214 | assert hasattr( 215 | self, "q_attn" 216 | ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`." 
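# cross-attention: queries are projected from the current hidden states, while keys and values are computed from encoder_hidden_states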
217 | query = self.q_attn(hidden_states) 218 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 219 | attention_mask = encoder_attention_mask 220 | else: 221 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 222 | 223 | query = self.split_heads(query) 224 | key = self.split_heads(key, k=True) 225 | value = self.split_heads(value) 226 | if layer_past is not None: 227 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 228 | key = torch.cat((past_key, key), dim=-1) 229 | value = torch.cat((past_value, value), dim=-2) 230 | 231 | if use_cache is True: 232 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 233 | else: 234 | present = (None,) 235 | 236 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) 237 | a = attn_outputs[0] 238 | 239 | a = self.merge_heads(a) 240 | a = self.c_proj(a) 241 | a = self.resid_dropout(a) 242 | 243 | outputs = [a, present] + attn_outputs[1:] 244 | return outputs # a, present, (attentions) 245 | 246 | 247 | class MLP(nn.Module): 248 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 249 | super().__init__() 250 | nx = config.n_embd 251 | self.c_fc = Conv1D(n_state, nx) 252 | self.c_proj = Conv1D(nx, n_state) 253 | self.act = ACT2FN[config.activation_function] 254 | self.dropout = nn.Dropout(config.resid_pdrop) 255 | 256 | def forward(self, x): 257 | h = self.act(self.c_fc(x)) 258 | h2 = self.c_proj(h) 259 | return self.dropout(h2) 260 | 261 | 262 | class AdapterMLP(nn.Module): 263 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 264 | super().__init__() 265 | nx = config.n_embd 266 | self.c_fc = Conv1D(n_state, nx) 267 | self.c_proj = Conv1D(nx, n_state) 268 | self.act = ACT2FN[config.activation_function] 269 | self.dropout = nn.Dropout(config.resid_pdrop) 270 | 271 | def forward(self, x): 272 | h = self.act(self.c_fc(x)) 273 | h2 = self.c_proj(h) 274 | return self.dropout(h2) 275 | 276 | 277 | class Block(nn.Module): 278 | def __init__(self, n_ctx, config, scale=False): 279 | super().__init__() 280 | hidden_size = config.n_embd 281 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 282 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 283 | self.attn = Attention(hidden_size, n_ctx, config, scale) 284 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 285 | # self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 286 | if config.add_cross_attention: 287 | self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True) 288 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 289 | self.mlp = MLP(inner_dim, config) 290 | # self.adapter_mlp = AdapterMLP(512, config) # ADAPTER 291 | 292 | def forward( 293 | self, 294 | hidden_states, 295 | layer_past=None, 296 | attention_mask=None, 297 | head_mask=None, 298 | encoder_hidden_states=None, 299 | encoder_attention_mask=None, 300 | use_cache=False, 301 | output_attentions=False, 302 | ): 303 | attn_outputs = self.attn( 304 | self.ln_1(hidden_states), 305 | layer_past=layer_past, 306 | attention_mask=attention_mask, 307 | head_mask=head_mask, 308 | use_cache=use_cache, 309 | output_attentions=output_attentions, 310 | ) 311 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 312 | outputs = attn_outputs[1:] 313 | # 
residual connection 314 | hidden_states = attn_output + hidden_states 315 | 316 | if encoder_hidden_states is not None: 317 | # add one self-attention block for cross-attention 318 | assert hasattr( 319 | self, "crossattention" 320 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 321 | cross_attn_outputs = self.crossattention( 322 | self.ln_cross_attn(hidden_states), 323 | attention_mask=attention_mask, 324 | head_mask=head_mask, 325 | encoder_hidden_states=encoder_hidden_states, 326 | encoder_attention_mask=encoder_attention_mask, 327 | output_attentions=output_attentions, 328 | ) 329 | attn_output = cross_attn_outputs[0] 330 | # residual connection 331 | hidden_states = hidden_states + attn_output 332 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 333 | 334 | feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) 335 | # residual connection 336 | hidden_states = hidden_states + feed_forward_hidden_states 337 | # hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states)) 338 | 339 | outputs = [hidden_states] + outputs 340 | return outputs # hidden_states, present, (attentions, cross_attentions) 341 | 342 | 343 | class GPT2PreTrainedModel(PreTrainedModel): 344 | """ 345 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 346 | models. 347 | """ 348 | 349 | config_class = GPT2Config 350 | load_tf_weights = load_tf_weights_in_gpt2 351 | base_model_prefix = "transformer" 352 | 353 | def __init__(self, *inputs, **kwargs): 354 | super().__init__(*inputs, **kwargs) 355 | 356 | def _init_weights(self, module): 357 | """Initialize the weights.""" 358 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): 359 | # Slightly different from the TF version which uses truncated_normal for initialization 360 | # cf https://github.com/pytorch/pytorch/pull/5617 361 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 362 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: 363 | module.bias.data.zero_() 364 | elif isinstance(module, nn.LayerNorm): 365 | module.bias.data.zero_() 366 | module.weight.data.fill_(1.0) 367 | # module.weight.data.fill_(.01) # KL: Adapter change 368 | 369 | 370 | @dataclass 371 | class GPT2DoubleHeadsModelOutput(ModelOutput): 372 | """ 373 | Base class for outputs of models predicting if two sentences are consecutive or not. 374 | Args: 375 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): 376 | Language modeling loss. 377 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): 378 | Multiple choice classification loss. 379 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): 380 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 381 | mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): 382 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 
383 | past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 384 | List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, 385 | batch_size, num_heads, sequence_length, embed_size_per_head)`). 386 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see 387 | :obj:`past_key_values` input) to speed up sequential decoding. 388 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 389 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 390 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 391 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 392 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 393 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 394 | sequence_length, sequence_length)`. 395 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 396 | heads. 397 | """ 398 | 399 | loss: Optional[torch.FloatTensor] = None 400 | mc_loss: Optional[torch.FloatTensor] = None 401 | logits: torch.FloatTensor = None 402 | mc_logits: torch.FloatTensor = None 403 | past_key_values: Optional[List[torch.FloatTensor]] = None 404 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 405 | attentions: Optional[Tuple[torch.FloatTensor]] = None 406 | 407 | 408 | GPT2_START_DOCSTRING = r""" 409 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 410 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 411 | pruning heads etc.) 412 | This model is also a PyTorch `torch.nn.Module `__ 413 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 414 | general usage and behavior. 415 | Parameters: 416 | config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model. 417 | Initializing with a config file does not load the weights associated with the model, only the 418 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 419 | weights. 420 | """ 421 | 422 | GPT2_INPUTS_DOCSTRING = r""" 423 | Args: 424 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): 425 | :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else 426 | ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input 427 | sequence tokens in the vocabulary. 428 | If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be 429 | passed as ``input_ids``. 430 | Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See 431 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 432 | details. 433 | `What are input IDs? 
<../glossary.html#input-ids>`__ 434 | past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): 435 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 436 | :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which 437 | have their past given to this model should not be passed as ``input_ids`` as they have already been 438 | computed. 439 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 440 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 441 | - 1 for tokens that are **not masked**, 442 | - 0 for tokens that are **masked**. 443 | `What are attention masks? <../glossary.html#attention-mask>`__ 444 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): 445 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 446 | 1]``: 447 | - 0 corresponds to a `sentence A` token, 448 | - 1 corresponds to a `sentence B` token. 449 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 450 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 451 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 452 | config.max_position_embeddings - 1]``. 453 | `What are position IDs? <../glossary.html#position-ids>`_ 454 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 455 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 456 | - 1 indicates the head is **not masked**, 457 | - 0 indicates the head is **masked**. 458 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 459 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 460 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 461 | vectors than the model's internal embedding lookup matrix. 462 | If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see 463 | :obj:`past_key_values`). 464 | use_cache (:obj:`bool`, `optional`): 465 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 466 | decoding (see :obj:`past_key_values`). 467 | output_attentions (:obj:`bool`, `optional`): 468 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 469 | tensors for more detail. 470 | output_hidden_states (:obj:`bool`, `optional`): 471 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 472 | more detail. 473 | return_dict (:obj:`bool`, `optional`): 474 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 475 | """ 476 | PARALLELIZE_DOCSTRING = r""" 477 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 478 | it will evenly distribute blocks across all devices. 479 | Args: 480 | device_map (:obj:`Dict[int, list]`, optional, defaults to None): 481 | A dictionary that maps attention modules to devices. 
Note that the embedding module and LMHead are always 482 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 483 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the 484 | following number of attention modules: 485 | - gpt2: 12 486 | - gpt2-medium: 24 487 | - gpt2-large: 36 488 | - gpt2-xl: 48 489 | Example:: 490 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 491 | model = GPT2LMHeadModel.from_pretrained('gpt2-xl') 492 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 493 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 494 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 495 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} 496 | model.parallelize(device_map) 497 | """ 498 | DEPARALLELIZE_DOCSTRING = r""" 499 | Moves the model to cpu from a model parallel state. 500 | Example:: 501 | # On a 4 GPU machine with gpt2-large: 502 | model = GPT2LMHeadModel.from_pretrained('gpt2-large') 503 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], 504 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 505 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 506 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} 507 | model.parallelize(device_map) # Splits the model across several devices 508 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 509 | """ 510 | 511 | 512 | @add_start_docstrings( 513 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", 514 | GPT2_START_DOCSTRING, 515 | ) 516 | class GPT2Model(GPT2PreTrainedModel): 517 | def __init__(self, config): 518 | super().__init__(config) 519 | 520 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 521 | # self.wpe = nn.Embedding(config.n_positions, config.n_embd) 522 | self.drop = nn.Dropout(config.embd_pdrop) 523 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 524 | self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 525 | 526 | self.init_weights() 527 | # Model parallel 528 | self.model_parallel = False 529 | self.device_map = None 530 | 531 | self.use_layers = None 532 | 533 | def set_layers(self, num_layers): 534 | assert 1 <= num_layers <= len(self.h) 535 | if num_layers is not None: 536 | num_layers -= 1 537 | self.use_layers = num_layers 538 | 539 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 540 | def parallelize(self, device_map=None): 541 | # Check validity of device_map 542 | self.device_map = ( 543 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 544 | ) 545 | assert_device_map(self.device_map, len(self.h)) 546 | self.model_parallel = True 547 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 548 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 549 | self.wte = self.wte.to(self.first_device) 550 | self.wpe = self.wpe.to(self.first_device) 551 | # Load onto devices 552 | for k, v in self.device_map.items(): 553 | for block in v: 554 | cuda_device = "cuda:" + str(k) 555 | self.h[block] = self.h[block].to(cuda_device) 556 | # ln_f to last 557 | self.ln_f = self.ln_f.to(self.last_device) 558 | 559 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 560 | def deparallelize(self): 561 | self.model_parallel = False 562 | self.device_map = None 563 | self.first_device = 
"cpu" 564 | self.last_device = "cpu" 565 | self.wte = self.wte.to("cpu") 566 | self.wpe = self.wpe.to("cpu") 567 | for index in range(len(self.h)): 568 | self.h[index] = self.h[index].to("cpu") 569 | self.ln_f = self.ln_f.to("cpu") 570 | torch.cuda.empty_cache() 571 | 572 | def get_input_embeddings(self): 573 | return self.wte 574 | 575 | def set_input_embeddings(self, new_embeddings): 576 | self.wte = new_embeddings 577 | 578 | def _prune_heads(self, heads_to_prune): 579 | """ 580 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 581 | """ 582 | for layer, heads in heads_to_prune.items(): 583 | self.h[layer].attn.prune_heads(heads) 584 | 585 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 586 | @add_code_sample_docstrings( 587 | tokenizer_class=_TOKENIZER_FOR_DOC, 588 | checkpoint="gpt2", 589 | output_type=BaseModelOutputWithPastAndCrossAttentions, 590 | config_class=_CONFIG_FOR_DOC, 591 | ) 592 | def forward( 593 | self, 594 | input_ids=None, 595 | past_key_values=None, 596 | attention_mask=None, 597 | token_type_ids=None, 598 | position_ids=None, 599 | head_mask=None, 600 | inputs_embeds=None, 601 | encoder_hidden_states=None, 602 | encoder_attention_mask=None, 603 | use_cache=None, 604 | output_attentions=None, 605 | output_hidden_states=None, 606 | return_dict=None, 607 | ): 608 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 609 | output_hidden_states = ( 610 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 611 | ) 612 | use_cache = use_cache if use_cache is not None else self.config.use_cache 613 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 614 | 615 | if input_ids is not None and inputs_embeds is not None: 616 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 617 | elif input_ids is not None: 618 | input_shape = input_ids.size() 619 | input_ids = input_ids.view(-1, input_shape[-1]) 620 | batch_size = input_ids.shape[0] 621 | elif inputs_embeds is not None: 622 | input_shape = inputs_embeds.size()[:-1] 623 | batch_size = inputs_embeds.shape[0] 624 | else: 625 | raise ValueError("You have to specify either input_ids or inputs_embeds") 626 | 627 | if token_type_ids is not None: 628 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 629 | if position_ids is not None: 630 | position_ids = position_ids.view(-1, input_shape[-1]) 631 | 632 | if past_key_values is None: 633 | past_length = 0 634 | past_key_values = [None] * len(self.h) 635 | else: 636 | past_length = past_key_values[0][0].size(-2) 637 | if position_ids is None: 638 | device = input_ids.device if input_ids is not None else inputs_embeds.device 639 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 640 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 641 | 642 | # Attention mask. 643 | if attention_mask is not None: 644 | assert batch_size > 0, "batch_size has to be defined and > 0" 645 | attention_mask = attention_mask.view(batch_size, -1) 646 | # We create a 3D attention mask from a 2D tensor mask. 
647 | # Sizes are [batch_size, 1, 1, to_seq_length] 648 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 649 | # this attention mask is more simple than the triangular masking of causal attention 650 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 651 | attention_mask = attention_mask[:, None, None, :] 652 | 653 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 654 | # masked positions, this operation will create a tensor which is 0.0 for 655 | # positions we want to attend and -10000.0 for masked positions. 656 | # Since we are adding it to the raw scores before the softmax, this is 657 | # effectively the same as removing these entirely. 658 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 659 | attention_mask = (1.0 - attention_mask) * -10000.0 660 | 661 | # If a 2D ou 3D attention mask is provided for the cross-attention 662 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 663 | if self.config.add_cross_attention and encoder_hidden_states is not None: 664 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 665 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 666 | if encoder_attention_mask is None: 667 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 668 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 669 | else: 670 | encoder_attention_mask = None 671 | 672 | # Prepare head mask if needed 673 | # 1.0 in head_mask indicate we keep the head 674 | # attention_probs has shape bsz x n_heads x N x N 675 | # head_mask has shape n_layer x batch x n_heads x N x N 676 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 677 | 678 | if inputs_embeds is None: 679 | inputs_embeds = self.wte(input_ids) 680 | # position_embeds = self.wpe(position_ids) 681 | hidden_states = inputs_embeds # + position_embeds 682 | 683 | if token_type_ids is not None: 684 | token_type_embeds = self.wte(token_type_ids) 685 | hidden_states = hidden_states + token_type_embeds 686 | 687 | hidden_states = self.drop(hidden_states) 688 | 689 | output_shape = input_shape + (hidden_states.size(-1),) 690 | 691 | presents = () if use_cache else None 692 | all_self_attentions = () if output_attentions else None 693 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 694 | all_hidden_states = () if output_hidden_states else None 695 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 696 | 697 | if self.use_layers is not None and i >= self.use_layers: 698 | break 699 | 700 | # Model parallel 701 | if self.model_parallel: 702 | torch.cuda.set_device(hidden_states.device) 703 | # Ensure layer_past is on same device as hidden_states (might not be correct) 704 | if layer_past is not None: 705 | layer_past = layer_past.to(hidden_states.device) 706 | # Ensure that attention_mask is always on the same device as hidden_states 707 | if attention_mask is not None: 708 | attention_mask = attention_mask.to(hidden_states.device) 709 | if isinstance(head_mask, torch.Tensor): 710 | head_mask = head_mask.to(hidden_states.device) 711 | if output_hidden_states: 712 | all_hidden_states = all_hidden_states + (hidden_states,) 713 | 714 | if getattr(self.config, "gradient_checkpointing", False): 715 | 716 | def create_custom_forward(module): 717 | def custom_forward(*inputs): 718 | # checkpointing only works with tuple 
returns, not with lists 719 | return tuple(output for output in module(*inputs, use_cache, output_attentions)) 720 | 721 | return custom_forward 722 | 723 | outputs = torch.utils.checkpoint.checkpoint( 724 | create_custom_forward(block), 725 | hidden_states, 726 | layer_past, 727 | attention_mask, 728 | head_mask[i], 729 | encoder_hidden_states, 730 | encoder_attention_mask, 731 | ) 732 | else: 733 | outputs = block( 734 | hidden_states, 735 | layer_past=layer_past, 736 | attention_mask=attention_mask, 737 | head_mask=head_mask[i], 738 | encoder_hidden_states=encoder_hidden_states, 739 | encoder_attention_mask=encoder_attention_mask, 740 | use_cache=use_cache, 741 | output_attentions=output_attentions, 742 | ) 743 | 744 | hidden_states, present = outputs[:2] 745 | if use_cache is True: 746 | presents = presents + (present,) 747 | 748 | if output_attentions: 749 | all_self_attentions = all_self_attentions + (outputs[2],) 750 | if self.config.add_cross_attention: 751 | all_cross_attentions = all_cross_attentions + (outputs[3],) 752 | 753 | # Model Parallel: If it's the last layer for that device, put things on the next device 754 | if self.model_parallel: 755 | for k, v in self.device_map.items(): 756 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 757 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 758 | 759 | hidden_states = self.ln_f(hidden_states) 760 | 761 | hidden_states = hidden_states.view(*output_shape) 762 | # Add last hidden state 763 | if output_hidden_states: 764 | all_hidden_states = all_hidden_states + (hidden_states,) 765 | 766 | if not return_dict: 767 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 768 | 769 | return BaseModelOutputWithPastAndCrossAttentions( 770 | last_hidden_state=hidden_states, 771 | past_key_values=presents, 772 | hidden_states=all_hidden_states, 773 | attentions=all_self_attentions, 774 | cross_attentions=all_cross_attentions, 775 | ) 776 | -------------------------------------------------------------------------------- /gym/decision_transformer/training/act_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class ActTrainer(Trainer): 8 | 9 | def train_step(self): 10 | states, actions, rewards, dones, rtg, _, attention_mask = self.get_batch(self.batch_size) 11 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, attention_mask=attention_mask, target_return=rtg[:,0], 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim) 19 | action_target = action_target[:,-1].reshape(-1, act_dim) 20 | 21 | loss = self.loss_fn( 22 | state_preds, action_preds, reward_preds, 23 | state_target, action_target, reward_target, 24 | ) 25 | self.optimizer.zero_grad() 26 | loss.backward() 27 | self.optimizer.step() 28 | 29 | return loss.detach().cpu().item() 30 | -------------------------------------------------------------------------------- /gym/decision_transformer/training/seq_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class SequenceTrainer(Trainer): 8 | 9 | def 
train_step(self): 10 | states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size) 11 | action_target = torch.clone(actions) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask, 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 19 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 20 | 21 | loss = self.loss_fn( 22 | None, action_preds, None, 23 | None, action_target, None, 24 | ) 25 | 26 | self.optimizer.zero_grad() 27 | loss.backward() 28 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) 29 | self.optimizer.step() 30 | 31 | with torch.no_grad(): 32 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() 33 | 34 | return loss.detach().cpu().item() 35 | -------------------------------------------------------------------------------- /gym/decision_transformer/training/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import time 5 | 6 | 7 | class Trainer: 8 | 9 | def __init__(self, model, optimizer, batch_size, get_batch, loss_fn, scheduler=None, eval_fns=None): 10 | self.model = model 11 | self.optimizer = optimizer 12 | self.batch_size = batch_size 13 | self.get_batch = get_batch 14 | self.loss_fn = loss_fn 15 | self.scheduler = scheduler 16 | self.eval_fns = [] if eval_fns is None else eval_fns 17 | self.diagnostics = dict() 18 | 19 | self.start_time = time.time() 20 | 21 | def train_iteration(self, num_steps, iter_num=0, print_logs=False): 22 | 23 | train_losses = [] 24 | logs = dict() 25 | 26 | train_start = time.time() 27 | 28 | self.model.train() 29 | for _ in range(num_steps): 30 | train_loss = self.train_step() 31 | train_losses.append(train_loss) 32 | if self.scheduler is not None: 33 | self.scheduler.step() 34 | 35 | logs['time/training'] = time.time() - train_start 36 | 37 | eval_start = time.time() 38 | 39 | self.model.eval() 40 | for eval_fn in self.eval_fns: 41 | outputs = eval_fn(self.model) 42 | for k, v in outputs.items(): 43 | logs[f'evaluation/{k}'] = v 44 | 45 | logs['time/total'] = time.time() - self.start_time 46 | logs['time/evaluation'] = time.time() - eval_start 47 | logs['training/train_loss_mean'] = np.mean(train_losses) 48 | logs['training/train_loss_std'] = np.std(train_losses) 49 | 50 | for k in self.diagnostics: 51 | logs[k] = self.diagnostics[k] 52 | 53 | if print_logs: 54 | print('=' * 80) 55 | print(f'Iteration {iter_num}') 56 | for k, v in logs.items(): 57 | print(f'{k}: {v}') 58 | 59 | return logs 60 | 61 | def train_step(self): 62 | states, actions, rewards, dones, attention_mask, returns = self.get_batch(self.batch_size) 63 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 64 | 65 | state_preds, action_preds, reward_preds = self.model.forward( 66 | states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns, 67 | ) 68 | 69 | # note: currently indexing & masking is not fully correct 70 | loss = self.loss_fn( 71 | state_preds, action_preds, reward_preds, 72 | state_target[:,1:], action_target, reward_target[:,1:], 73 | ) 74 | self.optimizer.zero_grad() 75 | loss.backward() 76 | self.optimizer.step() 77 | 78 | return loss.detach().cpu().item() 79 | 
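A minimal sketch of how the trainer classes above can be driven end to end is given below. `SequenceTrainer` supervises only the action predictions (masked to real, non-padded tokens), while the base `Trainer.train_iteration` handles the training/evaluation loop and timing logs. In this sketch, `dummy_get_batch`, the random tensors, and the hyperparameters are illustrative stand-ins rather than part of the codebase; the actual wiring (real D4RL batches, warmup scheduler, evaluation functions) lives in `gym/experiment.py` below.

```
# Illustrative only: drive SequenceTrainer with random data to show the expected
# batch format. dummy_get_batch and all shapes/hyperparameters are placeholders.
import torch
from decision_transformer.models.decision_transformer import DecisionTransformer
from decision_transformer.training.seq_trainer import SequenceTrainer

state_dim, act_dim, K, B = 11, 3, 20, 64
model = DecisionTransformer(state_dim=state_dim, act_dim=act_dim, hidden_size=128,
                            max_length=K, max_ep_len=1000, n_layer=3, n_head=1)

def dummy_get_batch(batch_size=B):
    # shapes mirror the batches produced by experiment.py's get_batch
    s = torch.randn(batch_size, K, state_dim)
    a = torch.randn(batch_size, K, act_dim)
    r = torch.randn(batch_size, K, 1)
    d = torch.zeros(batch_size, K, dtype=torch.long)
    rtg = torch.randn(batch_size, K + 1, 1)   # one extra step; the trainer uses rtg[:, :-1]
    timesteps = torch.randint(0, 1000, (batch_size, K))
    mask = torch.ones(batch_size, K)          # 1 = real token, 0 = padding
    return s, a, r, d, rtg, timesteps, mask

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
trainer = SequenceTrainer(
    model=model,
    optimizer=optimizer,
    batch_size=B,
    get_batch=dummy_get_batch,
    loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
)
logs = trainer.train_iteration(num_steps=10, iter_num=1, print_logs=True)
```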
-------------------------------------------------------------------------------- /gym/experiment.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import wandb 5 | 6 | import argparse 7 | import pickle 8 | import random 9 | import sys 10 | 11 | from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg 12 | from decision_transformer.models.decision_transformer import DecisionTransformer 13 | from decision_transformer.models.mlp_bc import MLPBCModel 14 | from decision_transformer.training.act_trainer import ActTrainer 15 | from decision_transformer.training.seq_trainer import SequenceTrainer 16 | 17 | 18 | def discount_cumsum(x, gamma): 19 | discount_cumsum = np.zeros_like(x) 20 | discount_cumsum[-1] = x[-1] 21 | for t in reversed(range(x.shape[0]-1)): 22 | discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] 23 | return discount_cumsum 24 | 25 | 26 | def experiment( 27 | exp_prefix, 28 | variant, 29 | ): 30 | device = variant.get('device', 'cuda') 31 | log_to_wandb = variant.get('log_to_wandb', False) 32 | 33 | env_name, dataset = variant['env'], variant['dataset'] 34 | model_type = variant['model_type'] 35 | group_name = f'{exp_prefix}-{env_name}-{dataset}' 36 | exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}' 37 | 38 | if env_name == 'hopper': 39 | env = gym.make('Hopper-v3') 40 | max_ep_len = 1000 41 | env_targets = [3600, 1800] # evaluation conditioning targets 42 | scale = 1000. # normalization for rewards/returns 43 | elif env_name == 'halfcheetah': 44 | env = gym.make('HalfCheetah-v3') 45 | max_ep_len = 1000 46 | env_targets = [12000, 6000] 47 | scale = 1000. 48 | elif env_name == 'walker2d': 49 | env = gym.make('Walker2d-v3') 50 | max_ep_len = 1000 51 | env_targets = [5000, 2500] 52 | scale = 1000. 53 | elif env_name == 'reacher2d': 54 | from decision_transformer.envs.reacher_2d import Reacher2dEnv 55 | env = Reacher2dEnv() 56 | max_ep_len = 100 57 | env_targets = [76, 40] 58 | scale = 10. 59 | else: 60 | raise NotImplementedError 61 | 62 | if model_type == 'bc': 63 | env_targets = env_targets[:1] # since BC ignores target, no need for different evaluations 64 | 65 | state_dim = env.observation_space.shape[0] 66 | act_dim = env.action_space.shape[0] 67 | 68 | # load dataset 69 | dataset_path = f'data/{env_name}-{dataset}-v2.pkl' 70 | with open(dataset_path, 'rb') as f: 71 | trajectories = pickle.load(f) 72 | 73 | # save all path information into separate lists 74 | mode = variant.get('mode', 'normal') 75 | states, traj_lens, returns = [], [], [] 76 | for path in trajectories: 77 | if mode == 'delayed': # delayed: all rewards moved to end of trajectory 78 | path['rewards'][-1] = path['rewards'].sum() 79 | path['rewards'][:-1] = 0. 
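# track each trajectory's observations, length, and total return; these feed the state normalization and the top-pct_traj filtering below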
80 | states.append(path['observations']) 81 | traj_lens.append(len(path['observations'])) 82 | returns.append(path['rewards'].sum()) 83 | traj_lens, returns = np.array(traj_lens), np.array(returns) 84 | 85 | # used for input normalization 86 | states = np.concatenate(states, axis=0) 87 | state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 88 | 89 | num_timesteps = sum(traj_lens) 90 | 91 | print('=' * 50) 92 | print(f'Starting new experiment: {env_name} {dataset}') 93 | print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found') 94 | print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}') 95 | print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}') 96 | print('=' * 50) 97 | 98 | K = variant['K'] 99 | batch_size = variant['batch_size'] 100 | num_eval_episodes = variant['num_eval_episodes'] 101 | pct_traj = variant.get('pct_traj', 1.) 102 | 103 | # only train on top pct_traj trajectories (for %BC experiment) 104 | num_timesteps = max(int(pct_traj*num_timesteps), 1) 105 | sorted_inds = np.argsort(returns) # lowest to highest 106 | num_trajectories = 1 107 | timesteps = traj_lens[sorted_inds[-1]] 108 | ind = len(trajectories) - 2 109 | while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps: 110 | timesteps += traj_lens[sorted_inds[ind]] 111 | num_trajectories += 1 112 | ind -= 1 113 | sorted_inds = sorted_inds[-num_trajectories:] 114 | 115 | # used to reweight sampling so we sample according to timesteps instead of trajectories 116 | p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds]) 117 | 118 | def get_batch(batch_size=256, max_len=K): 119 | batch_inds = np.random.choice( 120 | np.arange(num_trajectories), 121 | size=batch_size, 122 | replace=True, 123 | p=p_sample, # reweights so we sample according to timesteps 124 | ) 125 | 126 | s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], [] 127 | for i in range(batch_size): 128 | traj = trajectories[int(sorted_inds[batch_inds[i]])] 129 | si = random.randint(0, traj['rewards'].shape[0] - 1) 130 | 131 | # get sequences from dataset 132 | s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim)) 133 | a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim)) 134 | r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1)) 135 | if 'terminals' in traj: 136 | d.append(traj['terminals'][si:si + max_len].reshape(1, -1)) 137 | else: 138 | d.append(traj['dones'][si:si + max_len].reshape(1, -1)) 139 | timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1)) 140 | timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len-1 # padding cutoff 141 | rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1)) 142 | if rtg[-1].shape[1] <= s[-1].shape[1]: 143 | rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1) 144 | 145 | # padding and state + reward normalization 146 | tlen = s[-1].shape[1] 147 | s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1) 148 | s[-1] = (s[-1] - state_mean) / state_std 149 | a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1) 150 | r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1) 151 | d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1) 152 | rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale 153 | timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1) 154 | 
mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1)) 155 | 156 | s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device) 157 | a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device) 158 | r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device) 159 | d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device) 160 | rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device) 161 | timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device) 162 | mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device) 163 | 164 | return s, a, r, d, rtg, timesteps, mask 165 | 166 | def eval_episodes(target_rew): 167 | def fn(model): 168 | returns, lengths = [], [] 169 | for _ in range(num_eval_episodes): 170 | with torch.no_grad(): 171 | if model_type == 'dt': 172 | ret, length = evaluate_episode_rtg( 173 | env, 174 | state_dim, 175 | act_dim, 176 | model, 177 | max_ep_len=max_ep_len, 178 | scale=scale, 179 | target_return=target_rew/scale, 180 | mode=mode, 181 | state_mean=state_mean, 182 | state_std=state_std, 183 | device=device, 184 | ) 185 | else: 186 | ret, length = evaluate_episode( 187 | env, 188 | state_dim, 189 | act_dim, 190 | model, 191 | max_ep_len=max_ep_len, 192 | target_return=target_rew/scale, 193 | mode=mode, 194 | state_mean=state_mean, 195 | state_std=state_std, 196 | device=device, 197 | ) 198 | returns.append(ret) 199 | lengths.append(length) 200 | return { 201 | f'target_{target_rew}_return_mean': np.mean(returns), 202 | f'target_{target_rew}_return_std': np.std(returns), 203 | f'target_{target_rew}_length_mean': np.mean(lengths), 204 | f'target_{target_rew}_length_std': np.std(lengths), 205 | } 206 | return fn 207 | 208 | if model_type == 'dt': 209 | model = DecisionTransformer( 210 | state_dim=state_dim, 211 | act_dim=act_dim, 212 | max_length=K, 213 | max_ep_len=max_ep_len, 214 | hidden_size=variant['embed_dim'], 215 | n_layer=variant['n_layer'], 216 | n_head=variant['n_head'], 217 | n_inner=4*variant['embed_dim'], 218 | activation_function=variant['activation_function'], 219 | n_positions=1024, 220 | resid_pdrop=variant['dropout'], 221 | attn_pdrop=variant['dropout'], 222 | ) 223 | elif model_type == 'bc': 224 | model = MLPBCModel( 225 | state_dim=state_dim, 226 | act_dim=act_dim, 227 | max_length=K, 228 | hidden_size=variant['embed_dim'], 229 | n_layer=variant['n_layer'], 230 | ) 231 | else: 232 | raise NotImplementedError 233 | 234 | model = model.to(device=device) 235 | 236 | warmup_steps = variant['warmup_steps'] 237 | optimizer = torch.optim.AdamW( 238 | model.parameters(), 239 | lr=variant['learning_rate'], 240 | weight_decay=variant['weight_decay'], 241 | ) 242 | scheduler = torch.optim.lr_scheduler.LambdaLR( 243 | optimizer, 244 | lambda steps: min((steps+1)/warmup_steps, 1) 245 | ) 246 | 247 | if model_type == 'dt': 248 | trainer = SequenceTrainer( 249 | model=model, 250 | optimizer=optimizer, 251 | batch_size=batch_size, 252 | get_batch=get_batch, 253 | scheduler=scheduler, 254 | loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2), 255 | eval_fns=[eval_episodes(tar) for tar in env_targets], 256 | ) 257 | elif model_type == 'bc': 258 | trainer = ActTrainer( 259 | model=model, 260 | optimizer=optimizer, 261 | batch_size=batch_size, 262 | get_batch=get_batch, 263 | scheduler=scheduler, 264 | loss_fn=lambda s_hat, a_hat, 
r_hat, s, a, r: torch.mean((a_hat - a)**2), 265 | eval_fns=[eval_episodes(tar) for tar in env_targets], 266 | ) 267 | 268 | if log_to_wandb: 269 | wandb.init( 270 | name=exp_prefix, 271 | group=group_name, 272 | project='decision-transformer', 273 | config=variant 274 | ) 275 | # wandb.watch(model) # wandb has some bug 276 | 277 | for iter in range(variant['max_iters']): 278 | outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True) 279 | if log_to_wandb: 280 | wandb.log(outputs) 281 | 282 | 283 | if __name__ == '__main__': 284 | parser = argparse.ArgumentParser() 285 | parser.add_argument('--env', type=str, default='hopper') 286 | parser.add_argument('--dataset', type=str, default='medium') # medium, medium-replay, medium-expert, expert 287 | parser.add_argument('--mode', type=str, default='normal') # normal for standard setting, delayed for sparse 288 | parser.add_argument('--K', type=int, default=20) 289 | parser.add_argument('--pct_traj', type=float, default=1.) 290 | parser.add_argument('--batch_size', type=int, default=64) 291 | parser.add_argument('--model_type', type=str, default='dt') # dt for decision transformer, bc for behavior cloning 292 | parser.add_argument('--embed_dim', type=int, default=128) 293 | parser.add_argument('--n_layer', type=int, default=3) 294 | parser.add_argument('--n_head', type=int, default=1) 295 | parser.add_argument('--activation_function', type=str, default='relu') 296 | parser.add_argument('--dropout', type=float, default=0.1) 297 | parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4) 298 | parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4) 299 | parser.add_argument('--warmup_steps', type=int, default=10000) 300 | parser.add_argument('--num_eval_episodes', type=int, default=100) 301 | parser.add_argument('--max_iters', type=int, default=10) 302 | parser.add_argument('--num_steps_per_iter', type=int, default=10000) 303 | parser.add_argument('--device', type=str, default='cuda') 304 | parser.add_argument('--log_to_wandb', '-w', type=bool, default=False) 305 | 306 | args = parser.parse_args() 307 | 308 | experiment('gym-experiment', variant=vars(args)) 309 | -------------------------------------------------------------------------------- /gym/readme-gym.md: -------------------------------------------------------------------------------- 1 | 2 | # OpenAI Gym 3 | 4 | ## Installation 5 | 6 | Experiments require MuJoCo. 7 | Follow the instructions in the [mujoco-py repo](https://github.com/openai/mujoco-py) to install. 8 | Then, dependencies can be installed with the following command: 9 | 10 | ``` 11 | conda env create -f conda_env.yml 12 | ``` 13 | 14 | ## Downloading datasets 15 | 16 | Datasets are stored in the `data` directory. 17 | Install the [D4RL repo](https://github.com/rail-berkeley/d4rl), following the instructions there. 18 | Then, run the following script in order to download the datasets and save them in our format: 19 | 20 | ``` 21 | python download_d4rl_datasets.py 22 | ``` 23 | 24 | ## Example usage 25 | 26 | Experiments can be reproduced with the following: 27 | 28 | ``` 29 | python experiment.py --env hopper --dataset medium --model_type dt 30 | ``` 31 | 32 | Adding `-w True` will log results to Weights and Biases. 33 | --------------------------------------------------------------------------------
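Additional `experiment.py` runs follow the same pattern. The commands below are only a sketch: the flag names come from the argparse definitions in `experiment.py`, while the particular environment/dataset/value combinations are illustrative rather than tuned settings.

```
# behavior cloning restricted to the top 10% of trajectories by return (%BC-style baseline)
python experiment.py --env hopper --dataset medium --model_type bc --pct_traj 0.1

# decision transformer on Walker2d with sparse (delayed) returns and W&B logging
python experiment.py --env walker2d --dataset medium-expert --mode delayed -w True
```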